
Commit 569b02b

[Wave] Cleanup enable_scheduling_barriers, dump_generated_mlir copypaste from the test files (#46)
Set everything in one place.

---------

Signed-off-by: Ivan Butygin <[email protected]>
1 parent 51bf969

20 files changed: +94 -377 lines

tests/conftest.py

Lines changed: 21 additions & 0 deletions

@@ -53,6 +53,21 @@ def perf_filename_iree(dump_perf_path, request):
     return os.path.join(dump_perf_path, "iree_" + request.node.name + ".json")


+@pytest.fixture(scope="function", autouse=True)
+def set_mlir_filename(request):
+    option = request.config.getoption("--dump-mlir-files-path")
+    if not option:
+        return
+
+    import iree.turbine.kernel.wave.utils.run_utils as run_utils
+
+    run_utils.dump_generated_mlir = True
+    run_utils.dump_generated_mlir_file = os.path.join(
+        option,
+        "mlir_" + request.node.name + ".mlir",
+    )
+
+
 def pytest_addoption(parser):
     parser.addoption(
         "--run-e2e", action="store_true", default=False, help="run e2e tests"
@@ -78,6 +93,12 @@ def pytest_addoption(parser):
         default=0,
         help="Distribute over N gpu devices when running with pytest-xdist",
     )
+    parser.addoption(
+        "--dump-mlir-files-path",
+        action="store",
+        default=None,
+        help="save mlir files into provided directory, filename based on current test name",
+    )


 def pytest_configure(config):
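With the autouse fixture above, MLIR dumping is controlled in one place by the new pytest option rather than by per-test flags. A minimal usage sketch follows; `pytest.main` and both command-line options are real, but the output directory name is illustrative and creating it up front is an assumption (the fixture does not create it):

import os

import pytest

# Illustrative output directory; for each test, the autouse fixture joins
# this path with "mlir_<test name>.mlir" and points run_utils at it.
dump_dir = "mlir_dumps"
os.makedirs(dump_dir, exist_ok=True)  # assumption: caller creates the directory

# Equivalent to: pytest tests/ --run-e2e --dump-mlir-files-path=mlir_dumps
pytest.main(["tests/", "--run-e2e", f"--dump-mlir-files-path={dump_dir}"])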

tests/kernel/wave/attention/alibi_attention_test.py

Lines changed: 1 addition & 7 deletions

@@ -16,8 +16,6 @@
     set_default_run_config,
 )
 from wave_lang.kernel.wave.utils.torch_utils import (
-    device_arange,
-    device_full,
     device_randn,
     device_zeros,
     to_default_device,
@@ -30,14 +28,11 @@
     AttentionShape,
 )
 from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
-import os
 from torch.testing import assert_close
 from ..common.utils import (
     require_e2e,
-    enable_scheduling_barriers,
 )
-from ..common.shapes import get_test_shapes
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple

 shapes = [(128, 128, 128, 128, 128, 128)]

@@ -144,7 +139,6 @@ def test_alibi_attention(
         subs=hyperparams,
         canonicalize=True,
         run_bench=run_bench,
-        use_scheduling_barriers=enable_scheduling_barriers,
         benchmark_batch_size=10,
         benchmark_repetitions=3,
         benchmark_results_file=perf_filename_tk,
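Every test file in this commit drops the `use_scheduling_barriers=enable_scheduling_barriers` keyword from its `WaveCompileOptions(...)` call, as above. The diff shown here does not include where that default now lives; a conftest fixture in the style of `set_mlir_filename` could supply it. The option name, fixture name, and `run_utils` attribute below are all hypothetical, sketched only to illustrate the centralization pattern:

import pytest

import iree.turbine.kernel.wave.utils.run_utils as run_utils


def pytest_addoption(parser):
    # Hypothetical option; not part of the diff shown above.
    parser.addoption(
        "--enable-scheduling-barriers",
        action="store_true",
        default=False,
        help="insert scheduling barriers into compiled kernels",
    )


@pytest.fixture(scope="function", autouse=True)
def set_scheduling_barriers(request):
    # Mirrors the set_mlir_filename pattern: toggle a module-level flag
    # once, instead of threading use_scheduling_barriers= through every test.
    if request.config.getoption("--enable-scheduling-barriers"):
        run_utils.use_scheduling_barriers = True  # assumed attribute name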

tests/kernel/wave/attention/backward_attention_test.py

Lines changed: 5 additions & 43 deletions

@@ -8,7 +8,6 @@
 import torch
 from torch.nn import functional as F
 import math
-import wave_lang.kernel as tk
 import wave_lang.kernel.lang as tkl
 import wave_lang.kernel.wave as tkw
 from wave_lang.kernel.lang.global_symbols import *
@@ -29,8 +28,6 @@
 from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
 from wave_lang.kernel.wave.constraints import MMAType
 from ..common.utils import (
-    dump_generated_mlir,
-    enable_scheduling_barriers,
     expensive_test,
     require_e2e,
 )
@@ -1142,7 +1139,6 @@ def testAttentionForward(mfma_variant: MMAType, shape: tuple[int, ...]):
     hyperparams.update(get_default_scheduling_params())
     options = WaveCompileOptions(
         subs=hyperparams,
-        use_scheduling_barriers=enable_scheduling_barriers,
         run_bench=False,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
@@ -1154,13 +1150,7 @@ def testAttentionForward(mfma_variant: MMAType, shape: tuple[int, ...]):
     lse = device_zeros(batch, q_seq_len, dtype=torch.float16)
     s = device_zeros(batch, q_seq_len, kv_seq_len)

-    asm_fwd = attention_fwd(q, k, v.transpose(-1, -2), s, o, lse)
-
-    if dump_generated_mlir:
-        filename = f"out/wave_attention_fwd_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm_fwd)
-        print(f"IR dumped to {filename}")
+    attention_fwd(q, k, v.transpose(-1, -2), s, o, lse)

     assert_close(s, s_ref, **cmp_params)
     # Can't check P, since we don't actually compute the "real" thing in the
@@ -1209,7 +1199,6 @@ def testAttentionBackward(mfma_variant: MMAType, shape: tuple[int, ...]):
     hyperparams.update(get_default_scheduling_params())
     options = WaveCompileOptions(
         subs=hyperparams,
-        use_scheduling_barriers=enable_scheduling_barriers,
         run_bench=False,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
@@ -1229,7 +1218,7 @@ def testAttentionBackward(mfma_variant: MMAType, shape: tuple[int, ...]):
     dp = device_zeros(batch, q_seq_len, kv_seq_len, dtype=torch.float32)
     dp_sub = device_zeros(batch, q_seq_len, kv_seq_len, dtype=torch.float16)

-    asm_bwd = attention_bwd(
+    attention_bwd(
         q,
         k,
         v,
@@ -1247,12 +1236,6 @@ def testAttentionBackward(mfma_variant: MMAType, shape: tuple[int, ...]):
         dp_sub,
     )

-    if dump_generated_mlir:
-        filename = f"out/wave_attention_bwd_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm_bwd)
-        print(f"IR dumped to {filename}")
-
     assert_close(s, s_ref, **cmp_params)
     assert_close(p, p_ref, **cmp_params)

@@ -1305,7 +1288,6 @@ def testAttentionBackward_dv(mfma_variant: MMAType, shape: tuple[int, ...]):
     hyperparams_dv.update(get_default_scheduling_params())
     options = WaveCompileOptions(
         subs=hyperparams_dv,
-        use_scheduling_barriers=enable_scheduling_barriers,
         run_bench=False,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
@@ -1317,13 +1299,7 @@ def testAttentionBackward_dv(mfma_variant: MMAType, shape: tuple[int, ...]):
     s = device_zeros(batch, q_seq_len, kv_seq_len, dtype=torch.float32)
     p = device_zeros(batch, q_seq_len, kv_seq_len, dtype=torch.float16)

-    asm_bwd_dv = attention_bwd_dv(q, k, do, lse_ref, dv, s, p)
-
-    if dump_generated_mlir:
-        filename = f"out/wave_attention_bwd_dv_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm_bwd_dv)
-        print(f"IR dumped to {filename}")
+    attention_bwd_dv(q, k, do, lse_ref, dv, s, p)

     assert_close(s, s_ref, **cmp_params)
     assert_close(p, p_ref, **cmp_params)
@@ -1367,7 +1343,6 @@ def testAttentionBackward_dk(mfma_variant: MMAType, shape: tuple[int, ...]):
     hyperparams_dk.update(get_default_scheduling_params())
     options = WaveCompileOptions(
         subs=hyperparams_dk,
-        use_scheduling_barriers=enable_scheduling_barriers,
         run_bench=False,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
@@ -1383,7 +1358,7 @@ def testAttentionBackward_dk(mfma_variant: MMAType, shape: tuple[int, ...]):
     dp = torch.zeros_like(s)
     dp_sub = torch.zeros_like(p)

-    asm_bwd_dk = attention_bwd_dk(
+    attention_bwd_dk(
         q,
         k,
         v,
@@ -1398,12 +1373,6 @@ def testAttentionBackward_dk(mfma_variant: MMAType, shape: tuple[int, ...]):
         dp_sub,
     )

-    if dump_generated_mlir:
-        filename = f"out/wave_attention_bwd_dk_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm_bwd_dk)
-        print(f"IR dumped to {filename}")
-
     dp_sub_ref = (dp_ref - D.reshape((batch, q_seq_len, 1))).to(torch.float16)

     assert_close(s, s_ref, **cmp_params)
@@ -1452,7 +1421,6 @@ def testAttentionBackward_dq(mfma_variant: MMAType, shape: tuple[int, ...]):
     hyperparams_dq.update(get_default_scheduling_params())
     options = WaveCompileOptions(
         subs=hyperparams_dq,
-        use_scheduling_barriers=enable_scheduling_barriers,
         run_bench=False,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
@@ -1469,7 +1437,7 @@ def testAttentionBackward_dq(mfma_variant: MMAType, shape: tuple[int, ...]):
     dp = torch.zeros_like(s)
     dp_sub = torch.zeros_like(p)

-    asm_bwd_dq = attention_bwd_dq(
+    attention_bwd_dq(
         q,
         k,
         v,
@@ -1485,12 +1453,6 @@ def testAttentionBackward_dq(mfma_variant: MMAType, shape: tuple[int, ...]):
         dp_sub,
     )

-    if dump_generated_mlir:
-        filename = f"out/wave_attention_bwd_dq_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm_bwd_dq)
-        print(f"IR dumped to {filename}")
-
     s_sub_ref = s_ref.to(torch.float16) - lse_ref.reshape((batch, q_seq_len, 1)).expand(
         batch, q_seq_len, kv_seq_len
     )
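Each removed block above repeated the same copy-pasted shape: capture the returned asm, test a module-level flag, and write to a hand-built filename. After this commit, the two `run_utils` attributes set by the conftest fixture are the single point of control. A sketch of what the consuming side might look like; the helper name and call site are assumptions, only the two attribute names appear in the diff:

import iree.turbine.kernel.wave.utils.run_utils as run_utils


def maybe_dump_mlir(asm: str) -> None:
    # Hypothetical helper: the write each test used to perform inline,
    # now driven by the flags that the set_mlir_filename fixture sets.
    if not getattr(run_utils, "dump_generated_mlir", False):
        return
    filename = run_utils.dump_generated_mlir_file
    with open(filename, "w") as f:
        f.write(asm)
    print(f"IR dumped to {filename}")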

tests/kernel/wave/attention/chained_gemm_test.py

Lines changed: 2 additions & 19 deletions

@@ -6,7 +6,6 @@

 import pytest
 import torch
-import wave_lang.kernel as tk
 import wave_lang.kernel.lang as tkl
 import wave_lang.kernel.wave as tkw
 from wave_lang.kernel.lang.global_symbols import *
@@ -27,14 +26,11 @@
 )
 from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
 from wave_lang.kernel.wave.constraints import MMAType
-import os
 from torch.testing import assert_close
 from ..common.utils import (
     require_e2e,
     require_cdna3,
     param_bool,
-    enable_scheduling_barriers,
-    dump_generated_mlir,
 )
 from ..common.shapes import get_test_shapes

@@ -145,7 +141,6 @@ def repeat(
         subs=hyperparams,
         canonicalize=True,
         run_bench=run_bench,
-        use_scheduling_barriers=enable_scheduling_barriers,
         benchmark_batch_size=10,
         benchmark_repetitions=3,
         benchmark_results_file=perf_filename_tk,
@@ -157,13 +152,7 @@ def repeat(
     k = device_randn(batch, kv_seq_len, qk_head_dim, dtype=torch.float16)
     v = device_randn(batch, v_head_dim, kv_seq_len, dtype=torch.float16)
     output = device_zeros(batch, v_head_dim, q_seq_len, dtype=torch.float32)
-    asm = chained_gemm(q, k, v, output)
-
-    if dump_generated_mlir:
-        filename = f"wave_cgemm_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm)
-        print(f"IR dumped to {filename}")
+    chained_gemm(q, k, v, output)

     iree_ref = device_zeros(batch, v_head_dim, q_seq_len, dtype=torch.float32)
     generate_iree_ref("chain_mmt", [q, k, v], [iree_ref], options)
@@ -291,7 +280,6 @@ def repeat(
         subs=hyperparams,
         canonicalize=True,
         run_bench=run_bench,
-        use_scheduling_barriers=enable_scheduling_barriers,
         benchmark_batch_size=10,
         benchmark_repetitions=3,
         benchmark_results_file=perf_filename_tk,
@@ -303,12 +291,7 @@ def repeat(
     k = device_randn(batch, kv_seq_len, qk_head_dim, dtype=torch.float16)
     v = device_randn(batch, v_head_dim, kv_seq_len, dtype=torch.float16)
     output = device_zeros(batch, v_head_dim, q_seq_len, dtype=torch.float32)
-    asm = chained_gemm_f8(q, k, v, output)
-
-    if dump_generated_mlir:
-        filename = f"wave_cgemm_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm)
+    chained_gemm_f8(q, k, v, output)

     iree_ref = device_zeros(batch, v_head_dim, q_seq_len, dtype=torch.float32)
     generate_iree_ref("chain_mmt_f8", [q, k, v], [iree_ref], options)
