
Commit 51f92bc

[AutoBump] Merge with fixes of c1d01b2 (Jan 08)
2 parents: 0676d76 + c1d01b2

17 files changed: +124 −139 lines changed


mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 5 deletions
@@ -1576,21 +1576,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Example:

     ```mlir
-    %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+    %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
     ```

     Example 2:

     ```mlir
-    %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+    %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
     ```
   }];

   let arguments = (ins
     Tosa_RankedTensor:$input1,
-    Tosa_Int32Or64Tensor:$padding,
+    TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
     OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
   );
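The key change: `tosa.pad` now takes its padding amounts as a rank-1 tensor of length 2 * rank(input), interleaving the before/after amounts per dimension, instead of an Nx2 tensor. A minimal sketch of the new layout (the `%input` name is illustrative, not from this commit):

```mlir
// Flat layout: [dim0_before, dim0_after, dim1_before, dim1_after, ...].
// Dim 0 grows by 1 + 2 (1 -> 4) and dim 1 by 3 + 4 (2 -> 9).
%padding = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
%padded = tosa.pad %input, %padding
    : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<4x9xf32>
```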

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 7 additions & 7 deletions
@@ -65,17 +65,17 @@ def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
 // int8  : symmetric per tensor/per channel, signed
 // int16 : symmetric per tensor, signed
 //===----------------------------------------------------------------------===//
-def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
-                                    Tosa_QuantizedType<"int4", [4, 0], 1>,
-                                    Tosa_QuantizedType<"int8", [8, 0], 1>,
-                                    Tosa_QuantizedType<"int16", [16, 0], 1>,
-                                    Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
+                                   Tosa_QuantizedType<"int4", [4, 0], 1>,
+                                   Tosa_QuantizedType<"int8", [8, 0], 1>,
+                                   Tosa_QuantizedType<"int16", [16, 0], 1>,
+                                   Tosa_QuantizedType<"int32", [32, 0], 1>]>;

 //===----------------------------------------------------------------------===//
 // Multi-category types.
 //===----------------------------------------------------------------------===//
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
-                               "number">;
+                                "number">;

 // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
 // tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
@@ -115,7 +115,7 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
 def Tosa_I1Tensor : TosaTensorOf<[I1]>;
 def Tosa_IntTensor : TensorOf<[Tosa_Int]>;
 def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
-def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;
+def Tosa_Int32Or64Tensor : TosaTensorOf<[Tosa_Int32Or64]>;

 def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 4 additions & 8 deletions
@@ -338,23 +338,19 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
           padOp, "tosa.pad was unable to determine the pad constant value.");
     }

-    Value lowIndex =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
-    Value highIndex =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
-
     SmallVector<OpFoldResult, 3> lowValues;
     SmallVector<OpFoldResult, 3> highValues;

     lowValues.reserve(rank);
     highValues.reserve(rank);

     for (int i = 0; i < rank; i++) {
-      Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
+      Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
+      Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
       Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({inputIndex, lowIndex}));
+          loc, padding, ValueRange({lowIndex}));
       Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({inputIndex, highIndex}));
+          loc, padding, ValueRange({highIndex}));

       lowVal = rewriter.createOrFold<arith::IndexCastOp>(
           loc, rewriter.getIndexType(), lowVal);
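With the 1-D padding operand, the converter reads the low/high amounts for dimension `i` at flat indices `2*i` and `2*i+1` instead of indexing a 2-D tensor. A rough sketch of the IR shape this lowering should produce for a rank-2 input (value names are illustrative, and constant padding would normally fold away via `createOrFold`):

```mlir
%pad = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
// Dimension 0 uses flat indices 0 and 1; dimension 1 uses 2 and 3.
%low0_i32  = tensor.extract %pad[%c0] : tensor<4xi32>
%high0_i32 = tensor.extract %pad[%c1] : tensor<4xi32>
%low1_i32  = tensor.extract %pad[%c2] : tensor<4xi32>
%high1_i32 = tensor.extract %pad[%c3] : tensor<4xi32>
%low0  = arith.index_cast %low0_i32  : i32 to index
%high0 = arith.index_cast %high0_i32 : i32 to index
%low1  = arith.index_cast %low1_i32  : i32 to index
%high1 = arith.index_cast %high1_i32 : i32 to index
%zero = arith.constant 0.0 : f32
%padded = tensor.pad %input low[%low0, %low1] high[%high0, %high1] {
^bb0(%i: index, %j: index):
  tensor.yield %zero : f32
} : tensor<1x2xf32> to tensor<4x9xf32>
```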

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 8 additions & 4 deletions
@@ -822,7 +822,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     return success();
   }

-  outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
+  outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
   return success();
 }
@@ -858,13 +858,17 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 LogicalResult tosa::PadOp::verify() {
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
-  TensorType paddingType = getPadding().getType();
+  RankedTensorType paddingType = getPadding().getType();

   if (inputType.getRank() != outputType.getRank())
     return emitOpError() << "expect same input and output tensor rank.";

-  if (paddingType.hasRank() && paddingType.getRank() != 2)
-    return emitOpError() << "expect 'padding' tensor rank equal to 2.";
+  if (!paddingType.isDynamicDim(0) &&
+      paddingType.getDimSize(0) != inputType.getRank() * 2)
+    return emitOpError() << "expected padding tensor dim 0 to have size "
+                         << inputType.getRank() * 2
+                         << " (2*rank(shape1)) but got size "
+                         << paddingType.getDimSize(0);

   return success();
 }
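Shape inference now halves the static padding size to recover the output rank, and the verifier checks the flat length against 2 * rank(input). For example, a mismatched padding tensor should now be rejected (a hypothetical invalid case, not one of the tests in this commit):

```mlir
// Invalid: a rank-2 input requires tensor<4xi32> padding, so this
// six-element padding should trip the new tosa.pad verifier.
%bad_pad = arith.constant dense<[1, 1, 1, 1, 1, 1]> : tensor<6xi32>
%0 = tosa.pad %arg0, %bad_pad : (tensor<1x2xf32>, tensor<6xi32>) -> tensor<3x4xf32>
```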

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
       }
     }

-    auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type());
+    auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
     auto padSize =
         DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
     Value padSizeVal =
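This pattern and the two decompositions below (DepthwiseConv2D and the strided transpose-conv converter) make the matching producer-side change: the pad constants they synthesize become flat tensors of length 2 * rank. An illustrative shape of such a constant for a rank-4 NHWC input (the values here are made up for the example):

```mlir
// [N_before, N_after, H_before, H_after, W_before, W_after, C_before, C_after]
%conv_pad = "tosa.const"() {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
```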

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       }
     }

-    auto padSizeTy = RankedTensorType::get({5, 2}, rewriter.getI64Type());
+    auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
     auto padSize =
         DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
     Value padSizeVal =

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 3 additions & 3 deletions
@@ -139,7 +139,7 @@ class TransposeConvStridedConverter
     weightPadding[5] =
         (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
     DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
+        RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
     Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
         rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

@@ -202,7 +202,7 @@ class TransposeConvStridedConverter
     inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

     DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
+        RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);

     Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
         rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
@@ -314,7 +314,7 @@ class TransposeConvStridedConverter
     resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

     DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
+        RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);

     Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
         rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);

mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Lines changed: 16 additions & 36 deletions
@@ -484,85 +484,65 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
 // CHECK-LABEL: @pad_float
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
   // CHECK: tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
-  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
   return %1 : tensor<4x9xf32>
 }

 func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK: [[CST:%.+]] = arith.constant 0 : i32
   // CHECK: tensor.pad
   // CHECK: tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }

 func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK: [[CST:%.+]] = arith.constant 42 : i32
   // CHECK: tensor.pad
   // CHECK: tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }

 // -----

 func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
   // CHECK: tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
   %1 = arith.constant dense<42.0> : tensor<f32>
-  %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>) -> (tensor<4x9xf32>)
+  %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<4xi32>, tensor<f32>) -> (tensor<4x9xf32>)
   return %2 : tensor<4x9xf32>
 }

 // -----

 func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
   // CHECK: tensor.yield [[CST]]
   // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
-  %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
   return %1 : tensor<?x9xf32>
 }

 func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
   // CHECK: tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
-  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
   return %1 : tensor<?x9xf32>
 }

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 12 additions & 15 deletions
@@ -305,8 +305,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
 // CHECK-LABEL: @pad_noop
 func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %arg0
-  %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %0 = "tosa.const"() { value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }

@@ -316,8 +316,8 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[PAD:.+]] = tosa.pad
   // CHECK: return %[[PAD]]
-  %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }

@@ -329,42 +329,39 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
   // CHECK: return %[[PAD]]

   %c0_i32 = arith.constant 0 : i32
-  %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32>
+  %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<2xi32>

-  %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32>
+  %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

 // -----

 // CHECK-LABEL: @pad_determine_val_i32
-func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
   // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }

 // -----

 // CHECK-LABEL: @pad_determine_val_f32
-func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
+func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<4xi32>) -> tensor<?x?xf32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
   // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }

 // -----

 // CHECK-LABEL: @pad_determine_val_quant
-func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
   // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+  %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }

mlir/test/Dialect/Tosa/constant-pad-multi-user.mlir

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@ func.func @pad_int32_multi_user() -> (tensor<2x2xi32>, tensor<5x5xi32>) {
   // CHECK-DAG: "tosa.const"() <{value = dense<2> : tensor<2x2xi32>}>
   // CHECK-NOT: "tosa.pad"
   %0 = "tosa.const"() {value = dense<2> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %5 = "tosa.const"() {value = dense<[[1, 2], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>
+  %5 = "tosa.const"() {value = dense<[1, 2, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
   %6 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-  %1 = "tosa.pad"(%0, %5, %6) : (tensor<2x2xi32>, tensor<2x2xi64>, tensor<i32>) -> tensor<5x5xi32>
+  %1 = "tosa.pad"(%0, %5, %6) : (tensor<2x2xi32>, tensor<4xi64>, tensor<i32>) -> tensor<5x5xi32>
   return %0, %1 : tensor<2x2xi32>, tensor<5x5xi32>
 }
