diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index a99cf293b9eac..17ebc7dc32372 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -618,12 +618,8 @@ static Value createLinalgBodyCalculationForElementwiseOp( loc, rewriter.getIntegerAttr( getElementTypeOrSelf(dstTy), APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()))); - auto intMax = rewriter.create( - loc, rewriter.getIntegerAttr( - getElementTypeOrSelf(dstTy), - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); auto maxClamped = - rewriter.create(loc, overflow, intMax, conv); + rewriter.create(loc, overflow, intMin, conv); return rewriter.create(loc, underflow, intMin, maxClamped); } @@ -647,8 +643,13 @@ static Value createLinalgBodyCalculationForElementwiseOp( APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); + auto overflow = rewriter.create( + loc, arith::CmpFPredicate::UGT, rounded, intMaxFP); + Value maxClampedFP = + rewriter.create(loc, overflow, intMinFP, rounded); + Value clamped = - clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter); + clampFloatHelper(loc, maxClampedFP, intMinFP, intMaxFP, rewriter); return rewriter.create(loc, dstTy, clamped); } @@ -664,17 +665,17 @@ static Value createLinalgBodyCalculationForElementwiseOp( .getSExtValue()) + 1.0f)); - auto intMax = rewriter.create( + auto intMin = rewriter.create( loc, rewriter.getIntegerAttr( getElementTypeOrSelf(dstTy), - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); + APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()))); auto minClampedFP = rewriter.create(loc, rounded, intMinFP); auto minClamped = rewriter.create(loc, dstTy, minClampedFP); auto overflow = rewriter.create( loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP); - return rewriter.create(loc, overflow, intMax, + return rewriter.create(loc, overflow, intMin, minClamped); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 9ba9965315fd3..180db212b5448 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -541,13 +541,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () { // CHECK: linalg.generic // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32 - // CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32 + // CHECK: [[CSTMINF:%.+]] = arith.constant -2.14748365E+9 : f32 // CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32 - // CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32 - // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32 + // CHECK: [[CSTMIN:%.+]] = arith.constant -2147483648 : i32 + // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMINF]] : f32 // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32 // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32 - // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32 + // CHECK: arith.select [[CMP]], [[CSTMIN]], [[CONV]] : i32 %20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32> // CHECK: linalg.generic @@ -591,7 +591,9 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () { // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16 // CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16 // CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16 - // CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16 + // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ugt, [[ROUND]], [[CSTMAX]] : f16 + // CHECK: [[CLAMPMAX:%.+]] = arith.select [[OVERFLOW]], [[CSTMIN]], [[ROUND]] : f16 + // CHECK: [[MIN:%.+]] = arith.minimumf [[CLAMPMAX]], [[CSTMAX]] : f16 // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16 // CHECK: arith.fptosi [[CLAMP]] : f16 to i8 %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8> @@ -604,8 +606,7 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () { // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16 // CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16 // CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32 - // CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32 - // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32 + // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MININT]], [[CONV]] : i32 // CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32 %2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32> return @@ -1980,11 +1981,11 @@ func.func @test_dynamic_fft2d(%arg0: tensor, %arg1: tensor // CHECK: %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32 // CHECK: %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32 // CHECK: %[[FP_INT_MAX_PLUS_ONE:.*]] = arith.constant 9.22337203E+18 : f32 -// CHECK: %[[INT_MAX:.*]] = arith.constant 9223372036854775807 : i64 +// CHECK: %[[INT_MIN:.*]] = arith.constant -9223372036854775808 : i64 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ROUND_EVEN]], %[[FP_INT_MIN]] : f32 // CHECK: %[[FPTOSI:.*]] = arith.fptosi %[[MAX]] : f32 to i64 // CHECK: %[[CMPF:.*]] = arith.cmpf uge, %[[ROUND_EVEN]], %[[FP_INT_MAX_PLUS_ONE]] : f32 -// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MAX]], %[[FPTOSI]] : i64 +// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MIN]], %[[FPTOSI]] : i64 // CHECK: linalg.yield %[[SELECT]] : i64 // CHECK: } -> tensor<1xi64> // CHECK: return %[[RESULT]] : tensor<1xi64> @@ -1995,6 +1996,34 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) { // ----- +// CHECK-LABEL: @test_simple_f64 +func.func @test_simple_f64(%arg0: tensor<1xf64>) -> () { + // CHECK: linalg.generic + // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f64 + // CHECK: [[CSTMINF:%.+]] = arith.constant -9.2233720368547758E+18 : f64 + // CHECK: [[CSTMAXP1:%.+]] = arith.constant 9.2233720368547758E+18 : f64 + // CHECK: [[CSTMIN:%.+]] = arith.constant -9223372036854775808 : i64 + // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMINF]] : f64 + // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f64 to i64 + // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f64 + // CHECK: arith.select [[CMP]], [[CSTMIN]], [[CONV]] : i64 + %0 = tosa.cast %arg0 : (tensor<1xf64>) -> tensor<1xi64> + + // CHECK: linalg.generic + // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f64 + // CHECK: [[CSTMIN:%.+]] = arith.constant 0xC1E0000000000000 : f64 + // CHECK: [[CSTMAX:%.+]] = arith.constant 0x41DFFFFFFFC00000 : f64 + // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ugt, [[ROUND]], [[CSTMAX]] : f64 + // CHECK: [[CLAMPMAX:%.+]] = arith.select [[OVERFLOW]], [[CSTMIN]], [[ROUND]] : f64 + // CHECK: [[MIN:%.+]] = arith.minimumf [[CLAMPMAX]], [[CSTMAX]] : f64 + // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f64 + // CHECK: arith.fptosi [[CLAMP]] : f64 to i32 + %1 = tosa.cast %arg0 : (tensor<1xf64>) -> tensor<1xi32> + return +} + +// ----- + // CHECK-LABEL: @reduce_min_nan_propagate func.func @reduce_min_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () { // CHECK: linalg.reduce