Skip to content

Commit a6f51be

Browse files
jbschlosserpytorchmergebot
authored andcommitted
[Dynamo] Guard serialization for NN_MODULE (pytorch#152725)
Throws an error when attempting to serialize an NN_MODULE guard. It is not supported because it uses the unsupported ID_MATCH guard (pytorch#152330): https://github.com/pytorch/pytorch/blob/a6dd1c2208f29a3169c1fe96bf4e79a10aa5647d/torch/_dynamo/guards.py#L1738-L1739 Pull Request resolved: pytorch#152725 Approved by: https://github.com/jansel
1 parent 2cf7fd0 commit a6f51be

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

test/dynamo/test_guard_serialization.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pickle
66
import sys
77
import types
8+
from unittest.mock import patch
89

910
import torch
1011
import torch._dynamo.testing
@@ -604,6 +605,22 @@ def fn(x, y):
604605
# guard should fail for different y value
605606
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": 6}, False)
606607

608+
def test_nn_module(self):
609+
def fn(m, x):
610+
return m(x)
611+
612+
m = GlobalModule()
613+
x = torch.randn(3)
614+
615+
# config setting controls whether the NN_MODULE guard is installed
616+
with patch("torch._dynamo.config.inline_inbuilt_nn_modules", False):
617+
# we don't support NN_MODULE because it adds an ID_MATCH guard, and we don't
618+
# support that in serialization
619+
with self.assertRaisesRegex(
620+
RuntimeError, "NN_MODULE guard cannot be serialized."
621+
):
622+
self._test_serialization("NN_MODULE", fn, m, x)
623+
607624
def test_dict_version(self):
608625
def fn(x):
609626
return pytree.tree_leaves(x)[0] + 1

torch/_dynamo/guards.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,6 +1740,9 @@ def CONSTANT_MATCH(self, guard: Guard):
17401740
self.EQUALS_MATCH(guard)
17411741

17421742
def NN_MODULE(self, guard: Guard):
1743+
# don't support this in serialization because it uses unsupported ID_MATCH
1744+
if self.serialization_mode == "save":
1745+
raise RuntimeError("NN_MODULE guard cannot be serialized.")
17431746
self.ID_MATCH(guard)
17441747
val = self.get(guard.name)
17451748
if hasattr(val, "training"):

0 commit comments

Comments
 (0)