
Commit 4f7b27e

Updated 6.4 schedule to improve wave FAv3 perf (#788)
* added unrolling
* updated priority order
* tested with reduced VGPR spills
* reordered prefetch compute in clusters 1 and 3

---------

Signed-off-by: root <[email protected]>
Signed-off-by: xintin <[email protected]>
1 parent dab1e9f commit 4f7b27e

File tree: 2 files changed (+33, -21 lines)

examples/python/6.4_schedule.py

Lines changed: 27 additions & 11 deletions
@@ -39,26 +39,27 @@
 """
 
 import math
+
 import torch
+from utils import list_tests, parse_args, run_test
 
+import wave_lang.kernel.lang as tkl
 import wave_lang.kernel.wave as tkw
 from wave_lang.kernel.lang.global_symbols import *
 from wave_lang.kernel.lang.wave_types import *
 from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
-from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+from wave_lang.kernel.wave.schedules.attention_prefetch import (
+    get_attention_prefetch_schedule,
+)
 from wave_lang.kernel.wave.scheduling.schedule import SchedulingType
-from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
 
 # Import templates and schedules
 from wave_lang.kernel.wave.templates.attention_common import AttentionShape
 from wave_lang.kernel.wave.templates.tagged_attention import (
     get_tagged_bshd_attention_kernel,
 )
-from wave_lang.kernel.wave.schedules.attention_prefetch import (
-    get_attention_prefetch_schedule,
-)
-
-from utils import parse_args, list_tests, run_test
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
 
 
 def test_attention_manual_schedule(is_debug=False):

@@ -85,10 +86,10 @@ def test_attention_manual_schedule(is_debug=False):
     shape = AttentionShape(
         num_query_heads=64,
         num_kv_heads=64,
-        query_seq_len=128,
-        head_size_kv=64,
-        head_size=64,
-        kv_seq_len=256,
+        query_seq_len=16384,
+        head_size_kv=128,
+        head_size=128,
+        kv_seq_len=16384,
     )
     mfma_variant = (tkw.MMAType.F32_16x16x16_F16,) * 2
 
@@ -102,6 +103,10 @@ def test_attention_manual_schedule(is_debug=False):
     )
     hyperparams.update(get_default_scheduling_params())
 
+    # set the unroll factor
+    UNROLL_FACTOR = tkl.sym.UNROLL_FACTOR
+    hyperparams[UNROLL_FACTOR] = 4
+
     # Get the prefetch schedule
     attention_prefetch_schedule = get_attention_prefetch_schedule()
 
@@ -111,6 +116,17 @@ def test_attention_manual_schedule(is_debug=False):
         schedule=SchedulingType.MANUAL,
         use_global_to_shared=True, # Enable GatherToLDS
         print_ir_after="all" if is_debug else [],
+        use_buffer_ops=True,
+        postprocess="""
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.loop.unroll %0 { factor = %%UNROLL_FACTOR%% } : !transform.any_op
+    transform.yield
+  }
+}
+""",
+        linearize_shared_access=False, # This impacts the VGPR spills
     )
 
     options = set_default_run_config(options)
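The new postprocess option is a transform-dialect script that unrolls the kernel's main scf.for loop, and the %%UNROLL_FACTOR%% token corresponds to the UNROLL_FACTOR hyperparameter set to 4 above. How Wave expands that token is not shown in this hunk; purely as an illustration of the idea, a minimal Python sketch of a textual substitution is given below (resolve_placeholders is a made-up helper, not Wave's actual mechanism):

# Hypothetical sketch (not Wave internals): resolving a %%NAME%% placeholder in
# the postprocess script from a hyperparameter value before the script is applied.
postprocess_template = (
    "transform.loop.unroll %0 { factor = %%UNROLL_FACTOR%% } : !transform.any_op"
)

def resolve_placeholders(script: str, params: dict) -> str:
    # Replace each %%NAME%% token with the string form of its value.
    for name, value in params.items():
        script = script.replace(f"%%{name}%%", str(value))
    return script

print(resolve_placeholders(postprocess_template, {"UNROLL_FACTOR": 4}))
# -> transform.loop.unroll %0 { factor = 4 } : !transform.any_op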

wave_lang/kernel/wave/schedules/attention_prefetch.py

File mode changed: 100644 -> 100755
Lines changed: 6 additions & 10 deletions
@@ -240,45 +240,41 @@ def attention_prefetch_schedule():
     # - Cluster 2: PV computation + softmax0
     # - Cluster 3: V data movement + local load K
     clusters = [
-        # Cluster 0: QK computation and softmax1
+        # Cluster 0: QK computation and softmax1 (high priority through barrier)
         tkw.cluster(
             [
                 tkw.SetWavePrio(1),
                 mma_qk_init_kernel,
                 mma_qk_kernel,
-                tkw.SetWavePrio(0),
                 *softmax1_ops_kernel,
                 tkw.WorkgroupBarrier(),
-                tkw.SchedulingBarrier([]),
+                tkw.SetWavePrio(0), # Lower priority after barrier
             ],
         ),
         # Cluster 1: K data movement (global_to_shared_k) + local load V
         tkw.cluster(
             [
-                global_to_shared_k_kernel,
                 shared_load_v_kernel,
+                global_to_shared_k_kernel,
                 tkw.WorkgroupBarrier(),
-                tkw.SchedulingBarrier([]),
             ],
         ),
-        # Cluster 2: PV computation and softmax0
+        # Cluster 2: PV computation and softmax0 (high priority through barrier)
         tkw.cluster(
             [
                 tkw.SetWavePrio(1),
                 mma_pv_kernel,
-                tkw.SetWavePrio(0),
                 *softmax0_ops_kernel,
                 tkw.WorkgroupBarrier(),
-                tkw.SchedulingBarrier([]),
+                tkw.SetWavePrio(0),
             ],
         ),
         # Cluster 3: V data movement (global_to_shared_v) + local load K
         tkw.cluster(
             [
-                global_to_shared_v_kernel,
                 shared_load_k_kernel,
+                global_to_shared_v_kernel,
                 tkw.WorkgroupBarrier(),
-                tkw.SchedulingBarrier([]),
             ],
         ),
     ]
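Taken together, the schedule edits follow two patterns: compute clusters now hold the raised wave priority through the workgroup barrier before dropping it, and data-movement clusters issue the short local shared-memory load before the long-latency global-to-shared copy. A condensed sketch of those two patterns using the tkw primitives from the hunk is shown below; the helper functions and their names are illustrative only and are not part of this commit.

import wave_lang.kernel.wave as tkw

def compute_cluster(ops):
    # Cluster 0/2 pattern after this change: raise wave priority for the MMA and
    # softmax work, keep it raised through the barrier, then drop it. The
    # trailing SchedulingBarrier([]) calls were removed.
    return tkw.cluster(
        [
            tkw.SetWavePrio(1),
            *ops,
            tkw.WorkgroupBarrier(),
            tkw.SetWavePrio(0),
        ],
    )

def data_movement_cluster(shared_load_op, global_to_shared_op):
    # Cluster 1/3 pattern after this change: the local shared-memory load is
    # issued before the global->shared copy, then both waves sync at the barrier.
    return tkw.cluster(
        [
            shared_load_op,
            global_to_shared_op,
            tkw.WorkgroupBarrier(),
        ],
    )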
