4848
4949import pytest
5050
51- from tests .kernel .common .utils import require_cdna4
51+ from tests .kernel .common .utils import param_bool , require_cdna4
5252from wave_lang .kernel .wave .asm .waveasm_e2e import (
5353 WaveASMCompiler ,
5454 capture_wave_kernel_info ,
@@ -1319,6 +1319,8 @@ def _dbuf_mxfp4_helper(
13191319 compiler ,
13201320 backend ,
13211321 dump_asm ,
1322+ dynamic_dims = False ,
1323+ use_buffer_ops = True ,
13221324):
13231325 """Shared helper for double-buffered MXFP4 scheduled GEMM tests.
13241326
@@ -1349,6 +1351,8 @@ def _dbuf_mxfp4_helper(
13491351 from wave_lang .kernel .wave .utils .mxfp_utils import (
13501352 generate_gemm_afp4wfp4_inputs ,
13511353 torchScaledGemmMXFP4 ,
1354+ b_preshuffle ,
1355+ e8m0_shuffle ,
13521356 )
13531357
13541358 # Get tagged kernel + options (same as 7.1_schedule.py)
@@ -1359,8 +1363,9 @@ def _dbuf_mxfp4_helper(
13591363 shape ,
13601364 block ,
13611365 wave_shape = (1 , 4 ),
1366+ reorder_workgroups = not dynamic_dims ,
13621367 )
1363- schedule = get_mxfp4_asymmetric_schedule ()
1368+ schedule = get_mxfp4_asymmetric_schedule (is_bscale_shuffled = True )
13641369 else :
13651370 gemm , options = get_tagged_mxfp4_gemm (
13661371 shape ,
@@ -1373,8 +1378,24 @@ def _dbuf_mxfp4_helper(
13731378 options .backend = "asm"
13741379 options .wave_runtime = True
13751380 options .compile_to_mlir = False
1381+ options .use_buffer_ops = use_buffer_ops
13761382 options = set_default_run_config (options )
13771383
1384+ import wave_lang .kernel .lang as tkl
1385+
1386+ M = tkl .sym .M
1387+ N = tkl .sym .N
1388+ m , n , k = shape
1389+
1390+ dynamic_symbols = []
1391+ dynamic_values = {}
1392+ if dynamic_dims :
1393+ dynamic_symbols = [M , N ]
1394+ dynamic_values = {M : m , N : n }
1395+ del options .subs [M ]
1396+ del options .subs [N ]
1397+ options .dynamic_symbols = dynamic_symbols
1398+
13781399 # Generate MXFP4 inputs and reference output
13791400 x , w , x_scales , w_scales = generate_gemm_afp4wfp4_inputs (shape )
13801401 torch_out = torchScaledGemmMXFP4 (x , w , x_scales , w_scales )
@@ -1384,7 +1405,9 @@ def _dbuf_mxfp4_helper(
13841405 c = torch .zeros (shape [0 ], shape [1 ], dtype = torch .float32 ).cuda ()
13851406
13861407 # Capture MLIR with schedule applied
1387- kernel_info = capture_wave_kernel_info (options , gemm , schedule = schedule )
1408+ kernel_info = capture_wave_kernel_info (
1409+ options , gemm , schedule = schedule , dynamic_values = dynamic_values
1410+ )
13881411
13891412 # Verify MLIR contains scaled_mfma operation
13901413 assert (
@@ -1424,8 +1447,10 @@ def _dbuf_mxfp4_helper(
14241447
14251448 # Execute on GPU
14261449 # Kernel signature: (a, a_scale, b, b_scale, c)
1427- # For preshuffle B: transform B data and B scales to preshuffled layout
1450+ # For preshuffle B: transform all inputs to match kernel expectations.
1451+ # a_scale_preshuffle=True (default) means a_scales must also be shuffled.
14281452 if num_waves <= 4 :
1453+ x_scales = e8m0_shuffle (x_scales ).contiguous ()
14291454 w_input = b_preshuffle (w .T .contiguous ()).contiguous ()
14301455 w_scales_input = e8m0_shuffle (w_scales ).contiguous ()
14311456 else :
@@ -1439,6 +1464,7 @@ def _dbuf_mxfp4_helper(
14391464 block = block_size ,
14401465 shared_memory_bytes = lds_size ,
14411466 func_name = kernel_name ,
1467+ dynamic_dims = [dynamic_values [s ] for s in dynamic_symbols ],
14421468 )
14431469
14441470 # Numerical correctness validation (same tolerance as existing MXFP4 test)
@@ -1453,25 +1479,26 @@ def _dbuf_mxfp4_helper(
14531479 )
14541480
14551481
1456- @pytest .mark .xfail (
1457- reason = "Asymmetric schedule with wave_shape=(1,4) requires ~323 VGPRs, "
1458- "exceeding the 256 hardware encoding limit. Needs LDS scale layout "
1459- "fix or spilling to resolve." ,
1460- )
1461- def test_dbuf_4wave_mxfp4_gemm_cpp_backend (compiler , backend , dump_asm ):
1482+ @param_bool ("dynamic_dims" , "dyn" )
1483+ @param_bool ("use_buffer_ops" , "bufops" )
1484+ def test_dbuf_4wave_mxfp4_gemm_cpp_backend (
1485+ dynamic_dims , use_buffer_ops , compiler , backend , dump_asm
1486+ ):
14621487 """End-to-end test for asymmetric MXFP4 GEMM with 4 waves.
14631488
1464- Uses get_mxfp4_asymmetric_schedule() with wave_shape=(1,4) and
1465- B direct from global (no LDS) .
1489+ Uses get_mxfp4_asymmetric_schedule() with wave_shape=(1,4),
1490+ preshuffle B, and block=(128,256,256) matching 7.1_schedule.py .
14661491 """
14671492 _dbuf_mxfp4_helper (
14681493 shape = (1024 , 1024 , 8192 ),
1469- block = (256 , 256 , 256 ),
1494+ block = (128 , 256 , 256 ),
14701495 num_waves = 4 ,
14711496 use_stagger = False ,
14721497 compiler = compiler ,
14731498 backend = backend ,
14741499 dump_asm = dump_asm ,
1500+ dynamic_dims = dynamic_dims ,
1501+ use_buffer_ops = use_buffer_ops ,
14751502 )
14761503
14771504
0 commit comments