@@ -287,10 +287,12 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
287287
288288 if (isa<FloatType>(inputElementType)) {
289289 // Unlike integer types, floating point types can represent infinity.
290- auto minClamp = op.getMinFp ();
291- auto maxClamp = op.getMaxFp ();
292- bool isMin = minClamp.isInfinity () && minClamp.isNegative ();
293- bool isMax = maxClamp.isInfinity () && !maxClamp.isNegative ();
290+ auto minClamp =
291+ llvm::cast<mlir::FloatAttr>(op.getMinValAttr ()).getValue ();
292+ auto maxClamp =
293+ llvm::cast<mlir::FloatAttr>(op.getMaxValAttr ()).getValue ();
294+ bool isMin = minClamp.isNegInfinity ();
295+ bool isMax = maxClamp.isInfinity ();
294296
295297 if (isMin && isMax) {
296298 rewriter.replaceOp (op, input);
@@ -300,8 +302,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
300302 }
301303
302304 if (inputElementType.isUnsignedInteger ()) {
303- int64_t minClamp = op.getMinInt ();
304- int64_t maxClamp = op.getMaxInt ();
305+ int64_t minClamp =
306+ llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ()).getUInt ();
307+ int64_t maxClamp =
308+ llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ()).getUInt ();
305309
306310 int64_t intMin =
307311 APInt::getMinValue (inputElementType.getIntOrFloatBitWidth ())
@@ -318,8 +322,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
318322 }
319323
320324 if (llvm::isa<IntegerType>(inputElementType)) {
321- int64_t minClamp = op.getMinInt ();
322- int64_t maxClamp = op.getMaxInt ();
325+ int64_t minClamp =
326+ llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ()).getInt ();
327+ int64_t maxClamp =
328+ llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ()).getInt ();
323329
324330 int64_t intMin =
325331 APInt::getSignedMinValue (inputElementType.getIntOrFloatBitWidth ())
@@ -374,9 +380,10 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
374380
375381 LogicalResult matchAndRewrite (tosa::ClampOp op,
376382 PatternRewriter &rewriter) const override {
383+ Value input = op.getInput ();
384+
377385 // Check the input to the CLAMP op is itself a CLAMP.
378- auto clampOp =
379- dyn_cast_if_present<tosa::ClampOp>(op.getInput ().getDefiningOp ());
386+ auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp ());
380387 if (!clampOp)
381388 return failure ();
382389
@@ -386,34 +393,86 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
386393 if (opNanMode == " IGNORE" && clampNanMode == " PROPAGATE" )
387394 return failure ();
388395
389- // Check we have intersecting ranges.
390- const auto opMinInt = op.getMinInt ();
391- const auto opMaxInt = op.getMaxInt ();
392- const auto clampOpMinInt = clampOp.getMinInt ();
393- const auto clampOpMaxInt = clampOp.getMaxInt ();
394- ClampRange<std::int64_t > opRangeIntRange (opMinInt, opMaxInt);
395- ClampRange<std::int64_t > clampRangeIntRange (clampOpMinInt, clampOpMaxInt);
396- if (!opRangeIntRange.intersects (clampRangeIntRange))
397- return failure ();
396+ auto maxValAttr = op.getMaxValAttr ();
397+ auto minValAttr = op.getMinValAttr ();
398+ auto clampOpMaxValAttr = clampOp.getMaxValAttr ();
399+ auto clampOpMinValAttr = clampOp.getMinValAttr ();
398400
399- const auto opMinFloat = op.getMinFp ();
400- const auto opMaxFloat = op.getMaxFp ();
401- const auto clampOpMinFloat = clampOp.getMinFp ();
402- const auto clampOpMaxFloat = clampOp.getMaxFp ();
403- ClampRange<APFloat> opRangeFloatRange (opMinFloat, opMaxFloat);
404- ClampRange<APFloat> clampRangeFloatRange (clampOpMinFloat, clampOpMaxFloat);
405- if (!opRangeFloatRange.intersects (clampRangeFloatRange))
406- return failure ();
401+ auto inputEType = llvm::cast<ShapedType>(input.getType ()).getElementType ();
402+ if (auto quantType =
403+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
404+ inputEType = quantType.getStorageType ();
405+ }
406+
407+ Attribute newMinValAttr, newMaxValAttr;
408+ if (mlir::isa<FloatType>(inputEType)) {
409+ auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
410+ auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
411+ auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
412+ auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
413+
414+ // Check we have intersecting ranges.
415+ const auto opMinFloat = floatMinValAttr.getValue ();
416+ const auto opMaxFloat = floatMaxValAttr.getValue ();
417+ const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue ();
418+ const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue ();
419+ ClampRange<APFloat> opRangeFloatRange (opMinFloat, opMaxFloat);
420+ ClampRange<APFloat> clampRangeFloatRange (clampOpMinFloat,
421+ clampOpMaxFloat);
422+ if (!opRangeFloatRange.intersects (clampRangeFloatRange))
423+ return failure ();
424+
425+ // Run the transformation.
426+ auto newMinVal = std::max (opMinFloat, clampOpMinFloat);
427+ auto newMaxVal = std::min (opMaxFloat, clampOpMaxFloat);
428+ newMinValAttr = rewriter.getFloatAttr (inputEType, newMinVal);
429+ newMaxValAttr = rewriter.getFloatAttr (inputEType, newMaxVal);
430+ } else {
431+ assert (mlir::isa<IntegerType>(inputEType));
432+ auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
433+ auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
434+ auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
435+ auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
436+
437+ if (inputEType.isUnsignedInteger ()) {
438+ // Check we have intersecting ranges.
439+ const auto opMinInt = intMinValAttr.getUInt ();
440+ const auto opMaxInt = intMaxValAttr.getUInt ();
441+ const auto clampOpMinInt = clampOpIntMinValAttr.getUInt ();
442+ const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt ();
443+ ClampRange<std::uint64_t > opRangeIntRange (opMinInt, opMaxInt);
444+ ClampRange<std::uint64_t > clampRangeIntRange (clampOpMinInt,
445+ clampOpMaxInt);
446+ if (!opRangeIntRange.intersects (clampRangeIntRange))
447+ return failure ();
448+
449+ // Run the transformation.
450+ auto newMinVal = std::max (opMinInt, clampOpMinInt);
451+ auto newMaxVal = std::min (opMaxInt, clampOpMaxInt);
452+ newMinValAttr = rewriter.getIntegerAttr (inputEType, newMinVal);
453+ newMaxValAttr = rewriter.getIntegerAttr (inputEType, newMaxVal);
454+ } else {
455+ // Check we have intersecting ranges.
456+ const auto opMinInt = intMinValAttr.getInt ();
457+ const auto opMaxInt = intMaxValAttr.getInt ();
458+ const auto clampOpMinInt = clampOpIntMinValAttr.getInt ();
459+ const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt ();
460+ ClampRange<std::int64_t > opRangeIntRange (opMinInt, opMaxInt);
461+ ClampRange<std::int64_t > clampRangeIntRange (clampOpMinInt,
462+ clampOpMaxInt);
463+ if (!opRangeIntRange.intersects (clampRangeIntRange))
464+ return failure ();
465+
466+ // Run the transformation.
467+ auto newMinVal = std::max (opMinInt, clampOpMinInt);
468+ auto newMaxVal = std::min (opMaxInt, clampOpMaxInt);
469+ newMinValAttr = rewriter.getIntegerAttr (inputEType, newMinVal);
470+ newMaxValAttr = rewriter.getIntegerAttr (inputEType, newMaxVal);
471+ }
472+ }
407473
408- // Run the transformation.
409- const auto minFp = std::max (opMinFloat, clampOpMinFloat).convertToFloat ();
410- const auto maxFp = std::min (opMaxFloat, clampOpMaxFloat).convertToFloat ();
411- const auto minInt = std::max (opMinInt, clampOpMinInt);
412- const auto maxInt = std::min (opMaxInt, clampOpMaxInt);
413474 rewriter.replaceOpWithNewOp <tosa::ClampOp>(
414- op, op.getType (), clampOp.getInput (),
415- rewriter.getI64IntegerAttr (minInt), rewriter.getI64IntegerAttr (maxInt),
416- rewriter.getF32FloatAttr (minFp), rewriter.getF32FloatAttr (maxFp),
475+ op, op.getType (), clampOp.getInput (), newMinValAttr, newMaxValAttr,
417476 rewriter.getStringAttr ((opNanMode != clampNanMode) ? " IGNORE"
418477 : opNanMode));
419478 return success ();
0 commit comments