@@ -47,8 +47,8 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
4747}
4848
4949static Value createLinalgBodyCalculationForElementwiseOp (
50- Operation *op, ValueRange args, ArrayRef<Type> resultTypes ,
51- ConversionPatternRewriter &rewriter) {
50+ Operation *op, const TypeConverter &converter, ValueRange args ,
51+ ArrayRef<Type> resultTypes, ConversionPatternRewriter &rewriter) {
5252 Location loc = op->getLoc ();
5353 auto elementTy =
5454 cast<ShapedType>(op->getOperand (0 ).getType ()).getElementType ();
@@ -61,7 +61,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
6161
6262 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
6363 auto zero = rewriter.create <arith::ConstantOp>(
64- loc, rewriter.getZeroAttr (elementTy));
64+ loc, rewriter.getZeroAttr (converter. convertType ( elementTy) ));
6565 auto neg = rewriter.create <arith::SubIOp>(loc, zero, args[0 ]);
6666 return rewriter.create <arith::MaxSIOp>(loc, args[0 ], neg);
6767 }
@@ -416,17 +416,19 @@ static Value createLinalgBodyCalculationForElementwiseOp(
416416 if (intTy.isUnsignedInteger ()) {
417417 minRepresentable = 0 ;
418418 if (intTy.getIntOrFloatBitWidth () <= 63 ) {
419- maxRepresentable = (int64_t )APInt::getMaxValue (intTy.getIntOrFloatBitWidth ())
420- .getZExtValue ();
419+ maxRepresentable =
420+ (int64_t )APInt::getMaxValue (intTy.getIntOrFloatBitWidth ())
421+ .getZExtValue ();
421422 }
422- } else if (intTy.getIntOrFloatBitWidth () <= 64 ) {
423+ } else if (intTy.getIntOrFloatBitWidth () <= 64 ) {
423424 // Ensure that min & max fit into signed n-bit constants.
424425 minRepresentable = APInt::getSignedMinValue (intTy.getIntOrFloatBitWidth ())
425- .getSExtValue ();
426+ .getSExtValue ();
426427 maxRepresentable = APInt::getSignedMaxValue (intTy.getIntOrFloatBitWidth ())
427- .getSExtValue ();
428+ .getSExtValue ();
428429 }
429- // Ensure that the bounds are representable as n-bit signed/unsigned integers.
430+ // Ensure that the bounds are representable as n-bit signed/unsigned
431+ // integers.
430432 min = std::max (min, minRepresentable);
431433 max = std::max (max, minRepresentable);
432434 min = std::min (min, maxRepresentable);
@@ -946,7 +948,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
946948 getNParallelLoopsAttrs (rank),
947949 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
948950 Value opResult = createLinalgBodyCalculationForElementwiseOp (
949- operation, blockArgs.take_front (operation->getNumOperands ()),
951+ operation, converter,
952+ blockArgs.take_front (operation->getNumOperands ()),
950953 {resultType.getElementType ()}, rewriter);
951954 if (!opResult) {
952955 encounteredError = true ;
0 commit comments