Skip to content

Commit 058782c

Browse files
malaybagpytorchmergebot
authored andcommitted
[torch.export] Rmoving unused constants - add support for corner case (pytorch#165205)
Summary: In some cases unused constant had only one level of child node, no second level of child node. Those constants should be removed too. The added test case has the scenario where this scenario will happen. Test Plan: ``` buck test mode/opt caffe2/test:test_export -- 'test_unused_constant' ``` https://www.internalfb.com/intern/testinfra/testrun/15481123837456594 Differential Revision: D84398413 Pull Request resolved: pytorch#165205 Approved by: https://github.com/angelayi
1 parent 2b4ef6b commit 058782c

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

test/export/test_export.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,6 +1628,24 @@ def forward(self, x):
16281628
ep = export(M(), (torch.ones(3),))
16291629
self.assertEqual(len(ep.constants), 0)
16301630

1631+
class M(torch.nn.Module):
1632+
def __init__(self, num_features: int = 1) -> None:
1633+
super().__init__()
1634+
self.num_features = num_features
1635+
1636+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
1637+
res = [torch.Tensor([])] * self.num_features
1638+
for i in range(self.num_features):
1639+
res[i] = x * (i + 1)
1640+
return res
1641+
1642+
inp = torch.ones(3)
1643+
ep = export(M(), (inp,))
1644+
self.assertEqual(len(ep.constants), 0)
1645+
1646+
unf = unflatten(ep)
1647+
self.assertTrue(torch.allclose(M()(inp)[0], unf(inp)[0]))
1648+
16311649
def test_unbacked_bincount(self):
16321650
class Foo(torch.nn.Module):
16331651
def forward(self, xs):

torch/_export/passes/lift_constants_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]:
142142
if len(lift_fresh_node.users) > 1:
143143
return None
144144

145+
# Case 1: lift node is not used anywhere
146+
if len(lift_fresh_node.users) == 0:
147+
return [lift_fresh_node, node]
148+
145149
detach_node = next(iter(lift_fresh_node.users.keys()))
146150
if not (
147151
detach_node.op == "call_function"
@@ -156,6 +160,7 @@ def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]:
156160
if len(detach_node.users) > 0:
157161
return None
158162
else:
163+
# Case 2: Lift node's child is not used anywhere
159164
return [detach_node, lift_fresh_node, node]
160165

161166

0 commit comments

Comments
 (0)