Skip to content

Commit db5b69a

Browse files
authored
[AMDGPU] Do not rewrite or approximate math functions on ROCm (#20222)
This is the re-landing of #19970 which was rolled back due to a ONNX test failure which we want to accept as its root cause is a torch-mlir bug: llvm/torch-mlir#4091 This reverts commit 00e8873. Signed-off-by: Benoit Jacob <[email protected]>
1 parent 9d693cb commit db5b69a

File tree

5 files changed

+172
-152
lines changed

5 files changed

+172
-152
lines changed

compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,16 @@ static void populateMathFunctionsRewritePatterns(
6363

6464
static bool predicateRewrite(StringRef name,
6565
IREE::HAL::ExecutableTargetAttr target) {
66-
(void)target; // Currently unused.
6766
if (clNativeMathPrecision) { // Legacy.
6867
if (name == math::Exp2Op::getOperationName() ||
6968
name == math::RoundEvenOp::getOperationName()) {
7069
return false;
7170
}
7271
}
72+
if (isROCMBackend(target)) {
73+
// On ROCm, we want to use device library functions.
74+
return false;
75+
}
7376
// Currently enable all non-approximative rewrites.
7477
return true;
7578
}
@@ -109,6 +112,10 @@ static bool predicateApprox(StringRef name,
109112
}
110113
return false;
111114
}
115+
if (isROCMBackend(target)) {
116+
// On ROCm, we want to use device library functions.
117+
return false;
118+
}
112119
StringRef acos = math::AcosOp::getOperationName();
113120
StringRef asin = math::AsinOp::getOperationName();
114121
StringRef atan = math::AtanOp::getOperationName();
@@ -123,9 +130,6 @@ static bool predicateApprox(StringRef name,
123130
StringRef expm1 = math::ExpM1Op::getOperationName();
124131
StringRef cbrt = math::CbrtOp::getOperationName();
125132
StringRef erf = math::ErfOp::getOperationName();
126-
if (isROCMBackend(target) && name == erf) {
127-
return false;
128-
}
129133
return llvm::is_contained({atan, atan2, tanh, log, log2, log1p, erf, asin,
130134
acos, exp, expm1, cbrt, sin, cos},
131135
name);

compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,36 @@ func.func @rewrite_erf(%arg0: f16) -> f16 attributes {
5858

5959
// -----
6060

61-
// CHECK-LABEL: @no_approx_erf_on_rocm
62-
func.func @no_approx_erf_on_rocm(%arg0: f16) -> f16 attributes {
61+
// CHECK-LABEL: @no_approx_on_rocm
62+
func.func @no_approx_on_rocm(%arg0: f16) -> f16 attributes {
6363
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {}>
6464
} {
65-
// On ROCm, we want to use the native device library function, so math.erf
66-
// should not get rewritten. It's OK for f16 to still get casted to f32, as
67-
// the device library function for f16 is casting to f32 anyway.
65+
// On ROCm, we want to use the native device library functions.
66+
// It's OK for f16 to still get casted to f32, as
67+
// the device library functions for f16 are casting to f32 anyway.
68+
// CHECK: math.acos
69+
// CHECK: math.atan
70+
// CHECK: math.sin
71+
// CHECK: math.tanh
72+
// CHECK: math.log
73+
// CHECK: math.log2
74+
// CHECK: math.log1p
75+
// CHECK: math.exp
76+
// CHECK: math.exp2
77+
// CHECK: math.expm1
78+
// CHECK: math.cbrt
6879
// CHECK: math.erf
69-
// CHECK-NOT: math.exp
70-
// CHECK-NOT: math.log
71-
// CHECK-NOT: math.fma
72-
%0 = math.erf %arg0 : f16
73-
return %0 : f16
80+
%0 = math.acos %arg0 : f16
81+
%1 = math.atan %0 : f16
82+
%2 = math.sin %1 : f16
83+
%3 = math.tanh %2 : f16
84+
%4 = math.log %3 : f16
85+
%5 = math.log2 %4 : f16
86+
%6 = math.log1p %5 : f16
87+
%7 = math.exp %6 : f16
88+
%8 = math.exp2 %7 : f16
89+
%9 = math.expm1 %8 : f16
90+
%10 = math.cbrt %9 : f16
91+
%11 = math.erf %10 : f16
92+
return %11 : f16
7493
}

tests/e2e/math/math_ops_llvm-cpu.json

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
{
33
"op": "acos",
44
"type": "f32",
5-
"atol": 1.0e-06,
6-
"rtol": 1.0e-06
5+
"atol": 1.2e-07,
6+
"rtol": 1.2e-07
77
},
88
{
99
"op": "acos",
@@ -14,8 +14,8 @@
1414
{
1515
"op": "acosh",
1616
"type": "f32",
17-
"atol": 1.0e-06,
18-
"rtol": 1.0e-06
17+
"atol": 1.2e-07,
18+
"rtol": 1.2e-07
1919
},
2020
{
2121
"op": "acosh",
@@ -26,8 +26,8 @@
2626
{
2727
"op": "asin",
2828
"type": "f32",
29-
"atol": 1.0e-06,
30-
"rtol": 1.0e-06
29+
"atol": 1.2e-07,
30+
"rtol": 1.2e-07
3131
},
3232
{
3333
"op": "asin",
@@ -50,20 +50,20 @@
5050
{
5151
"op": "atan",
5252
"type": "f32",
53-
"atol": 1.0e-06,
54-
"rtol": 1.0e-06
53+
"atol": 1.2e-07,
54+
"rtol": 1.2e-07
5555
},
5656
{
5757
"op": "atan",
5858
"type": "f16",
59-
"atol": 1.0e-03,
60-
"rtol": 1.0e-03
59+
"atol": 1.0e-04,
60+
"rtol": 1.0e-04
6161
},
6262
{
6363
"op": "atanh",
6464
"type": "f32",
65-
"atol": 1.0e-06,
66-
"rtol": 1.0e-06
65+
"atol": 1.2e-07,
66+
"rtol": 1.2e-07
6767
},
6868
{
6969
"op": "atanh",
@@ -80,8 +80,8 @@
8080
{
8181
"op": "cbrt",
8282
"type": "f16",
83-
"atol": 1.0e-03,
84-
"rtol": 1.0e-03
83+
"atol": 1.0e-04,
84+
"rtol": 1.0e-04
8585
},
8686
{
8787
"op": "ceil",
@@ -104,8 +104,8 @@
104104
{
105105
"op": "cos",
106106
"type": "f16",
107-
"atol": 1.0e-03,
108-
"rtol": 1.0e-03
107+
"atol": 1.0e-04,
108+
"rtol": 1.0e-04
109109
},
110110
{
111111
"op": "cosh",
@@ -125,14 +125,14 @@
125125
{
126126
"op": "erf",
127127
"type": "f32",
128-
"atol": 1.0e-06,
129-
"rtol": 1.0e-06
128+
"atol": 1.2e-07,
129+
"rtol": 1.2e-07
130130
},
131131
{
132132
"op": "erf",
133133
"type": "f16",
134-
"atol": 1.0e-03,
135-
"rtol": 1.0e-03
134+
"atol": 1.0e-04,
135+
"rtol": 1.0e-04
136136
},
137137
{
138138
"op": "exp",
@@ -143,8 +143,8 @@
143143
{
144144
"op": "exp",
145145
"type": "f16",
146-
"atol": 1.0e-03,
147-
"rtol": 1.0e-03
146+
"atol": 1.0e-04,
147+
"rtol": 1.0e-04
148148
},
149149
{
150150
"op": "exp2",
@@ -168,8 +168,8 @@
168168
{
169169
"op": "expm1",
170170
"type": "f16",
171-
"atol": 1.0e-03,
172-
"rtol": 1.0e-03
171+
"atol": 1.0e-04,
172+
"rtol": 1.0e-04
173173
},
174174
{
175175
"op": "floor",
@@ -186,38 +186,38 @@
186186
{
187187
"op": "log",
188188
"type": "f32",
189-
"atol": 1.0e-03,
190-
"rtol": 1.0e-03
189+
"atol": 1.0e-04,
190+
"rtol": 1.0e-04
191191
},
192192
{
193193
"op": "log",
194194
"type": "f16",
195-
"atol": 1.0e-03,
196-
"rtol": 1.0e-03
195+
"atol": 1.0e-04,
196+
"rtol": 1.0e-04
197197
},
198198
{
199199
"op": "log1p",
200200
"type": "f32",
201-
"atol": 1.0e-06,
202-
"rtol": 1.0e-06
201+
"atol": 1.2e-07,
202+
"rtol": 1.2e-07
203203
},
204204
{
205205
"op": "log1p",
206206
"type": "f16",
207-
"atol": 1.0e-03,
208-
"rtol": 1.0e-03
207+
"atol": 1.0e-04,
208+
"rtol": 1.0e-04
209209
},
210210
{
211211
"op": "log2",
212212
"type": "f32",
213-
"atol": 1.0e-06,
214-
"rtol": 1.0e-06
213+
"atol": 1.2e-07,
214+
"rtol": 1.2e-07
215215
},
216216
{
217217
"op": "log2",
218218
"type": "f16",
219-
"atol": 1.0e-03,
220-
"rtol": 1.0e-03
219+
"atol": 1.0e-04,
220+
"rtol": 1.0e-04
221221
},
222222
{
223223
"op": "round",
@@ -246,8 +246,8 @@
246246
{
247247
"op": "rsqrt",
248248
"type": "f32",
249-
"atol": 1.0e-06,
250-
"rtol": 1.0e-06
249+
"atol": 1.2e-07,
250+
"rtol": 1.2e-07
251251
},
252252
{
253253
"op": "rsqrt",
@@ -264,8 +264,8 @@
264264
{
265265
"op": "sin",
266266
"type": "f16",
267-
"atol": 1.0e-03,
268-
"rtol": 1.0e-03
267+
"atol": 1.0e-04,
268+
"rtol": 1.0e-04
269269
},
270270
{
271271
"op": "sinh",
@@ -285,14 +285,14 @@
285285
{
286286
"op": "sqrt",
287287
"type": "f32",
288-
"atol": 1.0e-06,
289-
"rtol": 1.0e-06
288+
"atol": 1.2e-07,
289+
"rtol": 1.2e-07
290290
},
291291
{
292292
"op": "sqrt",
293293
"type": "f16",
294-
"atol": 1.0e-03,
295-
"rtol": 1.0e-03
294+
"atol": 1.0e-04,
295+
"rtol": 1.0e-04
296296
},
297297
{
298298
"op": "tan",
@@ -315,8 +315,8 @@
315315
{
316316
"op": "tanh",
317317
"type": "f16",
318-
"atol": 1.0e-03,
319-
"rtol": 1.0e-03
318+
"atol": 1.0e-04,
319+
"rtol": 1.0e-04
320320
},
321321
{
322322
"op": "atan2",
@@ -327,8 +327,8 @@
327327
{
328328
"op": "atan2",
329329
"type": "f16",
330-
"atol": 1.0e-03,
331-
"rtol": 1.0e-03
330+
"atol": 1.0e-04,
331+
"rtol": 1.0e-04
332332
},
333333
{
334334
"op": "powf",

0 commit comments

Comments
 (0)