
Commit 8e2b616

Direct lowering of torch.aten.convolution_backward from torch to linalg
1 parent c180509 commit 8e2b616
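
For reference, the semantics this lowering targets can be checked against autograd. Below is a minimal sketch, not part of this commit: the tensor shapes are borrowed from the existing ConvolutionBackwardModule2DStrided e2e test, and the stride/padding values are assumptions chosen so those shapes line up. It shows that torch.ops.aten.convolution_backward returns the input, weight, and bias gradients of a forward convolution.

import torch
import torch.nn.functional as F

# Shapes follow the strided e2e test; stride=2, padding=1 is an assumption
# that makes an 8x8 input produce a 4x4 output.
inp = torch.rand(1, 2, 8, 8, requires_grad=True)
weight = torch.rand(2, 2, 3, 3, requires_grad=True)
bias = torch.zeros(2, requires_grad=True)
grad_out = torch.rand(1, 2, 4, 4)

# Reference gradients via autograd.
F.conv2d(inp, weight, bias, stride=2, padding=1).backward(grad_out)

# The op computes the same three gradients directly (mkldnn disabled, as the
# e2e tests below also do).
with torch.backends.mkldnn.flags(enabled=False):
    gi, gw, gb = torch.ops.aten.convolution_backward(
        grad_out,
        inp.detach(),
        weight.detach(),
        bias_sizes=[2],
        stride=[2, 2],
        padding=[1, 1],
        dilation=[1, 1],
        transposed=False,
        output_padding=[0, 0],
        groups=1,
        output_mask=[True, True, True],
    )

assert torch.allclose(gi, inp.grad, atol=1e-5)
assert torch.allclose(gw, weight.grad, atol=1e-5)
assert torch.allclose(gb, bias.grad, atol=1e-5)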

File tree

7 files changed: +1015 -436 lines


lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 627 additions & 0 deletions
Large diffs are not rendered by default.
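
Since the diff itself is not rendered, here is a skeletal sketch of how a direct TorchToLinalg lowering pattern for this op is typically structured. This is an orientation aid only, not the commit's actual 627-line implementation; apart from AtenConvolutionBackwardOp (which appears elsewhere in this commit), the class name and registration details are assumptions.

// Sketch only: the real patch builds the full grad_input/grad_weight/
// grad_bias computation in linalg; this shows the usual pattern shape.
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class ConvertAtenConvolutionBackwardOp
    : public OpConversionPattern<AtenConvolutionBackwardOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenConvolutionBackwardOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Here the lowering would materialize linalg ops for grad_input
    // (a transposed convolution of grad_out with the weight), grad_weight
    // (a convolution of the input with grad_out), and grad_bias (a sum of
    // grad_out over batch and spatial dims), then replace the op with the
    // three results.
    return rewriter.notifyMatchFailure(op, "sketch only");
  }
};
} // namespace

// Registered alongside the other patterns in this file, e.g.:
//   target.addIllegalOp<AtenConvolutionBackwardOp>();
//   patterns.add<ConvertAtenConvolutionBackwardOp>(typeConverter, context);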

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 0 additions & 401 deletions
Large diffs are not rendered by default.

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 0 additions & 1 deletion
@@ -444,7 +444,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenNativeGroupNormOp>();
   target.addIllegalOp<AtenNativeBatchNormOp>();
   target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
-  target.addIllegalOp<AtenConvolutionBackwardOp>();
   target.addIllegalOp<AtenConvTbcOp>();
   target.addIllegalOp<AtenConv1dOp>();
   target.addIllegalOp<AtenConv2dOp>();

projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
         "aten.flatten.using_ints",
         "aten.adaptive_avg_pool1d",
         "aten.adaptive_avg_pool2d",
+        "aten.convolution_backward",
         "aten.unflatten.int",
     ],
     OutputType.STABLEHLO: [

projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py

Lines changed: 69 additions & 0 deletions
@@ -228,6 +228,75 @@ def ConvolutionBackwardModule2DStrided_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8), tu.rand(2, 2, 3, 3))
 
 
+class ConvolutionBackwardModule2DDilated(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 2, 6, 6], torch.float32, True),
+            ([1, 4, 8, 8], torch.float32, True),
+            ([2, 4, 3, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, grad_out, input_vec, weight):
+        return torch.ops.aten.convolution_backward(
+            grad_out,
+            input_vec,
+            weight,
+            bias_sizes=[4],
+            stride=[1, 1],
+            padding=[1, 1],
+            dilation=[2, 2],
+            transposed=False,
+            output_padding=[0, 0],
+            groups=1,
+            output_mask=[True, True, True],
+        )
+
+
+@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DDilated())
+def ConvolutionBackwardModule2DDilated_basic(module, tu: TestUtils):
+    with torch.backends.mkldnn.flags(enabled=False):
+        module.forward(tu.rand(1, 2, 6, 6), tu.rand(1, 4, 8, 8), tu.rand(2, 4, 3, 3))
+
+
+class ConvolutionBackwardModule2DStridedPaddedDilatedGrouped(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 16, 32, 32], torch.float32, True),
+            ([2, 128, 64, 64], torch.float32, True),
+            ([16, 32, 2, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, grad_out, input_vec, weight):
+        return torch.ops.aten.convolution_backward(
+            grad_out,
+            input_vec,
+            weight,
+            bias_sizes=[4],
+            stride=[2, 2],
+            padding=[2, 2],
+            dilation=[4, 4],
+            transposed=False,
+            output_padding=[0, 0],
+            groups=4,
+            output_mask=[True, True, True],
+        )
+
+
+@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStridedPaddedDilatedGrouped())
+def ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic(module, tu: TestUtils):
+    with torch.backends.mkldnn.flags(enabled=False):
+        module.forward(tu.rand(2, 16, 32, 32), tu.rand(2, 128, 64, 64), tu.rand(16, 32, 2, 2))
+
 # ==============================================================================
 
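
The grad_out shapes in these new tests follow from the usual convolution output-size formula, out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1. A quick check of both cases; the helper below is illustrative only, not part of the test suite.

def conv_out_size(in_size, kernel, stride, padding, dilation):
    # Standard convolution output-size formula, applied per spatial dim.
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Dilated case: 8x8 input, 3x3 kernel, stride 1, padding 1, dilation 2 -> 6x6,
# matching the [1, 2, 6, 6] grad_out above.
assert conv_out_size(8, 3, stride=1, padding=1, dilation=2) == 6

# Strided/padded/dilated/grouped case: 64x64 input, 2x2 kernel, stride 2,
# padding 2, dilation 4 -> 32x32, matching the [2, 16, 32, 32] grad_out.
assert conv_out_size(64, 2, stride=2, padding=2, dilation=4) == 32

The channel bookkeeping in the grouped case is consistent as well: with groups=4, the weight's per-group input-channel dim (32) times the group count matches the 128 input channels, and its 16 output filters match grad_out's channel dim.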

test/Conversion/TorchToLinalg/convolution_bwd.mlir

Lines changed: 318 additions & 0 deletions
Large diffs are not rendered by default.

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 0 additions & 34 deletions
@@ -273,40 +273,6 @@ func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int {
   return %arg0 : !torch.int
 }
 
-// -----
-
-// CHECK-LABEL: func.func @convolution_backward_none_result(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,3,3],f32>, %[[VAL_1:.*]]: !torch.vtensor<[1,1,5,5],f32>,
-// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1,1,3,3],f32>,
-// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>) {
-func.func @convolution_backward_none_result(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,3,3],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>) {
-// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
-// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_6:.*]] = torch.constant.none
-// CHECK: %[[VAL_7:.*]] = torch.constant.int 0
-// CHECK: %[[VAL_8:.*]] = torch.constant.bool false
-// CHECK: %[[VAL_9:.*]] = torch.constant.int 1
-// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_9]], %[[VAL_9]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_12:.*]] = torch.aten.transpose.int %[[VAL_1]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,5,5],f32>
-// CHECK: %[[VAL_13:.*]] = torch.aten.transpose.int %[[VAL_0]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3,3],f32>
-// CHECK: %[[VAL_14:.*]] = torch.aten.convolution %[[VAL_12]], %[[VAL_13]], %[[VAL_6]], %[[VAL_10]], %[[VAL_11]], %[[VAL_10]], %[[VAL_8]], %[[VAL_11]], %[[VAL_9]] : !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,3,3],f32>
-// CHECK: %[[VAL_15:.*]] = torch.aten.transpose.int %[[VAL_14]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3,3],f32>
-// CHECK: %[[VAL_16:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_5]], %[[VAL_4]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_17:.*]] = torch.aten.sum.dim_IntList %[[VAL_0]], %[[VAL_16]], %[[VAL_8]], %[[VAL_6]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1],f32>
-// CHECK: return %[[VAL_15]], %[[VAL_17]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
-  %true = torch.constant.bool true
-  %int0 = torch.constant.int 0
-  %false = torch.constant.bool false
-  %int1 = torch.constant.int 1
-  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
-  %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
-  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
-  return %result1, %result2 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
-}
-
 // -----
 // CHECK-LABEL: func.func @emptyLikeNoneDtype(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
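
For contrast with the new direct lowering, the decomposition that this removed test used to check computed grad_weight by a forward convolution with the batch and channel dims swapped on both operands, and grad_bias by a reduction over batch and spatial dims. A rough PyTorch equivalent of those deleted CHECK lines (unit stride, no padding, no dilation, as in the test; this sketch is not part of the commit):

import torch
import torch.nn.functional as F

grad_out = torch.rand(1, 1, 3, 3)  # corresponds to %arg0 in the deleted test
inp = torch.rand(1, 1, 5, 5)       # corresponds to %arg1 in the deleted test

# grad_weight: swap N and C on both tensors, convolve, swap back
# (the transpose / convolution / transpose sequence in the removed CHECKs).
grad_weight = F.conv2d(inp.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)

# grad_bias: sum grad_out over dims [0, 2, 3], as aten.sum.dim_IntList did.
grad_bias = grad_out.sum(dim=[0, 2, 3])

# Sanity check against autograd.
w = torch.rand(1, 1, 3, 3, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
F.conv2d(inp, w, b).backward(grad_out)
assert torch.allclose(grad_weight, w.grad, atol=1e-5)
assert torch.allclose(grad_bias, b.grad, atol=1e-5)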
