Commit 7891b80

Do not approximate erf on rocm. (#19969)
On ROCm, we want to use the device library functions, which we link as bitcode and inline. In this PR, we start with `math.erf` because that's the immediate use case, but this will likely be generalized to other functions in a subsequent PR.

Signed-off-by: Benoit Jacob <[email protected]>
1 parent c8ba691 commit 7891b80

2 files changed: +20 -1 lines changed

compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp

Lines changed: 3 additions & 1 deletion
@@ -99,7 +99,6 @@ static bool predicateF32Cast(StringRef name,
 
 static bool predicateApprox(StringRef name,
                             IREE::HAL::ExecutableTargetAttr target) {
-  (void)target; // Currently unused.
   if (clNativeMathPrecision) { // Legacy.
     if (name == math::ErfOp::getOperationName()) {
       // The legacy implementation had a bug: it always applied polynomial
@@ -124,6 +123,9 @@ static bool predicateApprox(StringRef name,
   StringRef expm1 = math::ExpM1Op::getOperationName();
   StringRef cbrt = math::CbrtOp::getOperationName();
   StringRef erf = math::ErfOp::getOperationName();
+  if (isROCMBackend(target) && name == erf) {
+    return false;
+  }
   return llvm::is_contained({atan, atan2, tanh, log, log2, log1p, erf, asin,
                              acos, exp, expm1, cbrt, sin, cos},
                             name);
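
For readers without the IREE tree at hand, the non-legacy path of the changed predicate can be sketched in standalone C++. This is only an illustration under stated assumptions: `TargetSketch`, `isROCMBackendSketch`, and `predicateApproxSketch` are hypothetical stand-ins for `IREE::HAL::ExecutableTargetAttr`, `isROCMBackend`, and `predicateApprox`. The sketch assumes the backend is identified by the string "rocm" (as in the test target attribute) and omits the legacy `clNativeMathPrecision` branch entirely.

```cpp
#include <algorithm>
#include <array>
#include <string_view>

// Hypothetical stand-in for IREE::HAL::ExecutableTargetAttr: only the backend
// name matters for this sketch.
struct TargetSketch {
  std::string_view backend;
};

// Assumption: a ROCm target is one whose backend string is "rocm".
static bool isROCMBackendSketch(const TargetSketch &target) {
  return target.backend == "rocm";
}

// Returns true if the op with the given name should be rewritten into a
// polynomial approximation. The legacy flag-controlled branch is omitted.
static bool predicateApproxSketch(std::string_view name,
                                  const TargetSketch &target) {
  // On ROCm, leave math.erf untouched so it can lower to the device library
  // function instead of being approximated.
  if (isROCMBackendSketch(target) && name == "math.erf") {
    return false;
  }
  // Same op list as in the diff, spelled as MLIR operation names.
  static constexpr std::array<std::string_view, 14> kApproxOps = {
      "math.atan",  "math.atan2", "math.tanh", "math.log",  "math.log2",
      "math.log1p", "math.erf",   "math.asin", "math.acos", "math.exp",
      "math.expm1", "math.cbrt",  "math.sin",  "math.cos"};
  return std::find(kApproxOps.begin(), kApproxOps.end(), name) !=
         kApproxOps.end();
}
```

With this sketch, `math.erf` is still approximated on a non-ROCm target, and every other op in the list is still approximated on ROCm; only the (ROCm, `math.erf`) combination is excluded, which matches the early return added in the diff.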

compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir

Lines changed: 17 additions & 0 deletions
@@ -55,3 +55,20 @@ func.func @rewrite_erf(%arg0: f16) -> f16 attributes {
   %0 = math.erf %arg0 : f16
   return %0 : f16
 }
+
+// -----
+
+// CHECK-LABEL: @no_approx_erf_on_rocm
+func.func @no_approx_erf_on_rocm(%arg0: f16) -> f16 attributes {
+  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {}>
+} {
+  // On ROCm, we want to use the native device library function, so math.erf
+  // should not get rewritten. It's OK for f16 to still get casted to f32, as
+  // the device library function for f16 is casting to f32 anyway.
+  // CHECK: math.erf
+  // CHECK-NOT: math.exp
+  // CHECK-NOT: math.log
+  // CHECK-NOT: math.fma
+  %0 = math.erf %arg0 : f16
+  return %0 : f16
+}
