
Commit 1682ce2

[Tosa]: Equalize ranks for all operands of tosa.select + slice conv inputs for dynamic batch as long as spatial dims are static. (#4162)
Fixes two bugs: (1) equalizes ranks for all operands of `tosa.select`; (2) slices convolution inputs when the batch dimension is dynamic, as long as the spatial dims are static.
1 parent 82ca7f3 commit 1682ce2

File tree

lib/Conversion/TorchToTosa/TorchToTosa.cpp
test/Conversion/TorchToTosa/basic.mlir

2 files changed: +120 −8 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 13 additions & 8 deletions
@@ -2453,9 +2453,13 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   }
 
   int64_t outputHDim, outputWDim;
-  if (inputTy.hasStaticShape()) {
-    int64_t inputHDim = inputShape[2];
-    int64_t inputWDim = inputShape[3];
+  int64_t inputHDim = inputShape[2];
+  int64_t inputWDim = inputShape[3];
+
+  bool isStaticSpatialDims =
+      !ShapedType::isDynamic(inputHDim) && !ShapedType::isDynamic(inputWDim);
+  if (isStaticSpatialDims) {
+
     int64_t weightHDim = weightShape[2];
     int64_t weightWDim = weightShape[3];
 

@@ -2473,8 +2477,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
       SmallVector<int64_t> sizeHSlice(transposedInputShape);
       // TOSA uses NHWC, so we will slice dim 1 for Height value
       sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
-      transposedInput = rewriter.create<tosa::SliceOp>(
-          op->getLoc(), RankedTensorType::get(sizeHSlice, inputElemTy),
+      transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
+          rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy),
           transposedInput,
           tosa::getTosaConstShape(rewriter, op->getLoc(), startHSlice),
           tosa::getTosaConstShape(rewriter, op->getLoc(), sizeHSlice));
@@ -2498,8 +2502,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
           dyn_cast<RankedTensorType>(transposedInput.getType()).getShape());
       // TOSA uses NHWC, so we will slice dim 2 for Width value
       sizeWSlice[2] = inputWDim - (remainderWDim - padding[3]);
-      transposedInput = rewriter.create<tosa::SliceOp>(
-          op->getLoc(), RankedTensorType::get(sizeWSlice, inputElemTy),
+      transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
+          rewriter, op->getLoc(), UnrankedTensorType::get(inputElemTy),
           transposedInput,
           tosa::getTosaConstShape(rewriter, op->getLoc(), startWSlice),
           tosa::getTosaConstShape(rewriter, op->getLoc(), sizeWSlice));
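
Why static spatial dims are enough to slice: the start/size vectors of the slice depend only on H/W, stride, padding, and the weight's spatial dims, so the batch entry of the size vector can stay -1 (dynamic). That is also why the two slices above switch to UnrankedTensorType plus shape inference via CreateOpAndInfer instead of a fully static RankedTensorType result. Below is a self-contained sketch of the arithmetic using the numbers from the sliced-input dynamic-batch test added further down; the remainderHDim formula is an assumption inferred from the test's expected shapes, not quoted from this file:

// Standalone check of the slice-size arithmetic for the sliced-input,
// dynamic-batch test below: 225x225 input, 3x3 weight, stride 3, pad 1.
#include <cstdio>

int main() {
  const long inputH = 225, weightH = 3, strideH = 3;
  const long padTop = 1, padBottom = 1;

  // Leftover rows once full stride steps are taken (assumed formula).
  const long remainderH = (inputH + padTop + padBottom - weightH) % strideH; // 2

  // Mirrors `sizeHSlice[1] = inputHDim - (remainderHDim - padding[1])`
  // from the hunk above: slice the input instead of keeping the bottom pad.
  const long slicedH = inputH - (remainderH - padBottom); // 225 - (2 - 1) = 224

  // TOSA conv2d output with the bottom pad dropped to 0 (the test expects
  // `pad = array<i64: 1, 0, 1, 0>`): the division is now exact.
  const long outH = (slicedH + padTop + 0 - weightH) / strideH + 1; // 75

  std::printf("sliced H = %ld, output H = %ld\n", slicedH, outH);
  return 0;
}

The same arithmetic applies to the width dimension, which is why the test below shows two back-to-back tosa.slice ops (225x225 -> 224x225 -> 224x224) feeding a conv that produces the expected ?x75x75x32 result.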
@@ -5004,7 +5008,8 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
       dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
 
   if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, self).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, other).failed())
+      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, other).failed() ||
+      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, other).failed())
     return rewriter.notifyMatchFailure(
         op, "Failed to equalize ranks among operands and result");
 
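
The rank bug in the where.self lowering, traced on the shapes from the new test below: EqualizeRanks reshapes only the lower-ranked tensor of the pair it is given (by prepending 1s, as the test's tosa.reshape ops show), so equalizing (cond, self) and (cond, other) can still leave self at a lower rank than the other two operands. A minimal standalone simulation of that behavior — the equalizeRanks helper here is a hypothetical stand-in, not the MLIR implementation:

// Simulation of pairwise rank equalization via leading-1 padding.
// Shapes are the ones from the where.self_differing_rank_inputs test below.
#include <cstdio>
#include <vector>

using Shape = std::vector<long>;

// Hypothetical stand-in for mlir::tosa::EqualizeRanks: pad the lower-ranked
// shape with leading 1s so both operands end up at the same rank.
static void equalizeRanks(Shape &a, Shape &b) {
  Shape &lo = a.size() < b.size() ? a : b;
  Shape &hi = a.size() < b.size() ? b : a;
  lo.insert(lo.begin(), hi.size() - lo.size(), 1);
}

int main() {
  Shape cond = {5, 4};               // tensor<5x4xi1>,          rank 2
  Shape self = {};                   // tensor<f32>,             rank 0
  Shape other = {1, 3, 1, 1, 5, 4};  // tensor<1x3x1x1x5x4xf32>, rank 6

  equalizeRanks(cond, self);   // self -> {1,1}         (rank 2)
  equalizeRanks(cond, other);  // cond -> {1,1,1,1,5,4} (rank 6)
  // Pre-fix state: cond/other are rank 6 but self is still rank 2, so the
  // generated tosa.select would be rank-mismatched.
  std::printf("before third call: self rank = %zu\n", self.size()); // 2
  equalizeRanks(self, other);  // the added third call: self -> rank 6
  std::printf("after third call:  self rank = %zu\n", self.size()); // 6
  return 0;
}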

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 107 additions & 0 deletions
@@ -1421,6 +1421,28 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to
   return %0 : !torch.vtensor<[1,12,5,5],f32>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.where.self_differing_rank_inputs(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,4],i1>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1,3,1,1,5,4],f32>) -> !torch.vtensor<[1,3,1,1,5,4],f32> {
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1,3,1,1,5,4],f32> -> tensor<1x3x1x1x5x4xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[],f32> -> tensor<f32>
+// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,4],i1> -> tensor<5x4xi1>
+// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_6]] : (tensor<f32>, !tosa.shape<2>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 1, 5, 4]> : tensor<6xindex>} : () -> !tosa.shape<6>
+// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_8]] : (tensor<5x4xi1>, !tosa.shape<6>) -> tensor<1x1x1x1x5x4xi1>
+// CHECK: %[[VAL_10:.*]] = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
+// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_10]] : (tensor<1x1xf32>, !tosa.shape<6>) -> tensor<1x1x1x1x1x1xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_9]], %[[VAL_11]], %[[VAL_3]] : (tensor<1x1x1x1x5x4xi1>, tensor<1x1x1x1x1x1xf32>, tensor<1x3x1x1x5x4xf32>) -> tensor<1x3x1x1x5x4xf32>
+// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x3x1x1x5x4xf32> -> !torch.vtensor<[1,3,1,1,5,4],f32>
+// CHECK: return %[[VAL_13]]
+func.func @torch.aten.where.self_differing_rank_inputs(%40: !torch.vtensor<[5,4],i1>, %41: !torch.vtensor<[],f32>, %38 : !torch.vtensor<[1,3,1,1,5,4],f32>) -> (!torch.vtensor<[1,3,1,1,5,4],f32>) {
+  %42 = torch.aten.where.self %40, %41, %38 : !torch.vtensor<[5,4],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[1,3,1,1,5,4],f32> -> !torch.vtensor<[1,3,1,1,5,4],f32>
+  return %42 : !torch.vtensor<[1,3,1,1,5,4],f32>
+}
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.remainder.Scalar(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
@@ -3780,6 +3802,91 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
   return %5 : !torch.vtensor<[1,32,75,75],f32>
 }
 
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,224,224],f32> -> tensor<?x3x224x224xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_5:.*]] = torch.constant.none
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x112x112x32xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x112x112x32xf32>) -> tensor<?x32x112x112xf32>
+// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<?x32x112x112xf32> to tensor<?x32x112x112xf32>
+// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<?x32x112x112xf32> -> !torch.vtensor<[?,32,112,112],f32>
+// CHECK: return %[[VAL_19]]
+
+func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
+  %none = torch.constant.none
+  %int2 = torch.constant.int 2
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,224,224],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,32,112,112],f32>
+  return %5 : !torch.vtensor<[?,32,112,112],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,225,225],f32> -> tensor<?x3x225x225xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_5:.*]] = torch.constant.none
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
+// CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
+// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x75x75x32xf32>
+// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x75x75x32xf32>) -> tensor<?x32x75x75xf32>
+// CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor<?x32x75x75xf32> to tensor<?x32x75x75xf32>
+// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<?x32x75x75xf32> -> !torch.vtensor<[?,32,75,75],f32>
+// CHECK: return %[[VAL_25]]
+func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
+  %none = torch.constant.none
+  %int3 = torch.constant.int 3
+  %1 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,225,225],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,32,75,75],f32>
+  return %5 : !torch.vtensor<[?,32,75,75],f32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.max_pool2d$zero_pad_with_sliced_input(
