Skip to content

Commit f474f76

Browse files
authored
[mlir][tosa] Convert TOSA enumerations from StringBasedAttr to Tosa_I32EnumAttr (#152856)
Fixes #152129 Use `Tosa_I32EnumAttr` instead of `StringBasedAttr` to represent Tosa enumerations. This PR replaces `StringBasedAttr` with `Tosa_I32EnumAttr` to represent Tosa enumerations as per the specification. The intent is to make the IR and C++ APIs more type-safe and prevent fragile string comparisons in passes. Enumerations rewritten are: - `Tosa_ResizeTypeAttr` - `Tosa_NanPropagationAttr` - `Tosa_RoundingTypeAttr` **BREAKING CHANGE**: This commit changes attribute assembly and the C++ API surface for the listed attributes. Code that previously used `StringAttr` for these fields must now be updated to use the new enum representation. In `.mlir` files, replace string literals with the enum assembly (e.g. `mode = #tosa.resize_type<BILINEAR>`). In C++, update call sites to either pass the generated enum (e.g. `::mlir::tosa::RoundingType::SINGLE_ROUND`) into builder overloads or construct the typed attribute with `tosa::RoundingTypeAttr::get(context, /*enum*/)` and pass that.
1 parent c118215 commit f474f76

30 files changed

+264
-244
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,34 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
381381
let instance = "ref";
382382
}
383383

384+
//===----------------------------------------------------------------------===//
385+
// Iterable attributes.
386+
//===----------------------------------------------------------------------===//
387+
// Defined in `section 3. Enumerations` of the TOSA specification.
388+
389+
def Tosa_RESIZE_NEAREST_NEIGHBOR : I32EnumAttrCase<"NEAREST_NEIGHBOR", 1>;
390+
def Tosa_RESIZE_BILINEAR : I32EnumAttrCase<"BILINEAR", 2>;
391+
392+
def Tosa_ResizeModeAttr
393+
: Tosa_I32EnumAttr<"ResizeMode", "Supported resize/upsampling strategies", "resize_mode",
394+
[Tosa_RESIZE_NEAREST_NEIGHBOR, Tosa_RESIZE_BILINEAR]>;
395+
396+
def Tosa_NANPROPAGATION_PROPAGATE : I32EnumAttrCase<"PROPAGATE", 1>;
397+
def Tosa_NANPROPAGATION_IGNORE : I32EnumAttrCase<"IGNORE", 2>;
398+
399+
def Tosa_NanPropagationModeAttr
400+
: Tosa_I32EnumAttr<"NanPropagationMode", "Supported NaN propagation strategies", "nan_mode",
401+
[Tosa_NANPROPAGATION_PROPAGATE, Tosa_NANPROPAGATION_IGNORE]>;
402+
403+
def Tosa_ROUNDING_SINGLE_ROUND : I32EnumAttrCase<"SINGLE_ROUND", 1>;
404+
def Tosa_ROUNDING_INEXACT_ROUND : I32EnumAttrCase<"INEXACT_ROUND", 2>;
405+
def Tosa_ROUNDING_DOUBLE_ROUND : I32EnumAttrCase<"DOUBLE_ROUND", 3>;
406+
407+
def Tosa_RoundingModeAttr
408+
: Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
409+
[Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
410+
411+
384412
//===----------------------------------------------------------------------===//
385413
// TOSA Interfaces.
386414
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
4343
let arguments = (ins
4444
Tosa_TensorAtLeast1D: $input,
4545
I32Attr: $axis,
46-
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
46+
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
4747
);
4848

4949
let results = (outs
@@ -357,7 +357,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
357357
Tosa_IntArrayAttr2:$kernel,
358358
Tosa_IntArrayAttr2:$stride,
359359
Tosa_IntArrayAttr4:$pad,
360-
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
360+
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
361361
);
362362

363363
let results = (outs
@@ -487,7 +487,7 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
487487
Tosa_Tensor:$input,
488488
Tosa_IntOrFloatAttr:$min_val,
489489
Tosa_IntOrFloatAttr:$max_val,
490-
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
490+
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
491491
);
492492

493493
let results = (outs
@@ -935,7 +935,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
935935
let arguments = (ins
936936
Tosa_Tensor:$input1,
937937
Tosa_Tensor:$input2,
938-
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
938+
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
939939
);
940940

941941
let results = (outs
@@ -964,7 +964,7 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
964964
let arguments = (ins
965965
Tosa_Tensor:$input1,
966966
Tosa_Tensor:$input2,
967-
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
967+
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
968968
);
969969

970970
let results = (outs
@@ -1711,7 +1711,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
17111711
let arguments = (ins
17121712
Tosa_TensorAtLeast1D:$input,
17131713
I32Attr:$axis,
1714-
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
1714+
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
17151715
);
17161716

17171717
let results = (outs
@@ -1751,7 +1751,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
17511751
let arguments = (ins
17521752
Tosa_TensorAtLeast1D:$input,
17531753
I32Attr:$axis,
1754-
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
1754+
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
17551755
);
17561756

17571757
let results = (outs
@@ -2224,7 +2224,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
22242224
Rank4TosaShape:$scale,
22252225
Rank2TosaShape:$offset,
22262226
Rank2TosaShape:$border,
2227-
Tosa_ResizeTypeAttr:$mode
2227+
Tosa_ResizeModeAttr:$mode
22282228
);
22292229

22302230
let results = (outs
@@ -2374,7 +2374,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
23742374
Tosa_ScalarIntOrFloatTensor:$input_zp,
23752375
Tosa_ScalarIntOrFloatTensor:$output_zp,
23762376
BoolAttr:$scale32,
2377-
Tosa_RoundingTypeAttr:$rounding_mode,
2377+
Tosa_RoundingModeAttr:$rounding_mode,
23782378
BoolAttr:$per_channel,
23792379
BoolAttr: $input_unsigned,
23802380
BoolAttr: $output_unsigned

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -234,29 +234,6 @@ def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
234234

235235
def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
236236

237-
//===----------------------------------------------------------------------===//
238-
// Iterable attributes.
239-
//===----------------------------------------------------------------------===//
240-
// Defined in `section 3. Enumerations` of the TOSA specification.
241-
242-
// Supported regimes for tosa.resize.
243-
def Tosa_ResizeTypeAttr : StringBasedAttr<
244-
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
245-
"::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
246-
"Supported resize/upsampling strategies">;
247-
248-
// Supported NaN propagation strategies.
249-
def Tosa_NanPropagationAttr : StringBasedAttr<
250-
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
251-
"::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
252-
"Supported NaN propagation strategies">;
253-
254-
// Rounding mode for tosa.rescale
255-
def Tosa_RoundingTypeAttr : StringBasedAttr<
256-
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\" || " #
257-
"::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
258-
"::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
259-
"Supported rounding modes">;
260237

261238
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
262239

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def Tosa_ApplyScaleOp :
4444
Tosa_IntLike:$value,
4545
Tosa_IntLike:$multiplier,
4646
Tosa_Int8Like:$shift,
47-
Tosa_RoundingTypeAttr:$rounding_mode
47+
Tosa_RoundingModeAttr:$rounding_mode
4848
);
4949

5050
let results = (outs

mlir/lib/Conversion/TosaToArith/TosaToArith.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ class ApplyScaleGenericOpConverter
6464

6565
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
6666
PatternRewriter &rewriter) const final {
67-
StringRef roundingMode = op.getRoundingMode();
68-
if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
67+
RoundingMode roundingMode = op.getRoundingMode();
68+
if (roundingMode != RoundingMode::DOUBLE_ROUND &&
69+
roundingMode != RoundingMode::SINGLE_ROUND) {
6970
return failure();
7071
}
7172

@@ -100,7 +101,7 @@ class ApplyScaleGenericOpConverter
100101
multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
101102

102103
// Apply double rounding if necessary.
103-
if (op.getRoundingMode() == "DOUBLE_ROUND") {
104+
if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
104105
int64_t roundInt = 1 << 30;
105106
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
106107
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -129,8 +130,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
129130

130131
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
131132
PatternRewriter &rewriter) const final {
132-
StringRef roundingMode = op.getRoundingMode();
133-
if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
133+
RoundingMode roundingMode = op.getRoundingMode();
134+
if (roundingMode != RoundingMode::DOUBLE_ROUND &&
135+
roundingMode != RoundingMode::SINGLE_ROUND) {
134136
return failure();
135137
}
136138

@@ -179,7 +181,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
179181
arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
180182

181183
// Conditionally perform our double round.
182-
if (op.getRoundingMode() == "DOUBLE_ROUND") {
184+
if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
183185
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
184186
Value valuePositive = arith::CmpIOp::create(
185187
rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
6565
return result;
6666

6767
auto nanMode = op.getNanMode();
68-
if (nanMode == "PROPAGATE")
68+
if (nanMode == NanPropagationMode::PROPAGATE)
6969
return result;
7070

7171
// Unordered comparison of NaN against itself will always return true.
@@ -160,9 +160,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
160160
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
161161

162162
auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
163-
auto result = tosa::ApplyScaleOp::create(
164-
rewriter, loc, rewriter.getI32Type(), a, b, shiftAmount,
165-
rewriter.getStringAttr("SINGLE_ROUND"));
163+
auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
164+
RoundingMode::SINGLE_ROUND);
165+
auto result =
166+
tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
167+
b, shiftAmount, roundingAttr);
166168

167169
return result;
168170
}
@@ -465,7 +467,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
465467

466468
// In the case of "PROPAGATE" semantics no compare and selection is
467469
// required.
468-
if (nanMode == "PROPAGATE")
470+
if (nanMode == NanPropagationMode::PROPAGATE)
469471
return result;
470472

471473
// In the case of "IGNORE" semantics materialize a comparison
@@ -1192,7 +1194,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
11921194
if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
11931195
std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
11941196
// NaN propagation has no meaning for non floating point types.
1195-
if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
1197+
if (isa<FloatType>(elementTy) &&
1198+
op.getNanMode() == NanPropagationMode::IGNORE) {
11961199
isNanIgnoreMode = true;
11971200
// Because the TOSA spec requires the result be NaN iff all elements in
11981201
// the reduction are NaN we can't simply perform a compare and select.
@@ -1355,11 +1358,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13551358
unsigned rank = inputTy.getRank();
13561359

13571360
// This is an illegal configuration. terminate and log an error
1358-
if (op.getRoundingMode() == "INEXACT_ROUND")
1361+
if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
13591362
return rewriter.notifyMatchFailure(
13601363
op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
13611364
"currently supported");
1362-
if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
1365+
if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
13631366
return rewriter.notifyMatchFailure(
13641367
op, "tosa.rescale requires scale32 for double_round to be true");
13651368

@@ -1405,11 +1408,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14051408
// is ever true.
14061409

14071410
bool doubleRound =
1408-
op.getRoundingMode() == "DOUBLE_ROUND" &&
1411+
op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
14091412
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1410-
StringAttr roundingMode = doubleRound
1411-
? rewriter.getStringAttr("DOUBLE_ROUND")
1412-
: rewriter.getStringAttr("SINGLE_ROUND");
1413+
RoundingMode roundingMode =
1414+
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
14131415

14141416
SmallVector<AffineMap> indexingMaps = {
14151417
rewriter.getMultiDimIdentityMap(rank)};
@@ -1592,7 +1594,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
15921594
auto input = op.getInput();
15931595
auto inputTy = cast<RankedTensorType>(input.getType());
15941596
auto resultTy = cast<RankedTensorType>(op.getType());
1595-
const bool isBilinear = op.getMode() == "BILINEAR";
1597+
const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;
15961598

15971599
auto inputH = inputTy.getDimSize(1);
15981600
auto inputW = inputTy.getDimSize(2);
@@ -1603,8 +1605,8 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
16031605
return rewriter.notifyMatchFailure(
16041606
op, "tosa.resize is not a pure 1x1->1x1 image operation");
16051607

1606-
// TODO(suderman): These string values should be declared the TOSA dialect.
1607-
if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1608+
if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1609+
op.getMode() != ResizeMode::BILINEAR)
16081610
return rewriter.notifyMatchFailure(
16091611
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
16101612

@@ -1804,7 +1806,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
18041806
return rewriter.notifyMatchFailure(
18051807
op, "unable to get dynamic dimensions of tosa.resize");
18061808

1807-
if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1809+
if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1810+
op.getMode() != ResizeMode::BILINEAR)
18081811
return rewriter.notifyMatchFailure(
18091812
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
18101813

@@ -1909,7 +1912,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
19091912
getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
19101913
}
19111914

1912-
if (op.getMode() == "NEAREST_NEIGHBOR") {
1915+
if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
19131916
auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
19141917

19151918
auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
@@ -1945,7 +1948,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
19451948
linalg::YieldOp::create(b, result);
19461949
} else {
19471950
// The mode here must be BILINEAR.
1948-
assert(op.getMode() == "BILINEAR");
1951+
assert(op.getMode() == ResizeMode::BILINEAR);
19491952

19501953
auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
19511954

@@ -2310,7 +2313,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
23102313

23112314
Value predicate;
23122315
if (isa<FloatType>(inElementTy)) {
2313-
if (argmaxOp.getNanMode() == "IGNORE") {
2316+
if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
23142317
// Only update index & max value for non NaN values. If all
23152318
// values are NaNs, the initial index will be return which is 0.
23162319
predicate = arith::CmpFOp::create(rewriter, nestedLoc,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
803803
dilationAttr);
804804

805805
rewriter.setInsertionPointAfter(op);
806-
StringRef nanMode = op.getNanMode();
806+
NanPropagationMode nanMode = op.getNanMode();
807807
rewriter.replaceOp(op, resultOp);
808808

809809
// NaN propagation has no meaning for non floating point types.
@@ -817,7 +817,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
817817
// we've already produced a named op we will just take its body and modify
818818
// it to include the appropriate checks. If the current value is NaN the
819819
// old value of pool will be taken otherwise we use the result.
820-
if (nanMode == "IGNORE") {
820+
if (nanMode == NanPropagationMode::IGNORE) {
821821
auto genericOp = linalg::GenericOp::create(
822822
rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
823823
resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
@@ -1040,11 +1040,13 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
10401040
rewriter, loc, rewriter.getI8IntegerAttr(30));
10411041
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
10421042

1043-
auto scaled =
1044-
tosa::ApplyScaleOp::create(
1045-
rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
1046-
shift, rewriter.getStringAttr("SINGLE_ROUND"))
1047-
.getResult();
1043+
auto roundingAttr = RoundingModeAttr::get(
1044+
rewriter.getContext(), RoundingMode::SINGLE_ROUND);
1045+
1046+
auto scaled = tosa::ApplyScaleOp::create(
1047+
rewriter, loc, rewriter.getI32Type(), poolVal,
1048+
multiplier, shift, roundingAttr)
1049+
.getResult();
10481050

10491051
// If we have quantization information we need to apply output
10501052
// zeropoint.

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
556556
// Check we have a valid NaN propagation combination.
557557
const auto opNanMode = op.getNanMode();
558558
const auto clampNanMode = clampOp.getNanMode();
559-
if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
559+
if (opNanMode == NanPropagationMode::IGNORE &&
560+
clampNanMode == NanPropagationMode::PROPAGATE)
560561
return failure();
561562

562563
auto maxValAttr = op.getMaxValAttr();
@@ -637,10 +638,16 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
637638
}
638639
}
639640

641+
auto newMode = (opNanMode != clampNanMode)
642+
? tosa::NanPropagationMode::IGNORE
643+
: opNanMode;
644+
645+
auto newModeAttr =
646+
NanPropagationModeAttr::get(rewriter.getContext(), newMode);
647+
640648
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
641649
op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
642-
rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
643-
: opNanMode));
650+
newModeAttr);
644651
return success();
645652
}
646653
};

0 commit comments

Comments
 (0)