3939"""
4040
4141import math
42+
4243import torch
44+ from utils import list_tests , parse_args , run_test
4345
46+ import wave_lang .kernel .lang as tkl
4447import wave_lang .kernel .wave as tkw
4548from wave_lang .kernel .lang .global_symbols import *
4649from wave_lang .kernel .lang .wave_types import *
4750from wave_lang .kernel .wave .compile import WaveCompileOptions , wave_compile
48- from wave_lang .kernel .wave .utils .run_utils import set_default_run_config
51+ from wave_lang .kernel .wave .schedules .attention_prefetch import (
52+ get_attention_prefetch_schedule ,
53+ )
4954from wave_lang .kernel .wave .scheduling .schedule import SchedulingType
50- from wave_lang .kernel .wave .utils .general_utils import get_default_scheduling_params
5155
5256# Import templates and schedules
5357from wave_lang .kernel .wave .templates .attention_common import AttentionShape
5458from wave_lang .kernel .wave .templates .tagged_attention import (
5559 get_tagged_bshd_attention_kernel ,
5660)
57- from wave_lang .kernel .wave .schedules .attention_prefetch import (
58- get_attention_prefetch_schedule ,
59- )
60-
61- from utils import parse_args , list_tests , run_test
61+ from wave_lang .kernel .wave .utils .general_utils import get_default_scheduling_params
62+ from wave_lang .kernel .wave .utils .run_utils import set_default_run_config
6263
6364
6465def test_attention_manual_schedule (is_debug = False ):
@@ -85,10 +86,10 @@ def test_attention_manual_schedule(is_debug=False):
     shape = AttentionShape(
         num_query_heads=64,
         num_kv_heads=64,
-        query_seq_len=128,
-        head_size_kv=64,
-        head_size=64,
-        kv_seq_len=256,
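+        # Bump to a larger problem size: 16K sequence length, 128 head dim.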
+        query_seq_len=16384,
+        head_size_kv=128,
+        head_size=128,
+        kv_seq_len=16384,
     )
     mfma_variant = (tkw.MMAType.F32_16x16x16_F16,) * 2

@@ -102,6 +103,10 @@ def test_attention_manual_schedule(is_debug=False):
     )
     hyperparams.update(get_default_scheduling_params())

+    # set the unroll factor
+    UNROLL_FACTOR = tkl.sym.UNROLL_FACTOR
+    hyperparams[UNROLL_FACTOR] = 4
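+    # (Assumption: this hyperparameter supplies the value for the
+    # %%UNROLL_FACTOR%% placeholder in the postprocess script passed to
+    # WaveCompileOptions below.)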
+
     # Get the prefetch schedule
     attention_prefetch_schedule = get_attention_prefetch_schedule()

@@ -111,6 +116,17 @@ def test_attention_manual_schedule(is_debug=False):
         schedule=SchedulingType.MANUAL,
         use_global_to_shared=True,  # Enable GatherToLDS
         print_ir_after="all" if is_debug else [],
+        use_buffer_ops=True,
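+        # MLIR transform-dialect postprocess script: match every scf.for in the
+        # compiled kernel and unroll it by the UNROLL_FACTOR hyperparameter.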
120+ postprocess = """
121+ module attributes {transform.with_named_sequence} {
122+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
123+ %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
124+ transform.loop.unroll %0 { factor = %%UNROLL_FACTOR%% } : !transform.any_op
125+ transform.yield
126+ }
127+ }
128+ """ ,
+        linearize_shared_access=False,  # Disabled because linearizing shared access impacts VGPR spills
     )

     options = set_default_run_config(options)