
Commit 5497d75

Make it so that vmap tests generate with bdim=-1 as well as 0 (#204)
Fixes #62. There are a lot of new xfails.
1 parent 66a05dc commit 5497d75
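
As a rough illustration of what the new coverage exercises (not part of the commit itself): the same logical tensor batched at dim 0 and at dim -1, fed through vmap with matching in_dims. The bdim=-1 construction mirrors the updated add_batch_dim helper in test/common_utils.py below; the snippet assumes a functorch install.

import torch
from functorch import vmap

# Illustrative only: the same logical tensor batched at dim 0 vs. dim -1.
# The bdim=-1 construction mirrors add_batch_dim in test/common_utils.py.
x = torch.randn(2, 5)
batch_size = 3

batched_front = x.unsqueeze(0).expand(batch_size, *x.shape).contiguous()  # bdim=0
batched_back = x.unsqueeze(-1).expand(*x.shape, batch_size).contiguous()  # bdim=-1

# Both placements should agree with looping over the batch slices.
out_front = vmap(torch.sin, in_dims=0)(batched_front)   # shape (3, 2, 5)
out_back = vmap(torch.sin, in_dims=-1)(batched_back)    # shape (3, 2, 5)
assert torch.allclose(out_front, out_back)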

7 files changed: +85, -29 lines

7 files changed

+85
-29
lines changed

functorch/csrc/BatchRulesLinearAlgebra.cpp

Lines changed: 1 addition & 0 deletions

@@ -175,6 +175,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VARIADIC_BDIMS(logdet);
   VARIADIC_BDIMS(matrix_exp);
   VARIADIC_BDIMS(pinverse);
+  VARIADIC_BDIMS(inverse);
   VARIADIC_BDIMS_BOXED(slogdet);
   VARIADIC_BDIMS_BOXED(_svd_helper);
   VARIADIC_BDIMS_BOXED(solve);

functorch/csrc/BatchRulesReduceOps.cpp

Lines changed: 2 additions & 2 deletions

@@ -271,8 +271,8 @@ std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> aminmax_batchin
     dim = maybe_wrap_dim(dim.value(), logical_rank) + 1;
   } else {
     // flatten the input except for batch-dim
-    auto bsize = self.size(0);
-    self_ = self.view({bsize, -1});
+    auto bsize = self_.size(0);
+    self_ = self_.view({bsize, -1});
     dim = 1;
   }
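
The two-line fix above flattens self_ (the tensor whose batch dimension has already been moved to the front) instead of self. A minimal Python sketch of that reduction path, with illustrative shapes only:

import torch

# Sketch of the "reduce over all dims" path of the batched aminmax:
# flatten everything except the leading batch dimension, then reduce over
# dim=1. Flattening a tensor whose batch dim is not already at the front
# would mix batch and data elements, which is what the fix avoids.
self_ = torch.randn(5, 3, 4)             # batch dim already moved to the front
bsize = self_.size(0)
flat = self_.view(bsize, -1)             # shape (5, 12)
mins, maxs = torch.aminmax(flat, dim=1)  # each of shape (5,)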

functorch/csrc/BatchRulesUnaryOps.cpp

Lines changed: 0 additions & 1 deletion

@@ -39,7 +39,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   UNARY_POINTWISE_ALL(floor);
   UNARY_POINTWISE_ALL(frac);
   UNARY_POINTWISE(glu);
-  UNARY_POINTWISE(inverse);
   UNARY_POINTWISE(isfinite);
   UNARY_POINTWISE(isnan);
   UNARY_POINTWISE(isposinf);
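
Together with the one-line addition in BatchRulesLinearAlgebra.cpp above, this moves inverse from the pointwise rule to the variadic-bdims rule: inverse consumes whole matrices, so its batch rule must cope with a batch dimension in any position rather than being applied elementwise. A hedged usage sketch (assumes functorch is installed):

import torch
from functorch import vmap

# Sketch: vmapping torch.inverse with the batch dimension last.
# Each slice A[..., i] is a 3x3 matrix; the result stacks the inverses
# along a new leading batch dimension.
A = torch.randn(3, 3, 7)                  # 7 matrices, batch dim at -1
out = vmap(torch.inverse, in_dims=-1)(A)
assert out.shape == (7, 3, 3)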

functorch/csrc/BatchRulesViews.cpp

Lines changed: 19 additions & 8 deletions

@@ -11,7 +11,7 @@
 #include <functorch/csrc/BatchedFallback.h>
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <c10/util/SmallBuffer.h>
-
+#include <ATen/InferSize.h>
 
 namespace at { namespace functorch {
 
@@ -134,10 +134,21 @@ std::tuple<Tensor,optional<int64_t>> _unsafe_view_batch_rule(
     const Tensor& self,
     optional<int64_t> self_bdim,
     IntArrayRef size) {
+  auto self_ = moveBatchDimToFront(self, self_bdim);
   VmapDimVector view_size(size);
-  view_size.insert(view_size.begin() + *self_bdim, self.size(*self_bdim));
-
-  return std::make_tuple(at::_unsafe_view(self, view_size), self_bdim);
+  view_size.insert(view_size.begin(), self_.size(0));
+
+  // See if the view is valid. If it's not, then we copy.
+  // It's OK to copy, because _unsafe_view(x) guarantees that x isn't used
+  // anymore.
+  const at::DimVector inferred_size = at::infer_size_dv(view_size, self_.numel());
+  const auto stride = at::detail::computeStride(self_.sizes(),
+                                                self_.strides(),
+                                                inferred_size);
+  if (!stride.has_value()) {
+    self_ = self_.contiguous();
+  }
+  return std::make_tuple(at::_unsafe_view(self_, view_size), 0);
 }
 
 Tensor trace_decomp(const Tensor& self) {
@@ -276,11 +287,11 @@ std::tuple<Tensor, optional<int64_t>> _reshape_alias_batch_rule(const Tensor& se
   (void) strides;
   TORCH_INTERNAL_ASSERT(bdim.has_value());
 
+  auto self_ = moveBatchDimToFront(self, bdim);
   c10::SmallBuffer<int64_t, 5> new_shape(shape.size() + 1);
-  new_shape[*bdim] = self.size(*bdim);
-  std::copy(shape.begin(), shape.begin() + *bdim, new_shape.begin());
-  std::copy(shape.begin() + *bdim, shape.end(), new_shape.begin() + *bdim + 1);
-  return std::make_tuple(at::reshape(self, new_shape), bdim);
+  new_shape[0] = self_.size(0);
+  std::copy(shape.begin(), shape.end(), new_shape.begin() + 1);
+  return std::make_tuple(at::reshape(self_, new_shape), 0);
 }
 
 std::tuple<Tensor, optional<int64_t>> roll_batch_rule(const Tensor& self, optional<int64_t> bdim, IntArrayRef shifts, IntArrayRef dims) {
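
The _unsafe_view change is the one most directly exposed by bdim=-1: after moveBatchDimToFront the tensor is generally non-contiguous, so the requested view may not be representable with strides alone, and the computeStride check above falls back to a contiguous copy. An illustrative PyTorch-level sketch of the same situation (not functorch internals):

import torch

# After moving a non-leading batch dimension to the front, the tensor is
# usually non-contiguous and some reshapes can no longer be expressed as a
# view. The batch rule detects this (computeStride has no answer) and makes
# a contiguous copy, which is safe because _unsafe_view's input is
# guaranteed not to be used again.
B = 5
x = torch.randn(3, B, 4)       # per-example shape (3, 4), batch dim in the middle
x_front = x.movedim(1, 0)      # shape (B, 3, 4), non-contiguous

try:
    out = x_front.view(B, 12)  # raises: strides cannot represent this view
except RuntimeError:
    out = x_front.contiguous().view(B, 12)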

test/common_utils.py

Lines changed: 25 additions & 18 deletions

@@ -37,32 +37,39 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):

 def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=3):
     def add_batch_dim(arg, bdim, batch_size=3):
+        assert bdim == 0 or bdim == -1
         if isinstance(arg, torch.Tensor):
-            shape = [1] * len(arg.shape)
-            shape.insert(bdim, batch_size)
-            return (arg.repeat(shape), bdim)
+            if bdim == 0:
+                shape = [1] * len(arg.shape)
+                shape.insert(bdim, batch_size)
+                return (arg.repeat(shape), bdim)
+            if bdim == -1:
+                arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous()
+                return (arg, bdim)
+            assert False
         else:
             return (arg, None)

-    batch_choices = []
-    def add_batch_choices(a):
-        if isinstance(a, torch.Tensor):
-            batched_val = add_batch_dim(a, 0, batch_size)
-            batch_choices.append((batched_val, (a, None)))
-        else:
-            batch_choices.append(((a, None),))
+    for bdim in [0, -1]:
+        batch_choices = []
+        def add_batch_choices(a):
+            if isinstance(a, torch.Tensor):
+                batched_val = add_batch_dim(a, bdim, batch_size)
+                batch_choices.append((batched_val, (a, None)))
+            else:
+                batch_choices.append(((a, None),))

-    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
-    for arg in flat_args:
-        add_batch_choices(arg)
+        flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
+        for arg in flat_args:
+            add_batch_choices(arg)

-    for batched_values in itertools.product(*batch_choices):
-        batched_args, in_dims = zip(*batched_values)
+        for batched_values in itertools.product(*batch_choices):
+            batched_args, in_dims = zip(*batched_values)

-        if all([i is None for i in in_dims]):
-            continue
+            if all([i is None for i in in_dims]):
+                continue

-        yield pytree.tree_unflatten(batched_args, arg_spec), pytree.tree_unflatten(in_dims, arg_spec), kwarg_values
+            yield pytree.tree_unflatten(batched_args, arg_spec), pytree.tree_unflatten(in_dims, arg_spec), kwarg_values


 def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, compute_loop_out=True):
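
With the new loop over bdim in [0, -1], the generator yields every batching choice twice, once per batch-dim placement. A hedged usage sketch of the helper shown above (the import path assumes you are running from the test directory):

import torch
from common_utils import get_exhaustive_batched_inputs  # helper shown above

# Each tensor argument is now offered with its batch dim at 0 and again at
# -1, so in_dims cycles through both placements (default batch_size=3).
args = (torch.randn(2, 3),)
for batched_args, in_dims, kwargs in get_exhaustive_batched_inputs(args, {}):
    print(in_dims, batched_args[0].shape)
    # expected: (0,) with shape (3, 2, 3), then (-1,) with shape (2, 3, 3)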

test/test_ops.py

Lines changed: 25 additions & 0 deletions

@@ -307,12 +307,24 @@ def vjp_of_vjp(*args_and_cotangents):

 @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
 @skipOps('TestOperators', 'test_vmapvjp', vjp_fail.union({
+    # All of the following are bugs and need to be fixed
     xfail('clamp', ''),
     xfail('diag_embed'),
     xfail('eig'),
+    xfail('matrix_exp'),
+    xfail('nn.functional.conv_transpose2d'),
+    xfail('nn.functional.pad', 'constant'),
+    xfail('view_as_complex'),
+    xfail('fft.fft'),
+    xfail('fft.ifft'),
+    xfail('fft.ihfft'),
     xfail('fft.ihfft'),
     xfail('fft.rfft'),
+    xfail('fft.rfft'),
+    xfail('fft.fftn'),
     xfail('fft.rfftn'),
+    xfail('fft.ifftn'),
+    xfail('cdist'),
     xfail('fmax'),
     xfail('fmin'),
     xfail('index_add'),
@@ -373,6 +385,8 @@ def test_vmapvjp(self, device, dtype, op):

 @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
 @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', {
+    xfail('nn.functional.pad', 'constant'),
+    xfail('view_as_complex'),
     xfail('__getitem__'),
     xfail('__rpow__'),
     xfail('cdist'),
@@ -388,9 +402,14 @@ def test_vmapvjp(self, device, dtype, op):
     xfail('diag'),
     xfail('diag_embed'),
     xfail('eig'),
+    xfail('fft.fft'),
+    xfail('fft.fftn'),
+    xfail('fft.ifft'),
+    xfail('fft.ifftn'),
     xfail('fft.ihfft'),
     xfail('fft.rfft'),
     xfail('fft.rfftn'),
+    xfail('cdist'),
     xfail('fill_'),
     xfail('float_power'),
     xfail('fmax'),
@@ -500,6 +519,7 @@ def test():

 @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
 @skipOps('TestOperators', 'test_vjpvmap', vjp_fail.union({
+    # All of the following are bugs and need to be fixed
     xfail('__getitem__'),
     xfail('clamp', ''),
     xfail('dsplit'),
@@ -518,6 +538,11 @@ def test():
     xfail('block_diag'),
     xfail('nn.functional.batch_norm'),
     xfail('nn.functional.nll_loss'),
+    xfail('cdist'),
+    xfail('lu_solve'),
+    xfail('lu_unpack'),
+    xfail('matrix_exp'),
+    xfail('view_as_complex'),
 }))
 def test_vjpvmap(self, device, dtype, op):
     # NB: there is no vjpvmap_has_batch_rule test because that is almost

test/test_vmap.py

Lines changed: 13 additions & 0 deletions

@@ -2989,6 +2989,7 @@ class TestVmapOperatorsOpInfo(TestCase):
     xfail('gradient'),
     xfail('hsplit'),
     xfail('nn.functional.pad', 'circular'),
+    xfail('resize_'),
     xfail('resize_as_'),
     xfail('tensor_split'),
     xfail('to_sparse'),
@@ -3000,15 +3001,24 @@ class TestVmapOperatorsOpInfo(TestCase):
     xfail('nanmean'),
     xfail('block_diag'),
     xfail('nn.functional.dropout'),
+    xfail('view_as_complex'),

     # entries in here don't work and need to be fixed.
     # Each one of these is a bug
     xfail('unfold'),
     xfail('svd', device_type='cuda'),
     xfail('linalg.svd', device_type='cuda'),
     xfail('index_put'),
+    xfail('matrix_exp'),
+    xfail('fft.fft'),
+    xfail('fft.ifft'),
+    xfail('fft.ihfft'),
+    xfail('fft.rfft'),
+    xfail('fft.rfftn'),
     xfail('nn.functional.batch_norm'),
     xfail('nn.functional.nll_loss'),
+    xfail('lu_unpack'),
+    xfail('nn.functional.pad', 'constant'),
 })
 def test_vmap_exhaustive(self, device, dtype, op):
     sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
@@ -3105,6 +3115,9 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('nn.functional.dropout'),
     xfail('nn.functional.conv2d', ''),
     xfail('nn.functional.batch_norm'),
+    xfail('resize_'),
+    xfail('view_as_complex'),
+    xfail('matrix_exp'),
 })
 def test_op_has_batch_rule(self, device, dtype, op):
     def test():
