Commit 4d89913

move assertion check cleanup back to stock export (#93)
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 1ec1448 commit 4d89913

4 files changed: +25 -42 lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 0 additions & 2 deletions
@@ -17,7 +17,5 @@ transforms:
     stage: post_export
   cleanup_noop_add:
     stage: post_export
-  cleanup_checks:
-    stage: post_export
   cleanup_input_constraints:
     stage: post_export

tensorrt_llm/_torch/auto_deploy/export/export.py

Lines changed: 23 additions & 0 deletions
@@ -17,6 +17,7 @@
     tree_to,
 )
 from ..utils.logger import ad_logger
+from ..utils.node_utils import is_op
 from .interface import ExportPatchRegistry, apply_export_patches
 
 try:
@@ -176,6 +177,24 @@ def aliasing_load_pre_hook(state_dict: Dict[str, torch.Tensor], prefix: str, *ar
     gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook)
 
 
+def _clean_up_assertions(gm: fx.GraphModule):
+    """This transformation removes shape checks and assertions from the graph."""
+    check_ops = {
+        torch.ops.aten._assert_scalar,
+        torch.ops.aten.sym_constrain_range,
+        torch.ops.aten.sym_constrain_range_for_size,
+        torch.ops.aten._assert_tensor_metadata,
+        # torch.ops.aten._functional_sym_constrain_range,
+        # torch.ops.aten._functional_sym_constrain_range_for_size
+    }
+    graph: fx.Graph = gm.graph
+    for node in reversed(graph.nodes):
+        if len(node.users) > 0 or not is_op(node, check_ops):
+            continue
+        graph.erase_node(node)
+    canonicalize_graph(gm)
+
+
 def torch_export_to_gm(
     model: nn.Module,
     args: Tuple[Any, ...],
@@ -196,6 +215,7 @@ def torch_export_to_gm(
     3. Automatically extract the GraphModule from the exported program.
     4. Retain load hooks for state_dict loading from the original module.
     5. Manage parameter aliasing in the model.
+    6. Remove assertions from the graph.
 
     Args:
         model: The model to export
@@ -255,6 +275,9 @@
     # This is a consequence of lifting to meta during export.
     _clean_up_device_info(egm)
 
+    # clean up checks --> generally the sanity checks are overly conservative and we can remove them
+    _clean_up_assertions(egm)
+
     # show exported graph
     ad_logger.debug("exported graph: " + str(egm))
 
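The new `_clean_up_assertions` pass walks the exported FX graph in reverse and erases assertion/constraint nodes that have no users. Below is a self-contained sketch of the same idea against plain `torch.export`, without the repo's `is_op` and `canonicalize_graph` helpers; the names `CHECK_OPS`, `_is_check_op`, `strip_assertions`, and `ToyModel` are illustrative rather than part of the codebase, and it assumes a recent PyTorch that emits these assert/constrain ops for data-dependent shapes.

```python
import torch
from torch import fx, nn

# Ops that only encode sanity checks; mirrors the set used in _clean_up_assertions.
# The exact ops emitted can vary with the PyTorch version.
CHECK_OPS = {
    torch.ops.aten._assert_scalar,
    torch.ops.aten.sym_constrain_range,
    torch.ops.aten.sym_constrain_range_for_size,
    torch.ops.aten._assert_tensor_metadata,
}


def _is_check_op(node: fx.Node) -> bool:
    """True if the node calls one of CHECK_OPS (any overload)."""
    if node.op != "call_function":
        return False
    # node.target is typically an OpOverload (e.g. aten._assert_scalar.default);
    # compare via its overload packet so every overload matches.
    return getattr(node.target, "overloadpacket", node.target) in CHECK_OPS


def strip_assertions(gm: fx.GraphModule) -> None:
    """Erase check nodes that have no users, then lint and recompile the module."""
    for node in reversed(list(gm.graph.nodes)):
        if len(node.users) == 0 and _is_check_op(node):
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()


class ToyModel(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = x[0].item()          # unbacked symbolic int
        torch._check_is_size(n)  # export records this as constrain/assert nodes
        return torch.zeros(n)


if __name__ == "__main__":
    gm = torch.export.export(ToyModel(), (torch.tensor([5]),)).module()
    before = sum(_is_check_op(n) for n in gm.graph.nodes)
    strip_assertions(gm)
    after = sum(_is_check_op(n) for n in gm.graph.nodes)
    print(f"check nodes: {before} -> {after}")
```

The overload-packet comparison stands in for the repo's `is_op` helper, and `graph.lint()` plus `recompile()` stand in for `canonicalize_graph`.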

tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_checks.py

Lines changed: 0 additions & 40 deletions
This file was deleted.
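With the dedicated `cleanup_checks` transform deleted and the logic folded into `torch_export_to_gm`, callers get the assertion cleanup as part of the export itself. A hedged usage sketch follows; the import path is inferred from the file location in this diff and may differ, and `TinyModel` and its inputs are placeholders.

```python
import torch
from torch import nn

# Import path assumed from tensorrt_llm/_torch/auto_deploy/export/export.py.
from tensorrt_llm._torch.auto_deploy.export.export import torch_export_to_gm


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


# Assertion/check nodes are now removed inside torch_export_to_gm itself,
# so no separate "cleanup_checks" transform has to run in the pipeline.
gm = torch_export_to_gm(TinyModel(), (torch.randn(2, 8),))
```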

tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@
 from ..interface import BaseTransform, TransformInfo, TransformRegistry
 
 
+# TODO (lucaslie): consider reconfiguring this transform to run before we switch to flattened
+# sequences which is done in update_in_out_nodes at the moment.
 @TransformRegistry.register("cleanup_input_constraints")
 class CleanupInputConstraints(BaseTransform):
     """Cleanup input constraints from the graph.
