Skip to content

Commit 52f7106

Browse files
jbschlosserpytorchmergebot
authored andcommitted
[Dynamo] Guard serialization for TUPLE_ITERATOR_LEN (pytorch#152865)
Tests serialization for TUPLE_ITERATOR_LEN; includes no non-test changes. Passing a tuple iterator as input results in the iterator being exhausted during testing. I threw together a super janky workaround via accepting a func for kwarg generation and replacing `frame_state.f_locals` with newly-generated kwargs to get fresh iterators, but insights into a better approach are welcome! Pull Request resolved: pytorch#152865 Approved by: https://github.com/jansel ghstack dependencies: pytorch#152725, pytorch#152727, pytorch#152728, pytorch#152730
1 parent fb500d0 commit 52f7106

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

test/dynamo/test_guard_serialization.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pickle
66
import sys
77
import types
8+
from collections.abc import Iterator
89
from unittest.mock import patch
910

1011
import torch
@@ -239,6 +240,11 @@ def _tracefunc(self, frame, event, arg):
239240
)
240241

241242
def _test_serialization(self, guard_type, fn, *args, **kwargs):
243+
# kwargs might contain a callable that generates kwargs
244+
kwarg_gen_fn = kwargs.get("_gen_fn", None)
245+
if kwarg_gen_fn is not None:
246+
kwargs = kwarg_gen_fn()
247+
242248
self._frame_state = None
243249
sys.settrace(self._tracefunc)
244250
if isinstance(fn, torch.nn.Module):
@@ -250,6 +256,14 @@ def _test_serialization(self, guard_type, fn, *args, **kwargs):
250256

251257
assert self._frame_state is not None
252258

259+
# Set f_locals from regenerated kwargs to handle exhausted input iterators
260+
# NB: This is super janky and might cause unforeseen problems
261+
if kwarg_gen_fn is not None:
262+
kwargs = kwarg_gen_fn()
263+
for key in self._frame_state.f_locals.keys():
264+
if key in kwargs and isinstance(kwargs[key], Iterator):
265+
self._frame_state.f_locals[key] = kwargs[key]
266+
253267
def guard_filter_fn(guards):
254268
ret = [
255269
g.guard_type == guard_type or guard_type in g.derived_guard_types
@@ -688,6 +702,41 @@ def fn(t, x):
688702
False,
689703
)
690704

705+
def test_tuple_iterator_len(self):
706+
def fn(t, x):
707+
if len(list(t)) > 2:
708+
return x * 2
709+
return x + 1
710+
711+
tup = (1, 2, 3)
712+
x = torch.randn(3)
713+
714+
# func to generate kwargs; useful for avoiding iterator exhaustion issues
715+
def _gen_kwargs(tup=tup, x=x):
716+
return {"t": iter(tup), "x": x}
717+
718+
ref, loaded = self._test_serialization(
719+
"TUPLE_ITERATOR_LEN", fn, _gen_fn=_gen_kwargs
720+
)
721+
722+
# same tuple
723+
self._test_check_fn(ref, loaded, {"t": iter(tup), "x": x}, True)
724+
self._test_check_fn(ref, loaded, {"t": iter(tup), "x": torch.randn(4)}, True)
725+
# same length tuple, different contents
726+
self._test_check_fn(ref, loaded, {"t": iter((3, 2, 1)), "x": x}, True)
727+
self._test_check_fn(
728+
ref, loaded, {"t": iter((3, 2, 1)), "x": torch.randn(4)}, True
729+
)
730+
# different tuple lengths
731+
self._test_check_fn(ref, loaded, {"t": iter((1, 2)), "x": x}, False)
732+
self._test_check_fn(
733+
ref, loaded, {"t": iter((1, 2)), "x": torch.randn(4)}, False
734+
)
735+
self._test_check_fn(ref, loaded, {"t": iter((1, 2, 3, 4)), "x": x}, False)
736+
self._test_check_fn(
737+
ref, loaded, {"t": iter((1, 2, 3, 4)), "x": torch.randn(4)}, False
738+
)
739+
691740
def test_dict_version(self):
692741
def fn(x):
693742
return pytree.tree_leaves(x)[0] + 1

0 commit comments

Comments
 (0)