38 changes: 27 additions & 11 deletions examples/python/6.4_schedule.py
@@ -39,26 +39,27 @@
"""

import math

import torch
from utils import list_tests, parse_args, run_test

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.lang.wave_types import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
from wave_lang.kernel.wave.schedules.attention_prefetch import (
get_attention_prefetch_schedule,
)
from wave_lang.kernel.wave.scheduling.schedule import SchedulingType
from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params

# Import templates and schedules
from wave_lang.kernel.wave.templates.attention_common import AttentionShape
from wave_lang.kernel.wave.templates.tagged_attention import (
get_tagged_bshd_attention_kernel,
)


def test_attention_manual_schedule(is_debug=False):
@@ -85,10 +86,10 @@ def test_attention_manual_schedule(is_debug=False):
shape = AttentionShape(
num_query_heads=64,
num_kv_heads=64,
query_seq_len=16384,  # was 128
head_size_kv=128,  # was 64
head_size=128,  # was 64
kv_seq_len=16384,  # was 256
)
mfma_variant = (tkw.MMAType.F32_16x16x16_F16,) * 2

@@ -102,6 +103,10 @@
)
hyperparams.update(get_default_scheduling_params())

# Set the unroll factor consumed by the transform-dialect postprocess below
UNROLL_FACTOR = tkl.sym.UNROLL_FACTOR
hyperparams[UNROLL_FACTOR] = 4

# Get the prefetch schedule
attention_prefetch_schedule = get_attention_prefetch_schedule()

@@ -111,6 +116,17 @@
schedule=SchedulingType.MANUAL,
use_global_to_shared=True, # Enable GatherToLDS
print_ir_after="all" if is_debug else [],
use_buffer_ops=True,
postprocess="""
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.loop.unroll %0 { factor = %%UNROLL_FACTOR%% } : !transform.any_op
transform.yield
}
}
""",
linearize_shared_access=False,  # This affects VGPR spills
)

options = set_default_run_config(options)
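The postprocess option above embeds a transform-dialect script with a %%UNROLL_FACTOR%% placeholder that is resolved from the hyperparameters before compilation. The sketch below is illustrative only: it assumes %%NAME%% placeholders are replaced with the matching symbolic hyperparameter values and is not the actual wave_lang implementation.

# Illustrative only (assumption, not the actual wave_lang mechanism): resolve
# %%NAME%% placeholders in the postprocess script from the hyperparameter dict.
import re


def resolve_placeholders(script: str, hyperparams: dict) -> str:
    # Map symbol names (e.g. "UNROLL_FACTOR") to their configured values.
    by_name = {str(sym): value for sym, value in hyperparams.items()}
    # Substitute every %%NAME%% occurrence with its value.
    return re.sub(r"%%(\w+)%%", lambda m: str(by_name[m.group(1)]), script)


# With hyperparams[UNROLL_FACTOR] = 4, the unroll line becomes:
#   transform.loop.unroll %0 { factor = 4 } : !transform.any_op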
16 changes: 6 additions & 10 deletions wave_lang/kernel/wave/schedules/attention_prefetch.py
File mode changed: 100644 → 100755
@@ -240,45 +240,41 @@ def attention_prefetch_schedule():
# - Cluster 2: PV computation + softmax0
# - Cluster 3: V data movement + local load K
clusters = [
# Cluster 0: QK computation and softmax1 (high priority through barrier)
tkw.cluster(
[
tkw.SetWavePrio(1),
mma_qk_init_kernel,
mma_qk_kernel,
*softmax1_ops_kernel,
tkw.WorkgroupBarrier(),
tkw.SchedulingBarrier([]),
tkw.SetWavePrio(0), # Lower priority after barrier
],
),
# Cluster 1: K data movement (global_to_shared_k) + local load V
tkw.cluster(
[
shared_load_v_kernel,
global_to_shared_k_kernel,
tkw.WorkgroupBarrier(),
tkw.SchedulingBarrier([]),
],
),
# Cluster 2: PV computation and softmax0 (high priority through barrier)
tkw.cluster(
[
tkw.SetWavePrio(1),
mma_pv_kernel,
*softmax0_ops_kernel,
tkw.WorkgroupBarrier(),
tkw.SchedulingBarrier([]),
tkw.SetWavePrio(0),
],
),
# Cluster 3: V data movement (global_to_shared_v) + local load K
tkw.cluster(
[
shared_load_k_kernel,
global_to_shared_v_kernel,
tkw.WorkgroupBarrier(),
tkw.SchedulingBarrier([]),
],
),
]
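Taken together, the four clusters describe one steady-state iteration of the prefetch pipeline: the compute clusters (QK + softmax1, PV + softmax0) run at raised wave priority through their barriers while the data-movement clusters stage the next K/V tiles. The sketch below is a conceptual illustration of that phase ordering, not wave_lang scheduling code.

# Conceptual sketch of the per-iteration phase order expressed by the clusters
# above; purely illustrative, not the wave_lang scheduling API.
PHASES = [
    "cluster 0: QK MMA + softmax1 (wave priority raised through the barrier)",
    "cluster 1: local load V + global->shared prefetch of K",
    "cluster 2: PV MMA + softmax0 (wave priority raised through the barrier)",
    "cluster 3: local load K + global->shared prefetch of V",
]


def iteration_trace(num_kv_tiles: int):
    # Every KV tile passes through all four phases; barriers keep the waves
    # in a workgroup aligned at the phase boundaries.
    for tile in range(num_kv_tiles):
        for phase in PHASES:
            yield tile, phase


for tile, phase in iteration_trace(2):
    print(f"tile {tile}: {phase}")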