Skip to content

Commit f8585ea

Browse files
author
Alaa Ali
committed
tosa.cast: fix result mismatch when casting the f64/f32 max value to i64/i32
1 parent 2c8b2dc commit f8585ea

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -618,12 +618,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
618618
loc, rewriter.getIntegerAttr(
619619
getElementTypeOrSelf(dstTy),
620620
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
621-
auto intMax = rewriter.create<arith::ConstantOp>(
622-
loc, rewriter.getIntegerAttr(
623-
getElementTypeOrSelf(dstTy),
624-
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
625621
auto maxClamped =
626-
rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
622+
rewriter.create<arith::SelectOp>(loc, overflow, intMin, conv);
627623
return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
628624
maxClamped);
629625
}
@@ -647,8 +643,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
647643
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
648644
.getSExtValue()));
649645

646+
auto overflow = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, rounded, intMaxFP);
647+
Value maxClampedFP = rewriter.create<arith::SelectOp>(loc, overflow, intMinFP, rounded);
648+
650649
Value clamped =
651-
clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
650+
clampFloatHelper(loc, maxClampedFP, intMinFP, intMaxFP, rewriter);
652651
return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
653652
}
654653

@@ -664,17 +663,17 @@ static Value createLinalgBodyCalculationForElementwiseOp(
664663
.getSExtValue()) +
665664
1.0f));
666665

667-
auto intMax = rewriter.create<arith::ConstantOp>(
666+
auto intMin = rewriter.create<arith::ConstantOp>(
668667
loc, rewriter.getIntegerAttr(
669668
getElementTypeOrSelf(dstTy),
670-
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
669+
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
671670
auto minClampedFP =
672671
rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
673672
auto minClamped =
674673
rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
675674
auto overflow = rewriter.create<arith::CmpFOp>(
676675
loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
677-
return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
676+
return rewriter.create<arith::SelectOp>(loc, overflow, intMin,
678677
minClamped);
679678
}
680679

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -541,13 +541,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
541541

542542
// CHECK: linalg.generic
543543
// CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
544-
// CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
544+
// CHECK: [[CSTMINF:%.+]] = arith.constant -2.14748365E+9 : f32
545545
// CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
546-
// CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
547-
// CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
546+
// CHECK: [[CSTMIN:%.+]] = arith.constant -2147483648 : i32
547+
// CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMINF]] : f32
548548
// CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
549549
// CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
550-
// CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
550+
// CHECK: arith.select [[CMP]], [[CSTMIN]], [[CONV]] : i32
551551
%20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
552552

553553
// CHECK: linalg.generic
@@ -591,7 +591,9 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
591591
// CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
592592
// CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
593593
// CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
594-
// CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
594+
// CHECK: [[OVERFLOW:%.+]] = arith.cmpf ugt, [[ROUND]], [[CSTMAX]] : f16
595+
// CHECK: [[CLAMPMAX:%.+]] = arith.select [[OVERFLOW]], [[CSTMIN]], [[ROUND]] : f16
596+
// CHECK: [[MIN:%.+]] = arith.minimumf [[CLAMPMAX]], [[CSTMAX]] : f16
595597
// CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
596598
// CHECK: arith.fptosi [[CLAMP]] : f16 to i8
597599
%1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
@@ -604,8 +606,7 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
604606
// CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
605607
// CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
606608
// CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
607-
// CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
608-
// CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
609+
// CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MININT]], [[CONV]] : i32
609610
// CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
610611
%2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
611612
return
@@ -1980,11 +1981,11 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>
19801981
// CHECK: %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32
19811982
// CHECK: %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32
19821983
// CHECK: %[[FP_INT_MAX_PLUS_ONE:.*]] = arith.constant 9.22337203E+18 : f32
1983-
// CHECK: %[[INT_MAX:.*]] = arith.constant 9223372036854775807 : i64
1984+
// CHECK: %[[INT_MIN:.*]] = arith.constant -9223372036854775808 : i64
19841985
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ROUND_EVEN]], %[[FP_INT_MIN]] : f32
19851986
// CHECK: %[[FPTOSI:.*]] = arith.fptosi %[[MAX]] : f32 to i64
19861987
// CHECK: %[[CMPF:.*]] = arith.cmpf uge, %[[ROUND_EVEN]], %[[FP_INT_MAX_PLUS_ONE]] : f32
1987-
// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MAX]], %[[FPTOSI]] : i64
1988+
// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MIN]], %[[FPTOSI]] : i64
19881989
// CHECK: linalg.yield %[[SELECT]] : i64
19891990
// CHECK: } -> tensor<1xi64>
19901991
// CHECK: return %[[RESULT]] : tensor<1xi64>

0 commit comments

Comments
 (0)