1717 tree_to ,
1818)
1919from ..utils .logger import ad_logger
20+ from ..utils .node_utils import is_op
2021from .interface import ExportPatchRegistry , apply_export_patches
2122
2223try :
@@ -176,6 +177,24 @@ def aliasing_load_pre_hook(state_dict: Dict[str, torch.Tensor], prefix: str, *ar
176177 gm ._register_load_state_dict_pre_hook (aliasing_load_pre_hook )
177178
178179
180+ def _clean_up_assertions (gm : fx .GraphModule ):
181+ """This transformations removes shape checks and assertions from the graph."""
182+ check_ops = {
183+ torch .ops .aten ._assert_scalar ,
184+ torch .ops .aten .sym_constrain_range ,
185+ torch .ops .aten .sym_constrain_range_for_size ,
186+ torch .ops .aten ._assert_tensor_metadata ,
187+ # torch.ops.aten._functional_sym_constrain_range,
188+ # torch.ops.aten._functional_sym_constrain_range_for_size
189+ }
190+ graph : fx .Graph = gm .graph
191+ for node in reversed (graph .nodes ):
192+ if len (node .users ) > 0 or not is_op (node , check_ops ):
193+ continue
194+ graph .erase_node (node )
195+ canonicalize_graph (gm )
196+
197+
179198def torch_export_to_gm (
180199 model : nn .Module ,
181200 args : Tuple [Any , ...],
@@ -196,6 +215,7 @@ def torch_export_to_gm(
196215 3. Automatically extract the GraphModule from the exported program.
197216 4. Retain load hooks for state_dict loading from the original module.
198217 5. Manage parameter aliasing in the model.
218+ 6. Remove assertions from the graph.
199219
200220 Args:
201221 model: The model to export
@@ -255,6 +275,9 @@ def torch_export_to_gm(
255275 # This is a consequence of lifting to meta during export.
256276 _clean_up_device_info (egm )
257277
278+ # clean up checks --> generally the sanity checks are overly conservative and we can remove them
279+ _clean_up_assertions (egm )
280+
258281 # show exported graph
259282 ad_logger .debug ("exported graph: " + str (egm ))
260283
0 commit comments