Skip to content

Commit 3cd6935

Browse files
angelayipytorchmergebot
authored andcommitted
[export] Unflatten None (pytorch#153000)
Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#153000 Approved by: https://github.com/pianpwk
1 parent 7b806a8 commit 3cd6935

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

test/export/test_unflatten.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,27 @@ def forward(self, x):
954954
unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
955955
self.compare_outputs(orig_eager, unflattened, inputs)
956956

957+
def test_unflatten_none(self):
958+
class M2(torch.nn.Module):
959+
def forward(self, x, y):
960+
return x + x, None
961+
962+
class M(torch.nn.Module):
963+
def __init__(self) -> None:
964+
super().__init__()
965+
self.m2 = M2()
966+
967+
def forward(self, x, y):
968+
x = x + x
969+
return self.m2(x, y)
970+
971+
ep = export(
972+
M(), (torch.rand(2, 3), None), preserve_module_call_signature=("m2",)
973+
)
974+
unflattened = unflatten(ep)
975+
inp = (torch.randn(2, 3), None)
976+
self.assertTrue(torch.allclose(M()(*inp)[0], unflattened(*inp)[0]))
977+
957978

958979
if __name__ == "__main__":
959980
run_tests()

torch/export/unflatten.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,13 @@ def finalize_outputs(self):
11711171
for output in signature.outputs:
11721172
if isinstance(
11731173
output,
1174-
(TensorArgument, SymIntArgument, SymBoolArgument, SymFloatArgument),
1174+
(
1175+
TensorArgument,
1176+
SymIntArgument,
1177+
SymBoolArgument,
1178+
SymFloatArgument,
1179+
ConstantArgument,
1180+
),
11751181
):
11761182
if output.name in self.seen_nodes:
11771183
orig_outputs.append(self.seen_nodes[output.name])

0 commit comments

Comments
 (0)