Skip to content

Commit 53fb23a

Browse files
committed
[mlir][tosa] Always generated pad_const and remove input_zp attr for PadOp
Co-authored-by: Udaya Ranga <[email protected]> Co-authored-by: Tai Ly <[email protected]> Signed-off-by: Jerry Ge <[email protected]> Change-Id: I2b7a0169b7ec1158d28779713ad125c061e04592
1 parent 299be61 commit 53fb23a

File tree

14 files changed

+118
-176
lines changed

14 files changed

+118
-176
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,6 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
197197
input, paddings);
198198
}]>;
199199

200-
def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
201-
(ins "Type":$outputType, "Value":$input, "Value":$paddings,
202-
"Value":$pad_value),
203-
[{
204-
buildExplicitValuePadOpWithQuantInfo($_builder, $_state, outputType,
205-
input, paddings, pad_value);
206-
}]>;
207-
208200
// Wrapper over base I32EnumAttr to set common fields.
209201
class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
210202
: I32EnumAttr<name, description, cases> {

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ namespace tosa {
168168
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
169169
Type srcElemType, int64_t zp = 0);
170170

171+
// Create a pad-const const tensor with value of `val` of required data-type
172+
std::optional<Value> createPadConstTensor(OpBuilder &builder, Location loc,
173+
Value src, int32_t val = 0);
174+
171175
} // namespace tosa
172176
} // namespace mlir
173177

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1882,8 +1882,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
18821882
let arguments = (ins
18831883
Tosa_RankedTensor:$input1,
18841884
Tosa_Shape:$padding,
1885-
Optional<Tosa_ScalarTensor>:$pad_const,
1886-
OptionalAttr<I32Attr>:$input_zp
1885+
Tosa_ScalarTensor:$pad_const
18871886
);
18881887

18891888
let results = (outs
@@ -1895,10 +1894,8 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
18951894
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
18961895
];
18971896

1898-
let builders = [Tosa_PadOpQuantInfoBuilder,
1899-
Tosa_ExplicitValuePadOpQuantInfoBuilder];
1897+
let builders = [Tosa_PadOpQuantInfoBuilder];
19001898

1901-
let hasCanonicalizer = 1;
19021899
let hasFolder = 1;
19031900
let hasVerifier = 1;
19041901
}

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -350,29 +350,12 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
350350
}
351351

352352
ShapedType inputTy = cast<ShapedType>(input.getType());
353-
Type elementTy = inputTy.getElementType();
354353
int64_t rank = inputTy.getRank();
355354

356355
// Setup the default constantAttr.
357356

358-
Value padConstant;
359-
360-
if (padOp.getPadConst()) {
361-
padConstant = rewriter.createOrFold<tensor::ExtractOp>(
362-
loc, padOp.getPadConst(), ValueRange({}));
363-
} else {
364-
TypedAttr constantAttr;
365-
if (isa<FloatType>(elementTy)) {
366-
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
367-
} else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
368-
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
369-
} else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
370-
int64_t value = padOp.getInputZpAttr().getInt();
371-
constantAttr = rewriter.getIntegerAttr(elementTy, value);
372-
}
373-
if (constantAttr)
374-
padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
375-
}
357+
Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
358+
loc, padOp.getPadConst(), ValueRange({}));
376359

377360
if (!padConstant) {
378361
return rewriter.notifyMatchFailure(

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -175,53 +175,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
175175
results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
176176
}
177177

178-
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
179-
using OpRewritePattern::OpRewritePattern;
180-
181-
LogicalResult matchAndRewrite(tosa::PadOp op,
182-
PatternRewriter &rewriter) const override {
183-
if (op.getPadConst())
184-
return failure();
185-
186-
auto input = op.getInput1();
187-
auto padding = op.getPadding();
188-
189-
ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
190-
Type elementTy = inputTy.getElementType();
191-
192-
Attribute constantAttr;
193-
if (llvm::isa<FloatType>(elementTy)) {
194-
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
195-
} else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
196-
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
197-
} else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
198-
int64_t value = op.getInputZpAttr().getInt();
199-
constantAttr = rewriter.getIntegerAttr(elementTy, value);
200-
}
201-
202-
if (!constantAttr) {
203-
return rewriter.notifyMatchFailure(
204-
op,
205-
"tosa.pad to linalg lowering encountered an unknown element type");
206-
}
207-
208-
auto denseAttr = DenseElementsAttr::get(
209-
RankedTensorType::get({1}, elementTy), constantAttr);
210-
auto constantVal = rewriter.create<tosa::ConstOp>(
211-
op.getLoc(), denseAttr.getType(), denseAttr);
212-
213-
rewriter.replaceOpWithNewOp<tosa::PadOp>(
214-
op, op.getType(), ValueRange{input, padding, constantVal},
215-
op->getAttrs());
216-
return success();
217-
}
218-
};
219-
220-
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
221-
MLIRContext *context) {
222-
results.add<MaterializePadValue>(context);
223-
}
224-
225178
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
226179
using OpRewritePattern::OpRewritePattern;
227180

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,23 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
214214
}
215215
}
216216

217+
// Create a pad-const const tensor with value of `val` of required data-type
218+
std::optional<Value> mlir::tosa::createPadConstTensor(OpBuilder &builder,
219+
Location loc, Value src,
220+
int32_t val) {
221+
const auto srcType = getElementTypeOrSelf(src);
222+
const auto srcElemType = getElementTypeOrSelf(src);
223+
const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
224+
const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
225+
const auto padConstAttr{
226+
llvm::isa<FloatType>(srcElemType)
227+
? DenseElementsAttr::get(padConstEType,
228+
builder.getFloatAttr(srcElemType, val))
229+
: DenseElementsAttr::get(padConstEType,
230+
builder.getIntegerAttr(srcElemType, val))};
231+
return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
232+
}
233+
217234
//===----------------------------------------------------------------------===//
218235
// Tosa utilities.
219236
//===----------------------------------------------------------------------===//
@@ -679,30 +696,14 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
679696
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
680697
Type outputType, Value input,
681698
Value paddings) {
682-
result.addOperands({input, paddings});
683-
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
699+
const Location loc{result.location};
700+
int32_t zp{0};
701+
const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
684702
if (quantAttr) {
685-
result.addAttribute("input_zp",
686-
builder.getI32IntegerAttr(
687-
static_cast<int32_t>(quantAttr.getInputZp())));
688-
}
689-
result.types.push_back(outputType);
690-
}
691-
692-
/// This builder is called on TOSA pad operator when an explicit pad_const
693-
/// value is passed in. It also optionally constructs quantization_attr.
694-
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
695-
OperationState &result,
696-
Type outputType, Value input,
697-
Value paddings,
698-
Value padConst) {
699-
result.addOperands({input, paddings, padConst});
700-
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
701-
if (quantAttr) {
702-
result.addAttribute("input_zp",
703-
builder.getI32IntegerAttr(
704-
static_cast<int32_t>(quantAttr.getInputZp())));
703+
zp = static_cast<int32_t>(quantAttr.getInputZp());
705704
}
705+
const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
706+
result.addOperands({input, paddings, padConstOp.value()});
706707
result.types.push_back(outputType);
707708
}
708709

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -148,16 +148,16 @@ class TransposeConvStridedConverter
148148
return rewriter.notifyMatchFailure(
149149
op, "zero point must be zero for non-int8 integer types");
150150

151-
if (weightZpVal != 0) {
152-
weight = CreateOpAndInferShape<tosa::PadOp>(
153-
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
154-
weightPaddingVal, nullptr, rewriter.getI32IntegerAttr(weightZpVal));
155-
156-
} else {
157-
weight = CreateOpAndInferShape<tosa::PadOp>(
158-
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
159-
weightPaddingVal);
160-
}
151+
// construct pad_const values from zp values
152+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
153+
const Value inputPadConst =
154+
createPadConstTensor(builder, op->getLoc(), input, inputZpVal).value();
155+
const Value weightPadConst =
156+
createPadConstTensor(builder, op->getLoc(), input, weightZpVal).value();
157+
158+
weight = CreateOpAndInferShape<tosa::PadOp>(
159+
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
160+
weightPaddingVal, weightPadConst);
161161

162162
weightTy = cast<ShapedType>(weight.getType());
163163
weightHeight = weightTy.getDimSize(1);
@@ -169,7 +169,6 @@ class TransposeConvStridedConverter
169169
stride[0], weightWidth / stride[1],
170170
stride[1], inputChannels};
171171

172-
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
173172
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
174173
builder, UnrankedTensorType::get(weightETy), weight,
175174
getTosaConstShape(rewriter, loc, weightReshapeDims0));
@@ -206,15 +205,9 @@ class TransposeConvStridedConverter
206205
Value inputPaddingVal =
207206
getTosaConstShape(rewriter, op->getLoc(), inputPadding);
208207

209-
if (inputZpVal != 0) {
210-
input = CreateOpAndInferShape<tosa::PadOp>(
211-
rewriter, loc, UnrankedTensorType::get(inputETy), input,
212-
inputPaddingVal, nullptr, rewriter.getI32IntegerAttr(inputZpVal));
213-
} else {
214-
input = CreateOpAndInferShape<tosa::PadOp>(
215-
rewriter, loc, UnrankedTensorType::get(inputETy), input,
216-
inputPaddingVal);
217-
}
208+
input = CreateOpAndInferShape<tosa::PadOp>(
209+
rewriter, loc, UnrankedTensorType::get(inputETy), input,
210+
inputPaddingVal, inputPadConst);
218211

219212
// We use a zero bias as we need to broadcast the bias.
220213
auto zeroBias = rewriter.create<tosa::ConstOp>(

mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -498,35 +498,38 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
498498
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
499499
func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
500500
%0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
501+
%pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
501502
// CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
502503
// CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
503504
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
504505
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
505-
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
506+
// CHECK-DAG: [[CST:%.+]] = arith.constant 3.140000e+00 : f32
506507
// CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
507508
// CHECK: tensor.yield [[CST]]
508509
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
509-
%1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)
510+
%1 = "tosa.pad"(%arg0, %0, %pad_const) : (tensor<1x2xf32>, !tosa.shape<4>, tensor<1xf32>) -> (tensor<4x9xf32>)
510511
return %1 : tensor<4x9xf32>
511512
}
512513
// -----
513514

514515
func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
515516
%0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
516-
// CHECK: [[CST:%.+]] = arith.constant 0 : i32
517+
%pad_const = "tosa.const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
518+
// CHECK: [[CST:%.+]] = arith.constant 3 : i32
517519
// CHECK: tensor.pad
518520
// CHECK: tensor.yield [[CST]]
519-
%1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
521+
%1 = "tosa.pad"(%arg0, %0, %pad_const) : (tensor<1x2xi32>, !tosa.shape<4>, tensor<1xi32>) -> (tensor<4x9xi32>)
520522
return %1 : tensor<4x9xi32>
521523
}
522524
// -----
523525

524526
func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
525527
%0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
526-
// CHECK: [[CST:%.+]] = arith.constant 42 : i32
528+
%pad_const = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
529+
// CHECK: [[CST:%.+]] = arith.constant 0 : i32
527530
// CHECK: tensor.pad
528531
// CHECK: tensor.yield [[CST]]
529-
%1 = "tosa.pad"(%arg0, %0) {input_zp = 42 : i32} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
532+
%1 = "tosa.pad"(%arg0, %0, %pad_const) {input_zp = 42 : i32} : (tensor<1x2xi32>, !tosa.shape<4>, tensor<1xi32>) -> (tensor<4x9xi32>)
530533
return %1 : tensor<4x9xi32>
531534
}
532535

@@ -551,30 +554,32 @@ func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
551554

552555
func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
553556
%0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
557+
%pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
554558
// CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
555559
// CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
556560
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
557561
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
558-
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
562+
// CHECK-DAG: [[CST:%.+]] = arith.constant 3.140000e+00 : f32
559563
// CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
560564
// CHECK: tensor.yield [[CST]]
561565
// CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
562-
%1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
566+
%1 = "tosa.pad"(%arg0, %0, %pad_const) : (tensor<?x2xf32>, !tosa.shape<4>, tensor<1xf32>) -> (tensor<?x9xf32>)
563567
return %1 : tensor<?x9xf32>
564568
}
565569
// -----
566570

567571
func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
568572
%0 = tosa.const_shape {value = dense<[-1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
573+
%pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
569574
// CHECK-DAG: [[INDEX1:%.+]] = arith.constant -1 : index
570575
// CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
571576
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
572577
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
573-
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
578+
// CHECK-DAG: [[CST:%.+]] = arith.constant 3.140000e+00 : f32
574579
// CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
575580
// CHECK: tensor.yield [[CST]]
576581
// CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
577-
%1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
582+
%1 = "tosa.pad"(%arg0, %0, %pad_const) : (tensor<1x2xf32>, !tosa.shape<4>, tensor<1xf32>) -> (tensor<?x9xf32>)
578583
return %1 : tensor<?x9xf32>
579584
}
580585

mlir/test/Dialect/Tosa/availability.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,9 +512,10 @@ func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -
512512
// CHECK-LABEL: pad
513513
func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
514514
%padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
515+
%pad_const = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
515516
// CHECK: profiles: [ [pro_int, pro_fp] ]
516517
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
517-
%0 = tosa.pad %arg0, %padding : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
518+
%0 = tosa.pad %arg0, %padding, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
518519
return %0 : tensor<13x21x3xf32>
519520
}
520521

0 commit comments

Comments
 (0)