Merged · Changes from 2 commits
115 changes: 97 additions & 18 deletions mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -17,8 +17,13 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/LogicalResult.h"
#include <cmath>

using namespace mlir;

@@ -311,40 +316,113 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}

// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
// Convert Powf(float a, float b) for some special cases
// where b == 1.0, b == 0.0, b == 0.5, b == -0.5, b == -1.0, and b % 2 == 0
static LogicalResult convertSpecialPowfOp(math::PowFOp op,
PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
Value operandB = op.getOperand(1);
auto baseType = operandB.getType();

auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
Contributor: Here, since math.powf requires float arguments (I just checked MathOps.td), the cast really shouldn't ever fail, so I think you can simply use cast instead of dyn_cast. You weren't checking for a null return value from dyn_cast anyway.

Contributor Author: Done!

.getFloatSemantics();

auto valueB = APFloat(sem);
if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
// Not a constant, return failure
return failure();
}
float floatValueB = valueB.convertToFloat();
bjacob (Contributor), Feb 12, 2025: Avoid converting to C++ float, as the actual type could have higher precision, so this conversion would be rounding to lesser precision and could end up enabling an incorrect rewrite. For example, if the type is f64, the rewrite from powf(a, 1.0 + 1.0e-10) into a would be incorrect: in f64, 1.0 + 1.0e-10 != 1.0, but the rounded floatValueB is exactly 1.0.

Can you keep the whole logic in this function in APFloat?

Contributor Author: Done, but this might be a bit verbose, so could you check it please?

Contributor: @hanhanW, I am not familiar myself with APFloat; can you review that aspect?

Contributor: I'm not very familiar with it either. After skimming through the doc, I think we can use the isExactlyValue method. I think we do not have precision issues for 0, +-1, +-0.5, and +-2.

/// We don't rely on operator== working on double values, as
/// it returns true for things that are clearly not equal, like -0.0 and 0.0.
/// As such, this method can be used to do an exact bit-for-bit comparison of
/// two floating point values.
///
/// We leave the version with the double argument here because it's just so
/// convenient to write "2.0" and the like. Without this function we'd
/// have to duplicate its logic everywhere it's called.
bool isExactlyValue(double V) const {
  bool ignored;
  APFloat Tmp(V);
  Tmp.convert(getSemantics(), APFloat::rmNearestTiesToEven, &ignored);
  return bitwiseIsEqual(Tmp);
}

Contributor Author: Fixed!


if (floatValueB == 1.0f) {
// a^1 -> a
rewriter.replaceOp(op, operandA);
return success();
}

if (floatValueB == 0.0) {
// a^0 -> 1
Value one =
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
rewriter.replaceOp(op, one);
return success();
}

if (floatValueB == 0.5f) {
// a^(1/2) -> sqrt(a)
Value sqrt = b.create<math::SqrtOp>(operandA);
rewriter.replaceOp(op, sqrt);
return success();
}

if (floatValueB == -0.5f) {
// a^(-1/2) -> 1 / sqrt(a)
Value rsqrt = b.create<math::RsqrtOp>(operandA);
rewriter.replaceOp(op, rsqrt);
return success();
}

if (floatValueB == -1.0f) {
// a^(-1) -> 1 / a
Value one =
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
Value div = b.create<arith::DivFOp>(one, operandA);
rewriter.replaceOp(op, div);
return success();
}

// Check if the power is an integer
if (floatValueB != std::floor(floatValueB)) {
bjacob (Contributor), Feb 10, 2025: Don't try to handle arbitrary integer x. Just handle the special value 2.0, and maybe also 3.0 and 4.0 if you want, but that's it. If someone really needs a larger integral exponent to be matched, we can always expand this pattern later.

Contributor: +1, let's handle a few cases for now and document it in the function comment.

Contributor Author: I kept only the |b| == 2.0 case and removed the other integer cases.

// We don't handle non-integer powers here, return failure
return failure();
}

auto sign = std::signbit(floatValueB) ? -1 : 1;
auto absIntValueB = std::abs(static_cast<int>(floatValueB));

auto cstOne =
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
auto base = operandA;
if (sign == -1) {
base = b.create<arith::DivFOp>(cstOne, base);
}
auto current = base;
auto result = cstOne;
while (absIntValueB > 0) {
if (absIntValueB & 1) {
result = b.create<arith::MulFOp>(result, current);
}
current = b.create<arith::MulFOp>(current, current);
absIntValueB >>= 1;
}
rewriter.replaceOp(op, result);
return success();
}
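
Pulling the review threads together (use cast instead of dyn_cast since math.powf only takes float operands, keep the exponent comparison in APFloat via isExactlyValue, and restrict the integer case to |b| == 2), a minimal sketch of what the matching logic might look like is below. This is a hedged illustration of the approach agreed on in review, not the exact code that was merged; the helper createFloatConst and the includes already present in this file are assumed.

// Sketch only: APFloat-based special-case matching as discussed in review.
static LogicalResult convertSpecialPowfOpSketch(math::PowFOp op,
                                                PatternRewriter &rewriter) {
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  Value operandA = op.getOperand(0);
  Value operandB = op.getOperand(1);

  // math.powf only accepts float-like operands, so a plain cast cannot fail.
  const auto &sem = cast<FloatType>(getElementTypeOrSelf(operandB.getType()))
                        .getFloatSemantics();

  // Only rewrite when the exponent is a compile-time float constant.
  APFloat valueB(sem);
  if (!matchPattern(operandB, m_ConstantFloat(&valueB)))
    return failure();

  // Compare in the operand's own semantics: converting to a C++ float/double
  // could round, e.g., an f64 exponent of 1.0 + 1.0e-10 to exactly 1.0 and
  // enable an incorrect rewrite.
  if (valueB.isExactlyValue(1.0)) { // a^1 -> a
    rewriter.replaceOp(op, operandA);
    return success();
  }
  if (valueB.isExactlyValue(0.5)) { // a^0.5 -> sqrt(a)
    Value sqrt = b.create<math::SqrtOp>(operandA);
    rewriter.replaceOp(op, sqrt);
    return success();
  }
  if (valueB.isExactlyValue(2.0)) { // a^2 -> a * a
    Value sq = b.create<arith::MulFOp>(operandA, operandA);
    rewriter.replaceOp(op, sq);
    return success();
  }
  // The 0.0, -0.5, -1.0, and -2.0 cases follow the same shape in the patch.
  return failure();
}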

// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
// Restricting a >= 0
Member: Where is this actually checked? This seems to be expanding under this assumption, but it always does it. Is this a new assumption on the op that should be documented?

ita9naiwa (Contributor Author), Feb 9, 2025: Sorry! I forgot to describe the PR in more detail.

  1. I believe it should be documented; in general, powf(a, b) where a < 0 yields NaN, and we (as far as I know) aren't able to check that at runtime.

  2. This transform should only be applied for some small 'b' (e.g., when 'abs(b) < 16'):

  while (absIntValueB > 0) {
    if (absIntValueB & 1) {
      result = b.create<arith::MulFOp>(result, current);
    }
    current = b.create<arith::MulFOp>(current, current);
    absIntValueB >>= 1;
  }
  rewriter.replaceOp(op, result);
  return success();

  The heuristic cutoff for b is not determined yet. This last case can be dropped if it's not necessary.

Contributor Author: One problem is that there are some special use cases where the variable a < 0 but the constant b is some multiple of 2. cc @hanhanW

Contributor: Just drop the comment // Restricting a >= 0 here.

Mathematically, the power operation a^b is well-defined in two separate (though overlapping) cases:

  1. When a > 0. In that case, a^b is defined as exp(b * ln(a)).
  2. When b is an integer. In that case, a^b is defined as a * ... * a (b times), or the reciprocal of that if b is negative.

These two definitions agree on the intersection of these two cases.

Because "power" inherently has that two-mode definition, the MLIR op powf should have been specified from the start to implement one of these two modes only. Obviously it should have been a > 0.

I believe that it is still time to clarify that. We have recently observed that some rewrite patterns for powf ops have been broken outside of the case a > 0, suggesting that no one was relying on that.

But that discussion doesn't need to be conflated into this PR, because this PR implements rewrites that are either agnostic as to which case we are in (e.g. the case of pow(a, 2.0)) or that explicitly do not apply to the other case anyway (e.g. the case of pow(a, 0.5)).
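
For reference, the two modes described in the comment above can be written out explicitly (this restatement is added for clarity and is not part of the PR discussion):

% a^b, per the two (overlapping) well-defined cases described above
\[
a^{b} =
\begin{cases}
\exp\!\bigl(b \,\ln a\bigr), & a > 0, \\
\underbrace{a \cdot a \cdots a}_{b \text{ factors}}, & b \in \mathbb{Z},\; b \ge 0, \\
1 / a^{-b}, & b \in \mathbb{Z},\; b < 0,
\end{cases}
\]
% The two definitions agree where both apply: a > 0 and b an integer.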

static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
Value operandB = op.getOperand(1);
Type opType = operandA.getType();
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);

Value logA = b.create<math::LogOp>(opType, opASquared);
Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
Value logA = b.create<math::LogOp>(opType, operandA);
Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
Value expResult = b.create<math::ExpOp>(opType, mult);
Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
Value negCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
Value oddPower =
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);

// First, we select between the exp value and the adjusted value for odd
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
Value zeroCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
expResult);
res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
rewriter.replaceOp(op, res);
Value finalResult =
b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
rewriter.replaceOp(op, finalResult);
return success();
}

@@ -660,6 +738,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
}

void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
patterns.add(convertSpecialPowfOp);
Contributor: Question for MLIR experts: here we want convertSpecialPowfOp to take precedence over the convertPowfOp pattern. Is that ensured by it being added first here? If not, do we need to merge these two patterns to ensure ordering?

ita9naiwa (Contributor Author), Feb 11, 2025:

  patterns.add(convertSpecialPowfOp, /*benefit=*/ 2);

  This would explicitly give the order we want.

Contributor Author: I'm not completely sure since I didn't check the code, but adding the patterns in this order makes convertSpecialPowfOp run first:

  patterns.add(convertSpecialPowfOp);
  patterns.add(convertPowfOp);

hanhanW (Contributor), Feb 12, 2025: This is correct, but I think it is not documented. IMO, we prefer using benefit to prioritize the patterns. See https://mlir.llvm.org/docs/PatternRewriter/#benefit

Contributor Author: Updated to give explicit benefit=2!

patterns.add(convertPowfOp);
}
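
Following up on the benefit discussion above, the registration the thread converges on would presumably look like the sketch below (assuming the final revision passes an explicit PatternBenefit of 2; the exact merged code is not shown in this two-commit view):

// Sketch: give the special-case pattern a higher benefit so it is tried
// before the generic exp(b * ln(a)) expansion.
void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
  patterns.add(convertSpecialPowfOp, /*benefit=*/2);
  patterns.add(convertPowfOp);
}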

71 changes: 20 additions & 51 deletions mlir/test/Dialect/Math/expand-math.mlir
Contributor: At the least, we need all the special cases tested in this file.

Contributor Author: My apologies that this PR got verbose; I added the appropriate tests and they run well!

@@ -202,25 +202,15 @@ func.func @roundf_func(%a: f32) -> f32 {

// CHECK-LABEL: func @powf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
func.func @powf_func(%a: f64, %b: f64) ->f64 {
func.func @powf_func(%a: f64, %b: f64) -> f64 {
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
// CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
// CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
// CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
// CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
// CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
// CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
// CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
// CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
// CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
// CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
// CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
// CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
// CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
// CHECK: return [[SEL1]]
// CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
// CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
// CHECK: [[EXP:%.+]] = math.exp [[MULB]]
// CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
// CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
// CHECK: return [[SEL]]
%ret = math.powf %a, %b : f64
return %ret : f64
}
@@ -602,26 +592,15 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
// CHECK: return %[[SEL1]] : tensor<8xf32>

// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
// CHECK: return %[[SEL]]
// -----

// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
@@ -630,25 +609,15 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32
// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32
// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
// CHECK: return %[[SEL1]] : f32
// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
// CHECK: return %[[SEL]] : f32

// -----

5 changes: 0 additions & 5 deletions mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,11 +202,6 @@ func.func @powf() {
%a_p = arith.constant 2.0 : f64
call @func_powff64(%a, %a_p) : (f64, f64) -> ()

// CHECK-NEXT: -27
%b = arith.constant -3.0 : f64
%b_p = arith.constant 3.0 : f64
call @func_powff64(%b, %b_p) : (f64, f64) -> ()

// CHECK-NEXT: 2.343
%c = arith.constant 2.343 : f64
%c_p = arith.constant 1.000 : f64