Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 0022325

Browse files
authored
Update eager_transform_test (#891)
Pass trace_factory_functions=False to match the behavior before pytorch/pytorch#79638
1 parent 88eae98 commit 0022325

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

test/test_eager_transforms.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2995,7 +2995,8 @@ def f(x: torch.Tensor) -> torch.Tensor:
29952995
return x
29962996
# There's a copy_ in the graph, because the input (x) was mutated.
29972997
# To preserve semantics, functionalize() needs to propagate the mutation.
2998-
out = make_fx(functionalize(f, remove='mutations_and_views'))(torch.zeros(4, 2, device=device))
2998+
fn = make_fx(functionalize(f, remove='mutations_and_views'), trace_factory_functions=False)
2999+
out = fn(torch.zeros(4, 2, device=device))
29993000
self.assertExpectedInline((out.code), """\
30003001
30013002
@@ -3013,7 +3014,8 @@ def test_functionalize_fx_transpose_simple(self, device):
30133014

30143015
def f(x: torch.Tensor) -> torch.Tensor:
30153016
return x.transpose(1, 0)
3016-
out = make_fx(functionalize(f, remove='mutations_and_views'))(torch.zeros(4, 2, device=device))
3017+
fn = make_fx(functionalize(f, remove='mutations_and_views'), trace_factory_functions=False)
3018+
out = fn(torch.zeros(4, 2, device=device))
30173019
self.assertExpectedInline(out.code, """\
30183020
30193021
@@ -3032,7 +3034,7 @@ def f(inpt: torch.Tensor) -> torch.Tensor:
30323034
out_view.add_(1)
30333035
return out
30343036

3035-
fn = make_fx(functionalize(f, remove='mutations_and_views'))
3037+
fn = make_fx(functionalize(f, remove='mutations_and_views'), trace_factory_functions=False)
30363038
out = fn(torch.arange(4, device=device, dtype=torch.float32))
30373039
self.assertExpectedInline(out.code, """\
30383040
@@ -3057,7 +3059,7 @@ def f(inpt: torch.Tensor) -> torch.Tensor:
30573059
torch.aminmax(inpt_view, dim=0, out=(mins, maxs_view))
30583060
return (maxs, mins)
30593061

3060-
fn = make_fx(functionalize(f, remove='mutations_and_views'))
3062+
fn = make_fx(functionalize(f, remove='mutations_and_views'), trace_factory_functions=False)
30613063
out = fn(torch.arange(8, device=device, dtype=torch.float32))
30623064
self.assertExpectedInline(out.code, """\
30633065
@@ -3080,7 +3082,7 @@ def f(x: torch.Tensor) -> torch.Tensor:
30803082
y.add_(tmp)
30813083
return x
30823084

3083-
out = make_fx(functionalize(f))(torch.zeros(4, 2, device=device))
3085+
out = make_fx(functionalize(f), trace_factory_functions=False)(torch.zeros(4, 2, device=device))
30843086
self.assertExpectedInline(out.code, """\
30853087
30863088

0 commit comments

Comments
 (0)