From 0e790b6f4a51ba7ab3e7a805e6141108036bab0a Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Wed, 29 Jan 2025 12:56:43 +0900 Subject: [PATCH 01/14] [mlir][math]Update `convertPowfOp` `ExpandPatterns.cpp` (#124402) The current implementation of `convertPowfOp` requires computing `a * a`, but the fp16 maximum is ~= 65,504, so if `a` is about 16, the result easily overflows to INF in fp8 or fp16. Remove support for `a < 0`: the overhead of handling negative values of `a` is large, and that handling overflows easily. - related issue in iree: https://github.com/iree-org/iree/issues/15936 --- .../Math/Transforms/ExpandPatterns.cpp | 25 ++----- mlir/test/Dialect/Math/expand-math.mlir | 71 ++++++------------- .../mlir-runner/test-expand-math-approx.mlir | 5 -- 3 files changed, 27 insertions(+), 74 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 3dadf9474cf4f..30bcdfc45837a 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -311,7 +311,8 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, return success(); } -// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) +// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) +// Restricting a >= 0 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operandA = op.getOperand(0); @@ -319,21 +320,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { Type opType = operandA.getType(); Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); - Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter); - Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter); - Value opASquared = b.create(opType, operandA, operandA); - Value opBHalf = b.create(opType, 
operandB, two); - Value logA = b.create(opType, opASquared); - Value mult = b.create(opType, opBHalf, logA); + Value logA = b.create(opType, operandA); + Value mult = b.create(opType, operandB, logA); Value expResult = b.create(opType, mult); - Value negExpResult = b.create(opType, expResult, negOne); - Value remainder = b.create(opType, operandB, two); - Value negCheck = - b.create(arith::CmpFPredicate::OLT, operandA, zero); - Value oddPower = - b.create(arith::CmpFPredicate::ONE, remainder, zero); - Value oddAndNeg = b.create(op->getLoc(), oddPower, negCheck); // First, we select between the exp value and the adjusted value for odd // powers of negatives. Then, we ensure that one is produced if `b` is zero. @@ -341,10 +331,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`. Value zeroCheck = b.create(arith::CmpFPredicate::OEQ, operandB, zero); - Value res = b.create(op->getLoc(), oddAndNeg, negExpResult, - expResult); - res = b.create(op->getLoc(), zeroCheck, one, res); - rewriter.replaceOp(op, res); + Value finalResult = + b.create(op->getLoc(), zeroCheck, one, expResult); + rewriter.replaceOp(op, finalResult); return success(); } diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 6055ed0504c84..5b443e9e8d4e7 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -202,25 +202,15 @@ func.func @roundf_func(%a: f32) -> f32 { // CHECK-LABEL: func @powf_func // CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64) -func.func @powf_func(%a: f64, %b: f64) ->f64 { +func.func @powf_func(%a: f64, %b: f64) -> f64 { // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00 // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0 - // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00 - // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00 - // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], 
[[ARG0]] - // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]] - // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]] - // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]] - // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]] - // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]] - // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]] - // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]] - // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]] - // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]] - // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]] - // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]] - // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]] - // CHECK: return [[SEL1]] + // CHECK: [[LOGA:%.+]] = math.log [[ARG0]] + // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]] + // CHECK: [[EXP:%.+]] = math.exp [[MULB]] + // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]] + // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]] + // CHECK: return [[SEL]] %ret = math.powf %a, %b : f64 return %ret : f64 } @@ -602,26 +592,15 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t return %2 : tensor<8xf32> } // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> { -// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32> -// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32> -// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> // CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32> -// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32> -// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32> -// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32> -// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32> -// CHECK: %[[MUL:.*]] = 
arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32> -// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32> -// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32> -// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32> -// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32> -// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32> -// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1> -// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] -// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32> -// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]] -// CHECK: return %[[SEL1]] : tensor<8xf32> - +// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> +// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32> +// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32> +// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32> +// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32> +// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32> +// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32> +// CHECK: return %[[SEL]] // ----- // CHECK-LABEL: func.func @math_fpowi_to_powf_scalar @@ -630,25 +609,15 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 { return %2 : f32 } // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 { -// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32 -// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 // CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32 -// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32 -// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32 
-// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32 -// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32 +// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32 +// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32 // CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32 -// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32 -// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32 -// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32 -// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32 -// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1 -// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] -// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32 -// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]] -// CHECK: return %[[SEL1]] : f32 +// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32 +// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32 +// CHECK: return %[[SEL]] : f32 // ----- diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir index 106b48a2daea2..d1916c28878b9 100644 --- a/mlir/test/mlir-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir @@ -202,11 +202,6 @@ func.func @powf() { %a_p = arith.constant 2.0 : f64 call @func_powff64(%a, %a_p) : (f64, f64) -> () - // CHECK-NEXT: -27 - %b = arith.constant -3.0 : f64 - %b_p = arith.constant 3.0 : f64 - call @func_powff64(%b, %b_p) : (f64, f64) -> () - // CHECK-NEXT: 2.343 %c = arith.constant 2.343 : f64 %c_p = arith.constant 1.000 : f64 From b248c2275c9d499695b3d63a96e65fcce88e9689 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Sat, 8 Feb 2025 12:51:18 +0900 Subject: [PATCH 02/14] add special cases for handling powf --- .../Math/Transforms/ExpandPatterns.cpp | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git 
a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 30bcdfc45837a..235ea38dd87d1 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -17,8 +17,13 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/Support/LogicalResult.h" +#include using namespace mlir; @@ -311,6 +316,90 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, return success(); } +// Convert Powf(float a, float b) for some special cases +// where b == 1.0, b == 0.0, b == 0.5, b == -0.5, b == -1.0, and b % 2 == 0 +static LogicalResult convertSpecialPowfOp(math::PowFOp op, + PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Value operandA = op.getOperand(0); + Value operandB = op.getOperand(1); + auto baseType = operandB.getType(); + + auto &sem = dyn_cast(getElementTypeOrSelf(baseType)) + .getFloatSemantics(); + + auto valueB = APFloat(sem); + if (!matchPattern(operandB, m_ConstantFloat(&valueB))) { + // Not a constant, return failure + return failure(); + } + float floatValueB = valueB.convertToFloat(); + + if (floatValueB == 1.0f) { + // a^1 -> a + rewriter.replaceOp(op, operandA); + return success(); + } + + if (floatValueB == 0.0) { + // a^0 -> 1 + Value one = + createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); + rewriter.replaceOp(op, one); + return success(); + } + + if (floatValueB == 0.5f) { + // a^(1/2) -> sqrt(a) + Value sqrt = b.create(operandA); + rewriter.replaceOp(op, sqrt); + return success(); + } + + if (floatValueB == -0.5f) { + // a^(-1/2) -> 1 / sqrt(a) + Value rsqrt = b.create(operandA); + rewriter.replaceOp(op, rsqrt); + return success(); 
+ } + + if (floatValueB == -1.0f) { + // a^(-1) -> 1 / a + Value one = + createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); + Value div = b.create(one, operandA); + rewriter.replaceOp(op, div); + return success(); + } + + // Check if the power is an integer + if (floatValueB != std::floor(floatValueB)) { + // We don't handle non-integer powers here, return failure + return failure(); + } + + auto sign = std::signbit(floatValueB) ? -1 : 1; + auto absIntValueB = std::abs(static_cast(floatValueB)); + + auto cstOne = + createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); + auto base = operandA; + if (sign == -1) { + base = b.create(cstOne, base); + } + auto current = base; + auto result = cstOne; + while (absIntValueB > 0) { + if (absIntValueB & 1) { + result = b.create(result, current); + } + current = b.create(current, current); + absIntValueB >>= 1; + } + rewriter.replaceOp(op, result); + return success(); +} + // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) // Restricting a >= 0 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { @@ -649,6 +738,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { } void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { + patterns.add(convertSpecialPowfOp); patterns.add(convertPowfOp); } From 0e7dc199d7ee765ded899c753c12724ae21db96e Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Tue, 11 Feb 2025 11:57:58 +0900 Subject: [PATCH 03/14] add test --- .../Math/Transforms/ExpandPatterns.cpp | 89 ++++++------------- mlir/test/Dialect/Math/expand-math.mlir | 24 ++--- .../mlir-runner/test-expand-math-approx.mlir | 64 +++++++------ 3 files changed, 73 insertions(+), 104 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 235ea38dd87d1..9ad1ac2308838 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ 
b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -9,7 +9,6 @@ // This file implements expansion of various math operations. // //===----------------------------------------------------------------------===// - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" @@ -316,13 +315,14 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, return success(); } -// Convert Powf(float a, float b) for some special cases -// where b == 1.0, b == 0.0, b == 0.5, b == -0.5, b == -1.0, and b % 2 == 0 +// Convert Powf(float a, float b) for special cases when b is constant: +// when b == 0, or |b| == 0.5, 1.0, or 2.0. static LogicalResult convertSpecialPowfOp(math::PowFOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operandA = op.getOperand(0); Value operandB = op.getOperand(1); + auto opType = operandA.getType(); auto baseType = operandB.getType(); auto &sem = dyn_cast(getElementTypeOrSelf(baseType)) @@ -334,95 +334,64 @@ static LogicalResult convertSpecialPowfOp(math::PowFOp op, return failure(); } float floatValueB = valueB.convertToFloat(); - + if (floatValueB == 0.0f) { + // a^0 -> 1 + Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); + rewriter.replaceOp(op, one); + return success(); + } if (floatValueB == 1.0f) { // a^1 -> a rewriter.replaceOp(op, operandA); return success(); } - - if (floatValueB == 0.0) { - // a^0 -> 1 - Value one = - createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); - rewriter.replaceOp(op, one); + if (floatValueB == -1.0f) { + // a^(-1) -> 1 / a + Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); + Value div = b.create(one, operandA); + rewriter.replaceOp(op, div); return success(); } - if (floatValueB == 0.5f) { // a^(1/2) -> sqrt(a) Value sqrt = b.create(operandA); rewriter.replaceOp(op, sqrt); return success(); } - if (floatValueB == -0.5f) { // a^(-1/2) -> 1 / sqrt(a) Value 
rsqrt = b.create(operandA); rewriter.replaceOp(op, rsqrt); return success(); } - - if (floatValueB == -1.0f) { - // a^(-1) -> 1 / a + if (floatValueB == 2.0f) { + // a^2 -> a * a + Value mul = b.create(operandA, operandA); + rewriter.replaceOp(op, mul); + return success(); + } + if (floatValueB == -2.0f) { + // a^(-2) -> 1 / (a * a) + Value mul = b.create(operandA, operandA); Value one = createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); - Value div = b.create(one, operandA); + Value div = b.create(one, mul); rewriter.replaceOp(op, div); return success(); } - // Check if the power is an integer - if (floatValueB != std::floor(floatValueB)) { - // We don't handle non-integer powers here, return failure - return failure(); - } - - auto sign = std::signbit(floatValueB) ? -1 : 1; - auto absIntValueB = std::abs(static_cast(floatValueB)); - - auto cstOne = - createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); - auto base = operandA; - if (sign == -1) { - base = b.create(cstOne, base); - } - auto current = base; - auto result = cstOne; - while (absIntValueB > 0) { - if (absIntValueB & 1) { - result = b.create(result, current); - } - current = b.create(current, current); - absIntValueB >>= 1; - } - rewriter.replaceOp(op, result); - return success(); + return failure(); } // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) -// Restricting a >= 0 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operandA = op.getOperand(0); Value operandB = op.getOperand(1); - Type opType = operandA.getType(); - Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); - Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); - - Value logA = b.create(opType, operandA); - Value mult = b.create(opType, operandB, logA); - Value expResult = b.create(opType, mult); - - // First, we select between the exp value and the adjusted value for odd - 
// powers of negatives. Then, we ensure that one is produced if `b` is zero. - // This corresponds to `libm` behavior, even for `0^0`. Without this check, - // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`. - Value zeroCheck = - b.create(arith::CmpFPredicate::OEQ, operandB, zero); - Value finalResult = - b.create(op->getLoc(), zeroCheck, one, expResult); - rewriter.replaceOp(op, finalResult); + Value logA = b.create(operandA); + Value mult = b.create(operandB, logA); + Value expResult = b.create(mult); + rewriter.replaceOp(op, expResult); return success(); } diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 5b443e9e8d4e7..3cf372ea0cf50 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -203,14 +203,10 @@ func.func @roundf_func(%a: f32) -> f32 { // CHECK-LABEL: func @powf_func // CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64) func.func @powf_func(%a: f64, %b: f64) -> f64 { - // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00 - // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0 - // CHECK: [[LOGA:%.+]] = math.log [[ARG0]] - // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]] - // CHECK: [[EXP:%.+]] = math.exp [[MULB]] - // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]] - // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]] - // CHECK: return [[SEL]] + // CHECK: [[LOGA:%.+]] = math.log [[ARG0]] : f64 + // CHECK: [[MUL:%.+]] = arith.mulf [[ARG1]], [[LOGA]] : f64 + // CHECK: [[EXP:%.+]] = math.exp [[MUL]] : f64 + // CHECK: return [[EXP]] : f64 %ret = math.powf %a, %b : f64 return %ret : f64 } @@ -592,15 +588,11 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t return %2 : tensor<8xf32> } // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> { -// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32> -// CHECK-DAG: %[[CST0:.*]] = 
arith.constant dense<0.000000e+00> : tensor<8xf32> // CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32> // CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32> // CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32> // CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32> -// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32> -// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32> -// CHECK: return %[[SEL]] +// CHECK: return %[[EXP]] : tensor<8xf32> // ----- // CHECK-LABEL: func.func @math_fpowi_to_powf_scalar @@ -609,15 +601,11 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 { return %2 : f32 } // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 { -// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32 // CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32 // CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32 // CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32 -// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32 -// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32 -// CHECK: return %[[SEL]] : f32 +// CHECK: return %[[EXP]] : f32 // ----- diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir index d1916c28878b9..b599c9d8435d4 100644 --- a/mlir/test/mlir-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir @@ -203,49 +203,61 @@ func.func @powf() { call @func_powff64(%a, %a_p) : (f64, f64) -> () // CHECK-NEXT: 2.343 - %c = arith.constant 2.343 : f64 - %c_p = arith.constant 1.000 : f64 - call @func_powff64(%c, %c_p) : (f64, f64) -> () + %b = arith.constant 2.343 : f64 + %b_p = arith.constant 1.000 : f64 + call @func_powff64(%b, %b_p) : (f64, f64) -> () // 
CHECK-NEXT: 0.176171 - %d = arith.constant 4.25 : f64 - %d_p = arith.constant -1.2 : f64 - call @func_powff64(%d, %d_p) : (f64, f64) -> () + %c = arith.constant 4.25 : f64 + %c_p = arith.constant -1.2 : f64 + call @func_powff64(%c, %c_p) : (f64, f64) -> () // CHECK-NEXT: 1 - %e = arith.constant 4.385 : f64 - %e_p = arith.constant 0.00 : f64 - call @func_powff64(%e, %e_p) : (f64, f64) -> () + %d = arith.constant 4.385 : f64 + %d_p = arith.constant 0.00 : f64 + call @func_powff64(%d, %d_p) : (f64, f64) -> () // CHECK-NEXT: 6.62637 - %f = arith.constant 4.835 : f64 - %f_p = arith.constant 1.2 : f64 - call @func_powff64(%f, %f_p) : (f64, f64) -> () + %e = arith.constant 4.835 : f64 + %e_p = arith.constant 1.2 : f64 + call @func_powff64(%e, %e_p) : (f64, f64) -> () // CHECK-NEXT: nan - %i = arith.constant 1.0 : f64 - %h = arith.constant 0x7fffffffffffffff : f64 - call @func_powff64(%i, %h) : (f64, f64) -> () + %f = arith.constant 1.0 : f64 + %f_p = arith.constant 0x7fffffffffffffff : f64 + call @func_powff64(%f, %f_p) : (f64, f64) -> () // CHECK-NEXT: inf - %j = arith.constant 29385.0 : f64 - %j_p = arith.constant 23598.0 : f64 - call @func_powff64(%j, %j_p) : (f64, f64) -> () + %g = arith.constant 29385.0 : f64 + %g_p = arith.constant 23598.0 : f64 + call @func_powff64(%g, %g_p) : (f64, f64) -> () // CHECK-NEXT: -nan - %k = arith.constant 1.0 : f64 - %k_p = arith.constant 0xfff0000001000000 : f64 - call @func_powff64(%k, %k_p) : (f64, f64) -> () + %h = arith.constant 1.0 : f64 + %h_p = arith.constant 0xfff0000001000000 : f64 + call @func_powff64(%h, %h_p) : (f64, f64) -> () // CHECK-NEXT: -nan - %l = arith.constant 1.0 : f32 - %l_p = arith.constant 0xffffffff : f32 - call @func_powff32(%l, %l_p) : (f32, f32) -> () + %i = arith.constant 1.0 : f32 + %i_p = arith.constant 0xffffffff : f32 + call @func_powff32(%i, %i_p) : (f32, f32) -> () // CHECK-NEXT: 1 - %zero = arith.constant 0.0 : f32 - call @func_powff32(%zero, %zero) : (f32, f32) -> () + %j = arith.constant 0.000 : 
f32 + %j_r = math.powf %j, %j : f32 + vector.print %j_r : f32 + // CHECK-NEXT: 4 + %k = arith.constant -2.0 : f32 + %k_p = arith.constant 2.0 : f32 + %k_r = math.powf %k, %k_p : f32 + vector.print %k_r : f32 + + // CHECK-NEXT: 0.25 + %l = arith.constant -2.0 : f32 + %l_p = arith.constant -2.0 : f32 + %l_r = math.powf %k, %l_p : f32 + vector.print %l_r : f32 return } From c52ba9fc6aa11743140fe0a6402ec93d68e3aed9 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Tue, 11 Feb 2025 11:59:43 +0900 Subject: [PATCH 04/14] formatting --- mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 9ad1ac2308838..5d9b264fd0d51 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -9,6 +9,7 @@ // This file implements expansion of various math operations. // //===----------------------------------------------------------------------===// + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" @@ -16,13 +17,8 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/Support/LogicalResult.h" -#include using namespace mlir; From e1e06ec257ded2ae2441556cf72b7d3732cba6c9 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Wed, 12 Feb 2025 11:26:33 +0900 Subject: [PATCH 05/14] Give explicit benefit to convertSpecialPowfOp --- mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp 
b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 5d9b264fd0d51..34ba98ca16a3e 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -703,7 +703,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { } void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { - patterns.add(convertSpecialPowfOp); + patterns.add(convertSpecialPowfOp, /*benefit=*/ 2); patterns.add(convertPowfOp); } From 9cf3d3bb2e4c387d53779460a935af76c4f3d94f Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Wed, 12 Feb 2025 11:33:30 +0900 Subject: [PATCH 06/14] formatting --- mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 34ba98ca16a3e..704001870601b 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -703,7 +703,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { } void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { - patterns.add(convertSpecialPowfOp, /*benefit=*/ 2); + patterns.add(convertSpecialPowfOp, /*benefit=*/2); patterns.add(convertPowfOp); } From ec9f1287c4f466a342797bc84c455299877efa49 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Wed, 12 Feb 2025 13:01:27 +0900 Subject: [PATCH 07/14] avoid using native float --- .../Math/Transforms/ExpandPatterns.cpp | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 704001870601b..1d27686b3acc1 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/ImplicitLocOpBuilder.h" #include 
"mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APFloat.h" using namespace mlir; @@ -318,55 +319,54 @@ static LogicalResult convertSpecialPowfOp(math::PowFOp op, ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operandA = op.getOperand(0); Value operandB = op.getOperand(1); - auto opType = operandA.getType(); - auto baseType = operandB.getType(); - - auto &sem = dyn_cast(getElementTypeOrSelf(baseType)) - .getFloatSemantics(); + auto typeA = operandA.getType(); + auto typeB = operandB.getType(); + auto &sem = + cast(getElementTypeOrSelf(typeB)).getFloatSemantics(); auto valueB = APFloat(sem); if (!matchPattern(operandB, m_ConstantFloat(&valueB))) { // Not a constant, return failure return failure(); } - float floatValueB = valueB.convertToFloat(); - if (floatValueB == 0.0f) { + + if (valueB.compare(APFloat(0.0f)) == APFloat::cmpEqual) { // a^0 -> 1 - Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); + Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); rewriter.replaceOp(op, one); return success(); } - if (floatValueB == 1.0f) { + if (valueB.compare(APFloat(1.0f)) == APFloat::cmpEqual) { // a^1 -> a rewriter.replaceOp(op, operandA); return success(); } - if (floatValueB == -1.0f) { + if (valueB.compare(APFloat(-1.0f)) == APFloat::cmpEqual) { // a^(-1) -> 1 / a - Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); + Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); Value div = b.create(one, operandA); rewriter.replaceOp(op, div); return success(); } - if (floatValueB == 0.5f) { + if (valueB.compare(APFloat(0.5f)) == APFloat::cmpEqual) { // a^(1/2) -> sqrt(a) Value sqrt = b.create(operandA); rewriter.replaceOp(op, sqrt); return success(); } - if (floatValueB == -0.5f) { + if (valueB.compare(APFloat(-0.5f)) == APFloat::cmpEqual) { // a^(-1/2) -> 1 / sqrt(a) Value rsqrt = b.create(operandA); rewriter.replaceOp(op, rsqrt); return success(); } - if (floatValueB 
== 2.0f) { + if (valueB.compare(APFloat(2.0f)) == APFloat::cmpEqual) { // a^2 -> a * a Value mul = b.create(operandA, operandA); rewriter.replaceOp(op, mul); return success(); } - if (floatValueB == -2.0f) { + if (valueB.compare(APFloat(-2.0f)) == APFloat::cmpEqual) { // a^(-2) -> 1 / (a * a) Value mul = b.create(operandA, operandA); Value one = From 95c3d55b1803ab057deb88266ef06be67ce1a496 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Wed, 12 Feb 2025 13:20:19 +0900 Subject: [PATCH 08/14] match APFloat Types --- .../Math/Transforms/ExpandPatterns.cpp | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 1d27686b3acc1..3718800437bf2 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -324,49 +324,52 @@ static LogicalResult convertSpecialPowfOp(math::PowFOp op, auto &sem = cast(getElementTypeOrSelf(typeB)).getFloatSemantics(); - auto valueB = APFloat(sem); + APFloat valueB(sem); if (!matchPattern(operandB, m_ConstantFloat(&valueB))) { // Not a constant, return failure return failure(); } - - if (valueB.compare(APFloat(0.0f)) == APFloat::cmpEqual) { + if (valueB.compare(APFloat::getZero(sem)) == APFloat::cmpEqual) { // a^0 -> 1 Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); rewriter.replaceOp(op, one); return success(); } - if (valueB.compare(APFloat(1.0f)) == APFloat::cmpEqual) { + if (valueB.compare(APFloat::getOne(sem)) == APFloat::cmpEqual) { // a^1 -> a rewriter.replaceOp(op, operandA); return success(); } - if (valueB.compare(APFloat(-1.0f)) == APFloat::cmpEqual) { + if (valueB.compare(-APFloat::getOne(sem)) == APFloat::cmpEqual) { // a^(-1) -> 1 / a Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); Value div = b.create(one, operandA); rewriter.replaceOp(op, div); return success(); } - if 
(valueB.compare(APFloat(0.5f)) == APFloat::cmpEqual) { + APFloat halfVal(0.5); + halfVal.convert(sem, APFloat::rmNearestTiesToEven, /*losesInfo=*/nullptr); + if (valueB.compare(halfVal) == APFloat::cmpEqual) { // a^(1/2) -> sqrt(a) Value sqrt = b.create(operandA); rewriter.replaceOp(op, sqrt); return success(); } - if (valueB.compare(APFloat(-0.5f)) == APFloat::cmpEqual) { + if (valueB.compare(-halfVal) == APFloat::cmpEqual) { // a^(-1/2) -> 1 / sqrt(a) Value rsqrt = b.create(operandA); rewriter.replaceOp(op, rsqrt); return success(); } - if (valueB.compare(APFloat(2.0f)) == APFloat::cmpEqual) { + APFloat twoVal(2.0); + twoVal.convert(sem, APFloat::rmNearestTiesToEven, /*losesInfo=*/nullptr); + if (valueB.compare(twoVal) == APFloat::cmpEqual) { // a^2 -> a * a Value mul = b.create(operandA, operandA); rewriter.replaceOp(op, mul); return success(); } - if (valueB.compare(APFloat(-2.0f)) == APFloat::cmpEqual) { + if (valueB.compare(-twoVal) == APFloat::cmpEqual) { // a^(-2) -> 1 / (a * a) Value mul = b.create(operandA, operandA); Value one = From 79c4ef4e5f18efa74edf6aeb89a056ff6af07a0a Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Wed, 12 Feb 2025 13:22:56 +0900 Subject: [PATCH 09/14] match APFloat Types --- mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 3718800437bf2..8837c283f46be 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -329,7 +329,7 @@ static LogicalResult convertSpecialPowfOp(math::PowFOp op, // Not a constant, return failure return failure(); } - if (valueB.compare(APFloat::getZero(sem)) == APFloat::cmpEqual) { + if (valueB.isZero()) { // a^0 -> 1 Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); rewriter.replaceOp(op, one); From 393aaa8305c5c0b86df6bfafad73549b92a1679e Mon Sep 
17 00:00:00 2001 From: Hyunsung Lee Date: Thu, 13 Feb 2025 07:52:43 +0900 Subject: [PATCH 10/14] APFloat Fix, merge two functions --- .../Math/Transforms/ExpandPatterns.cpp | 113 ++++++++---------- 1 file changed, 49 insertions(+), 64 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 8837c283f46be..c243ada80c8f3 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -312,10 +312,10 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, return success(); } -// Convert Powf(float a, float b) for special cases when b is constant: +// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) +// Some special cases where b is constant are handled separately: // when b == 0, or |b| == 0.5, 1.0, or 2.0. -static LogicalResult convertSpecialPowfOp(math::PowFOp op, - PatternRewriter &rewriter) { +static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operandA = op.getOperand(0); Value operandB = op.getOperand(1); @@ -325,68 +325,54 @@ static LogicalResult convertSpecialPowfOp(math::PowFOp op, auto &sem = cast(getElementTypeOrSelf(typeB)).getFloatSemantics(); APFloat valueB(sem); - if (!matchPattern(operandB, m_ConstantFloat(&valueB))) { - // Not a constant, return failure - return failure(); - } - if (valueB.isZero()) { - // a^0 -> 1 - Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); - rewriter.replaceOp(op, one); - return success(); - } - if (valueB.compare(APFloat::getOne(sem)) == APFloat::cmpEqual) { - // a^1 -> a - rewriter.replaceOp(op, operandA); - return success(); - } - if (valueB.compare(-APFloat::getOne(sem)) == APFloat::cmpEqual) { - // a^(-1) -> 1 / a - Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); - Value div = b.create(one, operandA); - rewriter.replaceOp(op, div); - return 
success(); - } - APFloat halfVal(0.5); - halfVal.convert(sem, APFloat::rmNearestTiesToEven, /*losesInfo=*/nullptr); - if (valueB.compare(halfVal) == APFloat::cmpEqual) { - // a^(1/2) -> sqrt(a) - Value sqrt = b.create(operandA); - rewriter.replaceOp(op, sqrt); - return success(); - } - if (valueB.compare(-halfVal) == APFloat::cmpEqual) { - // a^(-1/2) -> 1 / sqrt(a) - Value rsqrt = b.create(operandA); - rewriter.replaceOp(op, rsqrt); - return success(); - } - APFloat twoVal(2.0); - twoVal.convert(sem, APFloat::rmNearestTiesToEven, /*losesInfo=*/nullptr); - if (valueB.compare(twoVal) == APFloat::cmpEqual) { - // a^2 -> a * a - Value mul = b.create(operandA, operandA); - rewriter.replaceOp(op, mul); - return success(); + if (matchPattern(operandB, m_ConstantFloat(&valueB))) { + if (valueB.isZero()) { + // a^0 -> 1 + Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); + rewriter.replaceOp(op, one); + return success(); + } + if (valueB.compare(APFloat::getOne(sem)) == APFloat::cmpEqual) { + // a^1 -> a + rewriter.replaceOp(op, operandA); + return success(); + } + if (valueB.compare(-APFloat::getOne(sem)) == APFloat::cmpEqual) { + // a^(-1) -> 1 / a + Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); + Value div = b.create(one, operandA); + rewriter.replaceOp(op, div); + return success(); + } + if (valueB.isExactlyValue(0.5)) { + // a^(1/2) -> sqrt(a) + Value sqrt = b.create(operandA); + rewriter.replaceOp(op, sqrt); + return success(); + } + if (valueB.isExactlyValue(-0.5)) { + // a^(-1/2) -> 1 / sqrt(a) + Value rsqrt = b.create(operandA); + rewriter.replaceOp(op, rsqrt); + return success(); + } + if (valueB.isExactlyValue(2.0)) { + // a^2 -> a * a + Value mul = b.create(operandA, operandA); + rewriter.replaceOp(op, mul); + return success(); + } + if (valueB.isExactlyValue(-2.0)) { + // a^(-2) -> 1 / (a * a) + Value mul = b.create(operandA, operandA); + Value one = + createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); + 
Value div = b.create(one, mul); + rewriter.replaceOp(op, div); + return success(); + } } - if (valueB.compare(-twoVal) == APFloat::cmpEqual) { - // a^(-2) -> 1 / (a * a) - Value mul = b.create(operandA, operandA); - Value one = - createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); - Value div = b.create(one, mul); - rewriter.replaceOp(op, div); - return success(); - } - - return failure(); -} -// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) -static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { - ImplicitLocOpBuilder b(op->getLoc(), rewriter); - Value operandA = op.getOperand(0); - Value operandB = op.getOperand(1); Value logA = b.create(operandA); Value mult = b.create(operandB, logA); Value expResult = b.create(mult); @@ -706,7 +692,6 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { } void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { - patterns.add(convertSpecialPowfOp, /*benefit=*/2); patterns.add(convertPowfOp); } From f5205a6be34865eb93705724966d8ab11d7a0587 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Thu, 13 Feb 2025 09:18:00 +0900 Subject: [PATCH 11/14] fix tests --- .../Math/Transforms/ExpandPatterns.cpp | 4 +- mlir/test/Dialect/Math/expand-math.mlir | 77 ++++++++++++++++++- 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index c243ada80c8f3..d7953719d44b5 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -332,12 +332,12 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { rewriter.replaceOp(op, one); return success(); } - if (valueB.compare(APFloat::getOne(sem)) == APFloat::cmpEqual) { + if (valueB.isExactlyValue(1.0)) { // a^1 -> a rewriter.replaceOp(op, operandA); return success(); } - if 
(valueB.compare(-APFloat::getOne(sem)) == APFloat::cmpEqual) { + if (valueB.isExactlyValue(-1.0)) { // a^(-1) -> 1 / a Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); Value div = b.create(one, operandA); diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 3cf372ea0cf50..280b133926a0c 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -211,6 +211,81 @@ func.func @powf_func(%a: f64, %b: f64) -> f64 { return %ret : f64 } +// CHECK-LABEL: func @powf_func_zero +// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +func.func @powf_func_zero(%a: f64) -> f64{ + // CHECK: [[ONE:%.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: return [[ONE]] : f64 + %b = arith.constant 0.0 : f64 + %ret = math.powf %a, %b : f64 + return %ret : f64 +} + +// CHECK-LABEL: func @powf_func_one +// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +func.func @powf_func_one(%a: f64) -> f64{ + // CHECK: return [[ARG0]] : f64 + %b = arith.constant 1.0 : f64 + %ret = math.powf %a, %b : f64 + return %ret : f64 +} + + +// CHECK-LABEL: func @powf_func_negone +// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +func.func @powf_func_negone(%a: f64) -> f64{ + // CHECK: [[CSTONE:%.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[ARG0]] : f64 + // CHECK: return [[DIV]] : f64 + %b = arith.constant -1.0 : f64 + %ret = math.powf %a, %b : f64 + return %ret : f64 +} + +// CHECK-LABEL: func @powf_func_half +// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +func.func @powf_func_half(%a: f64) -> f64{ + // CHECK: [[SQRT:%.+]] = math.sqrt [[ARG0]] : f64 + // CHECK: return [[SQRT]] : f64 + %b = arith.constant 0.5 : f64 + %ret = math.powf %a, %b : f64 + return %ret : f64 +} + +// CHECK-LABEL: func @powf_func_neghalf +// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +func.func @powf_func_neghalf(%a: f64) -> f64{ + // CHECK: [[CSTONE:%.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: [[SQRT:%.+]] = 
math.sqrt [[ARG0]] : f64 + // CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[SQRT]] : f64 + // CHECK: return [[DIV]] : f64 + %b = arith.constant -0.5 : f64 + %ret = math.powf %a, %b : f64 + return %ret : f64 +} + +// CHECK-LABEL: func @powf_func_two +// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +func.func @powf_func_two(%a: f64) -> f64{ + // CHECK: [[MUL:%.+]] = arith.mulf [[ARG0]], [[ARG0]] : f64 + // CHECK: return [[MUL]] : f64 + %b = arith.constant 2.0 : f64 + %ret = math.powf %a, %b : f64 + return %ret : f64 +} + +// CHECK-LABEL: func @powf_func_negtwo +// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +func.func @powf_func_negtwo(%a: f64) -> f64{ + // CHECK-DAG: [[MUL:%.+]] = arith.mulf [[ARG0]], [[ARG0]] : f64 + // CHECK-DAG: [[CSTONE:%.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[MUL]] : f64 + // CHECK: return [[DIV]] : f64 + %b = arith.constant -2.0 : f64 + %ret = math.powf %a, %b : f64 + return %ret : f64 +} + // ----- // CHECK-LABEL: func.func @roundeven64 @@ -592,7 +667,7 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t // CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32> // CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32> // CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32> -// CHECK: return %[[EXP]] : tensor<8xf32> +// CHECK: return %[[EXP]] // ----- // CHECK-LABEL: func.func @math_fpowi_to_powf_scalar From d5ee522a217359264d35516a5b506149b5d1ab68 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Thu, 13 Feb 2025 09:18:17 +0900 Subject: [PATCH 12/14] fix tests --- mlir/test/Dialect/Math/expand-math.mlir | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 280b133926a0c..36a69ec83ce6e 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -230,7 +230,6 @@ func.func @powf_func_one(%a: f64) -> f64{ return %ret : f64 } - // 
CHECK-LABEL: func @powf_func_negone // CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 func.func @powf_func_negone(%a: f64) -> f64{ From 48b94056c8d9674462e63d68481f3b57fa0d613f Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Thu, 13 Feb 2025 10:37:08 +0900 Subject: [PATCH 13/14] lit fix --- mlir/test/Dialect/Math/expand-math.mlir | 44 ++++++++++++------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 36a69ec83ce6e..7491315300a99 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -201,20 +201,20 @@ func.func @roundf_func(%a: f32) -> f32 { // ----- // CHECK-LABEL: func @powf_func -// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64) +// CHECK-SAME: ([[ARG0:.+]]: f64, [[ARG1:.+]]: f64) func.func @powf_func(%a: f64, %b: f64) -> f64 { - // CHECK: [[LOGA:%.+]] = math.log [[ARG0]] : f64 - // CHECK: [[MUL:%.+]] = arith.mulf [[ARG1]], [[LOGA]] : f64 - // CHECK: [[EXP:%.+]] = math.exp [[MUL]] : f64 + // CHECK: [[LOGA:.+]] = math.log [[ARG0]] : f64 + // CHECK: [[MUL:.+]] = arith.mulf [[ARG1]], [[LOGA]] : f64 + // CHECK: [[EXP:.+]] = math.exp [[MUL]] : f64 // CHECK: return [[EXP]] : f64 %ret = math.powf %a, %b : f64 return %ret : f64 } // CHECK-LABEL: func @powf_func_zero -// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 func.func @powf_func_zero(%a: f64) -> f64{ - // CHECK: [[ONE:%.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: [[ONE:.+]] = arith.constant 1.000000e+00 : f64 // CHECK: return [[ONE]] : f64 %b = arith.constant 0.0 : f64 %ret = math.powf %a, %b : f64 @@ -222,7 +222,7 @@ func.func @powf_func_zero(%a: f64) -> f64{ } // CHECK-LABEL: func @powf_func_one -// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 func.func @powf_func_one(%a: f64) -> f64{ // CHECK: return [[ARG0]] : f64 %b = arith.constant 1.0 : f64 @@ -231,10 +231,10 @@ func.func 
@powf_func_one(%a: f64) -> f64{ } // CHECK-LABEL: func @powf_func_negone -// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 func.func @powf_func_negone(%a: f64) -> f64{ - // CHECK: [[CSTONE:%.+]] = arith.constant 1.000000e+00 : f64 - // CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[ARG0]] : f64 + // CHECK: [[CSTONE:.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: [[DIV:.+]] = arith.divf [[CSTONE]], [[ARG0]] : f64 // CHECK: return [[DIV]] : f64 %b = arith.constant -1.0 : f64 %ret = math.powf %a, %b : f64 @@ -242,9 +242,9 @@ func.func @powf_func_negone(%a: f64) -> f64{ } // CHECK-LABEL: func @powf_func_half -// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 func.func @powf_func_half(%a: f64) -> f64{ - // CHECK: [[SQRT:%.+]] = math.sqrt [[ARG0]] : f64 + // CHECK: [[SQRT:.+]] = math.sqrt [[ARG0]] : f64 // CHECK: return [[SQRT]] : f64 %b = arith.constant 0.5 : f64 %ret = math.powf %a, %b : f64 @@ -252,11 +252,11 @@ func.func @powf_func_half(%a: f64) -> f64{ } // CHECK-LABEL: func @powf_func_neghalf -// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 func.func @powf_func_neghalf(%a: f64) -> f64{ - // CHECK: [[CSTONE:%.+]] = arith.constant 1.000000e+00 : f64 - // CHECK: [[SQRT:%.+]] = math.sqrt [[ARG0]] : f64 - // CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[SQRT]] : f64 + // CHECK: [[CSTONE:.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: [[SQRT:.+]] = math.sqrt [[ARG0]] : f64 + // CHECK: [[DIV:.+]] = arith.divf [[CSTONE]], [[SQRT]] : f64 // CHECK: return [[DIV]] : f64 %b = arith.constant -0.5 : f64 %ret = math.powf %a, %b : f64 @@ -264,9 +264,9 @@ func.func @powf_func_neghalf(%a: f64) -> f64{ } // CHECK-LABEL: func @powf_func_two -// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 func.func @powf_func_two(%a: f64) -> f64{ - // CHECK: [[MUL:%.+]] = arith.mulf [[ARG0]], [[ARG0]] : f64 + // CHECK: [[MUL:.+]] = arith.mulf [[ARG0]], 
[[ARG0]] : f64 // CHECK: return [[MUL]] : f64 %b = arith.constant 2.0 : f64 %ret = math.powf %a, %b : f64 @@ -274,11 +274,11 @@ func.func @powf_func_two(%a: f64) -> f64{ } // CHECK-LABEL: func @powf_func_negtwo -// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 func.func @powf_func_negtwo(%a: f64) -> f64{ - // CHECK-DAG: [[MUL:%.+]] = arith.mulf [[ARG0]], [[ARG0]] : f64 - // CHECK-DAG: [[CSTONE:%.+]] = arith.constant 1.000000e+00 : f64 - // CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[MUL]] : f64 + // CHECK-DAG: [[MUL:.+]] = arith.mulf [[ARG0]], [[ARG0]] : f64 + // CHECK-DAG: [[CSTONE:.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: [[DIV:.+]] = arith.divf [[CSTONE]], [[MUL]] : f64 // CHECK: return [[DIV]] : f64 %b = arith.constant -2.0 : f64 %ret = math.powf %a, %b : f64 From 640ec459c6ce717999f1625e9e3ca7a0be83c96b Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Thu, 13 Feb 2025 11:50:32 +0900 Subject: [PATCH 14/14] fix tests --- mlir/test/Dialect/Math/expand-math.mlir | 60 ++++++++++++------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 7491315300a99..f39d1a7a6dc50 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -201,85 +201,85 @@ func.func @roundf_func(%a: f32) -> f32 { // ----- // CHECK-LABEL: func @powf_func -// CHECK-SAME: ([[ARG0:.+]]: f64, [[ARG1:.+]]: f64) +// CHECK-SAME: (%[[ARG0:.+]]: f64, %[[ARG1:.+]]: f64) -> f64 func.func @powf_func(%a: f64, %b: f64) -> f64 { - // CHECK: [[LOGA:.+]] = math.log [[ARG0]] : f64 - // CHECK: [[MUL:.+]] = arith.mulf [[ARG1]], [[LOGA]] : f64 - // CHECK: [[EXP:.+]] = math.exp [[MUL]] : f64 - // CHECK: return [[EXP]] : f64 + // CHECK: %[[LOGA:.+]] = math.log %[[ARG0]] : f64 + // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG1]], %[[LOGA]] : f64 + // CHECK: %[[EXP:.+]] = math.exp %[[MUL]] : f64 + // CHECK: return %[[EXP]] : f64 
%ret = math.powf %a, %b : f64 return %ret : f64 } // CHECK-LABEL: func @powf_func_zero -// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 +// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64 func.func @powf_func_zero(%a: f64) -> f64{ - // CHECK: [[ONE:.+]] = arith.constant 1.000000e+00 : f64 - // CHECK: return [[ONE]] : f64 + // CHECK: %[[ONE:.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: return %[[ONE]] : f64 %b = arith.constant 0.0 : f64 %ret = math.powf %a, %b : f64 return %ret : f64 } // CHECK-LABEL: func @powf_func_one -// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 +// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64 func.func @powf_func_one(%a: f64) -> f64{ - // CHECK: return [[ARG0]] : f64 + // CHECK: return %[[ARG0]] : f64 %b = arith.constant 1.0 : f64 %ret = math.powf %a, %b : f64 return %ret : f64 } // CHECK-LABEL: func @powf_func_negone -// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 +// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64 func.func @powf_func_negone(%a: f64) -> f64{ - // CHECK: [[CSTONE:.+]] = arith.constant 1.000000e+00 : f64 - // CHECK: [[DIV:.+]] = arith.divf [[CSTONE]], [[ARG0]] : f64 - // CHECK: return [[DIV]] : f64 + // CHECK: %[[CSTONE:.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: %[[DIV:.+]] = arith.divf %[[CSTONE]], %[[ARG0]] : f64 + // CHECK: return %[[DIV]] : f64 %b = arith.constant -1.0 : f64 %ret = math.powf %a, %b : f64 return %ret : f64 } // CHECK-LABEL: func @powf_func_half -// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 +// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64 func.func @powf_func_half(%a: f64) -> f64{ - // CHECK: [[SQRT:.+]] = math.sqrt [[ARG0]] : f64 - // CHECK: return [[SQRT]] : f64 + // CHECK: %[[SQRT:.+]] = math.sqrt %[[ARG0]] : f64 + // CHECK: return %[[SQRT]] : f64 %b = arith.constant 0.5 : f64 %ret = math.powf %a, %b : f64 return %ret : f64 } // CHECK-LABEL: func @powf_func_neghalf -// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 +// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64 func.func @powf_func_neghalf(%a: f64) -> f64{ - // CHECK: [[CSTONE:.+]] = arith.constant 
1.000000e+00 : f64 - // CHECK: [[SQRT:.+]] = math.sqrt [[ARG0]] : f64 - // CHECK: [[DIV:.+]] = arith.divf [[CSTONE]], [[SQRT]] : f64 - // CHECK: return [[DIV]] : f64 + // CHECK: %[[CSTONE:.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: %[[SQRT:.+]] = math.sqrt %[[ARG0]] : f64 + // CHECK: %[[DIV:.+]] = arith.divf %[[CSTONE]], %[[SQRT]] : f64 + // CHECK: return %[[DIV]] : f64 %b = arith.constant -0.5 : f64 %ret = math.powf %a, %b : f64 return %ret : f64 } // CHECK-LABEL: func @powf_func_two -// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 +// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64 func.func @powf_func_two(%a: f64) -> f64{ - // CHECK: [[MUL:.+]] = arith.mulf [[ARG0]], [[ARG0]] : f64 - // CHECK: return [[MUL]] : f64 + // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64 + // CHECK: return %[[MUL]] : f64 %b = arith.constant 2.0 : f64 %ret = math.powf %a, %b : f64 return %ret : f64 } // CHECK-LABEL: func @powf_func_negtwo -// CHECK-SAME: ([[ARG0:.+]]: f64) -> f64 +// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64 func.func @powf_func_negtwo(%a: f64) -> f64{ - // CHECK-DAG: [[MUL:.+]] = arith.mulf [[ARG0]], [[ARG0]] : f64 - // CHECK-DAG: [[CSTONE:.+]] = arith.constant 1.000000e+00 : f64 - // CHECK: [[DIV:.+]] = arith.divf [[CSTONE]], [[MUL]] : f64 - // CHECK: return [[DIV]] : f64 + // CHECK-DAG: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64 + // CHECK-DAG: %[[CSTONE:.+]] = arith.constant 1.000000e+00 : f64 + // CHECK: %[[DIV:.+]] = arith.divf %[[CSTONE]], %[[MUL]] : f64 + // CHECK: return %[[DIV]] : f64 %b = arith.constant -2.0 : f64 %ret = math.powf %a, %b : f64 return %ret : f64