 using namespace mlir;
 using namespace mlir::tosa;
 
+// Helper function to materialize the semantically correct compare and select
+// operations for a binary operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and selection is required
+// and this function does nothing.
+//
+// In the case of "IGNORE" semantics this function materializes a comparison
+// of each operand against itself, which returns true only for a NaN
+// argument, and then selects between the non-NaN operand and the calculated
+// result based on whether the lhs or rhs is NaN. In pseudo code:
+//
+// binary<op>(lhs, rhs):
+//   result = op(lhs, rhs)
+//   if lhs == NaN return rhs
+//   if rhs == NaN return lhs
+//   return result
+static Value materializeBinaryNanCheckIfRequired(Operation *op,
+                                                 PatternRewriter &rewriter,
+                                                 Value lhs, Value rhs,
+                                                 Value result) {
+  const auto nanMode = getNanMode(op, rewriter);
+  if (!nanMode)
+    return {};
+
+  if (*nanMode == "PROPAGATE")
+    return result;
+
+  assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
+
+  // Unordered comparison of NaN against itself will always return true.
+  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op->getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
+  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op->getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+  Value rhsOrResult =
+      rewriter.create<arith::SelectOp>(op->getLoc(), lhsIsNaN, rhs, result);
+  return rewriter.create<arith::SelectOp>(op->getLoc(), rhsIsNaN, lhs,
+                                          rhsOrResult);
+}
+
 template <typename T>
 static arith::ConstantOp
 createConstFromIntAttribute(Operation *op, const std::string &attrName,
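For reference, a rough sketch (not part of the commit; value names are made up) of the arith ops the helper above would materialize inside the linalg body for an f32 tosa.maximum with nan_mode = "IGNORE":

    %max    = arith.maximumf %lhs, %rhs : f32
    %lhsNaN = arith.cmpf uno, %lhs, %lhs : f32   // true iff %lhs is NaN
    %rhsNaN = arith.cmpf uno, %rhs, %rhs : f32   // true iff %rhs is NaN
    %tmp    = arith.select %lhsNaN, %rhs, %max : f32
    %result = arith.select %rhsNaN, %lhs, %tmp : f32

With "PROPAGATE" only the plain arith.maximumf is emitted.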
@@ -367,7 +408,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MaximumOp
   if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
+                                               max);
   }
 
   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -376,7 +419,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MinimumOp
   if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
+                                               min);
   }
 
   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -404,7 +449,34 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
     auto max = rewriter.create<arith::ConstantOp>(
         loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
-    return clampFloatHelper(loc, args[0], min, max, rewriter);
+    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
+
+    const auto nanMode = getNanMode(op, rewriter);
+    if (!nanMode)
+      return {};
+
+    // In the case of "PROPAGATE" semantics no compare and selection is
+    // required.
+    if (*nanMode == "PROPAGATE")
+      return result;
+
+    // In the case of "IGNORE" semantics, materialize a comparison of the
+    // input operand against itself, which returns true for a NaN argument,
+    // and then select between the clamp's lower bound and the calculated
+    // result based on whether the input is NaN or not. In pseudo code:
+    //
+    // clamp<op>(x, min, max):
+    //   result = clamp(x, min, max)
+    //   return min if x == NaN else result
+    assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
+
+    // Unordered comparison of NaN against itself will always return true.
+    Value isNaN = rewriter.create<arith::CmpFOp>(
+        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
+    // TOSA specifies that in "ignore" NaN mode the result is "min" if the
+    // input is NaN.
+    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
   }
 
   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
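Likewise, a hedged sketch (illustrative names; clampFloatHelper is assumed to lower to a max/min pair) of the clamp lowering with nan_mode = "IGNORE" on f32:

    %lo     = arith.maximumf %x, %minVal : f32
    %clamp  = arith.minimumf %lo, %maxVal : f32
    %isNaN  = arith.cmpf uno, %x, %x : f32
    // Per the TOSA "ignore" semantics a NaN input clamps to the lower bound.
    %result = arith.select %isNaN, %minVal, %clamp : f32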
@@ -1096,6 +1168,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
     }
   }
 
+  SmallVector<Value> inputs, outputs;
+  inputs.push_back(input);
+
   // First fill the output buffer with the init value.
   auto emptyTensor =
       rewriter
@@ -1113,26 +1188,124 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
           .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                   ValueRange{emptyTensor})
           .result();
+  outputs.push_back(filledTensor);
+
+  const auto nanMode = getNanMode(op, rewriter);
+  const bool isNanIgnoreMode = (nanMode && *nanMode == "IGNORE");
+  if (isNanIgnoreMode) {
+    // Because the TOSA spec requires the result to be NaN iff all elements
+    // in the reduction are NaN, we can't simply perform a compare and
+    // select; we also have to keep track of whether we've seen any non-NaN
+    // values and then do a final select based on this predicate.
+    auto trueAttr = rewriter.getBoolAttr(true);
+    auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
+    auto emptyBoolTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
+                                     dynDims)
+            .getResult();
+    auto allResultsNaNTensor =
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{trueValue},
+                                    ValueRange{emptyBoolTensor})
+            .result();
+    // Note that because the linalg::ReduceOp has two variadic arguments
+    // (inputs and outputs) and it has the SameVariadicOperandSize trait, we
+    // need to have the same number of inputs and outputs.
+    //
+    // The second input isn't actually used anywhere since the value used to
+    // update the NaN flag is calculated inside the body of the reduction and
+    // then used to update an out value.
+    // In order to satisfy type constraints we just pass another copy of the
+    // input here.
+    inputs.push_back(input);
+    outputs.push_back(allResultsNaNTensor);
+  }
 
   bool didEncounterError = false;
-  auto linalgOp = rewriter.create<linalg::ReduceOp>(
-      loc, input, filledTensor, axis,
+  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
+      loc, inputs, outputs, axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+        std::array<Value, 2> binaryArgs{
+            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
         auto result = createLinalgBodyCalculationForReduceOp(
-            op, blockArgs, elementTy, rewriter);
+            op, binaryArgs, elementTy, rewriter);
         if (result)
           didEncounterError = true;
 
-        nestedBuilder.create<linalg::YieldOp>(loc, result);
+        SmallVector<Value> resultsToYield;
+        if (isNanIgnoreMode) {
+          auto inputValue = blockArgs[0];
+          auto initialValue = blockArgs[2];
+          auto oldAllResultsNanFlagValue = blockArgs[3];
+
+          // Unordered comparison of NaN against itself will always return
+          // true.
+          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
+              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
+          // If we've encountered a NaN, take the non-NaN value.
+          auto selectOp = nestedBuilder.create<arith::SelectOp>(
+              op->getLoc(), isNaN, initialValue, result);
+          // Update the flag that tracks whether every value seen so far has
+          // been NaN.
+          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
+              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
+          resultsToYield.push_back(selectOp);
+          resultsToYield.push_back(newAllResultsNanFlagValue);
+        } else {
+          resultsToYield.push_back(result);
+        }
+        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
       });
 
   if (!didEncounterError)
     return rewriter.notifyMatchFailure(
         op, "unable to create linalg.generic body for reduce op");
 
+  if (isNanIgnoreMode) {
+    // Materialize a check to see whether we encountered any non-NaN values.
+    // If we didn't, we need to select a tensor of NaNs, since the result
+    // would otherwise just be the initial identity value propagated through
+    // all the compares and selects inside the reduction.
+
+    // Create a tensor full of NaNs.
+    auto nanValueAttr = rewriter.getFloatAttr(
+        elementTy,
+        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
+    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
+    auto emptyNanTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape,
+                                     resultTy.getElementType(), dynDims)
+            .getResult();
+    auto nanFilledTensor =
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{nanValue},
+                                    ValueRange{emptyNanTensor})
+            .result();
+
+    // Create an empty tensor; no need to fill it since it will be
+    // overwritten by the select.
+    auto finalEmptyTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape,
+                                     resultTy.getElementType(), dynDims)
+            .getResult();
+
+    // Do a selection between the tensors akin to:
+    //   result = NaN if "all results NaN" else result.
+    SmallVector<Value> ins, outs;
+    ins.push_back(linalgOp->getOpResult(1));
+    ins.push_back(nanFilledTensor);
+    ins.push_back(linalgOp->getResult(0));
+    outs.push_back(finalEmptyTensor);
+    auto linalgSelect =
+        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
+    linalgOp = linalgSelect;
+  }
+
   SmallVector<ReassociationExprs, 4> reassociationMap;
   uint64_t expandInputRank =
-      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
+      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
   reassociationMap.resize(expandInputRank);
 
   for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1151,7 +1324,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   // not have access to such information. This matters when handling dynamically
   // sized tensors.
   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp.getResults()[0], reassociationMap);
+      op, resultTy, linalgOp->getResults()[0], reassociationMap);
   return success();
 }
 
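To make the reduction flow above concrete, here is an illustrative sketch (not taken from the commit) of the combiner logic generated for a MAX reduction in "IGNORE" mode, where %x is the current input element, %acc the running result, and %allNaN the running i1 flag initialized to true:

    %partial = arith.maximumf %acc, %x : f32
    %isNaN   = arith.cmpf uno, %x, %x : f32
    %next    = arith.select %isNaN, %acc, %partial : f32
    %allNaN2 = arith.andi %allNaN, %isNaN : i1
    // %next and %allNaN2 are both yielded from the linalg.reduce body

After the reduce, the linalg.select picks the NaN-filled tensor wherever the flag result stayed true, so the reduction result is NaN iff every element along the axis was NaN.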
@@ -2088,6 +2261,32 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
                 nestedLoc, predicate, newValue, oldValue);
             auto resultIndex = rewriter.create<arith::SelectOp>(
                 nestedLoc, predicate, newIndex, oldIndex);
+
+            // Check if we need to materialize compare and select for the
+            // given NaN propagation mode.
+            const auto nanMode = getNanMode(argmaxOp, rewriter);
+            if (!nanMode) {
+              didEncounterError = true;
+              return;
+            }
+
+            // "PROPAGATE" matches the default NaN propagation mode of the
+            // arith dialect, so no compare and select is required.
+            //
+            // In the "IGNORE" case we check whether the current argument is
+            // NaN; if it is, we keep the old index and value, otherwise we
+            // take the updated index and value.
+            if (*nanMode == "IGNORE") {
+              // Unordered comparison of NaN against itself will always
+              // return true.
+              Value isNaN = rewriter.create<arith::CmpFOp>(
+                  argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
+                  newValue);
+              resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
+                                                           oldValue, resultMax);
+              resultIndex = rewriter.create<arith::SelectOp>(
+                  nestedLoc, isNaN, oldIndex, resultIndex);
+            }
             nestedBuilder.create<linalg::YieldOp>(
                 nestedLoc, ValueRange({resultIndex, resultMax}));
           });
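Finally, an illustrative sketch (names invented; the greater-than predicate and i32 index element type are assumptions) of the extra ops inside the argmax body when nan_mode = "IGNORE", appended after the existing compare/select pair:

    %isNaN = arith.cmpf uno, %newValue, %newValue : f32
    %max2  = arith.select %isNaN, %oldValue, %resultMax : f32
    %idx2  = arith.select %isNaN, %oldIndex, %resultIndex : i32
    // %idx2 and %max2 are yielded instead of the unguarded results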