3636using namespace mlir ;
3737using namespace mlir ::tosa;
3838
39+ // Helper function to materialize the semantically correct compare and select
40+ // operations a reduction operation with a specific NaN propagation mode.
41+ //
42+ // In the case of "PROPAGATE" semantics no compare and selection is required and
43+ // this function does nothing.
44+ //
45+ // In the case of "IGNORE" semantics this function materializes a comparison of
46+ // the current operand to the reduction which will return true for a NaN
47+ // argument and then selects between the initial reduction value and the
48+ // calculated result based on whether the argument is NaN or not. In pseudo
49+ // code:
50+ //
51+ // reduce<op>(x, init):
52+ // result = op(init, x)
53+ // return init if x == NaN else result
54+ static Value materializeReductionNanCheckIfRequired (Operation *op,
55+ PatternRewriter &rewriter,
56+ Value in, Value init,
57+ Value result) {
58+ const auto nanMode = getNanMode (op, rewriter);
59+ if (!nanMode)
60+ return {};
61+
62+ if (*nanMode == " PROPAGATE" )
63+ return result;
64+
65+ assert (*nanMode == " IGNORE" && " Unhandled nan-propagation mode" );
66+
67+ // Unordered comparison of NaN against itself will always return true.
68+ Value isNaN = rewriter.create <arith::CmpFOp>(
69+ op->getLoc (), arith::CmpFPredicate::UNO, in, in);
70+ return rewriter.create <arith::SelectOp>(op->getLoc (), isNaN, init, result);
71+ }
72+
73+ // Helper function to materialize the semantically correct compare and select
74+ // operations a binary operation with a specific NaN propagation mode.
75+ //
76+ // In the case of "PROPAGATE" semantics no compare and selection is required and
77+ // this function does nothing.
78+ //
79+ // In the case of "IGNORE" semantics this function materializes a comparison of
80+ // the current operands to the op which will return true for any NaN
81+ // argument and then selects between the non-NaN operation argument and the
82+ // calculated result based on whether the lhs or rhs is NaN or not. In pseudo
83+ // code:
84+ //
85+ // binary<op>(lhs, rhs):
86+ // result = op(lhs, rhs)
87+ // if lhs == NaN return rhs
88+ // if rhs == NaN return lhs
89+ // return result
90+ static Value materializeBinaryNanCheckIfRequired (Operation *op,
91+ PatternRewriter &rewriter,
92+ Value lhs, Value rhs,
93+ Value result) {
94+ const auto nanMode = getNanMode (op, rewriter);
95+ if (!nanMode)
96+ return {};
97+
98+ if (*nanMode == " PROPAGATE" )
99+ return result;
100+
101+ assert (*nanMode == " IGNORE" && " Unhandled nan-propagation mode" );
102+
103+ // Unordered comparison of NaN against itself will always return true.
104+ Value lhsIsNaN = rewriter.create <arith::CmpFOp>(
105+ op->getLoc (), arith::CmpFPredicate::UNO, lhs, lhs);
106+ Value rhsIsNaN = rewriter.create <arith::CmpFOp>(
107+ op->getLoc (), arith::CmpFPredicate::UNO, rhs, rhs);
108+ Value rhsOrResult =
109+ rewriter.create <arith::SelectOp>(op->getLoc (), lhsIsNaN, rhs, result);
110+ return rewriter.create <arith::SelectOp>(op->getLoc (), rhsIsNaN, lhs,
111+ rhsOrResult);
112+ }
113+
39114template <typename T>
40115static arith::ConstantOp
41116createConstFromIntAttribute (Operation *op, const std::string &attrName,
@@ -358,7 +433,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
358433
359434 // tosa::MaximumOp
360435 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
361- return rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
436+ auto max = rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
437+ return materializeBinaryNanCheckIfRequired (op, rewriter, args[0 ], args[1 ],
438+ max);
362439 }
363440
364441 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger ()) {
@@ -367,7 +444,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
367444
368445 // tosa::MinimumOp
369446 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
370- return rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
447+ auto min = rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
448+ return materializeBinaryNanCheckIfRequired (op, rewriter, args[0 ], args[1 ],
449+ min);
371450 }
372451
373452 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger ()) {
@@ -395,7 +474,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
395474 loc, elementTy, rewriter.getFloatAttr (elementTy, minApf));
396475 auto max = rewriter.create <arith::ConstantOp>(
397476 loc, elementTy, rewriter.getFloatAttr (elementTy, maxApf));
398- return clampFloatHelper (loc, args[0 ], min, max, rewriter);
477+ auto result = clampFloatHelper (loc, args[0 ], min, max, rewriter);
478+ // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
479+ // is NaN.
480+ return materializeReductionNanCheckIfRequired (op, rewriter, args[0 ], min,
481+ result);
399482 }
400483
401484 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1042,15 +1125,19 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
10421125 }
10431126
10441127 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1045- return rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
1128+ auto min = rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
1129+ return materializeReductionNanCheckIfRequired (op, rewriter, args[0 ],
1130+ args[1 ], min);
10461131 }
10471132
10481133 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
10491134 return rewriter.create <arith::MinSIOp>(loc, args[0 ], args[1 ]);
10501135 }
10511136
10521137 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1053- return rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
1138+ auto max = rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
1139+ return materializeReductionNanCheckIfRequired (op, rewriter, args[0 ],
1140+ args[1 ], max);
10541141 }
10551142
10561143 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
@@ -2078,6 +2165,32 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
20782165 nestedLoc, predicate, newValue, oldValue);
20792166 auto resultIndex = rewriter.create <arith::SelectOp>(
20802167 nestedLoc, predicate, newIndex, oldIndex);
2168+
2169+ // Check if we need to materialize compare and select for the given
2170+ // NaN propagation mode.
2171+ const auto nanMode = getNanMode (argmaxOp, rewriter);
2172+ if (!nanMode) {
2173+ didEncounterError = true ;
2174+ return ;
2175+ }
2176+
2177+ // "PROPAGATE" matches the default NaN propagation mode of the arith
2178+ // dialect so no compare and select is required.
2179+ //
2180+ // In the case "IGNORE" we check if the current argument is NaN and
2181+ // select the old index and value otherwise take the updated index and
2182+ // value.
2183+ if (*nanMode == " IGNORE" ) {
2184+ // Unordered comparison of NaN against itself will always return
2185+ // true.
2186+ Value isNaN = rewriter.create <arith::CmpFOp>(
2187+ argmaxOp.getLoc (), arith::CmpFPredicate::UNO, newValue,
2188+ newValue);
2189+ resultMax = rewriter.create <arith::SelectOp>(nestedLoc, isNaN,
2190+ oldValue, resultMax);
2191+ resultIndex = rewriter.create <arith::SelectOp>(
2192+ nestedLoc, isNaN, oldIndex, resultIndex);
2193+ }
20812194 nestedBuilder.create <linalg::YieldOp>(
20822195 nestedLoc, ValueRange ({resultIndex, resultMax}));
20832196 });
0 commit comments