@@ -191,7 +191,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
191191 Value max = rewriter.create <arith::ConstantIntOp>(
192192 loc, APInt::getSignedMaxValue (inputBitWidth).getSExtValue (),
193193 intermediateType);
194- auto clamp = clampIntHelper (loc, sub, min, max, rewriter);
194+ auto clamp =
195+ clampIntHelper (loc, sub, min, max, rewriter, /* isUnsigned=*/ false );
195196
196197 // Truncate to the final value.
197198 return rewriter.create <arith::TruncIOp>(loc, elementTy, clamp);
@@ -402,24 +403,26 @@ static Value createLinalgBodyCalculationForElementwiseOp(
402403 int64_t max =
403404 cast<IntegerAttr>(op->getAttr (" max_int" )).getValue ().getSExtValue ();
404405
406+ int64_t minRepresentable = std::numeric_limits<int64_t >::min ();
407+ int64_t maxRepresentable = std::numeric_limits<int64_t >::max ();
405408 if (intTy.isUnsignedInteger ()) {
406- if (intTy. getIntOrFloatBitWidth () > 63 ) {
407- ( void )rewriter. notifyMatchFailure (
408- op, " support for 64-bit or larger integers is not implemented " );
409- return {} ;
409+ minRepresentable = 0 ;
410+ if (intTy. getIntOrFloatBitWidth () <= 63 ) {
411+ maxRepresentable = ( int64_t ) APInt::getMaxValue (intTy. getIntOrFloatBitWidth ())
412+ . getZExtValue () ;
410413 }
411- min = std::max (min, (int64_t )0 );
412- max = std::min (max,
413- (int64_t )APInt::getMaxValue (intTy.getIntOrFloatBitWidth ())
414- .getZExtValue ());
415- } else {
416- min =
417- std::max (min, APInt::getSignedMinValue (intTy.getIntOrFloatBitWidth ())
418- .getSExtValue ());
419- max =
420- std::min (max, APInt::getSignedMaxValue (intTy.getIntOrFloatBitWidth ())
421- .getSExtValue ());
414+ } else if (intTy.getIntOrFloatBitWidth () <= 64 ) {
415+ // Ensure that min & max fit into signed n-bit constants.
416+ minRepresentable = APInt::getSignedMinValue (intTy.getIntOrFloatBitWidth ())
417+ .getSExtValue ();
418+ maxRepresentable = APInt::getSignedMaxValue (intTy.getIntOrFloatBitWidth ())
419+ .getSExtValue ();
422420 }
421+ // Ensure that the bounds are representable as n-bit signed/unsigned integers.
422+ min = std::max (min, minRepresentable);
423+ max = std::max (max, minRepresentable);
424+ min = std::min (min, maxRepresentable);
425+ max = std::min (max, maxRepresentable);
423426
424427 auto minVal = rewriter.create <arith::ConstantIntOp>(
425428 loc, min, intTy.getIntOrFloatBitWidth ());
@@ -666,10 +669,8 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
666669}
667670
668671static SmallVector<Value> expandInputRanks (PatternRewriter &rewriter,
669- Location loc, Operation *operation,
670- ValueRange operands) {
671- auto rank =
672- cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
672+ Location loc, ValueRange operands,
673+ int64_t rank) {
673674 return llvm::map_to_vector (operands, [&](Value operand) {
674675 return expandRank (rewriter, loc, operand, rank);
675676 });
@@ -898,10 +899,10 @@ static LogicalResult
898899emitElementwiseComputation (ConversionPatternRewriter &rewriter, Location loc,
899900 Operation *operation, ValueRange operands,
900901 ArrayRef<OpFoldResult> targetShape,
901- const TypeConverter * converter) {
902+ const TypeConverter & converter) {
902903 // Generate output tensor
903- auto resultType = cast_or_null<RankedTensorType>(converter-> convertType (
904- cast<RankedTensorType> (operation->getResultTypes ().front () )));
904+ auto resultType = cast_or_null<RankedTensorType>(
905+ converter. convertType (operation->getResultTypes ().front ()));
905906 if (!resultType) {
906907 return rewriter.notifyMatchFailure (operation, " failed to convert type" );
907908 }
@@ -953,7 +954,7 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
953954static LogicalResult
954955elementwiseMatchAndRewriteHelper (Operation *operation, ValueRange operands,
955956 ConversionPatternRewriter &rewriter,
956- const TypeConverter * converter) {
957+ const TypeConverter & converter) {
957958
958959 // Collect op properties
959960 assert (operation->getNumResults () == 1 && " elementwise op expects 1 result" );
@@ -966,7 +967,9 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
966967 // Lower operation
967968 IndexPool indexPool;
968969 auto loc = operation->getLoc ();
969- auto expandedOperands = expandInputRanks (rewriter, loc, operation, operands);
970+ auto rank =
971+ cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
972+ auto expandedOperands = expandInputRanks (rewriter, loc, operands, rank);
970973 auto [targetShape, masterOperands] =
971974 computeTargetShape (rewriter, loc, indexPool, expandedOperands);
972975 auto broadcastOperands = broadcastDynamicDimensions (
@@ -1173,8 +1176,8 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
11731176 LogicalResult
11741177 matchAndRewrite (SrcOp op, OpAdaptor operands,
11751178 ConversionPatternRewriter &rewriter) const final {
1176- return elementwiseMatchAndRewriteHelper (op, operands. getOperands (),
1177- rewriter, this ->getTypeConverter ());
1179+ return elementwiseMatchAndRewriteHelper (
1180+ op, operands. getOperands (), rewriter, * this ->getTypeConverter ());
11781181 }
11791182};
11801183
@@ -1398,7 +1401,7 @@ class RescaleConverter : public OpConversionPattern<tosa::RescaleOp> {
13981401 loc, nestedBuilder.getI32IntegerAttr (intMax));
13991402
14001403 value = clampIntHelper (nestedLoc, value, intMinVal, intMaxVal,
1401- nestedBuilder);
1404+ nestedBuilder, /* isUnsigned= */ false );
14021405
14031406 if (outIntType.getWidth () < 32 ) {
14041407 value = nestedBuilder.create <arith::TruncIOp>(
@@ -1772,7 +1775,7 @@ class GenericResizeConverter : public OpConversionPattern<tosa::ResizeOp> {
17721775
17731776 auto offset = b.create <arith::SelectOp>(pred, one, zeroI32);
17741777 val = b.create <arith::AddIOp>(val, offset);
1775- val = clampIntHelper (loc, val, zeroI32, max, b);
1778+ val = clampIntHelper (loc, val, zeroI32, max, b, /* isUnsigned= */ false );
17761779 return b.create <arith::IndexCastOp>(b.getIndexType (), val);
17771780 };
17781781
@@ -1793,8 +1796,10 @@ class GenericResizeConverter : public OpConversionPattern<tosa::ResizeOp> {
17931796 Value max, ImplicitLocOpBuilder &b) {
17941797 val0 = in;
17951798 val1 = b.create <arith::AddIOp>(val0, oneVal);
1796- val0 = clampIntHelper (loc, val0, zeroI32, max, b);
1797- val1 = clampIntHelper (loc, val1, zeroI32, max, b);
1799+ val0 =
1800+ clampIntHelper (loc, val0, zeroI32, max, b, /* isUnsigned=*/ false );
1801+ val1 =
1802+ clampIntHelper (loc, val1, zeroI32, max, b, /* isUnsigned=*/ false );
17981803 val0 = b.create <arith::IndexCastOp>(b.getIndexType (), val0);
17991804 val1 = b.create <arith::IndexCastOp>(b.getIndexType (), val1);
18001805 };
@@ -2760,7 +2765,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
27602765 PointwiseConverter<tosa::CeilOp>,
27612766 PointwiseConverter<tosa::FloorOp>,
27622767 PointwiseConverter<tosa::ClampOp>,
2763- PointwiseConverter<tosa::SigmoidOp>,
2768+ PointwiseConverter<tosa::SigmoidOp>
2769+ >(converter, patterns->getContext ());
2770+
2771+ patterns->add <
27642772 IdentityNConverter<tosa::IdentityOp>,
27652773 ReduceConverter<tosa::ReduceAllOp>,
27662774 ReduceConverter<tosa::ReduceAnyOp>,
0 commit comments