Skip to content

Commit 649ceda

Browse files
Raman-RHangelayi
authored andcommitted
[export] handling NamedTuple inputs (pytorch#162959)
Fixes pytorch#160547 ### Summary: bug ``` def test_namedtuple(self): from collections import namedtuple Point = namedtuple('Point', 'x y') class M(torch.nn.Module): def forward(self, x, y): return x + y inp = Point(torch.ones(3), torch.ones(3)) print(M()(*inp)) # errors ep = torch.export.export(M(), inp, strict=False) print(ep) # succeeds ep = torch.export.export(M(), inp, strict=True) print(ep) # workaround could be to convert namedtuple to a kwarg inp_kwargs = {field: getattr(inp, field) for field in inp._fields} ep = torch.export.export(M(), (), inp_kwargs) print(ep) ``` FIx : namedtuple is subclass of tuple but namedtuple is not expected So, this change handles named tuple case I have added 🧪 test case for this as well Pull Request resolved: pytorch#162959 Approved by: https://github.com/angelayi Co-authored-by: Angela Yi <[email protected]>
1 parent 2aadcea commit 649ceda

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

test/export/test_export.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16017,6 +16017,26 @@ def forward(self, q, k, v):
1601716017
):
1601816018
export(Foo(), (torch.randn(1, 33, 256, 128), k, v))
1601916019

16020+
def test_namedtuple_input_export(self):
16021+
# test for NamedTuple inputs with both strict and non-strict export modes
16022+
from collections import namedtuple
16023+
16024+
PointNT = namedtuple("PointNT", ["x", "y"])
16025+
16026+
class M(torch.nn.Module):
16027+
def forward(self, x, y):
16028+
return x + y
16029+
16030+
inp = PointNT(torch.ones(3), torch.ones(3))
16031+
16032+
ep_non_strict = export(M(), inp)
16033+
result_non_strict = ep_non_strict.module()(*inp)
16034+
16035+
ep_strict = export(M(), inp, strict=True)
16036+
result_strict = ep_strict.module()(*inp)
16037+
16038+
self.assertEqual(result_non_strict, result_strict)
16039+
1602016040

1602116041
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
1602216042
class TestOneOffModelExportResult(TestCase):

torch/export/_trace.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,6 +1265,9 @@ def _process_export_inputs(
12651265
f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
12661266
)
12671267
kwargs = kwargs if kwargs is not None else {}
1268+
if pytree.is_namedtuple_instance(args):
1269+
args = tuple(args)
1270+
12681271
_, original_in_spec = pytree.tree_flatten((args, kwargs))
12691272

12701273
verify_additional_inputs: Callable[[ExportedProgram], None]

0 commit comments

Comments
 (0)