Skip to content

Commit fb7f983

Browse files
authored
[NVIDIA#8924][fix] Fix AutoDeploy pattern matcher for torch 2.9 (NVIDIA#8920)
Signed-off-by: Fridah-nv <[email protected]>
1 parent b181568 commit fb7f983

File tree

1 file changed

+18
-3
lines changed
  • tensorrt_llm/_torch/auto_deploy/export

1 file changed

+18
-3
lines changed

tensorrt_llm/_torch/auto_deploy/export/export.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def aliasing_load_pre_hook(state_dict: Dict[str, torch.Tensor], prefix: str, *ar
172172
gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook)
173173

174174

175-
def _clean_up_assertions(gm: fx.GraphModule):
175+
def _clean_up_assertions_and_guards(gm: fx.GraphModule):
176176
"""This transformation removes shape checks and assertions from the graph."""
177177
check_ops = {
178178
torch.ops.aten._assert_scalar,
@@ -183,11 +183,26 @@ def _clean_up_assertions(gm: fx.GraphModule):
183183
# torch.ops.aten._functional_sym_constrain_range_for_size
184184
}
185185
graph: fx.Graph = gm.graph
186+
removed = False
186187
for node in reversed(graph.nodes):
187188
if len(node.users) > 0 or not is_op(node, check_ops):
188189
continue
189190
graph.erase_node(node)
190-
canonicalize_graph(gm)
191+
removed = True
192+
for node in reversed(graph.nodes):
193+
if node.op == "call_module" and (
194+
str(node.target) == "_guards_fn" or str(node.target).startswith("_guards")
195+
):
196+
# there are typically no users of the guards, but if there are, we route them through the first arg
197+
if len(node.users) > 0 and len(node.args) >= 1:
198+
node.replace_all_uses_with(node.args[0])
199+
graph.erase_node(node)
200+
removed = True
201+
202+
if removed and hasattr(gm, "_guards_fn"):
203+
delattr(gm, "_guards_fn")
204+
if removed:
205+
canonicalize_graph(gm)
191206

192207

193208
def run_forward_for_capture(
@@ -308,7 +323,7 @@ def _capture_fn(model, args, kwargs):
308323
_clean_up_device_info(egm)
309324

310325
# clean up checks --> generally the sanity checks are overly conservative and we can remove them
311-
_clean_up_assertions(egm)
326+
_clean_up_assertions_and_guards(egm)
312327

313328
# show exported graph
314329
ad_logger.debug("exported graph: " + str(egm))

0 commit comments

Comments
 (0)