This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 5304c81

Fix convolution batch rule in the transpose case (#345)
We were making wrong assumptions about where the input_channels / output_channels were in the weight tensor and where the groups dimension gets included.

Test Plan:
- run tests
1 parent 67075bd commit 5304c81

File tree: 4 files changed, +14 -16 lines
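For context on the commit message above: in PyTorch, a regular convolution weight is laid out as (out_channels, in_channels / groups, *kernel), while a transposed convolution weight is (in_channels, out_channels / groups, *kernel), so the first two dims swap roles and groups divides a different one. A minimal sketch of that layout difference (illustration only, not part of this diff):

import torch

# Conv2d weight:          (out_channels, in_channels / groups, kH, kW)
# ConvTranspose2d weight: (in_channels, out_channels / groups, kH, kW)
conv = torch.nn.Conv2d(4, 6, kernel_size=3, groups=2)
convT = torch.nn.ConvTranspose2d(4, 6, kernel_size=3, groups=2)
print(conv.weight.shape)   # torch.Size([6, 2, 3, 3])
print(convT.weight.shape)  # torch.Size([4, 3, 3, 3])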

functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 3 additions & 0 deletions

@@ -224,6 +224,9 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE(arctan2);
   OP_DECOMPOSE(layer_norm);
   OP_DECOMPOSE(diag_backward);
+  OP_DECOMPOSE(conv_transpose1d);
+  OP_DECOMPOSE2(conv_transpose2d, input);
+  OP_DECOMPOSE2(conv_transpose3d, input);
   DECOMPOSE_FUNCTIONAL(diag_embed);
   DECOMPOSE_FUNCTIONAL(block_diag);
 }
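The three new registrations make conv_transpose1d/2d/3d decompose into the general at::convolution call with transposed=True, so the convolution batch rule patched below covers them instead of each op needing its own rule. A small Python sketch of the equivalence those decompositions rely on (illustration only; the real decomposition lives inside PyTorch):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)   # (N, C_in, L)
w = torch.randn(4, 3, 5)   # transposed-conv weight: (C_in, C_out / groups, kL)

out1 = F.conv_transpose1d(x, w, stride=2, padding=1, output_padding=1)
# aten::convolution is the general entry point the batch rule targets
out2 = torch.ops.aten.convolution(
    x, w, None,
    [2],    # stride
    [1],    # padding
    [1],    # dilation
    True,   # transposed
    [1],    # output_padding
    1,      # groups
)
assert torch.allclose(out1, out2)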

functorch/csrc/BatchRulesModules.cpp

Lines changed: 11 additions & 5 deletions

@@ -15,10 +15,14 @@ namespace at { namespace functorch {
 // Does not support batch_group_count (needed for convolution backwards)
 std::tuple<Tensor,optional<int64_t>>
 convolution_batch_rule(const Tensor& lhs, optional<int64_t> lhs_bdim, const Tensor& rhs, optional<int64_t> rhs_bdim, const optional<Tensor>& bias, optional<int64_t> bias_bdim, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, int64_t groups) {
-  std::vector<int64_t> lhs_spec(stride.size() + 2);
+  DimVector lhs_spec(stride.size() + 2);
   std::iota(lhs_spec.begin(), lhs_spec.end(), 0);
-  std::vector<int64_t> rhs_spec = lhs_spec;
-  std::vector<int64_t> out_spec = lhs_spec;
+  DimVector rhs_spec = lhs_spec;
+  DimVector out_spec = lhs_spec;
+  if (transposed) {
+    rhs_spec[0] = 1;
+    rhs_spec[1] = 0;
+  }

   // If we have a batched bias or weight, we need to perform the computation separately.
   optional<Tensor> unbatched_bias;
@@ -45,7 +49,8 @@ convolution_batch_rule(const Tensor& lhs, optional<int64_t> lhs_bdim, const Tens
       out = reshape_dim_outof(out_spec[1], rhs.sizes()[*rhs_bdim], out);
       result = std::make_tuple(out, out_spec[1]);
     } else {
-      auto new_w = reshape_dim_outof(rhs_spec[0] + (*rhs_bdim <= rhs_spec[0]), groups, rhs);
+      auto dim_with_groups = transposed ? 1 : 0;
+      auto new_w = reshape_dim_outof(rhs_spec[dim_with_groups] + (*rhs_bdim <= rhs_spec[0]), groups, rhs);
       new_w = reshape_dim_into(*rhs_bdim + (rhs_spec[0] < rhs_bdim), rhs_spec[0] + 1, new_w);
       new_w = reshape_dim_into(rhs_spec[0], rhs_spec[0], new_w);
       auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
@@ -57,7 +62,8 @@ convolution_batch_rule(const Tensor& lhs, optional<int64_t> lhs_bdim, const Tens
   } else if (lhs_bdim && rhs_bdim) {
     auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[1], lhs);
     groups *= lhs.sizes()[*lhs_bdim];
-    auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[0], rhs);
+    auto dim_with_groups = transposed ? 1 : 0;
+    auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[dim_with_groups], rhs);
     auto out = at::convolution(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
     out = reshape_dim_outof(out_spec[1], lhs.sizes()[*lhs_bdim], out);
     result = std::make_tuple(out, out_spec[1]);
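The rhs_spec swap and the dim_with_groups selection encode the layout difference noted above: for a transposed convolution, the dimension that carries the groups factor is dim 1 of the weight, not dim 0. A sketch of the case this fixes, assuming the functorch vmap API of this era (batched weight, groups > 1):

import torch
import torch.nn.functional as F
from functorch import vmap

x = torch.randn(2, 4, 8, 8)      # (N, C_in, H, W)
ws = torch.randn(5, 4, 3, 2, 2)  # 5 weights, each (C_in, C_out / groups, kH, kW)

# Batching over the weight's leading dim hits the rhs_bdim branch of
# convolution_batch_rule; groups=2 exercises the dim_with_groups fix.
batched = vmap(lambda w: F.conv_transpose2d(x, w, groups=2))(ws)
expected = torch.stack([F.conv_transpose2d(x, w, groups=2) for w in ws])
assert torch.allclose(batched, expected)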

test/test_ops.py

Lines changed: 0 additions & 8 deletions

@@ -467,7 +467,6 @@ def vjp_of_vjp(*args_and_cotangents):
     xfail('symeig'),
     xfail('take'),
     xfail('linalg.tensorinv'),
-    xfail('nn.functional.conv_transpose2d', device_type='cuda'),
     xfail('nanmean'),
     xfail('block_diag'),
     xfail('nn.functional.dropout'),
@@ -487,13 +486,6 @@ def vjp_of_vjp(*args_and_cotangents):
     xfail('nn.functional.fractional_max_pool3d'),
     xfail('as_strided'),
     xfail('nn.functional.fractional_max_pool2d'),
-
-    # PyTorch changed its convolution recently.
-    # Maybe it is responsible for all of the following changes.
-    xfail('nn.functional.conv_transpose1d'),
-    xfail('nn.functional.conv_transpose2d'),
-    xfail('nn.functional.conv_transpose3d'),
-
 })
 @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
 @skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail)

test/test_vmap.py

Lines changed: 0 additions & 3 deletions

@@ -3197,7 +3197,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('masked_scatter'),
     xfail('masked_select'),
     xfail('nanquantile'),
-    xfail('nn.functional.conv_transpose2d'),
     xfail('norm', 'fro'),
     xfail('norm', 'nuc'),
     xfail('ormqr'),
@@ -3265,7 +3264,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('nn.functional.poisson_nll_loss'),
     xfail('nn.functional.max_pool3d'),
     xfail('histc'),
-    xfail('nn.functional.conv_transpose1d'),
     xfail('as_strided'),
     xfail('istft'),
     xfail('nonzero'),
@@ -3282,7 +3280,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('isclose'),
     xfail('cartesian_prod'),
     xfail('nn.functional.fractional_max_pool3d'),
-    xfail('nn.functional.conv_transpose3d'),
     xfail('nn.functional.rrelu'),
     xfail('nn.functional.bilinear'),
     xfail('nn.functional.embedding_bag'),
