Commit 0b15a12

Convolution refactor
1 parent 46bb2eb commit 0b15a12

2 files changed: +190 -213 lines changed
Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <functorch/csrc/BatchRulesHelper.h>
#include <functorch/csrc/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at { namespace functorch {

// convolution_batch_rule translated from jax with modifications:
// https://github.com/google/jax/blob/master/jax/_src/lax/lax.py#L3143

// PyTorch's convolution is different from JAX's conv_general_dilated:
// we do not support batch_group_count (which is needed for convolution backwards).
// Instead, there's a convolution_backward op that needs a batching rule.
std::tuple<Tensor,optional<int64_t>>
convolution_batch_rule(const Tensor& lhs, optional<int64_t> lhs_bdim, const Tensor& rhs, optional<int64_t> rhs_bdim, const optional<Tensor>& bias, optional<int64_t> bias_bdim, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, int64_t groups) {
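  // Dimension bookkeeping, following the JAX translation: index 0 of lhs_spec/out_spec
  // is the batch dim and index 1 the channel dim; index 0 of rhs_spec is the weight's
  // out-channel dim and index 1 its in-channel dim. Transposed convolution weights are
  // laid out as (in_channels, out_channels / groups, ...), hence the swap below.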
  DimVector lhs_spec(stride.size() + 2);
  std::iota(lhs_spec.begin(), lhs_spec.end(), 0);
  DimVector rhs_spec = lhs_spec;
  DimVector out_spec = lhs_spec;
  if (transposed) {
    rhs_spec[0] = 1;
    rhs_spec[1] = 0;
  }

  // If we have a batched bias or weight, we need to perform the computation separately.
  optional<Tensor> unbatched_bias;
  bool separate_bias;
  if ((rhs_bdim && bias && bias->defined()) || bias_bdim) {
    TORCH_INTERNAL_ASSERT(bias.has_value());
    TORCH_INTERNAL_ASSERT(bias->defined());
    unbatched_bias = nullopt;
    separate_bias = true;
  } else {
    unbatched_bias = bias;
    separate_bias = false;
  }
  std::tuple<Tensor, optional<int64_t>> result;
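  // Case 1: batched input, unbatched weight. Fold the vmap dim into the input's
  // batch dim, run a single convolution, then split the vmap dim back out of the
  // output's batch dim. A shape sketch, assuming lhs_bdim = 0 and a 2d conv:
  // input (B, N, C_in, H, W) -> (B*N, C_in, H, W), output (B*N, C_out, H', W')
  // -> (B, N, C_out, H', W').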
  if (lhs_bdim && !rhs_bdim) {
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[0], lhs);
    auto out = at::convolution(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof(out_spec[0], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[0]);
  } else if (!lhs_bdim && rhs_bdim) {
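    // Case 2: unbatched input, batched weight. With groups == 1, fold the vmap dim
    // into the weight's out-channel dim and split it back out of the output's channel
    // dim; e.g. (a shape sketch, assuming rhs_bdim = 0 and a 2d conv) weight
    // (B, C_out, C_in, kH, kW) -> (B*C_out, C_in, kH, kW) and output
    // (N, B*C_out, H', W') -> (N, B, C_out, H', W'). With groups > 1, the reshapes
    // below do the same while keeping the per-group channel layout intact.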
    if (groups == 1) {
      auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[0], rhs);
      auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
      out = reshape_dim_outof(out_spec[1], rhs.sizes()[*rhs_bdim], out);
      result = std::make_tuple(out, out_spec[1]);
    } else {
      auto dim_with_groups = transposed ? 1 : 0;
      auto new_w = reshape_dim_outof(rhs_spec[dim_with_groups] + (*rhs_bdim <= rhs_spec[0]), groups, rhs);
      new_w = reshape_dim_into(*rhs_bdim + (rhs_spec[0] < rhs_bdim), rhs_spec[0] + 1, new_w);
      new_w = reshape_dim_into(rhs_spec[0], rhs_spec[0], new_w);
      auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
      out = reshape_dim_outof(out_spec[1], groups, out);
      out = reshape_dim_outof(out_spec[1] + 1, rhs.sizes()[*rhs_bdim], out);
      out = reshape_dim_into(out_spec[1], out_spec[1] + 1, out);
      result = std::make_tuple(out, out_spec[1]);
    }
  } else if (lhs_bdim && rhs_bdim) {
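    // Case 3: both input and weight are batched. Fold the vmap dim into the input's
    // channel dim and into the weight's group-carrying dim, and multiply groups by
    // the vmap size so each batch element is convolved with its own weight slice;
    // the vmap dim is then split back out of the output's channel dim.
    // A shape sketch, assuming lhs_bdim = rhs_bdim = 0, groups == 1, and a 2d conv:
    // input (B, N, C_in, H, W) -> (N, B*C_in, H, W), weight (B, C_out, C_in, kH, kW)
    // -> (B*C_out, C_in, kH, kW), groups becomes B, output (N, B*C_out, H', W')
    // -> (N, B, C_out, H', W').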
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[1], lhs);
    groups *= lhs.sizes()[*lhs_bdim];
    auto dim_with_groups = transposed ? 1 : 0;
    auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[dim_with_groups], rhs);
    auto out = at::convolution(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof(out_spec[1], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[1]);
  } else {
    result = std::make_tuple(at::convolution(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), nullopt);
  }
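  // If the weight was batched and a defined bias was given, or the bias itself was
  // batched, the convolution above ran without a bias; add it here, broadcasting it
  // over the spatial dims (and the batch dim where needed).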
  if (separate_bias) {
    auto A = std::get<0>(result);
    auto A_batch_dim = std::get<1>(result);
    auto B = *bias;
    auto B_batch_dim = bias_bdim;
    A = moveBatchDimToFront(A, A_batch_dim);
    B = moveBatchDimToFront(B, B_batch_dim);
    for (size_t i = 0; i < out_spec.size() - 2; i++) {
      B = B.unsqueeze(-1);
    }
    B = maybePadToLogicalRank(B, B_batch_dim, rankWithoutBatchDim(A, A_batch_dim));

    return std::make_tuple(at::add(A, B), 0);
  } else {
    return result;
  }
}

Tensor _convolution_decomp(
    const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
  // Ignore everything. If the user called this in the normal way,
  // then they should be fine.
  (void) benchmark;
  (void) deterministic;
  (void) cudnn_enabled;
  (void) allow_tf32;
  return at::convolution(
      input_r, weight_r, bias_r_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_);
}

// TODO: delete the following after confirming performance
// bool first_dim_has_size_1(const Tensor& value, int64_t bdim) {
//   if (bdim == 0) {
//     return value.size(1) == 1;
//   }
//   return value.size(0) == 1;
// }
//
// std::tuple<Tensor,int64_t,Tensor,int64_t> cudnn_conv_per_sample_grad_rule(
//     const Tensor& self, optional<int64_t> self_bdim,
//     const Tensor& grad_output, optional<int64_t> grad_output_bdim,
//     const Tensor& weight, optional<int64_t> weight_bdim,
//     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark,
//     bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
//   TORCH_INTERNAL_ASSERT(self_bdim && grad_output_bdim && !weight_bdim);
//   // TODO: No clue if this works if the first non-batch dim isn't size 1
//   TORCH_INTERNAL_ASSERT(first_dim_has_size_1(self, *self_bdim));
//   TORCH_INTERNAL_ASSERT(self.dim() == 5);
//
//   auto bdim_size = self.size(*self_bdim);
//   auto self_ = reshape_dim_into(*self_bdim, 0, self);
//   auto in_channels = self_.size(1);
//   auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
//
//   auto grad_self = at::cudnn_convolution_backward_input(
//       self_.sizes(), grad_output_, weight,
//       padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
//   grad_self = reshape_dim_outof(0, bdim_size, grad_self);
//
//   // Copied from https://github.com/pytorch/opacus/blob/master/opacus/grad_sample/conv.py
//   auto A = at::im2col(self_, {weight.size(2), weight.size(3)}, dilation, padding, stride);
//   auto B = grad_output_.reshape({bdim_size, -1, A.size(-1)});
//   auto grad_sample = at::einsum("noq,npq->nop", {B, A});
//   grad_sample = grad_sample.view({
//       bdim_size, groups, -1, groups, in_channels / groups,
//       weight.size(2) * weight.size(3) });
//   grad_sample = at::einsum("ngrg...->ngr...", {grad_sample});
//   grad_sample = grad_sample.reshape(
//       {bdim_size, weight.size(0), weight.size(1), weight.size(2), weight.size(3)});
//
//   return std::make_tuple(grad_self, 0, grad_sample, 0);
// }
//
// std::tuple<Tensor,Tensor> cudnn_convolution_backward_plumbing(const Tensor & self, const Tensor & grad_output, const Tensor & weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
//   auto maybe_layer = maybeCurrentDynamicLayer();
//   TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
//   int64_t cur_level = maybe_layer->layerId();
//
//   Tensor self_value;
//   optional<int64_t> self_bdim;
//   std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
//   Tensor grad_output_value;
//   optional<int64_t> grad_output_bdim;
//   std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output, cur_level);
//   Tensor weight_value;
//   optional<int64_t> weight_bdim;
//   std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
//
//   if (self_bdim.has_value() && self_value.dim() == 5 && first_dim_has_size_1(self_value, *self_bdim) && grad_output_bdim.has_value() && !weight_bdim.has_value()) {
//     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
//     auto result = cudnn_conv_per_sample_grad_rule(
//         self_value, self_bdim,
//         grad_output_value, grad_output_bdim,
//         weight_value, weight_bdim,
//         padding, stride, dilation, groups,
//         benchmark, deterministic, allow_tf32, output_mask);
//     return std::make_tuple(
//         makeBatched(std::get<0>(result), std::get<1>(result), cur_level),
//         makeBatched(std::get<2>(result), std::get<3>(result), cur_level));
//   }
//
//   static auto op = c10::Dispatcher::singleton()
//     .findSchemaOrThrow("aten::cudnn_convolution_backward", "");
//   return slow_fallback<Tensor,Tensor>(op, { self, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, output_mask });
// }

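// Register the convolution batching rule under functorch's batched dispatch key;
// _convolution is decomposed into plain convolution so the same rule covers it.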
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT("convolution", convolution_batch_rule);
  m.impl("_convolution", _convolution_decomp);
}

}} // namespace at::functorch
