Skip to content

Commit fb500d0

Browse files
jbschlosserpytorchmergebot
authored andcommitted
[Dynamo] Guard serialization for SEQUENCE_LENGTH (pytorch#152730)
Tests only; no other changes needed. Test logic uses a tuple function input to trigger installation of a SEQUENCE_LENGTH guard. Pull Request resolved: pytorch#152730 Approved by: https://github.com/jansel ghstack dependencies: pytorch#152725, pytorch#152727, pytorch#152728
1 parent 42954ab commit fb500d0

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

test/dynamo/test_guard_serialization.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,39 @@ def fn(x):
655655
):
656656
self._test_serialization("CLOSURE_MATCH", fn, x)
657657

658+
def test_sequence_length(self):
659+
# tuple input installs a SEQUENCE_LENGTH guard
660+
def fn(t, x):
661+
return t[1] + x
662+
663+
t = tuple(torch.randn(3) for _ in range(3))
664+
x = torch.randn(3)
665+
666+
ref, loaded = self._test_serialization("SEQUENCE_LENGTH", fn, t, x)
667+
self._test_check_fn(ref, loaded, {"x": x, "t": t}, True)
668+
self._test_check_fn(
669+
ref,
670+
loaded,
671+
{
672+
"x": torch.randn(3),
673+
"t": tuple(torch.randn(3) for _ in range(3)),
674+
},
675+
True,
676+
)
677+
# different types in tuple of same length shouldn't fail SEQUENCE_LENGTH guard
678+
# (it should fail the separate TYPE_MATCH guard but that isn't tested here)
679+
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "t": (0, 1, 2)}, True)
680+
# different length tuple
681+
self._test_check_fn(
682+
ref,
683+
loaded,
684+
{
685+
"x": torch.randn(3),
686+
"t": tuple(torch.randn(3) for _ in range(4)),
687+
},
688+
False,
689+
)
690+
658691
def test_dict_version(self):
659692
def fn(x):
660693
return pytree.tree_leaves(x)[0] + 1

0 commit comments

Comments
 (0)