Commit c1d01b2

[mlir][tosa] Add missing verifier for tosa.pad (#120934)
This PR adds a missing verifier for `tosa.pad`, ensuring that the padding operand has shape `[2 * rank(shape1)]`, as required by the TOSA v1.0.0 specification. Fixes #119840.
1 parent e7244d8 commit c1d01b2
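To illustrate the constraint (a hypothetical snippet, not taken from the diff below; the function name and pad values are made up): for a rank-2 input, the padding operand must be a rank-1 tensor holding 2 * 2 = 4 values, one (low, high) pair per dimension.

```mlir
func.func @pad_rank2_example(%input: tensor<2x3xf32>) -> tensor<4x7xf32> {
  // Layout is [dim0_low, dim0_high, dim1_low, dim1_high].
  %padding = "tosa.const"() {value = dense<[1, 1, 2, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
  // dim 0: 1 + 2 + 1 = 4, dim 1: 2 + 3 + 2 = 7.
  %padded = tosa.pad %input, %padding : (tensor<2x3xf32>, tensor<4xi32>) -> tensor<4x7xf32>
  return %padded : tensor<4x7xf32>
}
```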

File tree: 15 files changed, +100 −115 lines

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (5 additions & 5 deletions)

@@ -1552,21 +1552,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Example:

     ```mlir
-    %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+    %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
     ```

     Example 2:

     ```mlir
-    %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+    %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
     ```
   }];

   let arguments = (ins
     Tosa_RankedTensor:$input1,
-    Tosa_Int32Or64Tensor:$padding,
+    TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
     OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
   );
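Reading the updated example: the flattened padding pairs up as (1, 2) for dimension 0 and (3, 4) for dimension 1, so the 1x2 input grows to (1 + 1 + 2) x (2 + 3 + 4) = 4x9, matching the documented result type.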

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (7 additions & 7 deletions)

@@ -65,17 +65,17 @@ def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
 // int8 : symmetric per tensor/per channel, signed
 // int16 : symmetric per tensor, signed
 //===----------------------------------------------------------------------===//
-def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
-                                    Tosa_QuantizedType<"int4", [4, 0], 1>,
-                                    Tosa_QuantizedType<"int8", [8, 0], 1>,
-                                    Tosa_QuantizedType<"int16", [16, 0], 1>,
-                                    Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
+                                   Tosa_QuantizedType<"int4", [4, 0], 1>,
+                                   Tosa_QuantizedType<"int8", [8, 0], 1>,
+                                   Tosa_QuantizedType<"int16", [16, 0], 1>,
+                                   Tosa_QuantizedType<"int32", [32, 0], 1>]>;

 //===----------------------------------------------------------------------===//
 // Multi-category types.
 //===----------------------------------------------------------------------===//
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
-                                "number">;
+                               "number">;

 // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
 // tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
@@ -112,7 +112,7 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>

 def Tosa_I1Tensor : TosaTensorOf<[I1]>;
 def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
-def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;
+def Tosa_Int32Or64Tensor : TosaTensorOf<[Tosa_Int32Or64]>;

 def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (4 additions & 8 deletions)

@@ -338,23 +338,19 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
          padOp, "tosa.pad was unable to determine the pad constant value.");
     }

-    Value lowIndex =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
-    Value highIndex =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
-
     SmallVector<OpFoldResult, 3> lowValues;
     SmallVector<OpFoldResult, 3> highValues;

     lowValues.reserve(rank);
     highValues.reserve(rank);

     for (int i = 0; i < rank; i++) {
-      Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
+      Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
+      Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
       Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({inputIndex, lowIndex}));
+          loc, padding, ValueRange({lowIndex}));
       Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({inputIndex, highIndex}));
+          loc, padding, ValueRange({highIndex}));

       lowVal = rewriter.createOrFold<arith::IndexCastOp>(
           loc, rewriter.getIndexType(), lowVal);
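For orientation, a hand-written static equivalent of what this pattern now produces for padding `dense<[1, 2, 3, 4]>` (the actual lowering extracts the values dynamically and index-casts them; the function name and zero pad constant here are illustrative):

```mlir
func.func @lowered_pad_sketch(%input: tensor<1x2xf32>) -> tensor<4x9xf32> {
  // Element 2*i of the 1-D padding is the low pad of dim i and element 2*i+1
  // the high pad, so low = [1, 3] and high = [2, 4].
  %cst = arith.constant 0.0 : f32
  %padded = tensor.pad %input low[1, 3] high[2, 4] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<1x2xf32> to tensor<4x9xf32>
  return %padded : tensor<4x9xf32>
}
```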

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (8 additions & 4 deletions)

@@ -787,7 +787,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     return success();
   }

-  outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
+  outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
   return success();
 }
@@ -823,13 +823,17 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 LogicalResult tosa::PadOp::verify() {
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
-  TensorType paddingType = getPadding().getType();
+  RankedTensorType paddingType = getPadding().getType();

   if (inputType.getRank() != outputType.getRank())
     return emitOpError() << "expect same input and output tensor rank.";

-  if (paddingType.hasRank() && paddingType.getRank() != 2)
-    return emitOpError() << "expect 'padding' tensor rank equal to 2.";
+  if (!paddingType.isDynamicDim(0) &&
+      paddingType.getDimSize(0) != inputType.getRank() * 2)
+    return emitOpError() << "expected padding tensor dim 0 to have size "
+                         << inputType.getRank() * 2
+                         << " (2*rank(shape1)) but got size "
+                         << paddingType.getDimSize(0);

   return success();
 }
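As an illustration of what the new check catches (a made-up case, not one of the tests added in this commit): a rank-2 input paired with only two padding values.

```mlir
func.func @pad_wrong_padding_size(%arg0: tensor<1x2xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
  // A rank-2 input needs a padding tensor with 2 * 2 = 4 elements, so this
  // should now fail verification with "expected padding tensor dim 0 to have
  // size 4 (2*rank(shape1)) but got size 2".
  %0 = tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```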

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (1 addition & 1 deletion)

@@ -81,7 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
       }
     }

-    auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type());
+    auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
     auto padSize =
         DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
     Value padSizeVal =
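With the rank-1 layout, the pad constant this decomposition materializes looks roughly like the following sketch (the concrete values depend on the convolution's pad attribute and are placeholders here):

```mlir
func.func @decomposed_pad_const() -> tensor<8xi64> {
  // Each consecutive pair is the (low, high) padding of one input dimension;
  // the values shown are placeholders.
  %pad_const = "tosa.const"() {value = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
  return %pad_const : tensor<8xi64>
}
```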

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (1 addition & 1 deletion)

@@ -108,7 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       }
     }

-    auto padSizeTy = RankedTensorType::get({5, 2}, rewriter.getI64Type());
+    auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
     auto padSize =
         DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
     Value padSizeVal =

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (3 additions & 3 deletions)

@@ -139,7 +139,7 @@ class TransposeConvStridedConverter
     weightPadding[5] =
         (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
     DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
+        RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
     Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
         rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

@@ -202,7 +202,7 @@ class TransposeConvStridedConverter
     inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

     DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
+        RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);

     Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
         rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
@@ -314,7 +314,7 @@ class TransposeConvStridedConverter
     resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

     DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
+        RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);

     Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
         rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);

mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (16 additions & 36 deletions)

@@ -459,85 +459,65 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
 // CHECK-LABEL: @pad_float
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
   // CHECK: tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
-  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
   return %1 : tensor<4x9xf32>
 }

 func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK: [[CST:%.+]] = arith.constant 0 : i32
   // CHECK: tensor.pad
   // CHECK: tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }

 func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK: [[CST:%.+]] = arith.constant 42 : i32
   // CHECK: tensor.pad
   // CHECK: tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }

 // -----

 func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
   // CHECK: tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
   %1 = arith.constant dense<42.0> : tensor<f32>
-  %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>) -> (tensor<4x9xf32>)
+  %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<4xi32>, tensor<f32>) -> (tensor<4x9xf32>)
   return %2 : tensor<4x9xf32>
 }

 // -----

 func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
   // CHECK: tensor.yield [[CST]]
   // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
-  %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
   return %1 : tensor<?x9xf32>
 }

 func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
   // CHECK: tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
-  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
   return %1 : tensor<?x9xf32>
 }
mlir/test/Dialect/Tosa/canonicalize.mlir (12 additions & 15 deletions)

@@ -210,8 +210,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
 // CHECK-LABEL: @pad_noop
 func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %arg0
-  %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %0 = "tosa.const"() { value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }

@@ -221,8 +221,8 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[PAD:.+]] = tosa.pad
   // CHECK: return %[[PAD]]
-  %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }

@@ -234,42 +234,39 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
   // CHECK: return %[[PAD]]

   %c0_i32 = arith.constant 0 : i32
-  %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32>
+  %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<2xi32>

-  %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32>
+  %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

 // -----

 // CHECK-LABEL: @pad_determine_val_i32
-func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
   // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }

 // -----

 // CHECK-LABEL: @pad_determine_val_f32
-func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
+func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<4xi32>) -> tensor<?x?xf32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
   // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }

 // -----

 // CHECK-LABEL: @pad_determine_val_quant
-func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
   // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+  %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
