Skip to content

Commit de2a081

Browse files
committed
Remove None from returns.
1 parent 40bafbc commit de2a081

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

graph_net/torch/sample_passes/backward_graph_extractor.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __call__(self):
3737
self.model_path, module, example_inputs
3838
)
3939
bw_gm, backward_inputs = self.capture_backward_graph(module, example_inputs)
40-
print(bw_gm.graph)
40+
# print(bw_gm.graph)
4141
self.builtin_extractor(bw_gm, backward_inputs)
4242

4343
def capture_backward_graph(self, module, example_inputs):
@@ -76,7 +76,28 @@ def wrapped_forward(*args):
7676

7777
outs_grad = [torch.ones_like(out) for out in outs]
7878
torch.autograd.backward(outs, outs_grad)
79-
return backward_gm_holder["gm"], backward_inputs
79+
bw_gm = self._remove_none_from_output(backward_gm_holder["gm"])
80+
return bw_gm, backward_inputs
81+
82+
def _remove_none_from_output(self, gm):
83+
output_node = next(
84+
(n for n in gm.graph.nodes if n.op == "output"),
85+
None,
86+
)
87+
outs = (
88+
output_node.args[0]
89+
if output_node and isinstance(output_node.args, (tuple, list))
90+
else output_node.args
91+
)
92+
if isinstance(outs, (tuple, list)):
93+
new_outs = tuple(out for out in outs if out is not None)
94+
if new_outs != outs:
95+
output_node.args = (new_outs,)
96+
97+
gm.graph.eliminate_dead_code()
98+
gm.graph.lint()
99+
gm.recompile()
100+
return gm
80101

81102
def _requires_grad(self, name, tensor):
82103
if not tensor.is_floating_point():

0 commit comments

Comments
 (0)