@@ -2995,7 +2995,8 @@ def f(x: torch.Tensor) -> torch.Tensor:
2995
2995
return x
2996
2996
# There's a copy_ in the graph, because the input (x) was mutated.
2997
2997
# To preserve semantics, functionalize() needs to propagate the mutation.
2998
- out = make_fx (functionalize (f , remove = 'mutations_and_views' ))(torch .zeros (4 , 2 , device = device ))
2998
+ fn = make_fx (functionalize (f , remove = 'mutations_and_views' ), trace_factory_functions = False )
2999
+ out = fn (torch .zeros (4 , 2 , device = device ))
2999
3000
self .assertExpectedInline ((out .code ), """\
3000
3001
3001
3002
@@ -3013,7 +3014,8 @@ def test_functionalize_fx_transpose_simple(self, device):
3013
3014
3014
3015
def f (x : torch .Tensor ) -> torch .Tensor :
3015
3016
return x .transpose (1 , 0 )
3016
- out = make_fx (functionalize (f , remove = 'mutations_and_views' ))(torch .zeros (4 , 2 , device = device ))
3017
+ fn = make_fx (functionalize (f , remove = 'mutations_and_views' ), trace_factory_functions = False )
3018
+ out = fn (torch .zeros (4 , 2 , device = device ))
3017
3019
self .assertExpectedInline (out .code , """\
3018
3020
3019
3021
@@ -3032,7 +3034,7 @@ def f(inpt: torch.Tensor) -> torch.Tensor:
3032
3034
out_view .add_ (1 )
3033
3035
return out
3034
3036
3035
- fn = make_fx (functionalize (f , remove = 'mutations_and_views' ))
3037
+ fn = make_fx (functionalize (f , remove = 'mutations_and_views' ), trace_factory_functions = False )
3036
3038
out = fn (torch .arange (4 , device = device , dtype = torch .float32 ))
3037
3039
self .assertExpectedInline (out .code , """\
3038
3040
@@ -3057,7 +3059,7 @@ def f(inpt: torch.Tensor) -> torch.Tensor:
3057
3059
torch .aminmax (inpt_view , dim = 0 , out = (mins , maxs_view ))
3058
3060
return (maxs , mins )
3059
3061
3060
- fn = make_fx (functionalize (f , remove = 'mutations_and_views' ))
3062
+ fn = make_fx (functionalize (f , remove = 'mutations_and_views' ), trace_factory_functions = False )
3061
3063
out = fn (torch .arange (8 , device = device , dtype = torch .float32 ))
3062
3064
self .assertExpectedInline (out .code , """\
3063
3065
@@ -3080,7 +3082,7 @@ def f(x: torch.Tensor) -> torch.Tensor:
3080
3082
y .add_ (tmp )
3081
3083
return x
3082
3084
3083
- out = make_fx (functionalize (f ))(torch .zeros (4 , 2 , device = device ))
3085
+ out = make_fx (functionalize (f ), trace_factory_functions = False )(torch .zeros (4 , 2 , device = device ))
3084
3086
self .assertExpectedInline (out .code , """\
3085
3087
3086
3088
0 commit comments