Commit c2d2065

[AutoDeploy] Modular export patches + registry; fixes NVIDIA#5728 (#91)

* Modular export patches + registry; fixes NVIDIA#5728
* patch library for models
* unit test fixes
* addressing reviewer feedback

Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 5b69d2c commit c2d2065

37 files changed: +966 -415 lines

examples/auto_deploy/build_and_run_flux.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from diffusers import DiffusionPipeline
 
 from tensorrt_llm._torch.auto_deploy.compile import compile_and_capture
-from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm
+from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transformations.library.fusion import fuse_gemms
 from tensorrt_llm._torch.auto_deploy.transformations.library.quantization import quantize
 from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger

tensorrt_llm/_torch/auto_deploy/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # import submodules that require registration process
-from . import compile, custom_ops, models, shim  # noqa: F401
+from . import compile, custom_ops, export, models, shim  # noqa: F401
 
 # import AutoDeploy LLM and LlmArgs
 from .llm import *
tensorrt_llm/_torch/auto_deploy/export/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+"""AutoDeploy's modular export patch system."""
+
+from . import library  # ensure all patches are registered
+from .export import *
+from .interface import *
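
Note: the "from . import library" line above is load-bearing: importing the library subpackage runs each patch module's registration side effects, so the registry is fully populated before torch_export_to_gm (in export.py below) builds its default patch set. A minimal sketch of how a caller could inspect the registry; this assumes ExportPatchRegistry is re-exported by the package via "from .interface import *", which the diff implies but does not show:

    import tensorrt_llm._torch.auto_deploy.export as ad_export

    # The same registry call torch_export_to_gm uses to build its default
    # patch_configs; the names depend on what library/ registered at import.
    print(ad_export.ExportPatchRegistry.list_patches())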
tensorrt_llm/_torch/auto_deploy/export/export.py

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
+"""Main export functionality with utilities for torch.export."""
+
+from collections import defaultdict
+from contextlib import nullcontext
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.export as te
+import torch.nn as nn
+from torch import fx
+
+from ..transformations._graph import (
+    canonicalize_graph,
+    lift_to_meta,
+    load_buffers_and_params,
+    tree_to,
+)
+from ..utils.logger import ad_logger
+from .interface import ExportPatchRegistry, apply_export_patches
+
+try:
+    from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context
+except ImportError:
+    torch_export_context = nullcontext
+
+
+def _clean_up_device_info(gm: fx.GraphModule) -> None:
+    """Correct device information in the graph."""
+    devices = {t.device for _, t in gm.named_parameters()}
+    if len(devices) == 0:
+        return
+    elif len(devices) > 1:
+        raise AssertionError("All parameters should be on the same device.")
+    device = devices.pop()
+    meta_device = torch.device("meta")
+
+    for node in gm.graph.nodes:
+        if any(a == meta_device for a in node.args):
+            new_args = list(node.args)
+            new_args = [a if a != meta_device else device for a in new_args]
+            node.args = tuple(new_args)
+        if any(a == meta_device for a in node.kwargs.values()):
+            new_kwargs = dict(node.kwargs)
+            new_kwargs = {k: v if v != meta_device else device for k, v in new_kwargs.items()}
+            node.kwargs = new_kwargs
+
+    canonicalize_graph(gm)
+
+
+def _load_hook_for_deduplication(
+    state_dict, prefix, *args, param_key_remaining: str, param_key_removed: str
+):
+    """Check for the removed param key and put it into the key that remains."""
+    ad_logger.debug(f"Loading hook for deduplication: {param_key_remaining} <- {param_key_removed}")
+    k_remaining = prefix + param_key_remaining
+    k_removed = prefix + param_key_removed
+    if k_removed in state_dict:
+        state_dict[k_remaining] = state_dict.pop(k_removed)
+
+
+def _deduplicate_params_and_buffers(gm: fx.GraphModule) -> None:
+    """This will de-duplicate params and buffers that share the same tensor."""
+    # get all get_attr nodes
+    get_attr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
+
+    # sort by id of target
+    targets: Dict[int, List[fx.Node]] = defaultdict(list)
+    for n in get_attr_nodes:
+        submod, _, name = n.target.rpartition(".")
+        t_target = getattr(gm.get_submodule(submod), name)
+        targets[id(t_target)].append(n)
+    # now replace all instances of the same tensor with the same get_attr node (idx 0 in the list)
+    for nodes in targets.values():
+        node_kept = nodes[0]
+        for n in nodes[1:]:
+            n.replace_all_uses_with(node_kept)
+            gm.graph.erase_node(n)
+
+            # remove the param/buffer from the submodule
+            submod, _, name = n.target.rpartition(".")
+            delattr(gm.get_submodule(submod), name)
+
+            # add load hooks to also load the weights correctly
+            gm._register_load_state_dict_pre_hook(
+                partial(
+                    _load_hook_for_deduplication,
+                    param_key_remaining=str(node_kept.target),
+                    param_key_removed=str(n.target),
+                )
+            )
+
+            ad_logger.debug(f"Deduplicated: {n.target} --> {node_kept.target}")
+
+    canonicalize_graph(gm)
+
+
+def _add_missing_load_hooks(gm: fx.GraphModule, model: nn.Module) -> None:
+    """Adds back the state dict load hooks stripped away during export."""
+    hooks = {
+        k: mod._load_state_dict_pre_hooks
+        for k, mod in model.named_modules()
+        if mod._load_state_dict_pre_hooks
+    }
+
+    for mod_name, mod in gm.named_modules():
+        if mod_name in hooks:
+            for hook in hooks.pop(mod_name).values():
+                mod._register_load_state_dict_pre_hook(hook.hook, with_module=hook.with_module)
+    assert not (bool(hooks)), f"""Mismatch in names of exported and source modules with hooks.
+        The following module names were not found in exported module {list(hooks.keys())}"""
+
+
+def _add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module) -> None:
+    """
+    Add a load hook to handle aliased parameters in the model.
+
+    When parameters are aliased (multiple parameter names point to the same tensor),
+    we need to ensure all aliases get the same value during loading. This hook:
+    1. Identifies groups of aliased parameters
+    2. For each group, finds a valid parameter value from the state dict
+    3. Applies that value to all aliases in the group
+
+    Args:
+        gm: The graph module to add the hook to
+        model: The source model containing the original parameter aliases
+    """
+
+    def find_valid_param_value(
+        state_dict: Dict[str, torch.Tensor], param_names: List[str]
+    ) -> Optional[torch.Tensor]:
+        """Find a valid parameter value from state dict for a group of aliased parameters.
+
+        Args:
+            state_dict: The state dict being loaded
+            param_names: List of parameter names that are aliases of each other
+
+        Returns:
+            A valid tensor value if found, None otherwise
+        """
+        # First try to find a non-meta tensor value
+        value = None
+        for name in param_names:
+            if name in state_dict:
+                value = state_dict[name]
+                if value.device.type != "meta":
+                    return value
+
+        return value
+
+    def aliasing_load_pre_hook(state_dict: Dict[str, torch.Tensor], prefix: str, *args, **kwargs):
+        """Load hook that ensures aliased parameters get the same value."""
+        for group in aliased_groups:
+            # Find a valid value for this group of aliases
+            value = find_valid_param_value(state_dict, group)
+
+            if value is not None:
+                # Apply the value to all aliases
+                for name in group:
+                    state_dict[name] = value
+
+            ad_logger.debug(f"Applied value from {group[0]} to aliased parameters: {group}")
+
+    # Find all parameter aliases in the source model
+    param_to_names = defaultdict(list)
+    for name, param in model.named_parameters(remove_duplicate=False):
+        param_to_names[id(param)].append(name)
+
+    # Filter to only groups with multiple aliases
+    aliased_groups = [names for names in param_to_names.values() if len(names) > 1]
+
+    if not aliased_groups:
+        return
+
+    # Register the hook
+    gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook)
+
+
+def torch_export_to_gm(
+    model: nn.Module,
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    clone: bool = False,  # clone or don't clone the model state_dict
+    *,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
+    strict: bool = False,
+    patch_configs: Optional[Dict[str, Union[dict, Any]]] = None,
+    patch_list: Optional[List[str]] = None,
+) -> fx.GraphModule:
+    """torch's export with wrapping into GraphModule + useful additions to the resulting module.
+
+    This utility improves over stock torch.export.export in the following aspects:
+
+    1. Provide patches for certain corner cases that torch.export does not support.
+    2. Standardize the export process to strictly run on the meta device.
+    3. Automatically extract the GraphModule from the exported program.
+    4. Retain load hooks for state_dict loading from the original module.
+    5. Manage parameter aliasing in the model.
+
+    Args:
+        model: The model to export
+        args: Arguments for the model
+        kwargs: Keyword arguments for the model
+        clone: Whether to clone the model state_dict
+        dynamic_shapes: Dynamic shapes for the export
+        strict: Whether to use strict mode for export
+        patch_configs: Optional patch configurations. If None, all registered patches
+            will be applied with default settings.
+        patch_list: Optional list of patch names to apply with default settings.
+            Cannot be used together with patch_configs.
+    """
+    # Validate that both patch_configs and patch_list are not provided simultaneously
+    if patch_configs is not None and patch_list is not None:
+        raise ValueError("Cannot specify both patch_configs and patch_list. Use only one.")
+
+    # Handle patch configuration
+    if patch_list is not None:
+        # Convert patch_list to patch_configs format
+        patch_configs = {patch_name: {} for patch_name in patch_list}
+    elif patch_configs is None:
+        # Default patch configurations - apply all registered patches with default settings
+        patch_configs = {patch_name: {} for patch_name in ExportPatchRegistry.list_patches()}
+
+    # run export with patches and lifted to meta
+    with apply_export_patches(patch_configs), lift_to_meta(model) as state_dict:
+        # clean up args, kwargs and move to correct device
+        args, kwargs = tree_to((args, kwargs or {}), device="meta")
+
+        # NOTE (lucaslie): export is VERY sensitive to the location of the inference_mode
+        # context manager. Do NOT move it unless absolutely necessary.
+        with torch.inference_mode():
+            ep = te.export(model, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict)
+        egm = ep.module()
+        assert isinstance(egm, fx.GraphModule)
+
+        # load state_dict into egm
+        # NOTE: export might have removed unused params/buffers (hence we allow unexpected keys)
+        load_buffers_and_params(
+            egm, state_dict, strict_missing=True, strict_unexpected=False, clone=clone
+        )
+
+    # Export strips away all methods not traced during forward. The model could have
+    # load hooks that contain logic for correct state_dict loading. We need to add those
+    # hooks back to the exported graph module.
+    _add_missing_load_hooks(egm, model)
+
+    # Add load hook to correctly load parameters that are aliased in the source model.
+    # deduplicate params and buffers
+    # TODO (lucaslie, suyoggupta): seems there is some overlap here. I believe we should just have
+    # the deduplicate function and extend it to handle reading from state dict for any name.
+    _add_load_hook_for_aliased_params(egm, model)
+    _deduplicate_params_and_buffers(egm)
+
+    # clean up devices in the graph
+    # This is a consequence of lifting to meta during export.
+    _clean_up_device_info(egm)
+
+    # show exported graph
+    ad_logger.debug("exported graph: " + str(egm))
+
+    return egm
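
For orientation, a minimal usage sketch of the new entry point. The toy model and the commented-out patch name are illustrative assumptions, not part of this commit; the parameters shown (args, patch_list) are those documented in the docstring above:

    import torch
    import torch.nn as nn

    from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm


    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(16, 16)

        def forward(self, x):
            return self.linear(x)


    # Default behavior: every registered patch is applied with default settings.
    gm = torch_export_to_gm(TinyModel(), args=(torch.randn(2, 16),))

    # Or restrict patching to an explicit subset by name; "some_patch" is a
    # placeholder for a name returned by ExportPatchRegistry.list_patches().
    # Passing both patch_list and patch_configs raises a ValueError.
    # gm = torch_export_to_gm(TinyModel(), args=(torch.randn(2, 16),), patch_list=["some_patch"])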

0 commit comments