Skip to content

Commit d9b8473

Browse files
jbschlosserpytorchmergebot
authored andcommitted
[Dynamo] Guard serialization for RANGE_ITERATOR_MATCH (pytorch#152872)
Tests serialization for RANGE_ITERATOR_MATCH; includes no non-test changes. This PR handles iterator exhaustion issues by utilizing the janky solution from pytorch#152865; it passes a function to generate kwargs and `frame_state.f_locals` is updated with fresh iterators through a second kwarg generation pass after initial tracing. Pull Request resolved: pytorch#152872 Approved by: https://github.com/jansel ghstack dependencies: pytorch#152725, pytorch#152727, pytorch#152728, pytorch#152730, pytorch#152865
1 parent 52f7106 commit d9b8473

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

test/dynamo/test_guard_serialization.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,48 @@ def _gen_kwargs(tup=tup, x=x):
737737
ref, loaded, {"t": iter((1, 2, 3, 4)), "x": torch.randn(4)}, False
738738
)
739739

740+
def test_range_iterator_match(self):
741+
def fn(x, r):
742+
y = x
743+
for val in r:
744+
y = x + val
745+
return y
746+
747+
x = torch.randn(3)
748+
749+
def _gen_kwargs(x=x):
750+
return {"x": x, "r": iter(range(2, 15, 3))}
751+
752+
ref, loaded = self._test_serialization(
753+
"RANGE_ITERATOR_MATCH", fn, _gen_fn=_gen_kwargs
754+
)
755+
756+
# same range
757+
self._test_check_fn(ref, loaded, {"x": x, "r": iter(range(2, 15, 3))}, True)
758+
self._test_check_fn(
759+
ref, loaded, {"x": torch.randn(4), "r": iter(range(2, 15, 3))}, True
760+
)
761+
# equivalent even with different end
762+
self._test_check_fn(ref, loaded, {"x": x, "r": iter(range(2, 16, 3))}, True)
763+
self._test_check_fn(
764+
ref, loaded, {"x": torch.randn(4), "r": iter(range(2, 16, 3))}, True
765+
)
766+
# different start
767+
self._test_check_fn(ref, loaded, {"x": x, "r": iter(range(1, 15, 3))}, False)
768+
self._test_check_fn(
769+
ref, loaded, {"x": torch.randn(4), "r": iter(range(1, 15, 3))}, False
770+
)
771+
# different end resulting in different values
772+
self._test_check_fn(ref, loaded, {"x": x, "r": iter(range(2, 18, 3))}, False)
773+
self._test_check_fn(
774+
ref, loaded, {"x": torch.randn(4), "r": iter(range(2, 18, 3))}, False
775+
)
776+
# different step
777+
self._test_check_fn(ref, loaded, {"x": x, "r": iter(range(2, 15, 4))}, False)
778+
self._test_check_fn(
779+
ref, loaded, {"x": torch.randn(4), "r": iter(range(2, 15, 4))}, False
780+
)
781+
740782
def test_dict_version(self):
741783
def fn(x):
742784
return pytree.tree_leaves(x)[0] + 1

0 commit comments

Comments
 (0)