diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index d7953719d44b5..23356d752146d 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -325,6 +325,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { auto &sem = cast(getElementTypeOrSelf(typeB)).getFloatSemantics(); APFloat valueB(sem); + auto mulf = [&](Value x, Value y) -> Value { + return b.create(x, y); + }; if (matchPattern(operandB, m_ConstantFloat(&valueB))) { if (valueB.isZero()) { // a^0 -> 1 @@ -358,19 +361,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { } if (valueB.isExactlyValue(2.0)) { // a^2 -> a * a - Value mul = b.create(operandA, operandA); - rewriter.replaceOp(op, mul); + rewriter.replaceOp(op, mulf(operandA, operandA)); return success(); } if (valueB.isExactlyValue(-2.0)) { // a^(-2) -> 1 / (a * a) - Value mul = b.create(operandA, operandA); Value one = createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); - Value div = b.create(one, mul); + Value div = b.create(one, mulf(operandA, operandA)); rewriter.replaceOp(op, div); return success(); } + if (valueB.isExactlyValue(3.0)) { + rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA)); + return success(); + } } Value logA = b.create(operandA); diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index f39d1a7a6dc50..1fdfb854325b4 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -285,6 +285,17 @@ func.func @powf_func_negtwo(%a: f64) -> f64{ return %ret : f64 } +// CHECK-LABEL: func @powf_func_three +// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64 +func.func @powf_func_three(%a: f64) -> f64{ + // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64 + // CHECK: %[[MUL2:.+]] = arith.mulf %[[MUL]], %[[ARG0]] : f64 + // CHECK: return %[[MUL2]] : f64 + %b = arith.constant 3.0 : f64 + %ret = math.powf %a, %b : f64 + return %ret : f64 +} + // ----- // CHECK-LABEL: func.func @roundeven64