Commit 7dcf261

convolution_backward batch rule (#359)
This is the most ridiculous batching rule we have. Featuring a guest appearance of efficient zeros tensors. We should really consider upstreaming einops.rearrange (https://einops.rocks/api/rearrange/). Many batching rules are straight-up dimension manipulation, and if we could specify that with strings we'd be done 10x faster.

Test Plan:
- run tests
1 parent acb46ce commit 7dcf261
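
As an aside on the rearrange wish in the commit message: here is an illustrative Python sketch (not part of this commit) of one reshape this file performs everywhere, BNO -> N(BO), written with einops.rearrange. The tensor name and shapes are hypothetical:

```python
import torch
from einops import rearrange

# Hypothetical vmapped grad_output for a 2d conv:
# B=4 (vmap batch), N=2 (examples), O=8 (out channels), 16x16 spatial.
grad_output = torch.randn(4, 2, 8, 16, 16)

# The C++ batch rules spell this as reshape_dim_into(*grad_output_bdim, 1, grad_output);
# with einops the whole move is one pattern string:
grad_output_ = rearrange(grad_output, 'b n o h w -> n (b o) h w')
print(grad_output_.shape)  # torch.Size([2, 32, 16, 16])
```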

File tree

2 files changed: +317 −8 lines changed

functorch/csrc/BatchRulesConvolution.cpp
test/test_ops.py

functorch/csrc/BatchRulesConvolution.cpp

Lines changed: 317 additions & 1 deletion
@@ -180,11 +180,327 @@ Tensor _convolution_decomp(
 // static auto op = c10::Dispatcher::singleton()
 //     .findSchemaOrThrow("aten::cudnn_convolution_backward", "");
 // return slow_fallback<Tensor,Tensor>(op, { self, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, output_mask });
-// }
+
+static Tensor compute_grad_bias(
+    const Tensor& grad_output_, std::array<bool, 3> output_mask) {
+  if (!output_mask[2]) {
+    return Tensor();
+  }
+  DimVector reduce_dims;
+  reduce_dims.resize(grad_output_.dim() - 1);
+  reduce_dims[0] = 0;
+  std::iota(reduce_dims.begin() + 1, reduce_dims.end(), 2);
+  return grad_output_.sum(reduce_dims);
+}
+
+// reshapes the batch_size into dim
+Tensor make_dummy(
+    const Tensor& tensor, optional<int64_t> tensor_bdim,
+    int64_t dim, int64_t batch_size) {
+  auto tensor_ = tensor_bdim ? tensor.select(*tensor_bdim, 0) : tensor;
+  auto orig_size = tensor_.size(dim);
+  tensor_ = tensor_.slice(dim, 0, 1);
+
+  DimVector expand_shape(tensor_.sizes().begin(), tensor_.sizes().end());
+  expand_shape[dim] = batch_size * orig_size;
+
+  // return tensor_.new_zeros(expand_shape);
+  return at::_efficientzerotensor(expand_shape, tensor.options());
+}
+
+std::tuple<Tensor,optional<int64_t>>
+convolution_backward_input_batch_rule(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& weight, optional<int64_t> weight_bdim,
+    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
+    IntArrayRef output_padding, int64_t groups) {
+  const std::array<bool, 3> mask = {true, false, false};
+  if (grad_output_bdim && weight_bdim) {
+    // regular: BNO, BOI -> N(BO), (BO)I -> N(BI)
+    // transposed: BNO, BIO -> N(BO), (BI)O -> N(BI)
+    const auto batch_size = weight.size(*weight_bdim);
+    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
+    const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
+    auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
+    const auto result = at::convolution_backward(
+        grad_output_, dummy_input, weight_, nullopt, stride, padding,
+        dilation, transposed, output_padding, groups * batch_size, mask);
+    const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
+    return std::make_tuple(grad_input, 1);
+  } else if (grad_output_bdim && !weight_bdim) {
+    // BNO, OI -> (BN)O, OI -> (BN)I
+    // transposed is the same.
+    const auto batch_size = grad_output.size(*grad_output_bdim);
+    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+    auto dummy_input = make_dummy(input, input_bdim, 0, batch_size);
+    const auto result = at::convolution_backward(
+        grad_output_, dummy_input, weight, nullopt, stride, padding,
+        dilation, transposed, output_padding, groups, mask);
+    const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
+    return std::make_tuple(grad_input, 0);
+  } else if (!grad_output_bdim && weight_bdim) {
+    const auto batch_size = weight.size(*weight_bdim);
+    if (groups == 1) {
+      // regular: NO, BOI -> NO, O(BI) -> N(BI)
+      // transposed: NO, BIO -> NO, (BI)O -> N(BI)
+      const auto in_ch_dim = transposed ? 0 : 1;
+      const auto weight_ = reshape_dim_into(*weight_bdim, in_ch_dim, weight);
+      auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
+      const auto result = at::convolution_backward(
+          grad_output, dummy_input, weight_, nullopt, stride, padding,
+          dilation, transposed, output_padding, groups, mask);
+      const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
+      return std::make_tuple(grad_input, 1);
+    }
+    Tensor grad_input;
+    if (!transposed) {
+      // N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI)
+      const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight);
+      auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
+      const auto result = at::convolution_backward(
+          grad_output, dummy_input, weight_, nullopt, stride, padding,
+          dilation, transposed, output_padding, groups, mask);
+      grad_input = std::get<0>(result); // N(GBI)
+    } else {
+      // N(GO), B(GI)O -> N(GO), (GBI)O -> N(GBI)
+      auto weight_ = moveBatchDimToFront(weight, weight_bdim); // B(GI)O
+      weight_ = reshape_dim_outof(1, groups, weight_); // BGIO
+      weight_ = weight_.transpose(0, 1); // GBIO
+      weight_ = weight_.flatten(0, 2); // (GBI)O
+      const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
+      const auto result = at::convolution_backward(
+          grad_output, dummy_input, weight_, nullopt, stride, padding,
+          dilation, transposed, output_padding, groups, mask);
+      grad_input = std::get<0>(result); // N(GBI)
+    }
+    // N(GBI) -> NG(BI) -> NGBI -> NBGI -> NB(GI)
+    grad_input = reshape_dim_outof(1, groups, grad_input);
+    grad_input = reshape_dim_outof(2, batch_size, grad_input);
+    grad_input = grad_input.transpose(1, 2);
+    grad_input = reshape_dim_into(2, 2, grad_input);
+    return std::make_tuple(grad_input, 1);
+  } else {
+    TORCH_INTERNAL_ASSERT(input_bdim);
+    const auto dummy_input = make_dummy(input, input_bdim, 0, 1);
+    const auto result = at::convolution_backward(
+        grad_output, dummy_input, weight, nullopt, stride, padding,
+        dilation, transposed, output_padding, groups, mask);
+    return std::make_tuple(std::get<0>(result), nullopt);
+  }
+}
+std::tuple<Tensor,optional<int64_t>>
+convolution_backward_weight_batch_rule(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& weight, optional<int64_t> weight_bdim,
+    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
+    IntArrayRef output_padding, int64_t groups) {
+  const std::array<bool, 3> mask = {false, true, false};
+  if (grad_output_bdim && input_bdim) {
+    // BNO, BNI -> N(BO), N(BI) -> (BO)I (regular) (BI)O (transposed)
+    const auto batch_size = input.size(*input_bdim);
+    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
+    const auto input_ = reshape_dim_into(*input_bdim, 1, input);
+    const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
+    const auto result = at::convolution_backward(
+        grad_output_, input_, dummy_weight, nullopt, stride, padding,
+        dilation, transposed, output_padding, groups * batch_size, mask);
+    auto grad_weight = std::get<1>(result);
+    grad_weight = reshape_dim_outof(0, batch_size, grad_weight);
+    return std::make_tuple(grad_weight, 0);
+  } else if (grad_output_bdim && !input_bdim) {
+    const auto batch_size = grad_output.size(*grad_output_bdim);
+    if (groups == 1) {
+      // regular: BNO, NI -> N(BO), NI -> (BO)I
+      // transposed: BNO, NI -> N(BO), NI -> I(BO)
+      const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
+      const auto out_ch_dim = transposed ? 1 : 0;
+      const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size);
+      const auto result = at::convolution_backward(
+          grad_output_, input, dummy_weight, nullopt, stride, padding,
+          dilation, transposed, output_padding, groups, mask);
+      auto grad_weight = std::get<1>(result);
+      grad_weight = reshape_dim_outof(out_ch_dim, batch_size, grad_weight);
+      return std::make_tuple(grad_weight, out_ch_dim);
+    } else {
+      auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); // BN(GO)
+      grad_output_ = reshape_dim_outof(2, groups, grad_output_); // BNGO
+      grad_output_ = grad_output_.movedim(0, 2); // NGBO
+      grad_output_ = grad_output_.flatten(1, 3); // N(GBO)
+      if (!transposed) {
+        // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I
+        const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
+        const auto result = at::convolution_backward(
+            grad_output_, input, dummy_weight, nullopt, stride, padding,
+            dilation, transposed, output_padding, groups, mask);
+        auto grad_weight = std::get<1>(result);
+        grad_weight = grad_weight.unflatten(0, { groups, batch_size, -1 }); // GBOI
+        grad_weight = grad_weight.transpose(0, 1); // BGOI
+        grad_weight = grad_weight.flatten(1, 2); // B(GO)I
+        return std::make_tuple(grad_weight, 0);
+      } else {
+        // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO)
+        const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
+        const auto result = at::convolution_backward(
+            grad_output_, input, dummy_weight, nullopt, stride, padding,
+            dilation, transposed, output_padding, groups, mask);
+        auto grad_weight = std::get<1>(result);
+        grad_weight = reshape_dim_outof(1, batch_size, grad_weight);
+        return std::make_tuple(grad_weight, 1);
+      }
+    }
+  } else if (!grad_output_bdim && input_bdim) {
+    const auto batch_size = input.size(*input_bdim);
+    if (groups == 1) {
+      // regular: NO, BNI -> NO, N(BI) -> O(BI)
+      // transposed: NO, BNI -> NO, N(BI) -> (BI)O
+      const auto input_ = reshape_dim_into(*input_bdim, 1, input);
+      const auto in_ch_dim = transposed ? 0 : 1;
+      const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size);
+      const auto result = at::convolution_backward(
+          grad_output, input_, dummy_weight, nullopt, stride, padding,
+          dilation, transposed, output_padding, groups, mask);
+      auto grad_weight = std::get<1>(result);
+      grad_weight = reshape_dim_outof(in_ch_dim, batch_size, grad_weight);
+      return std::make_tuple(grad_weight, in_ch_dim);
+    } else {
+      auto input_ = moveBatchDimToFront(input, input_bdim); // BN(GI)
+      input_ = reshape_dim_outof(2, groups, input_); // BNGI
+      input_ = input_.movedim(0, 2); // NGBI
+      input_ = input_.flatten(1, 3); // N(GBI)
+      if (!transposed) {
+        // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI)
+        const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
+        const auto result = at::convolution_backward(
+            grad_output, input_, dummy_weight, nullopt, stride, padding,
+            dilation, transposed, output_padding, groups, mask);
+        auto grad_weight = std::get<1>(result);
+        grad_weight = reshape_dim_outof(1, batch_size, grad_weight);
+        return std::make_tuple(grad_weight, 1);
+      } else {
+        // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O
+        const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
+        const auto result = at::convolution_backward(
+            grad_output, input_, dummy_weight, nullopt, stride, padding,
+            dilation, transposed, output_padding, groups, mask);
+        auto grad_weight = std::get<1>(result);
+        grad_weight = grad_weight.unflatten(0, { groups, batch_size, -1 }); // GBIO
+        grad_weight = grad_weight.transpose(0, 1); // BGIO
+        grad_weight = grad_weight.flatten(1, 2); // B(GI)O
+        return std::make_tuple(grad_weight, 0);
+      }
+    }
+  } else {
+    TORCH_INTERNAL_ASSERT(weight_bdim);
+    const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1);
+    const auto result = at::convolution_backward(
+        grad_output, input, dummy_weight, nullopt, stride, padding,
+        dilation, transposed, output_padding, groups, mask);
+    return std::make_tuple(std::get<1>(result), nullopt);
+
+  }
+}
+
+std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
+    const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
+    const c10::optional<IntArrayRef> bias_sizes_opt,
+    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
+    IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
+  const auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
+  Tensor grad_output;
+  optional<int64_t> grad_output_bdim;
+  std::tie(grad_output, grad_output_bdim) = unwrapTensorAtLevel(grad_output_, cur_level);
+  Tensor input;
+  optional<int64_t> input_bdim;
+  std::tie(input, input_bdim) = unwrapTensorAtLevel(input_, cur_level);
+  Tensor weight;
+  optional<int64_t> weight_bdim;
+  std::tie(weight, weight_bdim) = unwrapTensorAtLevel(weight_, cur_level);
+
+  const auto grad_bias = compute_grad_bias(grad_output_, output_mask);
+  output_mask[2] = false;
+
+  // TODO: A little bird says that unfold + matmul is actually faster than
+  // group convolution in many cases. We should benchmark some of
+  // the common cases and replace things with unfold + matmul as necessary.
+
+  // Notation:
+  // B - a batch dimension
+  // G - groups (sometimes omitted because it doesn't matter)
+  // NO - grad_output
+  // NI - input
+  // OI - weight
+  // "(BO)I" - we don't actually care about the values of this Tensor,
+  //           we just need to create a tensor on the same device with the
+  //           correct shape and pray that the implementation is smart enough
+  //           to not do anything with it.
+
+  // BNO, BNI, BOI
+  // AKA one of the model ensembling case
+  if (grad_output_bdim && input_bdim && weight_bdim) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    grad_output = reshape_dim_into(*grad_output_bdim, 1, grad_output);
+
+    // BNO, BNI, BOI -> N(BO), N(BI), (BO)I
+    const auto batch_size = weight.size(*weight_bdim);
+    input = reshape_dim_into(*input_bdim, 1, input);
+    weight = reshape_dim_into(*weight_bdim, 0, weight);
+    const auto result = at::convolution_backward(
+        grad_output, input, weight, nullopt, stride, padding, dilation,
+        transposed, output_padding, batch_size * groups, output_mask);
+    // N(BI), (BO)I -> NBI, BOI
+    const auto grad_input = output_mask[0] ?
+        reshape_dim_outof(1, batch_size, std::get<0>(result)) : Tensor();
+    const auto grad_weight = output_mask[1] ?
+        reshape_dim_outof(0, batch_size, std::get<1>(result)) : Tensor();
+    return std::make_tuple(
+        output_mask[0] ? makeBatched(grad_input, 1, cur_level) : grad_input,
+        output_mask[1] ? makeBatched(grad_weight, 0, cur_level) : grad_weight,
+        grad_bias);
+  }
+
+  Tensor grad_input;
+  if (output_mask[0]) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    const auto result = convolution_backward_input_batch_rule(
+        grad_output, grad_output_bdim,
+        input, input_bdim,
+        weight, weight_bdim,
+        stride, padding, dilation, transposed, output_padding, groups);
+    grad_input = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
+  }
+
+  Tensor grad_weight;
+  if (output_mask[1]) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    const auto result = convolution_backward_weight_batch_rule(
+        grad_output, grad_output_bdim,
+        input, input_bdim,
+        weight, weight_bdim,
+        stride, padding, dilation, transposed, output_padding, groups);
+    grad_weight = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
+  }
+  return std::make_tuple(grad_input, grad_weight, grad_bias);
+
+  // Someone's definitely going to find a problem with this batching rule so
+  // I'm leaving the following fallback if we need it back.
+  // static auto op = c10::Dispatcher::singleton()
+  //     .findSchemaOrThrow("aten::convolution_backward", "");
+  // auto result = slow_fallback<Tensor,Tensor,Tensor>(op, {
+  //     grad_output_, input_, weight_, bias_sizes_opt,
+  //     stride, padding, dilation, transposed, output_padding, groups, output_mask
+  // });
+  // return std::make_tuple(grad_input, std::get<1>(result), grad_bias);
+}
+
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("convolution", convolution_batch_rule);
   m.impl("_convolution", _convolution_decomp);
+  m.impl("convolution_backward", convolution_backward_plumbing);
 }
 
 }} // namespace at::functorch
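
For context, here is a minimal sketch of the "model ensembling" case the plumbing above special-cases (BNO, BNI, BOI: grad_output, input, and weight all carry a batch dimension). It is illustrative only — it assumes the functorch Python API of this era, and all shapes are made up:

```python
import torch
import torch.nn.functional as F
from functorch import grad, vmap

def loss(weight, x):
    # Differentiating w.r.t. `weight` dispatches to aten::convolution_backward,
    # which under vmap is handled by convolution_backward_plumbing above.
    return F.conv2d(x, weight).sum()

weights = torch.randn(4, 8, 3, 3, 3)  # 4 stacked conv weights, each (O=8, I=3, kH=3, kW=3)
xs = torch.randn(4, 2, 3, 16, 16)     # 4 matching input batches, each (N=2, C=3, H=W=16)

per_model_grads = vmap(grad(loss))(weights, xs)  # one weight gradient per ensemble member
print(per_model_grads.shape)  # torch.Size([4, 8, 3, 3, 3])
```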

test/test_ops.py

Lines changed: 0 additions & 7 deletions
@@ -790,7 +790,6 @@ def test_vmapjvpall(self, device, dtype, op):
790790
xfail('masked_select'),
791791
xfail('matrix_exp'),
792792
xfail('nanquantile'),
793-
xfail('nn.functional.conv_transpose2d'),
794793
xfail('nn.functional.gelu'),
795794
xfail('norm', 'nuc'),
796795
xfail('pinverse'),
@@ -816,13 +815,11 @@ def test_vmapjvpall(self, device, dtype, op):
816815
xfail('cross'),
817816
xfail('double', 'channels_last'),
818817
xfail('linalg.cross'),
819-
skip('nn.functional.conv1d'),
820818
xfail('nn.functional.gaussian_nll_loss'),
821819
xfail('nn.functional.hardsigmoid'),
822820
xfail('nn.functional.huber_loss'),
823821
xfail('nn.functional.instance_norm'),
824822
xfail('nn.functional.poisson_nll_loss'),
825-
xfail('nn.functional.conv_transpose3d'),
826823
xfail('nn.functional.bilinear'),
827824
xfail('nn.functional.prelu'),
828825
xfail('nn.functional.glu'),
@@ -833,16 +830,12 @@ def test_vmapjvpall(self, device, dtype, op):
833830
xfail('nn.functional.rrelu'),
834831
xfail('nn.functional.embedding_bag'),
835832
xfail('nn.functional.softshrink'),
836-
xfail('nn.functional.conv_transpose1d'),
837833
xfail('nn.functional.max_pool3d'),
838834
xfail('istft'),
839835
xfail('nn.functional.fractional_max_pool2d'),
840836
xfail('linalg.tensorsolve'),
841837
}))
842838
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
843-
# These are too annoying to put into the list above
844-
if op.name in {'nn.functional.conv2d'}:
845-
self.skipTest("Skipped! ExpectedF failures")
846839
if not op.supports_autograd:
847840
self.skipTest("Skipped! Autograd not supported.")
848841
return
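
The xfail/skip entries removed above are the convolution ops whose vmap-of-vjp now goes through the new batch rule instead of falling over. A minimal illustrative check of that path (hypothetical shapes, functorch API of this era) — under pytest, these cases are exercised by test_vmapvjp_has_batch_rule in test/test_ops.py:

```python
import torch
import torch.nn.functional as F
from functorch import vjp, vmap

weight = torch.randn(8, 4, 3, 3)  # conv_transpose2d weight: (I=8, O=4, kH=3, kW=3)

def f(x):
    return F.conv_transpose2d(x, weight)

def input_grad(x):
    out, vjp_fn = vjp(f, x)
    (grad_x,) = vjp_fn(torch.ones_like(out))  # backward pass hits convolution_backward
    return grad_x

xs = torch.randn(3, 2, 8, 16, 16)  # a vmap batch of 3 inputs, each (N=2, C=8, H=W=16)
grads = vmap(input_grad)(xs)
print(grads.shape)  # torch.Size([3, 2, 8, 16, 16])
```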
