@@ -172,7 +172,7 @@ def aliasing_load_pre_hook(state_dict: Dict[str, torch.Tensor], prefix: str, *ar
     gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook)
 
 
-def _clean_up_assertions(gm: fx.GraphModule):
+def _clean_up_assertions_and_guards(gm: fx.GraphModule):
     """This transformations removes shape checks and assertions from the graph."""
     check_ops = {
         torch.ops.aten._assert_scalar,
@@ -183,11 +183,26 @@ def _clean_up_assertions(gm: fx.GraphModule):
         # torch.ops.aten._functional_sym_constrain_range_for_size
     }
     graph: fx.Graph = gm.graph
+    removed = False
     for node in reversed(graph.nodes):
         if len(node.users) > 0 or not is_op(node, check_ops):
             continue
         graph.erase_node(node)
-    canonicalize_graph(gm)
+        removed = True
+    for node in reversed(graph.nodes):
+        if node.op == "call_module" and (
+            str(node.target) == "_guards_fn" or str(node.target).startswith("_guards")
+        ):
+            # there are typically no users of the guards, but if there are, we route through the first arg
+            if len(node.users) > 0 and len(node.args) >= 1:
+                node.replace_all_uses_with(node.args[0])
+            graph.erase_node(node)
+            removed = True
+
+    if removed and hasattr(gm, "_guards_fn"):
+        delattr(gm, "_guards_fn")
+    if removed:
+        canonicalize_graph(gm)
 
 
 def run_forward_for_capture(
@@ -308,7 +323,7 @@ def _capture_fn(model, args, kwargs):
     _clean_up_device_info(egm)
 
     # clean up checks --> generally the sanity checks are overly conservative and we can remove them
-    _clean_up_assertions(egm)
+    _clean_up_assertions_and_guards(egm)
 
     # show exported graph
     ad_logger.debug("exported graph: " + str(egm))
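For context, a minimal sketch (not part of the diff) of what the new guard-removal loop does to a graph that carries a `_guards_fn` submodule. The toy module and hand-built graph below are illustrative assumptions; the real pass also erases the assertion ops listed in `check_ops` and finishes with the project's `canonicalize_graph` helper, while this sketch sticks to plain `torch.fx` APIs.

```python
# Sketch only: a toy fx.GraphModule with a call_module node targeting "_guards_fn",
# rerouted through its first arg, erased, and the attribute dropped afterwards.
import torch
from torch import fx, nn


class _Guards(nn.Module):
    # Stand-in for the guard submodule; the real one runs runtime shape checks.
    def forward(self, x):
        return x


root = nn.Module()
root._guards_fn = _Guards()

# Build: out = _guards_fn(x) + 1, so the guard node has one user.
graph = fx.Graph()
x = graph.placeholder("x")
guarded = graph.call_module("_guards_fn", args=(x,))
out = graph.call_function(torch.add, args=(guarded, 1))
graph.output(out)
gm = fx.GraphModule(root, graph)

# Mirror the cleanup: route any users through the first arg, erase the node,
# then drop the now-unused "_guards_fn" attribute.
for node in reversed(list(gm.graph.nodes)):
    if node.op == "call_module" and str(node.target).startswith("_guards"):
        if len(node.users) > 0 and len(node.args) >= 1:
            node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
if hasattr(gm, "_guards_fn"):
    delattr(gm, "_guards_fn")

gm.graph.lint()
gm.recompile()
print(gm(torch.zeros(2)))  # tensor([1., 1.]) -- the guard call is gone, the add remains
```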