@@ -46,10 +46,9 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
4646 op->getLoc (), IntegerAttr::get (requiredAttrType, castedN));
4747}
4848
49- static Value
50- createLinalgBodyCalculationForElementwiseOp (Operation *op, ValueRange args,
51- ArrayRef<Type> resultTypes,
52- PatternRewriter &rewriter) {
49+ static Value createLinalgBodyCalculationForElementwiseOp (
50+ Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51+ ConversionPatternRewriter &rewriter) {
5352 Location loc = op->getLoc ();
5453 auto elementTy =
5554 cast<ShapedType>(op->getOperand (0 ).getType ()).getElementType ();
@@ -186,7 +185,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
186185 Value max = rewriter.create <arith::ConstantIntOp>(
187186 loc, APInt::getSignedMaxValue (inputBitWidth).getSExtValue (),
188187 intermediateType);
189- auto clamp = clampIntHelper (loc, sub, min, max, rewriter);
188+ auto clamp =
189+ clampIntHelper (loc, sub, min, max, rewriter, /* isUnsigned=*/ false );
190190
191191 // Truncate to the final value.
192192 return rewriter.create <arith::TruncIOp>(loc, elementTy, clamp);
@@ -389,25 +389,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
389389 int64_t max =
390390 cast<IntegerAttr>(op->getAttr (" max_int" )).getValue ().getSExtValue ();
391391
392+ int64_t minRepresentable = std::numeric_limits<int64_t >::min ();
393+ int64_t maxRepresentable = std::numeric_limits<int64_t >::max ();
392394 if (intTy.isUnsignedInteger ()) {
393- min = std::max (min, ( int64_t ) 0 ) ;
394- max = std::min (
395- max,
396- APInt::getMaxValue (intTy. getIntOrFloatBitWidth ()). getSExtValue () );
397- } else {
398- min =
399- std:: max(min, APInt::getSignedMinValue (intTy. getIntOrFloatBitWidth ())
400- . getSExtValue ());
401- max =
402- std::min (max, APInt::getSignedMaxValue (intTy.getIntOrFloatBitWidth ())
403- .getSExtValue ()) ;
395+ minRepresentable = 0 ;
396+ if (intTy. getIntOrFloatBitWidth () <= 63 ) {
397+ maxRepresentable = ( int64_t ) APInt::getMaxValue (intTy. getIntOrFloatBitWidth ())
398+ . getZExtValue ( );
399+ }
400+ } else if (intTy. getIntOrFloatBitWidth () <= 64 ) {
401+ // Ensure that min & max fit into signed n-bit constants.
402+ minRepresentable = APInt::getSignedMinValue (intTy. getIntOrFloatBitWidth ())
403+ . getSExtValue ();
404+ maxRepresentable = APInt::getSignedMaxValue (intTy.getIntOrFloatBitWidth ())
405+ .getSExtValue ();
404406 }
407+ // Ensure that the bounds are representable as n-bit signed/unsigned integers.
408+ min = std::max (min, minRepresentable);
409+ max = std::max (max, minRepresentable);
410+ min = std::min (min, maxRepresentable);
411+ max = std::min (max, maxRepresentable);
405412
406413 auto minVal = rewriter.create <arith::ConstantIntOp>(
407414 loc, min, intTy.getIntOrFloatBitWidth ());
408415 auto maxVal = rewriter.create <arith::ConstantIntOp>(
409416 loc, max, intTy.getIntOrFloatBitWidth ());
410- return clampIntHelper (loc, args[0 ], minVal, maxVal, rewriter);
417+ return clampIntHelper (loc, args[0 ], minVal, maxVal, rewriter,
418+ intTy.isUnsignedInteger ());
411419 }
412420
413421 // tosa::SigmoidOp
@@ -615,10 +623,9 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
615623}
616624
617625static SmallVector<Value> expandInputRanks (PatternRewriter &rewriter,
618- Location loc, Operation *operation) {
619- auto rank =
620- cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
621- return llvm::map_to_vector (operation->getOperands (), [&](Value operand) {
626+ Location loc, ValueRange operands,
627+ int64_t rank) {
628+ return llvm::map_to_vector (operands, [&](Value operand) {
622629 return expandRank (rewriter, loc, operand, rank);
623630 });
624631}
@@ -843,11 +850,16 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
843850}
844851
845852static LogicalResult
846- emitElementwiseComputation (PatternRewriter &rewriter, Location loc,
853+ emitElementwiseComputation (ConversionPatternRewriter &rewriter, Location loc,
847854 Operation *operation, ValueRange operands,
848- ArrayRef<OpFoldResult> targetShape) {
855+ ArrayRef<OpFoldResult> targetShape,
856+ const TypeConverter &converter) {
849857 // Generate output tensor
850- auto resultType = cast<RankedTensorType>(operation->getResultTypes ().front ());
858+ auto resultType = cast_or_null<RankedTensorType>(
859+ converter.convertType (operation->getResultTypes ().front ()));
860+ if (!resultType) {
861+ return rewriter.notifyMatchFailure (operation, " failed to convert type" );
862+ }
851863 Value outputTensor = rewriter.create <tensor::EmptyOp>(
852864 loc, targetShape, resultType.getElementType ());
853865
@@ -894,8 +906,9 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
894906}
895907
896908static LogicalResult
897- elementwiseMatchAndRewriteHelper (Operation *operation,
898- PatternRewriter &rewriter) {
909+ elementwiseMatchAndRewriteHelper (Operation *operation, ValueRange operands,
910+ ConversionPatternRewriter &rewriter,
911+ const TypeConverter &converter) {
899912
900913 // Collect op properties
901914 assert (operation->getNumResults () == 1 && " elementwise op expects 1 result" );
@@ -908,13 +921,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
908921 // Lower operation
909922 IndexPool indexPool;
910923 auto loc = operation->getLoc ();
911- auto expandedOperands = expandInputRanks (rewriter, loc, operation);
924+ auto rank =
925+ cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
926+ auto expandedOperands = expandInputRanks (rewriter, loc, operands, rank);
912927 auto [targetShape, masterOperands] =
913928 computeTargetShape (rewriter, loc, indexPool, expandedOperands);
914929 auto broadcastOperands = broadcastDynamicDimensions (
915930 rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
916931 return emitElementwiseComputation (rewriter, loc, operation, broadcastOperands,
917- targetShape);
932+ targetShape, converter );
918933}
919934
920935// Returns the constant initial value for a given reduction operation. The
@@ -1100,13 +1115,16 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
11001115namespace {
11011116
11021117template <typename SrcOp>
1103- class PointwiseConverter : public OpRewritePattern <SrcOp> {
1118+ class PointwiseConverter : public OpConversionPattern <SrcOp> {
11041119public:
1105- using OpRewritePattern<SrcOp>::OpRewritePattern;
1120+ using OpConversionPattern<SrcOp>::OpConversionPattern;
1121+ using typename OpConversionPattern<SrcOp>::OpAdaptor;
11061122
1107- LogicalResult matchAndRewrite (SrcOp op,
1108- PatternRewriter &rewriter) const final {
1109- return elementwiseMatchAndRewriteHelper (op, rewriter);
1123+ LogicalResult
1124+ matchAndRewrite (SrcOp op, OpAdaptor operands,
1125+ ConversionPatternRewriter &rewriter) const final {
1126+ return elementwiseMatchAndRewriteHelper (
1127+ op, operands.getOperands (), rewriter, *this ->getTypeConverter ());
11101128 }
11111129};
11121130
@@ -1279,7 +1297,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
12791297 loc, nestedBuilder.getI32IntegerAttr (intMax));
12801298
12811299 value = clampIntHelper (nestedLoc, value, intMinVal, intMaxVal,
1282- nestedBuilder);
1300+ nestedBuilder, /* isUnsigned= */ false );
12831301
12841302 if (outIntType.getWidth () < 32 ) {
12851303 value = nestedBuilder.create <arith::TruncIOp>(
@@ -1643,7 +1661,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
16431661
16441662 auto offset = b.create <arith::SelectOp>(pred, one, zeroI32);
16451663 val = b.create <arith::AddIOp>(val, offset);
1646- val = clampIntHelper (loc, val, zeroI32, max, b);
1664+ val = clampIntHelper (loc, val, zeroI32, max, b, /* isUnsigned= */ false );
16471665 return b.create <arith::IndexCastOp>(b.getIndexType (), val);
16481666 };
16491667
@@ -1664,8 +1682,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
16641682 Value max, ImplicitLocOpBuilder &b) {
16651683 val0 = in;
16661684 val1 = b.create <arith::AddIOp>(val0, oneVal);
1667- val0 = clampIntHelper (loc, val0, zeroI32, max, b);
1668- val1 = clampIntHelper (loc, val1, zeroI32, max, b);
1685+ val0 =
1686+ clampIntHelper (loc, val0, zeroI32, max, b, /* isUnsigned=*/ false );
1687+ val1 =
1688+ clampIntHelper (loc, val1, zeroI32, max, b, /* isUnsigned=*/ false );
16691689 val0 = b.create <arith::IndexCastOp>(b.getIndexType (), val0);
16701690 val1 = b.create <arith::IndexCastOp>(b.getIndexType (), val1);
16711691 };
@@ -2555,7 +2575,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
25552575} // namespace
25562576
25572577void mlir::tosa::populateTosaToLinalgConversionPatterns (
2558- RewritePatternSet *patterns) {
2578+ TypeConverter &converter, RewritePatternSet *patterns) {
25592579
25602580 // We have multiple resize converters to handle degenerate cases.
25612581 patterns->add <GenericResizeConverter>(patterns->getContext (),
@@ -2602,7 +2622,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
26022622 PointwiseConverter<tosa::CeilOp>,
26032623 PointwiseConverter<tosa::FloorOp>,
26042624 PointwiseConverter<tosa::ClampOp>,
2605- PointwiseConverter<tosa::SigmoidOp>,
2625+ PointwiseConverter<tosa::SigmoidOp>
2626+ >(converter, patterns->getContext ());
2627+
2628+ patterns->add <
26062629 IdentityNConverter<tosa::IdentityOp>,
26072630 ReduceConverter<tosa::ReduceAllOp>,
26082631 ReduceConverter<tosa::ReduceAnyOp>,
0 commit comments