@@ -37,7 +37,7 @@ def __call__(self):
             self.model_path, module, example_inputs
         )
         bw_gm, backward_inputs = self.capture_backward_graph(module, example_inputs)
-        print(bw_gm.graph)
+        # print(bw_gm.graph)
         self.builtin_extractor(bw_gm, backward_inputs)
 
     def capture_backward_graph(self, module, example_inputs):
@@ -76,7 +76,28 @@ def wrapped_forward(*args):
 
         outs_grad = [torch.ones_like(out) for out in outs]
         torch.autograd.backward(outs, outs_grad)
-        return backward_gm_holder["gm"], backward_inputs
+        bw_gm = self._remove_none_from_output(backward_gm_holder["gm"])
+        return bw_gm, backward_inputs
+
+    def _remove_none_from_output(self, gm):
+        output_node = next(
+            (n for n in gm.graph.nodes if n.op == "output"),
+            None,
+        )
+        outs = (
+            output_node.args[0]
+            if output_node and isinstance(output_node.args, (tuple, list))
+            else output_node.args
+        )
+        if isinstance(outs, (tuple, list)):
+            new_outs = tuple(out for out in outs if out is not None)
+            if new_outs != outs:
+                output_node.args = (new_outs,)
+
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+        gm.recompile()
+        return gm
 
     def _requires_grad(self, name, tensor):
         if not tensor.is_floating_point():
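For reference, here is a minimal standalone sketch (not part of this commit) of the cleanup that the new `_remove_none_from_output` helper performs on the captured backward graph. The `Toy` module and the `symbolic_trace` call below are hypothetical stand-ins, used only to show how dropping `None` entries from an FX graph's output node and recompiling works.

```python
import torch
import torch.fx as fx


class Toy(torch.nn.Module):
    def forward(self, x):
        # The second output is None, mimicking a backward graph whose
        # output tuple carries gradients for inputs that need none.
        return x * 2, None


gm = fx.symbolic_trace(Toy())

# Locate the output node and drop the None entries from its args,
# then clean up and regenerate the module's forward code.
output_node = next(n for n in gm.graph.nodes if n.op == "output")
outs = output_node.args[0]
output_node.args = (tuple(o for o in outs if o is not None),)

gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()

print(gm(torch.randn(3)))  # a 1-tuple; the None entry is gone
```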