
Commit bcfea48

bdhirsh authored and pytorchmergebot committed
add and fix OpInfo tests for the default partitioner (pytorch#165372)
I noticed the default partitioner was breaking in some dynamic shape tests, so prior to turning off functionalization I want to tweak it to pass all of our OpInfo tests.

Pull Request resolved: pytorch#165372
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#165327
1 parent d2e1dbc commit bcfea48
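
For orientation, a minimal sketch of what the new tests exercise (the toy function fn and its input are illustrative, not taken from the PR): the same function compiled through AOTAutograd with the min-cut rematerialization partitioner and with the default partitioner, mirroring the compiled_function call used by the test harness below.

import torch
from functorch.compile import (
    compiled_function,
    default_partition,
    min_cut_rematerialization_partition,
    nop,
)


def fn(x):
    # Toy function whose forward produces an intermediate the partitioner must place.
    return (x.sin() * x).sum()


x = torch.randn(4, requires_grad=True)

for partition_fn in (min_cut_rematerialization_partition, default_partition):
    compiled = compiled_function(
        fn,
        nop,  # forward compiler: return the traced forward graph unchanged
        nop,  # backward compiler: return the traced backward graph unchanged
        partition_fn=partition_fn,
        keep_inference_input_mutations=True,
    )
    compiled(x).backward()  # gradients flow through whichever partitioner was chosen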

3 files changed (+50, -11 lines)

test/functorch/test_aotdispatch.py

Lines changed: 25 additions & 1 deletion
@@ -8059,7 +8059,7 @@ def fn(x):
 }
 
 
-def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
+def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False, use_min_cut=True):
     if not op.supports_autograd:
         self.skipTest("Op does not support autograd")
 
@@ -8090,6 +8090,7 @@ def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
                 check_gradients=True,
                 try_check_data_specialization=try_check_data_specialization,
                 skip_correctness_check=op.skip_correctness_check_compile_vs_eager,
+                use_min_cut=use_min_cut,
             )
         except DynamicOutputShapeException:
             self.skipTest("Dynamic output shape operation in trace")
@@ -8190,6 +8191,29 @@ def test_aot_autograd_exhaustive(self, device, dtype, op):
     def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
         _test_aot_autograd_helper(self, device, dtype, op, dynamic=True)
 
+    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
+    @skipOps(
+        "TestEagerFusionOpInfo",
+        "test_aot_autograd_default_partition_exhaustive",
+        aot_autograd_failures,
+    )
+    def test_aot_autograd_default_partition_exhaustive(self, device, dtype, op):
+        _test_aot_autograd_helper(self, device, dtype, op, use_min_cut=False)
+
+    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
+    @patch("functorch.compile.config.debug_assert", True)
+    @skipOps(
+        "TestEagerFusionOpInfo",
+        "test_aot_autograd_symbolic_default_partition_exhaustive",
+        aot_autograd_failures | symbolic_aot_autograd_failures,
+    )
+    def test_aot_autograd_symbolic_default_partition_exhaustive(
+        self, device, dtype, op
+    ):
+        _test_aot_autograd_helper(
+            self, device, dtype, op, dynamic=True, use_min_cut=False
+        )
+
 
 aot_autograd_module_failures = set(
     {

torch/_functorch/partitioners.py

Lines changed: 5 additions & 1 deletion
@@ -1025,7 +1025,11 @@ def default_partition(
             # Symints must be kept separate from tensors so that PythonFunction only calls
             # save_for_backward on tensors and stashes symints in autograd .ctx
             saved_sym_nodes.append(node)
-        elif "tensor_meta" not in node.meta and node.op == "call_function":
+        elif (
+            "tensor_meta" not in node.meta
+            and node.op == "call_function"
+            and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
+        ):
             # Since we can't save tuple of tensor values, we need to flatten out what we're saving
             users = node.users
             assert all(user.target == operator.getitem for user in users)
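
Read in isolation, the new guard means: a call_function node without tensor_meta is only treated as a multi-output value to flatten through operator.getitem when its meta["val"] is not a single FakeTensor; nodes that carry a FakeTensor val (and can lack tensor_meta, e.g. under dynamic shapes) are now saved directly. A standalone paraphrase of the predicate, with a hypothetical helper name and for illustration only:

import torch


def _should_flatten_via_getitem(node) -> bool:
    # Paraphrase of the updated elif in default_partition: a node whose
    # meta["val"] is a single FakeTensor is a real tensor output and is
    # saved as-is rather than unpacked through operator.getitem.
    return (
        "tensor_meta" not in node.meta
        and node.op == "call_function"
        and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
    )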

torch/testing/_internal/optests/aot_autograd.py

Lines changed: 20 additions & 9 deletions
@@ -3,7 +3,7 @@
 import torch
 import torch.utils._pytree as pytree
 from torch.testing._utils import wrapper_set_seed
-from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop
+from functorch.compile import compiled_function, min_cut_rematerialization_partition, default_partition, nop
 from .make_fx import randomize
 import re
 
@@ -38,6 +38,7 @@ def aot_autograd_check(
         assert_equals_fn=torch.testing.assert_close,
         check_gradients=True,
         try_check_data_specialization=False,
+        use_min_cut=True,
         skip_correctness_check=False):
     """Compares func(*args, **kwargs) in eager-mode to under AOTAutograd.
 
@@ -63,14 +64,24 @@ def func_no_tensors(args):
         c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec)
         return func(*c_args, **c_kwargs)
 
-    compiled_f = compiled_function(
-        func_no_tensors,
-        nop,
-        nop,
-        dynamic=dynamic,
-        partition_fn=min_cut_rematerialization_partition,
-        keep_inference_input_mutations=True
-    )
+    if use_min_cut:
+        compiled_f = compiled_function(
+            func_no_tensors,
+            nop,
+            nop,
+            dynamic=dynamic,
+            partition_fn=min_cut_rematerialization_partition,
+            keep_inference_input_mutations=True
+        )
+    else:
+        compiled_f = compiled_function(
+            func_no_tensors,
+            nop,
+            nop,
+            dynamic=dynamic,
+            partition_fn=default_partition,
+            keep_inference_input_mutations=True
+        )
 
     out = wrapper_set_seed(func_no_tensors, args)
     if check_gradients == "auto":
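
The two branches differ only in partition_fn, so a hypothetical, behavior-equivalent way to write the same selection (not what the PR does) would be:

partition_fn = (
    min_cut_rematerialization_partition if use_min_cut else default_partition
)
compiled_f = compiled_function(
    func_no_tensors,
    nop,
    nop,
    dynamic=dynamic,
    partition_fn=partition_fn,
    keep_inference_input_mutations=True,
)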
