Skip to content

Commit 57e1943

Browse files
committed
[mlir] Add support for non-f32 polynomial approximation
Polynomial approximations assume F32 values. We can convert all non-f32 cases to operate on f32s with intermediate casts. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D146677
1 parent c2a1ab3 commit 57e1943

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
331331
SmallVector<Value> operands;
332332
for (auto operand : op->getOperands())
333333
operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
334-
auto result = rewriter.create<math::Atan2Op>(loc, newType, operands);
334+
auto result =
335+
rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
335336
rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
336337
return success();
337338
}
@@ -1381,13 +1382,24 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
13811382
void mlir::populateMathPolynomialApproximationPatterns(
13821383
RewritePatternSet &patterns,
13831384
const MathPolynomialApproximationOptions &options) {
1385+
// Patterns for leveraging existing f32 lowerings on other data types.
1386+
patterns
1387+
.add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1388+
ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1389+
ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1390+
ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1391+
ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1392+
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1393+
patterns.getContext());
1394+
13841395
patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
13851396
LogApproximation, Log2Approximation, Log1pApproximation,
13861397
ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1387-
CbrtApproximation, ReuseF32Expansion<math::Atan2Op>,
1388-
SinAndCosApproximation<true, math::SinOp>,
1398+
CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
13891399
SinAndCosApproximation<false, math::CosOp>>(
13901400
patterns.getContext());
1391-
if (options.enableAvx2)
1392-
patterns.add<RsqrtApproximation>(patterns.getContext());
1401+
if (options.enableAvx2) {
1402+
patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1403+
patterns.getContext());
1404+
}
13931405
}

mlir/test/Dialect/Math/polynomial-approximation.mlir

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,4 +642,55 @@ func.func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
642642
func.func @cbrt_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
643643
%0 = "math.cbrt"(%arg0) : (vector<4xf32>) -> vector<4xf32>
644644
func.return %0 : vector<4xf32>
645-
}
645+
}
646+
647+
648+
// CHECK-LABEL: @math_f16
649+
func.func @math_f16(%arg0 : vector<4xf16>) -> vector<4xf16> {
650+
651+
// CHECK-NOT: math.atan
652+
%0 = "math.atan"(%arg0) : (vector<4xf16>) -> vector<4xf16>
653+
654+
// CHECK-NOT: math.atan2
655+
%1 = "math.atan2"(%0, %arg0) : (vector<4xf16>, vector<4xf16>) -> vector<4xf16>
656+
657+
// CHECK-NOT: math.tanh
658+
%2 = "math.tanh"(%1) : (vector<4xf16>) -> vector<4xf16>
659+
660+
// CHECK-NOT: math.log
661+
%3 = "math.log"(%2) : (vector<4xf16>) -> vector<4xf16>
662+
663+
// CHECK-NOT: math.log2
664+
%4 = "math.log2"(%3) : (vector<4xf16>) -> vector<4xf16>
665+
666+
// CHECK-NOT: math.log1p
667+
%5 = "math.log1p"(%4) : (vector<4xf16>) -> vector<4xf16>
668+
669+
// CHECK-NOT: math.erf
670+
%6 = "math.erf"(%5) : (vector<4xf16>) -> vector<4xf16>
671+
672+
// CHECK-NOT: math.exp
673+
%7 = "math.exp"(%6) : (vector<4xf16>) -> vector<4xf16>
674+
675+
// CHECK-NOT: math.expm1
676+
%8 = "math.expm1"(%7) : (vector<4xf16>) -> vector<4xf16>
677+
678+
// CHECK-NOT: math.cbrt
679+
%9 = "math.cbrt"(%8) : (vector<4xf16>) -> vector<4xf16>
680+
681+
// CHECK-NOT: math.sin
682+
%10 = "math.sin"(%9) : (vector<4xf16>) -> vector<4xf16>
683+
684+
// CHECK-NOT: math.cos
685+
%11 = "math.cos"(%10) : (vector<4xf16>) -> vector<4xf16>
686+
687+
return %11 : vector<4xf16>
688+
}
689+
690+
691+
// AVX2-LABEL: @rsqrt_f16
692+
func.func @rsqrt_f16(%arg0 : vector<2x8xf16>) -> vector<2x8xf16> {
693+
// AVX2-NOT: math.rsqrt
694+
%0 = "math.rsqrt"(%arg0) : (vector<2x8xf16>) -> vector<2x8xf16>
695+
return %0 : vector<2x8xf16>
696+
}

0 commit comments

Comments
 (0)