Skip to content

Commit 42954ab

Browse files
jbschlosserpytorchmergebot
authored andcommitted
[Dynamo] Guard serialization for CLOSURE_MATCH (pytorch#152728)
Unsupported because it uses unsupported FUNCTION_MATCH. Pull Request resolved: pytorch#152728 Approved by: https://github.com/jansel ghstack dependencies: pytorch#152725, pytorch#152727
1 parent a9186ec commit 42954ab

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

test/dynamo/test_guard_serialization.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def forward(self, x):
3838
return x + 1
3939

4040

41+
def global_func(x):
42+
return x + 1
43+
44+
4145
class SubclassWithMeta(torch.Tensor):
4246
@staticmethod
4347
def __new__(cls, a, extra, outer_size=None, outer_stride=None):
@@ -637,6 +641,20 @@ def fn(x):
637641
):
638642
self._test_serialization("FUNCTION_MATCH", fn, x)
639643

644+
def test_closure_match(self):
645+
def fn(x):
646+
# usage of this global function installs a CLOSURE_MATCH guard
647+
return global_func(x)
648+
649+
x = torch.randn(3)
650+
651+
# we don't support CLOSURE_MATCH because it adds a FUNCTION_MATCH guard, and we don't
652+
# support that in serialization
653+
with self.assertRaisesRegex(
654+
RuntimeError, "CLOSURE_MATCH guard cannot be serialized."
655+
):
656+
self._test_serialization("CLOSURE_MATCH", fn, x)
657+
640658
def test_dict_version(self):
641659
def fn(x):
642660
return pytree.tree_leaves(x)[0] + 1

torch/_dynamo/guards.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,6 +1768,9 @@ def FUNCTION_MATCH(self, guard: Guard):
17681768

17691769
def CLOSURE_MATCH(self, guard: Guard):
17701770
"""matches a closure by __code__ id."""
1771+
# don't support this in serialization because it uses unsupported FUNCTION_MATCH
1772+
if self.serialization_mode == "save":
1773+
raise RuntimeError("CLOSURE_MATCH guard cannot be serialized.")
17711774
val = self.get(guard.name)
17721775
# Strictly only want user-defined functions
17731776
if type(val) == types.FunctionType and hasattr(val, "__code__"):

0 commit comments

Comments
 (0)