diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 3dadf9474cf4f..30bcdfc45837a 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -311,7 +311,8 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, return success(); } -// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) +// Converts Powf(float a, float b) (meaning a^b) to exp(b * ln(a)). +// Only valid for a >= 0: ln(a) is undefined for negative a. static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operandA = op.getOperand(0); @@ -319,21 +320,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { Type opType = operandA.getType(); Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); - Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter); - Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter); - Value opASquared = b.create(opType, operandA, operandA); - Value opBHalf = b.create(opType, operandB, two); - Value logA = b.create(opType, opASquared); - Value mult = b.create(opType, opBHalf, logA); + Value logA = b.create(opType, operandA); + Value mult = b.create(opType, operandB, logA); Value expResult = b.create(opType, mult); - Value negExpResult = b.create(opType, expResult, negOne); - Value remainder = b.create(opType, operandB, two); - Value negCheck = - b.create(arith::CmpFPredicate::OLT, operandA, zero); - Value oddPower = - b.create(arith::CmpFPredicate::ONE, remainder, zero); - Value oddAndNeg = b.create(op->getLoc(), oddPower, negCheck); // First, we select between the exp value and the adjusted value for odd // powers of negatives. Then, we ensure that one is produced if `b` is zero. 
@@ -341,10 +331,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`. Value zeroCheck = b.create(arith::CmpFPredicate::OEQ, operandB, zero); - Value res = b.create(op->getLoc(), oddAndNeg, negExpResult, - expResult); - res = b.create(op->getLoc(), zeroCheck, one, res); - rewriter.replaceOp(op, res); + Value finalResult = + b.create(op->getLoc(), zeroCheck, one, expResult); + rewriter.replaceOp(op, finalResult); return success(); } diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 6055ed0504c84..5b443e9e8d4e7 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -202,25 +202,15 @@ func.func @roundf_func(%a: f32) -> f32 { // CHECK-LABEL: func @powf_func // CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64) -func.func @powf_func(%a: f64, %b: f64) ->f64 { +func.func @powf_func(%a: f64, %b: f64) -> f64 { // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00 // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0 - // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00 - // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00 - // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]] - // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]] - // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]] - // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]] - // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]] - // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]] - // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]] - // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]] - // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]] - // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]] - // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]] - // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]] - // CHECK-DAG: [[SEL1:%.+]] = arith.select 
[[CMPZERO]], [[CST1]], [[SEL]] - // CHECK: return [[SEL1]] + // CHECK: [[LOGA:%.+]] = math.log [[ARG0]] + // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]] + // CHECK: [[EXP:%.+]] = math.exp [[MULB]] + // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]] + // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]] + // CHECK: return [[SEL]] %ret = math.powf %a, %b : f64 return %ret : f64 } @@ -602,26 +592,15 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t return %2 : tensor<8xf32> } // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> { -// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32> -// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32> -// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> // CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32> -// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32> -// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32> -// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32> -// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32> -// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32> -// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32> -// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32> -// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32> -// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32> -// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32> -// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1> -// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] -// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32> -// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]] -// 
CHECK: return %[[SEL1]] : tensor<8xf32> - +// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> +// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32> +// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32> +// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32> +// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32> +// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32> +// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32> +// CHECK: return %[[SEL]] // ----- // CHECK-LABEL: func.func @math_fpowi_to_powf_scalar @@ -630,25 +609,15 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 { return %2 : f32 } // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 { -// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32 -// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 // CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32 -// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32 -// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32 -// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32 -// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32 +// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32 +// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32 // CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32 -// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32 -// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32 -// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32 -// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32 -// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1 -// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] -// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], 
%[[EXP]] : f32 -// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]] -// CHECK: return %[[SEL1]] : f32 +// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32 +// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32 +// CHECK: return %[[SEL]] : f32 // ----- diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir index 106b48a2daea2..d1916c28878b9 100644 --- a/mlir/test/mlir-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir @@ -202,11 +202,6 @@ func.func @powf() { %a_p = arith.constant 2.0 : f64 call @func_powff64(%a, %a_p) : (f64, f64) -> () - // CHECK-NEXT: -27 - %b = arith.constant -3.0 : f64 - %b_p = arith.constant 3.0 : f64 - call @func_powff64(%b, %b_p) : (f64, f64) -> () - // CHECK-NEXT: 2.343 %c = arith.constant 2.343 : f64 %c_p = arith.constant 1.000 : f64