Skip to content

Commit c328c30

Browse files
authored
Reconcile math.rsqrt expansion with upstream (#222)
1 parent f8d597a commit c328c30

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,8 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
485485

486486
Location loc = op->getLoc();
487487
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
488-
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op->getOperand(0));
489-
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, operandTy,
490-
ValueRange{constOneFloat, sqrtOp});
488+
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, operand);
489+
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
491490
return success();
492491
}
493492

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,20 +515,48 @@ func.func @roundeven16(%arg: f16) -> f16 {
515515

516516
// -----
517517

518+
// CHECK-LABEL: func.func @rsqrt
519+
// CHECK-SAME: (%[[ARG:.*]]: f16)
520+
// CHECK-SAME: -> f16
521+
// CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f16
522+
// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : f16
523+
// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f16
524+
// CHECK: return %[[DIV]] : f16
525+
func.func @rsqrt16(%float: f16) -> (f16) {
526+
%float_result = math.rsqrt %float : f16
527+
return %float_result : f16
528+
}
529+
530+
// -----
531+
518532
// CHECK-LABEL: func.func @rsqrt
519533
// CHECK-SAME: (%[[ARG:.*]]: f32)
520534
// CHECK-SAME: -> f32
521535
// CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
522536
// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : f32
523537
// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f32
524538
// CHECK: return %[[DIV]] : f32
525-
func.func @rsqrt(%float: f32) -> (f32) {
539+
func.func @rsqrt32(%float: f32) -> (f32) {
526540
%float_result = math.rsqrt %float : f32
527541
return %float_result : f32
528542
}
529543

530544
// -----
531545

546+
// CHECK-LABEL: func.func @rsqrt
547+
// CHECK-SAME: (%[[ARG:.*]]: f64)
548+
// CHECK-SAME: -> f64
549+
// CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f64
550+
// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : f64
551+
// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f64
552+
// CHECK: return %[[DIV]] : f64
553+
func.func @rsqrt64(%float: f64) -> (f64) {
554+
%float_result = math.rsqrt %float : f64
555+
return %float_result : f64
556+
}
557+
558+
// -----
559+
532560
// CHECK-LABEL: func.func @rsqrt_vec
533561
// CHECK-SAME: (%[[ARG:.*]]: vector<5xf32>)
534562
// CHECK-SAME: -> vector<5xf32>

0 commit comments

Comments
 (0)