
Commit 88a4619

Generate 2^n tests, not 3^n tests for vmap (#937)
Previously, our vmap tests were generating 3^n tests per OpInfo sample. For each tensor argument, we would generate all permutations of bdim = (0, -1, None). This is pretty redundant and also performance intensive. The original purpose of this was to make sure functorch's batching rules work with bdim other than 0 (it's really easy to forget that the bdim is not always at the front of the tensor). The new strategy is to generate all permutations of bdim = (-1, None) and also include the case where all bdims are 0 as a sanity check. This leads to 2^n tests. On my machine test_vmap goes from 3m25s to 2m45s, which is promising. However the biggest wins are going to be in test_ops where n can be as high as 10.
1 parent 590e861 commit 88a4619
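
As a quick illustration of the new strategy (not part of the commit itself), the sketch below mirrors the enumeration performed by the get_bdim_choices helper added in test/common_utils.py: one all-zeros sanity case plus every combination of bdim in (-1, None) except the all-None one, which comes out to exactly 2^n choices for n tensor arguments.

import itertools

def bdim_choices_sketch(num_tensors):
    # Sanity case: every tensor argument gets its batch dimension at 0.
    choices = [(0,) * num_tensors]
    # All combinations of (-1, None), dropping the all-unbatched (all-None) case.
    for choice in itertools.product((-1, None), repeat=num_tensors):
        if any(bdim is not None for bdim in choice):
            choices.append(choice)
    return tuple(choices)

print(len(bdim_choices_sketch(4)))   # 16 == 2 ** 4; the old (0, -1, None) product gave 3 ** 4 == 81
print(len(bdim_choices_sketch(10)))  # 1024, versus 59049 under the old strategy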

File tree: 2 files changed, +87 -15 lines


test/common_utils.py

Lines changed: 87 additions & 10 deletions
@@ -38,7 +38,90 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
     return loop_out


-def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=3, bdims=(0, -1), for_batch_norm=False):
+# This is kind of dangerous, please think carefully before using it.
+# Known risks:
+# - the return better not be mutated so it's best to return immutable types
+#   (e.g. prefer tuples to list)
+# - Don't hash tensors in a global context, that'll keep them around forever
+def memoize(fn):
+    memo = {}
+    def wrapped(*args):
+        if args not in memo:
+            memo[args] = fn(*args)
+        return memo[args]
+    return wrapped
+
+
+# NB: This is O(2 ** num_tensors).
+# num_tensors ranges from 1 to 10, with 2-4 being most common.
+# Try not to extravagate it if you're modifying it.
+@memoize
+def get_bdim_choices(num_tensors):
+    choices = []
+
+    # full of zeros
+    choices.append((0,) * num_tensors)
+
+    # All permutations of (-1, None)
+    options = (-1, None)
+    for choice in itertools.product(options, repeat=num_tensors):
+        choices.append(choice)
+
+    assert choices[-1] == (None,) * num_tensors
+    return tuple(choices[:-1])
+
+
+def add_batch_dim(arg, bdim, batch_size=3):
+    assert bdim == 0 or bdim == -1
+    assert isinstance(arg, torch.Tensor)
+    if bdim == 0:
+        shape = [1] * len(arg.shape)
+        shape.insert(bdim, batch_size)
+        return (arg.repeat(shape), bdim)
+    if bdim == -1:
+        arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous()
+        return (arg, bdim)
+
+
+def construct_in_dims(bdim_choice_for_tensors, is_tensors):
+    result = []
+    bdim = iter(bdim_choice_for_tensors)
+    for is_tensor in is_tensors:
+        if not is_tensor:
+            result.append(None)
+            continue
+        result.append(next(bdim))
+    return tuple(result)
+
+
+def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2, *, for_batch_norm=False):
+    if for_batch_norm:
+        # TODO: delete this path
+        return get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size)
+
+    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
+    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
+    bdim_choices = get_bdim_choices(sum(is_tensors))
+
+    @memoize
+    def get_batched_arg(arg, bdim):
+        assert isinstance(arg, torch.Tensor)
+        assert bdim is not None
+        result, _ = add_batch_dim(arg, bdim, batch_size)
+        return result
+
+    for bdim_choice in bdim_choices:
+        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)
+
+        flat_batched_args = tuple(arg if in_dim is None else get_batched_arg(arg, in_dim)
+                                  for arg, in_dim in zip(flat_args, flat_in_dims))
+        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
+        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
+        yield batched_args, in_dims, kwarg_values
+
+
+def get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size=3, bdims=(0, -1)):
+    for_batch_norm = True
     assert bdims == (0,) or bdims == (0, -1)

     def add_batch_dim(arg, bdim, batch_size=3):
@@ -112,18 +195,12 @@ def add_batch_choices(a):
         yield batched_args_tuple, in_dims_tuple, kwarg_values


-def get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size=3, bdims=(0, -1)):
-    return get_exhaustive_batched_inputs(arg_values, kwarg_values,
-                                         batch_size=batch_size, bdims=bdims, for_batch_norm=True)
-
-
-def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True, bdims=(0, -1)):
+def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True):
     out_dim = 0
     batch_size = 2
-    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
     batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
-    if opinfo is not None and opinfo.name in batch_norm_fns:
-        generator = get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size, bdims=bdims)
+    for_batch_norm = opinfo is not None and opinfo.name in batch_norm_fns
+    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, for_batch_norm=for_batch_norm)
     for batched_args, in_dims, kwarg_values in generator:
         if compute_loop_out:
             loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
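
For readers unfamiliar with the two bdim placements these tests exercise, here is a minimal, self-contained sketch (illustrative only, not part of the commit) of what add_batch_dim above does to a sample tensor: bdim=0 tiles the sample along a new leading batch dimension, while bdim=-1 expands a new trailing one.

import torch

def add_batch_dim_sketch(arg, bdim, batch_size=3):
    # Hypothetical helper for demonstration; mirrors add_batch_dim in the diff above.
    assert bdim in (0, -1)
    if bdim == 0:
        # Repeat the sample batch_size times along a new leading dimension.
        repeats = [batch_size] + [1] * arg.dim()
        return arg.unsqueeze(0).repeat(*repeats), bdim
    # bdim == -1: append a trailing batch dimension and materialize it.
    return arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous(), bdim

x = torch.randn(2, 5)
front, _ = add_batch_dim_sketch(x, 0)   # front.shape == torch.Size([3, 2, 5])
back, _ = add_batch_dim_sketch(x, -1)   # back.shape == torch.Size([2, 5, 3])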

test/test_ops.py

Lines changed: 0 additions & 5 deletions
@@ -653,14 +653,9 @@ def test_vmapvjp(self, device, dtype, op):

         # The following are bugs that we should fix
         skip('nn.functional.max_pool1d'),  # fails on cpu, runs on cuda
-        xfail('nn.functional.batch_norm', device_type='cuda'),
-        xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
         xfail('_masked.mean'),
         xfail('_masked.prod'),

-        # Causing issues with multiple cpu levels of forward mode AD
-        xfail('nn.functional.batch_norm', device_type='cpu'),
-
         # Not actually a problem: embedding with max_norm mutates the weight
         # and causes different runs to produce different results.
         # skip because this is flaky depending on what the max_norm is!
