Skip to content

Commit 3684be0

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo] Fix source for lru_cache method (pytorch#157292)
Fixes - pytorch#157273 Pull Request resolved: pytorch#157292 Approved by: https://github.com/zou3519, https://github.com/malfet, https://github.com/jansel
1 parent 2349151 commit 3684be0

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

test/dynamo/test_functions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4132,6 +4132,21 @@ def fn(x):
41324132
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
41334133
self.assertEqual(fn(x), opt_fn(x))
41344134

4135+
def test_functools_cache_guard(self):
4136+
class Foo:
4137+
@functools.lru_cache # noqa: B019
4138+
def run(self, val, c=1.0):
4139+
return val * c * 2
4140+
4141+
f = Foo()
4142+
4143+
def fn(x):
4144+
return f.run(x)
4145+
4146+
x = torch.randn(2)
4147+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4148+
self.assertEqual(fn(x), opt_fn(x))
4149+
41354150

41364151
def udf_mul(x, y):
41374152
return x * y

torch/_dynamo/variables/user_defined.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1266,7 +1266,7 @@ def var_getattr(self, tx: "InstructionTranslator", name):
12661266
# extract the underlying method from the wrapped function. To handle
12671267
# it, manually create a wrapped user method vt.
12681268
return variables.WrapperUserMethodVariable(
1269-
subobj, "__wrapped__", self, source=self.source
1269+
subobj, "__wrapped__", self, source=source
12701270
)
12711271
elif inspect.getattr_static(
12721272
type(subobj), "__get__", NO_SUCH_SUBOBJ

0 commit comments

Comments
 (0)