@@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
6565 return result;
6666
6767 auto nanMode = op.getNanMode ();
68- if (nanMode == " PROPAGATE" )
68+ if (nanMode == NanPropagationMode:: PROPAGATE)
6969 return result;
7070
7171 // Unordered comparison of NaN against itself will always return true.
@@ -160,9 +160,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
160160 b = arith::ExtSIOp::create (rewriter, loc, rewriter.getI32Type (), b);
161161
162162 auto shiftAmount = shiftIsConstant ? shiftConst : args[2 ];
163- auto result = tosa::ApplyScaleOp::create (
164- rewriter, loc, rewriter.getI32Type (), a, b, shiftAmount,
165- rewriter.getStringAttr (" SINGLE_ROUND" ));
163+ auto roundingAttr = RoundingModeAttr::get (rewriter.getContext (),
164+ RoundingMode::SINGLE_ROUND);
165+ auto result =
166+ tosa::ApplyScaleOp::create (rewriter, loc, rewriter.getI32Type (), a,
167+ b, shiftAmount, roundingAttr);
166168
167169 return result;
168170 }
@@ -465,7 +467,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
465467
466468 // In the case of "PROPAGATE" semantics no compare and selection is
467469 // required.
468- if (nanMode == " PROPAGATE" )
470+ if (nanMode == NanPropagationMode:: PROPAGATE)
469471 return result;
470472
471473 // In the case of "IGNORE" semantics materialize a comparison
@@ -1192,7 +1194,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
11921194 if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
11931195 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
11941196 // NaN propagation has no meaning for non floating point types.
1195- if (isa<FloatType>(elementTy) && op.getNanMode () == " IGNORE" ) {
1197+ if (isa<FloatType>(elementTy) &&
1198+ op.getNanMode () == NanPropagationMode::IGNORE) {
11961199 isNanIgnoreMode = true ;
11971200 // Because the TOSA spec requires the result be NaN iff all elements in
11981201 // the reduction are NaN we can't simply perform a compare and select.
@@ -1355,11 +1358,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13551358 unsigned rank = inputTy.getRank ();
13561359
13571360 // This is an illegal configuration. terminate and log an error
1358- if (op.getRoundingMode () == " INEXACT_ROUND" )
1361+ if (op.getRoundingMode () == RoundingMode:: INEXACT_ROUND)
13591362 return rewriter.notifyMatchFailure (
13601363 op, " tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
13611364 " currently supported" );
1362- if (op.getRoundingMode () == " DOUBLE_ROUND" && !op.getScale32 ())
1365+ if (op.getRoundingMode () == RoundingMode:: DOUBLE_ROUND && !op.getScale32 ())
13631366 return rewriter.notifyMatchFailure (
13641367 op, " tosa.rescale requires scale32 for double_round to be true" );
13651368
@@ -1405,11 +1408,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14051408 // is ever true.
14061409
14071410 bool doubleRound =
1408- op.getRoundingMode () == " DOUBLE_ROUND" &&
1411+ op.getRoundingMode () == RoundingMode:: DOUBLE_ROUND &&
14091412 llvm::any_of (shiftValues, [](int32_t v) { return v > 31 ; });
1410- StringAttr roundingMode = doubleRound
1411- ? rewriter.getStringAttr (" DOUBLE_ROUND" )
1412- : rewriter.getStringAttr (" SINGLE_ROUND" );
1413+ RoundingMode roundingMode =
1414+ doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
14131415
14141416 SmallVector<AffineMap> indexingMaps = {
14151417 rewriter.getMultiDimIdentityMap (rank)};
@@ -1592,7 +1594,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
15921594 auto input = op.getInput ();
15931595 auto inputTy = cast<RankedTensorType>(input.getType ());
15941596 auto resultTy = cast<RankedTensorType>(op.getType ());
1595- const bool isBilinear = op.getMode () == " BILINEAR" ;
1597+ const bool isBilinear = op.getMode () == ResizeMode:: BILINEAR;
15961598
15971599 auto inputH = inputTy.getDimSize (1 );
15981600 auto inputW = inputTy.getDimSize (2 );
@@ -1603,8 +1605,8 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
16031605 return rewriter.notifyMatchFailure (
16041606 op, " tosa.resize is not a pure 1x1->1x1 image operation" );
16051607
1606- // TODO(suderman): These string values should be declared the TOSA dialect.
1607- if (op. getMode () != " NEAREST_NEIGHBOR " && op.getMode () != " BILINEAR" )
1608+ if (op. getMode () != ResizeMode::NEAREST_NEIGHBOR &&
1609+ op.getMode () != ResizeMode:: BILINEAR)
16081610 return rewriter.notifyMatchFailure (
16091611 op, " tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR" );
16101612
@@ -1804,7 +1806,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
18041806 return rewriter.notifyMatchFailure (
18051807 op, " unable to get dynamic dimensions of tosa.resize" );
18061808
1807- if (op.getMode () != " NEAREST_NEIGHBOR" && op.getMode () != " BILINEAR" )
1809+ if (op.getMode () != ResizeMode::NEAREST_NEIGHBOR &&
1810+ op.getMode () != ResizeMode::BILINEAR)
18081811 return rewriter.notifyMatchFailure (
18091812 op, " tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR" );
18101813
@@ -1909,7 +1912,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
19091912 getIndexAndDeltaInt (ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
19101913 }
19111914
1912- if (op.getMode () == " NEAREST_NEIGHBOR" ) {
1915+ if (op.getMode () == ResizeMode:: NEAREST_NEIGHBOR) {
19131916 auto one = arith::ConstantOp::create (b, b.getI32IntegerAttr (1 ));
19141917
19151918 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
@@ -1945,7 +1948,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
19451948 linalg::YieldOp::create (b, result);
19461949 } else {
19471950 // The mode here must be BILINEAR.
1948- assert (op.getMode () == " BILINEAR" );
1951+ assert (op.getMode () == ResizeMode:: BILINEAR);
19491952
19501953 auto oneVal = arith::ConstantOp::create (b, b.getI32IntegerAttr (1 ));
19511954
@@ -2310,7 +2313,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
23102313
23112314 Value predicate;
23122315 if (isa<FloatType>(inElementTy)) {
2313- if (argmaxOp.getNanMode () == " IGNORE" ) {
2316+ if (argmaxOp.getNanMode () == NanPropagationMode:: IGNORE) {
23142317 // Only update index & max value for non NaN values. If all
23152318 // values are NaNs, the initial index will be return which is 0.
23162319 predicate = arith::CmpFOp::create (rewriter, nestedLoc,
0 commit comments