Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 74 additions & 30 deletions mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,40 +311,83 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}

// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
// Convert Powf(float a, float b) for special cases when b is constant:
// when b == 0, or |b| == 0.5, 1.0, or 2.0.
static LogicalResult convertSpecialPowfOp(math::PowFOp op,
PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
Value operandB = op.getOperand(1);
auto opType = operandA.getType();
auto baseType = operandB.getType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you rename opType to typeA, and baseType to typeB ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!


auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, since math.powf requires float arguments (I just checked MathOps.td), the cast really shouldn't ever fail, so I think you can simply use cast instead of dyn_cast. You weren't checking for a null return value from dyn_cast anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

.getFloatSemantics();

auto valueB = APFloat(sem);
if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
// Not a constant, return failure
return failure();
}
float floatValueB = valueB.convertToFloat();
Copy link
Contributor

@bjacob bjacob Feb 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid converting to a C++ float. The actual type could have higher precision, in which case this conversion rounds to a lower precision and could end up enabling an incorrect rewrite. For example, if the type is f64, then rewriting powf(a, 1.0 + 1.0e-10) into a would be incorrect: in f64, 1.0 + 1.0e-10 != 1.0, but the rounded floatValueB is exactly 1.0.

Can you keep the whole logic in this function in APFloat?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, but this might be a bit verbose, so could you check it please?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hanhanW, I am not familiar with APFloat myself; can you review that aspect?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not very familiar with it either. After skimming through the doc, I think we can use the isExactlyValue method. I don't think we have precision issues for the values 0, +-1, +-0.5, and +-2.

/// We don't rely on operator== working on double values, as
/// it returns true for things that are clearly not equal, like -0.0 and 0.0.
/// As such, this method can be used to do an exact bit-for-bit comparison of
/// two floating point values.
///
/// We leave the version with the double argument here because it's just so
/// convenient to write "2.0" and the like. Without this function we'd
/// have to duplicate its logic everywhere it's called.
///
/// \returns the result of bitwiseIsEqual() against \p V after \p V has been
/// converted to the semantics reported by getSemantics().
bool isExactlyValue(double V) const {
// Passed to convert() by address below; its value is never read here.
bool ignored;
// Wrap V in an APFloat, then convert it to getSemantics() (presumably this
// value's own format) using round-to-nearest, ties-to-even, so the final
// comparison is made between two values of the same format.
APFloat Tmp(V);
Tmp.convert(getSemantics(), APFloat::rmNearestTiesToEven, &ignored);
return bitwiseIsEqual(Tmp);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

if (floatValueB == 0.0f) {
// a^0 -> 1
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
rewriter.replaceOp(op, one);
return success();
}
if (floatValueB == 1.0f) {
// a^1 -> a
rewriter.replaceOp(op, operandA);
return success();
}
if (floatValueB == -1.0f) {
// a^(-1) -> 1 / a
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
Value div = b.create<arith::DivFOp>(one, operandA);
rewriter.replaceOp(op, div);
return success();
}
if (floatValueB == 0.5f) {
// a^(1/2) -> sqrt(a)
Value sqrt = b.create<math::SqrtOp>(operandA);
rewriter.replaceOp(op, sqrt);
return success();
}
if (floatValueB == -0.5f) {
// a^(-1/2) -> 1 / sqrt(a)
Value rsqrt = b.create<math::RsqrtOp>(operandA);
rewriter.replaceOp(op, rsqrt);
return success();
}
if (floatValueB == 2.0f) {
// a^2 -> a * a
Value mul = b.create<arith::MulFOp>(operandA, operandA);
rewriter.replaceOp(op, mul);
return success();
}
if (floatValueB == -2.0f) {
// a^(-2) -> 1 / (a * a)
Value mul = b.create<arith::MulFOp>(operandA, operandA);
Value one =
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
Value div = b.create<arith::DivFOp>(one, mul);
rewriter.replaceOp(op, div);
return success();
}

return failure();
}

// Converts Powf(float a, float b) (meaning a^b) to exp(b * ln(a)).
//
// NOTE(review): this listing appears to be a diff rendered without +/-
// markers, so two implementations are interleaved in a single body: `logA`,
// `mult` and `expResult` are each declared twice and `rewriter.replaceOp` is
// called twice. Presumably the first implementation is the removed side of the
// diff and the second is the added side -- confirm against the real file
// before reusing this code.
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
Value operandB = op.getOperand(1);
// --- First implementation (presumably the removed side of the diff) ---
// Builds exp((b / 2) * ln(a * a)), i.e. the logarithm is taken of a squared
// rather than of a itself.
Type opType = operandA.getType();
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);

Value logA = b.create<math::LogOp>(opType, opASquared);
Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
Value expResult = b.create<math::ExpOp>(opType, mult);
// Negated result, selected below when a < 0 and remf(b, 2) != 0.
Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
Value negCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
Value oddPower =
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);

// First, we select between the exp value and the adjusted value for odd
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
Value zeroCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
expResult);
res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
rewriter.replaceOp(op, res);
// --- Second implementation (presumably the added side of the diff) ---
// Builds exp(b * ln(a)) directly, with none of the negative-base or
// zero-exponent selects used above.
Value logA = b.create<math::LogOp>(operandA);
Value mult = b.create<arith::MulFOp>(operandB, logA);
Value expResult = b.create<math::ExpOp>(mult);
rewriter.replaceOp(op, expResult);
return success();
}

Expand Down Expand Up @@ -660,6 +703,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
}

void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
patterns.add(convertSpecialPowfOp, /*benefit=*/2);
patterns.add(convertPowfOp);
}

Expand Down
69 changes: 13 additions & 56 deletions mlir/test/Dialect/Math/expand-math.mlir
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At a minimum, all the special cases need to be tested in this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My apologies that this PR got verbose. I added the appropriate tests and they run well!

Original file line number Diff line number Diff line change
Expand Up @@ -202,25 +202,11 @@ func.func @roundf_func(%a: f32) -> f32 {

// CHECK-LABEL: func @powf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
func.func @powf_func(%a: f64, %b: f64) ->f64 {
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
// CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
// CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
// CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
// CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
// CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
// CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
// CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
// CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
// CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
// CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
// CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
// CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
// CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
// CHECK: return [[SEL1]]
func.func @powf_func(%a: f64, %b: f64) -> f64 {
// CHECK: [[LOGA:%.+]] = math.log [[ARG0]] : f64
// CHECK: [[MUL:%.+]] = arith.mulf [[ARG1]], [[LOGA]] : f64
// CHECK: [[EXP:%.+]] = math.exp [[MUL]] : f64
// CHECK: return [[EXP]] : f64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's follow the %[[XXX:.+]] style because it is more common in the MLIR codebase. One of the benefits is that we do not need to escape [ in the [%[[XXX]]] case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed other tests, which use %[[VAR:%.+]]; I will fix it.

// FileCheck expectations for a function whose body is a single math.ceil on
// f64: an fptosi/sitofp round trip on the argument, a copysign of that result
// with the argument, and a select between 1.0 and 0.0 (on `argument > copysign
// result`) whose value is added to the copysign result and returned.
// CHECK-LABEL:     func @ceilf_func
// CHECK-SAME:      ([[ARG0:%.+]]: f64) -> f64
func.func @ceilf_func(%a: f64) -> f64 {
  // CHECK-DAG:   [[CST:%.+]] = arith.constant 0.000
  // CHECK-DAG:   [[CST_0:%.+]] = arith.constant 1.000
  // CHECK-NEXT:   [[CVTI:%.+]] = arith.fptosi [[ARG0]]
  // CHECK-NEXT:   [[CVTF:%.+]] = arith.sitofp [[CVTI]]
  // CHECK-NEXT:   [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
  // CHECK-NEXT:   [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
  // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
  // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
  // CHECK-NEXT:   return [[ADDF]]
  %ret = math.ceil %a : f64
  return %ret : f64
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed it, thanks. Now I see how FileCheck tests work.

%ret = math.powf %a, %b : f64
return %ret : f64
}
Expand Down Expand Up @@ -602,26 +588,11 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
// CHECK: return %[[SEL1]] : tensor<8xf32>

// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
// CHECK: return %[[EXP]] : tensor<8xf32>
// -----

// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
Expand All @@ -630,25 +601,11 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32
// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32
// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
// CHECK: return %[[SEL1]] : f32
// CHECK: return %[[EXP]] : f32

// -----

Expand Down
69 changes: 38 additions & 31 deletions mlir/test/mlir-runner/test-expand-math-approx.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -202,55 +202,62 @@ func.func @powf() {
%a_p = arith.constant 2.0 : f64
call @func_powff64(%a, %a_p) : (f64, f64) -> ()

// CHECK-NEXT: -27
%b = arith.constant -3.0 : f64
%b_p = arith.constant 3.0 : f64
call @func_powff64(%b, %b_p) : (f64, f64) -> ()

// CHECK-NEXT: 2.343
%c = arith.constant 2.343 : f64
%c_p = arith.constant 1.000 : f64
call @func_powff64(%c, %c_p) : (f64, f64) -> ()
%b = arith.constant 2.343 : f64
%b_p = arith.constant 1.000 : f64
call @func_powff64(%b, %b_p) : (f64, f64) -> ()

// CHECK-NEXT: 0.176171
%d = arith.constant 4.25 : f64
%d_p = arith.constant -1.2 : f64
call @func_powff64(%d, %d_p) : (f64, f64) -> ()
%c = arith.constant 4.25 : f64
%c_p = arith.constant -1.2 : f64
call @func_powff64(%c, %c_p) : (f64, f64) -> ()

// CHECK-NEXT: 1
%e = arith.constant 4.385 : f64
%e_p = arith.constant 0.00 : f64
call @func_powff64(%e, %e_p) : (f64, f64) -> ()
%d = arith.constant 4.385 : f64
%d_p = arith.constant 0.00 : f64
call @func_powff64(%d, %d_p) : (f64, f64) -> ()

// CHECK-NEXT: 6.62637
%f = arith.constant 4.835 : f64
%f_p = arith.constant 1.2 : f64
call @func_powff64(%f, %f_p) : (f64, f64) -> ()
%e = arith.constant 4.835 : f64
%e_p = arith.constant 1.2 : f64
call @func_powff64(%e, %e_p) : (f64, f64) -> ()

// CHECK-NEXT: nan
%i = arith.constant 1.0 : f64
%h = arith.constant 0x7fffffffffffffff : f64
call @func_powff64(%i, %h) : (f64, f64) -> ()
%f = arith.constant 1.0 : f64
%f_p = arith.constant 0x7fffffffffffffff : f64
call @func_powff64(%f, %f_p) : (f64, f64) -> ()

// CHECK-NEXT: inf
%j = arith.constant 29385.0 : f64
%j_p = arith.constant 23598.0 : f64
call @func_powff64(%j, %j_p) : (f64, f64) -> ()
%g = arith.constant 29385.0 : f64
%g_p = arith.constant 23598.0 : f64
call @func_powff64(%g, %g_p) : (f64, f64) -> ()

// CHECK-NEXT: -nan
%k = arith.constant 1.0 : f64
%k_p = arith.constant 0xfff0000001000000 : f64
call @func_powff64(%k, %k_p) : (f64, f64) -> ()
%h = arith.constant 1.0 : f64
%h_p = arith.constant 0xfff0000001000000 : f64
call @func_powff64(%h, %h_p) : (f64, f64) -> ()

// CHECK-NEXT: -nan
%l = arith.constant 1.0 : f32
%l_p = arith.constant 0xffffffff : f32
call @func_powff32(%l, %l_p) : (f32, f32) -> ()
%i = arith.constant 1.0 : f32
%i_p = arith.constant 0xffffffff : f32
call @func_powff32(%i, %i_p) : (f32, f32) -> ()

// CHECK-NEXT: 1
%zero = arith.constant 0.0 : f32
call @func_powff32(%zero, %zero) : (f32, f32) -> ()
%j = arith.constant 0.000 : f32
%j_r = math.powf %j, %j : f32
vector.print %j_r : f32

// CHECK-NEXT: 4
%k = arith.constant -2.0 : f32
%k_p = arith.constant 2.0 : f32
%k_r = math.powf %k, %k_p : f32
vector.print %k_r : f32

// CHECK-NEXT: 0.25
%l = arith.constant -2.0 : f32
%l_p = arith.constant -2.0 : f32
%l_r = math.powf %k, %l_p : f32
vector.print %l_r : f32
return
}

Expand Down