Skip to content

Commit 7c0259c

Browse files
authored
Refactor PolynomialApproximationPass into MathTransformPass. (#19922)
The name `PolynomialApproximationPass` was a misnomer since that pass did more than polynomial approximation. It also does other non-approximative rewrites, and casts to f32. This PR renames it and refactors it to explicitly adjust the rewrites to the target. This also reverse-engineers, reimplements and deprecates the `clNativeMathPrecision` flag which had unwitting semantics. Signed-off-by: Benoit Jacob <[email protected]>
1 parent 3fce185 commit 7c0259c

File tree

12 files changed

+234
-116
lines changed

12 files changed

+234
-116
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,14 @@ iree_compiler_cc_library(
131131
"MaterializeEncodingIntoPadding.cpp",
132132
"MaterializeEncodingPatterns.cpp",
133133
"MaterializeTuningSpecsPass.cpp",
134+
"MathTransformPass.cpp",
134135
"MemrefCopyToLinalg.cpp",
135136
"NormalizeLoopBounds.cpp",
136137
"OptimizeTensorInsertExtractSlices.cpp",
137138
"OptimizeVectorTransferPass.cpp",
138139
"PadDynamicAlloc.cpp",
139140
"PassUtils.cpp",
140141
"Passes.cpp",
141-
"PolynomialApproximationPass.cpp",
142142
"PropagateDispatchSizeBounds.cpp",
143143
"PropagateReshapesByExpansion.cpp",
144144
"ReconcileTranslationInfo.cpp",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,14 @@ iree_cc_library(
124124
"MaterializeEncodingIntoPadding.cpp"
125125
"MaterializeEncodingPatterns.cpp"
126126
"MaterializeTuningSpecsPass.cpp"
127+
"MathTransformPass.cpp"
127128
"MemrefCopyToLinalg.cpp"
128129
"NormalizeLoopBounds.cpp"
129130
"OptimizeTensorInsertExtractSlices.cpp"
130131
"OptimizeVectorTransferPass.cpp"
131132
"PadDynamicAlloc.cpp"
132133
"PassUtils.cpp"
133134
"Passes.cpp"
134-
"PolynomialApproximationPass.cpp"
135135
"PropagateDispatchSizeBounds.cpp"
136136
"PropagateReshapesByExpansion.cpp"
137137
"ReconcileTranslationInfo.cpp"
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
// Copyright 2022 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "mlir/Dialect/Math/Transforms/Approximation.h"
9+
#include "mlir/Dialect/Math/Transforms/Passes.h"
10+
#include "mlir/IR/PatternMatch.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
14+
namespace mlir::iree_compiler {
15+
16+
/// Deprecated! This flag had buggy/unintentional semantics.
17+
/// Its original comment said:
18+
/// ""use native hardware operations instead of polynomial approximation".
19+
static llvm::cl::opt<bool> clNativeMathPrecision(
20+
"iree-codegen-gpu-native-math-precision",
21+
llvm::cl::desc("Deprecated! This flag had buggy/unintentional semantics. "
22+
"Its original description said: \"Skip polynomial lowering "
23+
"for math op natively available on GPU.\""),
24+
llvm::cl::init(false));
25+
26+
#define GEN_PASS_DEF_MATHTRANSFORMPASS
27+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
28+
29+
static void populateMathFunctionsRewritePatterns(
30+
RewritePatternSet &patterns,
31+
const std::function<bool(StringRef)> &predicate) {
32+
if (predicate(math::TanOp::getOperationName())) {
33+
populateExpandTanPattern(patterns);
34+
}
35+
if (predicate(math::SinhOp::getOperationName())) {
36+
populateExpandSinhPattern(patterns);
37+
}
38+
if (predicate(math::CoshOp::getOperationName())) {
39+
populateExpandCoshPattern(patterns);
40+
}
41+
if (predicate(math::AsinhOp::getOperationName())) {
42+
populateExpandAsinhPattern(patterns);
43+
}
44+
if (predicate(math::AcoshOp::getOperationName())) {
45+
populateExpandAcoshPattern(patterns);
46+
}
47+
if (predicate(math::AtanhOp::getOperationName())) {
48+
populateExpandAtanhPattern(patterns);
49+
}
50+
if (predicate(math::PowFOp::getOperationName())) {
51+
populateExpandPowFPattern(patterns);
52+
}
53+
if (predicate(math::FPowIOp::getOperationName())) {
54+
populateExpandFPowIPattern(patterns);
55+
}
56+
if (predicate(math::Exp2Op::getOperationName())) {
57+
populateExpandExp2FPattern(patterns);
58+
}
59+
if (predicate(math::RoundEvenOp::getOperationName())) {
60+
populateExpandRoundEvenPattern(patterns);
61+
}
62+
}
63+
64+
static bool predicateRewrite(StringRef name,
65+
IREE::HAL::ExecutableTargetAttr target) {
66+
(void)target; // Currently unused.
67+
if (clNativeMathPrecision) { // Legacy.
68+
if (name == math::Exp2Op::getOperationName() ||
69+
name == math::RoundEvenOp::getOperationName()) {
70+
return false;
71+
}
72+
}
73+
// Currently enable all non-approximative rewrites.
74+
return true;
75+
}
76+
77+
static bool predicateF32Cast(StringRef name,
78+
IREE::HAL::ExecutableTargetAttr target) {
79+
(void)target; // Currently unused.
80+
if (clNativeMathPrecision) { // Legacy.
81+
return false;
82+
}
83+
StringRef atan = math::AtanOp::getOperationName();
84+
StringRef atan2 = math::Atan2Op::getOperationName();
85+
StringRef cos = math::CosOp::getOperationName();
86+
StringRef sin = math::SinOp::getOperationName();
87+
StringRef tanh = math::TanhOp::getOperationName();
88+
StringRef log = math::LogOp::getOperationName();
89+
StringRef log2 = math::Log2Op::getOperationName();
90+
StringRef log1p = math::Log1pOp::getOperationName();
91+
StringRef exp = math::ExpOp::getOperationName();
92+
StringRef expm1 = math::ExpM1Op::getOperationName();
93+
StringRef cbrt = math::CbrtOp::getOperationName();
94+
StringRef erf = math::ErfOp::getOperationName();
95+
return llvm::is_contained(
96+
{atan, atan2, tanh, log, log2, log1p, erf, exp, expm1, cbrt, sin, cos},
97+
name);
98+
}
99+
100+
static bool predicateApprox(StringRef name,
101+
IREE::HAL::ExecutableTargetAttr target) {
102+
(void)target; // Currently unused.
103+
if (clNativeMathPrecision) { // Legacy.
104+
if (name == math::ErfOp::getOperationName()) {
105+
// The legacy implementation had a bug: it always applied polynomial
106+
// approximation of math.erf, even when clNativeMathPrecision was passed.
107+
// We actually have CI tests that rely on that bug: they pass
108+
// clNativeMathPrecision but fail unless math.erf is approximated.
109+
return true;
110+
}
111+
return false;
112+
}
113+
StringRef acos = math::AcosOp::getOperationName();
114+
StringRef asin = math::AsinOp::getOperationName();
115+
StringRef atan = math::AtanOp::getOperationName();
116+
StringRef atan2 = math::Atan2Op::getOperationName();
117+
StringRef cos = math::CosOp::getOperationName();
118+
StringRef sin = math::SinOp::getOperationName();
119+
StringRef tanh = math::TanhOp::getOperationName();
120+
StringRef log = math::LogOp::getOperationName();
121+
StringRef log2 = math::Log2Op::getOperationName();
122+
StringRef log1p = math::Log1pOp::getOperationName();
123+
StringRef exp = math::ExpOp::getOperationName();
124+
StringRef expm1 = math::ExpM1Op::getOperationName();
125+
StringRef cbrt = math::CbrtOp::getOperationName();
126+
StringRef erf = math::ErfOp::getOperationName();
127+
return llvm::is_contained({atan, atan2, tanh, log, log2, log1p, erf, asin,
128+
acos, exp, expm1, cbrt, sin, cos},
129+
name);
130+
}
131+
132+
namespace {
133+
134+
class MathTransformPass final
135+
: public impl::MathTransformPassBase<MathTransformPass> {
136+
public:
137+
using Base::Base;
138+
139+
void runOnOperation() override {
140+
RewritePatternSet patterns(&getContext());
141+
auto target = IREE::HAL::ExecutableTargetAttr::lookup(getOperation());
142+
if (!target) {
143+
return signalPassFailure();
144+
}
145+
populateMathFunctionsRewritePatterns(patterns, [target](StringRef name) {
146+
return predicateRewrite(name, target);
147+
});
148+
149+
populateMathF32ExpansionPatterns(patterns, [target](StringRef name) {
150+
return predicateF32Cast(name, target);
151+
});
152+
153+
populateMathPolynomialApproximationPatterns(
154+
patterns,
155+
[target](StringRef name) { return predicateApprox(name, target); });
156+
157+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
158+
return signalPassFailure();
159+
}
160+
}
161+
};
162+
163+
} // namespace
164+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -537,15 +537,9 @@ def PadDynamicAllocPass :
537537
let summary = "Pass to pad dynamic alloc into static one.";
538538
}
539539

540-
def PolynomialApproximationPass :
541-
Pass<"iree-codegen-polynomial-approximation", ""> {
542-
let summary = "Convert math operations to their polynomial approximation";
543-
let options = [
544-
ListOption<"noApproxOps", "no-approx-ops", "std::string",
545-
[{List of operations that should not be approximated.\n"
546-
"As of now, possible options are:\n"
547-
"\ttan, sinh, cosh, asinh, acosh, atanh, powf, fpowf, erf\n}]>,
548-
];
540+
def MathTransformPass :
541+
Pass<"iree-codegen-math-transform", ""> {
542+
let summary = "Apply math ops transformations: approximations, rewrites to other math ops, operand casts.";
549543
}
550544

551545
def PropagateDispatchSizeBoundsPass :

compiler/src/iree/compiler/Codegen/Common/PolynomialApproximationPass.cpp

Lines changed: 0 additions & 78 deletions
This file was deleted.

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ iree_lit_test_suite(
6767
"materialize_tuning_specs_invalid_spec.mlir",
6868
"materialize_user_config_from_tuning_spec.mlir",
6969
"materialize_user_configs.mlir",
70+
"math_transform.mlir",
7071
"normalize_loop_bounds.mlir",
7172
"optimize_tensor_insert_extract_slices.mlir",
7273
"pad_dynamic_alloc.mlir",
73-
"polynomial_approximation.mlir",
7474
"propagate_dispatch_size_bounds.mlir",
7575
"propagate_reshapes_by_expansion.mlir",
7676
"reconcile_translation_info.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ iree_lit_test_suite(
6363
"materialize_tuning_specs_invalid_spec.mlir"
6464
"materialize_user_config_from_tuning_spec.mlir"
6565
"materialize_user_configs.mlir"
66+
"math_transform.mlir"
6667
"normalize_loop_bounds.mlir"
6768
"optimize_tensor_insert_extract_slices.mlir"
6869
"pad_dynamic_alloc.mlir"
69-
"polynomial_approximation.mlir"
7070
"propagate_dispatch_size_bounds.mlir"
7171
"propagate_reshapes_by_expansion.mlir"
7272
"reconcile_translation_info.mlir"
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-codegen-math-transform))' --split-input-file %s | FileCheck %s
2+
3+
// CHECK-LABEL: @rewrite_tan
4+
func.func @rewrite_tan(%arg0: f16) -> f16 attributes {
5+
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>
6+
} {
7+
// Tan should be directly approximated by a rational function. It's also possible
8+
// (though not good) that it gets rewritten as sin/cos and those get approximated by
9+
// rational functions. Either way, we expect to see rational arithmetic here, on f32
10+
// as the operands get casted to f32.
11+
// CHECK-NOT: math.tan
12+
// CHECK-NOT: math.sin
13+
// CHECK-NOT: math.cos
14+
// CHECK: math.fma {{.*}} : f32
15+
// Final division after cast to f16.
16+
// CHECK: arith.divf {{.*}} : f16
17+
%0 = math.tan %arg0 : f16
18+
return %0 : f16
19+
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: @rewrite_pow
24+
func.func @rewrite_pow(%arg0: f16, %arg1: f16) -> f16 attributes {
25+
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>
26+
} {
27+
28+
// Powf should be either directly approximated, or first rewritten into log and
29+
// exp and then those get approximated. Some targets with fast exponentials might
30+
// prefer to keep the exponential form, but this is not the case with the current
31+
// lowering for CPU, so we expect to see rational arithmetic here, on f32 as the
32+
// operands get casted to f32.
33+
// CHECK-NOT: math.powf
34+
// CHECK-NOT: math.exp
35+
// CHECK-NOT: math.log
36+
// CHECK: math.fma {{.*}} : f32
37+
%0 = math.powf %arg0, %arg1 : f16
38+
return %0 : f16
39+
}
40+
41+
// -----
42+
43+
// CHECK-LABEL: @rewrite_erf
44+
func.func @rewrite_erf(%arg0: f16) -> f16 attributes {
45+
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>
46+
} {
47+
// Erf should be directly approximated by a rational function. Some targets
48+
// with fast exponentials might prefer an exponential approximation, but this
49+
// is not the case with the current lowering for CPU, so we expect to see rational
50+
// arithmetic here, on f32 as the operands get casted to f32.
51+
// CHECK-NOT: math.erf
52+
// CHECK-NOT: math.exp
53+
// CHECK-NOT: math.log
54+
// CHECK: math.fma {{.*}} : f32
55+
%0 = math.erf %arg0 : f16
56+
return %0 : f16
57+
}

compiler/src/iree/compiler/Codegen/Common/test/polynomial_approximation.mlir

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)