@@ -8,7 +8,6 @@
 import torch
 from torch.nn import functional as F
 import math
-import wave_lang.kernel as tk
 import wave_lang.kernel.lang as tkl
 import wave_lang.kernel.wave as tkw
 from wave_lang.kernel.lang.global_symbols import *
@@ -29,8 +28,6 @@
 from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
 from wave_lang.kernel.wave.constraints import MMAType
 from ..common.utils import (
-    dump_generated_mlir,
-    enable_scheduling_barriers,
     expensive_test,
     require_e2e,
 )
@@ -1142,7 +1139,6 @@ def testAttentionForward(mfma_variant: MMAType, shape: tuple[int, ...]):
     hyperparams.update(get_default_scheduling_params())
     options = WaveCompileOptions(
         subs=hyperparams,
-        use_scheduling_barriers=enable_scheduling_barriers,
         run_bench=False,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
@@ -1154,13 +1150,7 @@ def testAttentionForward(mfma_variant: MMAType, shape: tuple[int, ...]):
     lse = device_zeros(batch, q_seq_len, dtype=torch.float16)
     s = device_zeros(batch, q_seq_len, kv_seq_len)
 
-    asm_fwd = attention_fwd(q, k, v.transpose(-1, -2), s, o, lse)
-
-    if dump_generated_mlir:
-        filename = f"out/wave_attention_fwd_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm_fwd)
-        print(f"IR dumped to {filename}")
+    attention_fwd(q, k, v.transpose(-1, -2), s, o, lse)
 
     assert_close(s, s_ref, **cmp_params)
     # Can't check P, since we don't actually compute the "real" thing in the
@@ -1209,7 +1199,6 @@ def testAttentionBackward(mfma_variant: MMAType, shape: tuple[int, ...]):
     hyperparams.update(get_default_scheduling_params())
     options = WaveCompileOptions(
         subs=hyperparams,
-        use_scheduling_barriers=enable_scheduling_barriers,
         run_bench=False,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
@@ -1229,7 +1218,7 @@ def testAttentionBackward(mfma_variant: MMAType, shape: tuple[int, ...]):
     dp = device_zeros(batch, q_seq_len, kv_seq_len, dtype=torch.float32)
     dp_sub = device_zeros(batch, q_seq_len, kv_seq_len, dtype=torch.float16)
 
-    asm_bwd = attention_bwd(
+    attention_bwd(
         q,
         k,
         v,
@@ -1247,12 +1236,6 @@ def testAttentionBackward(mfma_variant: MMAType, shape: tuple[int, ...]):
         dp_sub,
     )
 
-    if dump_generated_mlir:
-        filename = f"out/wave_attention_bwd_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm_bwd)
-        print(f"IR dumped to {filename}")
-
     assert_close(s, s_ref, **cmp_params)
     assert_close(p, p_ref, **cmp_params)
 
@@ -1305,7 +1288,6 @@ def testAttentionBackward_dv(mfma_variant: MMAType, shape: tuple[int, ...]):
     hyperparams_dv.update(get_default_scheduling_params())
     options = WaveCompileOptions(
         subs=hyperparams_dv,
-        use_scheduling_barriers=enable_scheduling_barriers,
         run_bench=False,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
@@ -1317,13 +1299,7 @@ def testAttentionBackward_dv(mfma_variant: MMAType, shape: tuple[int, ...]):
     s = device_zeros(batch, q_seq_len, kv_seq_len, dtype=torch.float32)
     p = device_zeros(batch, q_seq_len, kv_seq_len, dtype=torch.float16)
 
-    asm_bwd_dv = attention_bwd_dv(q, k, do, lse_ref, dv, s, p)
-
-    if dump_generated_mlir:
-        filename = f"out/wave_attention_bwd_dv_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm_bwd_dv)
-        print(f"IR dumped to {filename}")
+    attention_bwd_dv(q, k, do, lse_ref, dv, s, p)
 
     assert_close(s, s_ref, **cmp_params)
     assert_close(p, p_ref, **cmp_params)
@@ -1367,7 +1343,6 @@ def testAttentionBackward_dk(mfma_variant: MMAType, shape: tuple[int, ...]):
     hyperparams_dk.update(get_default_scheduling_params())
     options = WaveCompileOptions(
         subs=hyperparams_dk,
-        use_scheduling_barriers=enable_scheduling_barriers,
         run_bench=False,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
@@ -1383,7 +1358,7 @@ def testAttentionBackward_dk(mfma_variant: MMAType, shape: tuple[int, ...]):
     dp = torch.zeros_like(s)
     dp_sub = torch.zeros_like(p)
 
-    asm_bwd_dk = attention_bwd_dk(
+    attention_bwd_dk(
         q,
         k,
         v,
@@ -1398,12 +1373,6 @@ def testAttentionBackward_dk(mfma_variant: MMAType, shape: tuple[int, ...]):
         dp_sub,
     )
 
-    if dump_generated_mlir:
-        filename = f"out/wave_attention_bwd_dk_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm_bwd_dk)
-        print(f"IR dumped to {filename}")
-
     dp_sub_ref = (dp_ref - D.reshape((batch, q_seq_len, 1))).to(torch.float16)
 
     assert_close(s, s_ref, **cmp_params)
@@ -1452,7 +1421,6 @@ def testAttentionBackward_dq(mfma_variant: MMAType, shape: tuple[int, ...]):
     hyperparams_dq.update(get_default_scheduling_params())
     options = WaveCompileOptions(
         subs=hyperparams_dq,
-        use_scheduling_barriers=enable_scheduling_barriers,
         run_bench=False,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
@@ -1469,7 +1437,7 @@ def testAttentionBackward_dq(mfma_variant: MMAType, shape: tuple[int, ...]):
     dp = torch.zeros_like(s)
     dp_sub = torch.zeros_like(p)
 
-    asm_bwd_dq = attention_bwd_dq(
+    attention_bwd_dq(
         q,
         k,
         v,
@@ -1485,12 +1453,6 @@ def testAttentionBackward_dq(mfma_variant: MMAType, shape: tuple[int, ...]):
         dp_sub,
     )
 
-    if dump_generated_mlir:
-        filename = f"out/wave_attention_bwd_dq_{'x'.join(map(str, shape))}.mlir"
-        with open(filename, "w") as f:
-            f.write(asm_bwd_dq)
-        print(f"IR dumped to {filename}")
-
     s_sub_ref = s_ref.to(torch.float16) - lse_ref.reshape((batch, q_seq_len, 1)).expand(
         batch, q_seq_len, kv_seq_len
     )