Skip to content

Commit 1e53fb5

Browse files
committed
Fixed type conversion for tosa.abs when lowering to linalg
1 parent 37f5d68 commit 1e53fb5

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
4747
}
4848

4949
static Value createLinalgBodyCalculationForElementwiseOp(
50-
Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51-
ConversionPatternRewriter &rewriter) {
50+
Operation *op, const TypeConverter &converter, ValueRange args,
51+
ArrayRef<Type> resultTypes, ConversionPatternRewriter &rewriter) {
5252
Location loc = op->getLoc();
5353
auto elementTy =
5454
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
@@ -61,7 +61,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
6161

6262
if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
6363
auto zero = rewriter.create<arith::ConstantOp>(
64-
loc, rewriter.getZeroAttr(elementTy));
64+
loc, rewriter.getZeroAttr(converter.convertType(elementTy)));
6565
auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
6666
return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
6767
}
@@ -416,17 +416,19 @@ static Value createLinalgBodyCalculationForElementwiseOp(
416416
if (intTy.isUnsignedInteger()) {
417417
minRepresentable = 0;
418418
if (intTy.getIntOrFloatBitWidth() <= 63) {
419-
maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
420-
.getZExtValue();
419+
maxRepresentable =
420+
(int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
421+
.getZExtValue();
421422
}
422-
} else if(intTy.getIntOrFloatBitWidth() <= 64) {
423+
} else if (intTy.getIntOrFloatBitWidth() <= 64) {
423424
// Ensure that min & max fit into signed n-bit constants.
424425
minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
425-
.getSExtValue();
426+
.getSExtValue();
426427
maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
427-
.getSExtValue();
428+
.getSExtValue();
428429
}
429-
// Ensure that the bounds are representable as n-bit signed/unsigned integers.
430+
// Ensure that the bounds are representable as n-bit signed/unsigned
431+
// integers.
430432
min = std::max(min, minRepresentable);
431433
max = std::max(max, minRepresentable);
432434
min = std::min(min, maxRepresentable);
@@ -946,7 +948,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
946948
getNParallelLoopsAttrs(rank),
947949
[&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
948950
Value opResult = createLinalgBodyCalculationForElementwiseOp(
949-
operation, blockArgs.take_front(operation->getNumOperands()),
951+
operation, converter,
952+
blockArgs.take_front(operation->getNumOperands()),
950953
{resultType.getElementType()}, rewriter);
951954
if (!opResult) {
952955
encounteredError = true;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,3 +2010,12 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>
20102010
%output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = true} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
20112011
return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
20122012
}
2013+
2014+
// -----
2015+
// CHECK-LABEL: @test_abs_conversion
2016+
// CHECK: linalg.generic
2017+
// CHECK: arith.constant 0 : i64
2018+
func.func @test_abs_conversion(%arg0: tensor<9xui64> {func.orig_type = tensor<9xui64>, onnx.name = "in0"}) -> (tensor<9xui64> {func.orig_type = tensor<9xui64>, onnx.name = "out0"}) {
2019+
%0 = tosa.abs %arg0 : (tensor<9xui64>) -> tensor<9xui64>
2020+
return %0 : tensor<9xui64>
2021+
}

0 commit comments

Comments
 (0)