Skip to content

Commit bceb332

Browse files
authored
[asm] Add dynamic shapes support to C++ ASM backend (#1087)
1 parent 7160aa0 commit bceb332

23 files changed

+1369
-105
lines changed

tests/kernel/wave/asm/test_waveasm_e2e.py

Lines changed: 40 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@
4848

4949
import pytest
5050

51-
from tests.kernel.common.utils import require_cdna4
51+
from tests.kernel.common.utils import param_bool, require_cdna4
5252
from wave_lang.kernel.wave.asm.waveasm_e2e import (
5353
WaveASMCompiler,
5454
capture_wave_kernel_info,
@@ -1319,6 +1319,8 @@ def _dbuf_mxfp4_helper(
13191319
compiler,
13201320
backend,
13211321
dump_asm,
1322+
dynamic_dims=False,
1323+
use_buffer_ops=True,
13221324
):
13231325
"""Shared helper for double-buffered MXFP4 scheduled GEMM tests.
13241326
@@ -1349,6 +1351,8 @@ def _dbuf_mxfp4_helper(
13491351
from wave_lang.kernel.wave.utils.mxfp_utils import (
13501352
generate_gemm_afp4wfp4_inputs,
13511353
torchScaledGemmMXFP4,
1354+
b_preshuffle,
1355+
e8m0_shuffle,
13521356
)
13531357

13541358
# Get tagged kernel + options (same as 7.1_schedule.py)
@@ -1359,8 +1363,9 @@ def _dbuf_mxfp4_helper(
13591363
shape,
13601364
block,
13611365
wave_shape=(1, 4),
1366+
reorder_workgroups=not dynamic_dims,
13621367
)
1363-
schedule = get_mxfp4_asymmetric_schedule()
1368+
schedule = get_mxfp4_asymmetric_schedule(is_bscale_shuffled=True)
13641369
else:
13651370
gemm, options = get_tagged_mxfp4_gemm(
13661371
shape,
@@ -1373,8 +1378,24 @@ def _dbuf_mxfp4_helper(
13731378
options.backend = "asm"
13741379
options.wave_runtime = True
13751380
options.compile_to_mlir = False
1381+
options.use_buffer_ops = use_buffer_ops
13761382
options = set_default_run_config(options)
13771383

1384+
import wave_lang.kernel.lang as tkl
1385+
1386+
M = tkl.sym.M
1387+
N = tkl.sym.N
1388+
m, n, k = shape
1389+
1390+
dynamic_symbols = []
1391+
dynamic_values = {}
1392+
if dynamic_dims:
1393+
dynamic_symbols = [M, N]
1394+
dynamic_values = {M: m, N: n}
1395+
del options.subs[M]
1396+
del options.subs[N]
1397+
options.dynamic_symbols = dynamic_symbols
1398+
13781399
# Generate MXFP4 inputs and reference output
13791400
x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape)
13801401
torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales)
@@ -1384,7 +1405,9 @@ def _dbuf_mxfp4_helper(
13841405
c = torch.zeros(shape[0], shape[1], dtype=torch.float32).cuda()
13851406

13861407
# Capture MLIR with schedule applied
1387-
kernel_info = capture_wave_kernel_info(options, gemm, schedule=schedule)
1408+
kernel_info = capture_wave_kernel_info(
1409+
options, gemm, schedule=schedule, dynamic_values=dynamic_values
1410+
)
13881411

13891412
# Verify MLIR contains scaled_mfma operation
13901413
assert (
@@ -1424,8 +1447,10 @@ def _dbuf_mxfp4_helper(
14241447

14251448
# Execute on GPU
14261449
# Kernel signature: (a, a_scale, b, b_scale, c)
1427-
# For preshuffle B: transform B data and B scales to preshuffled layout
1450+
# For preshuffle B: transform all inputs to match kernel expectations.
1451+
# a_scale_preshuffle=True (default) means a_scales must also be shuffled.
14281452
if num_waves <= 4:
1453+
x_scales = e8m0_shuffle(x_scales).contiguous()
14291454
w_input = b_preshuffle(w.T.contiguous()).contiguous()
14301455
w_scales_input = e8m0_shuffle(w_scales).contiguous()
14311456
else:
@@ -1439,6 +1464,7 @@ def _dbuf_mxfp4_helper(
14391464
block=block_size,
14401465
shared_memory_bytes=lds_size,
14411466
func_name=kernel_name,
1467+
dynamic_dims=[dynamic_values[s] for s in dynamic_symbols],
14421468
)
14431469

14441470
# Numerical correctness validation (same tolerance as existing MXFP4 test)
@@ -1453,25 +1479,26 @@ def _dbuf_mxfp4_helper(
14531479
)
14541480

14551481

1456-
@pytest.mark.xfail(
1457-
reason="Asymmetric schedule with wave_shape=(1,4) requires ~323 VGPRs, "
1458-
"exceeding the 256 hardware encoding limit. Needs LDS scale layout "
1459-
"fix or spilling to resolve.",
1460-
)
1461-
def test_dbuf_4wave_mxfp4_gemm_cpp_backend(compiler, backend, dump_asm):
1482+
@param_bool("dynamic_dims", "dyn")
1483+
@param_bool("use_buffer_ops", "bufops")
1484+
def test_dbuf_4wave_mxfp4_gemm_cpp_backend(
1485+
dynamic_dims, use_buffer_ops, compiler, backend, dump_asm
1486+
):
14621487
"""End-to-end test for asymmetric MXFP4 GEMM with 4 waves.
14631488
1464-
Uses get_mxfp4_asymmetric_schedule() with wave_shape=(1,4) and
1465-
B direct from global (no LDS).
1489+
Uses get_mxfp4_asymmetric_schedule() with wave_shape=(1,4),
1490+
preshuffle B, and block=(128,256,256) matching 7.1_schedule.py.
14661491
"""
14671492
_dbuf_mxfp4_helper(
14681493
shape=(1024, 1024, 8192),
1469-
block=(256, 256, 256),
1494+
block=(128, 256, 256),
14701495
num_waves=4,
14711496
use_stagger=False,
14721497
compiler=compiler,
14731498
backend=backend,
14741499
dump_asm=dump_asm,
1500+
dynamic_dims=dynamic_dims,
1501+
use_buffer_ops=use_buffer_ops,
14751502
)
14761503

14771504

wave_lang/kernel/wave/asm/waveasm_e2e.py

Lines changed: 26 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
import tempfile
3434
from dataclasses import dataclass
3535
from pathlib import Path
36-
from typing import Optional, List, Tuple
36+
from typing import Dict, Optional, List, Tuple
3737

3838
import torch
3939

@@ -436,7 +436,12 @@ def capture_wave_mlir(options, kernel_func) -> str:
436436
return mlir_text
437437

438438

439-
def capture_wave_kernel_info(options, kernel_func, schedule=None) -> CapturedKernelInfo:
439+
def capture_wave_kernel_info(
440+
options,
441+
kernel_func,
442+
schedule=None,
443+
dynamic_values: Optional[Dict] = None,
444+
) -> CapturedKernelInfo:
440445
"""
441446
Capture MLIR and kernel launch info from Wave compilation.
442447
@@ -447,6 +452,9 @@ def capture_wave_kernel_info(options, kernel_func, schedule=None) -> CapturedKer
447452
options: WaveCompileOptions
448453
kernel_func: Decorated wave kernel function
449454
schedule: Optional WaveSchedule to apply during compilation
455+
dynamic_values: Optional dict mapping dynamic symbols to their concrete
456+
values. Used for grid computation when symbols are not in
457+
options.subs (i.e. truly dynamic shapes).
450458
451459
Returns:
452460
CapturedKernelInfo with all launch information
@@ -517,13 +525,19 @@ def capture_wave_kernel_info(options, kernel_func, schedule=None) -> CapturedKer
517525
dynamic_syms = list(getattr(options, "dynamic_symbols", None) or [])
518526
grid_symbols = list(kernel_func.bound_scalar_symbols.keys()) + dynamic_syms
519527
grid_values = []
528+
dv = dynamic_values or {}
520529
for sym in grid_symbols:
521-
if sym not in options.subs:
530+
if sym in options.subs:
531+
grid_values.append(options.subs[sym])
532+
elif sym in dv:
533+
grid_values.append(dv[sym])
534+
else:
522535
raise ValueError(
523-
f"Grid symbol {sym} not found in options.subs. "
524-
f"Available: {list(options.subs.keys())}"
536+
f"Grid symbol {sym} not found in options.subs or "
537+
f"dynamic_values. "
538+
f"Available subs: {list(options.subs.keys())}, "
539+
f"dynamic_values: {list(dv.keys())}"
525540
)
526-
grid_values.append(options.subs[sym])
527541
grid = launch_info.grid(grid_values)
528542
grid = tuple(int(x) for x in grid)
529543

@@ -617,6 +631,7 @@ def run_with_wave_runtime(
617631
block: Tuple[int, int, int],
618632
shared_memory_bytes: int = 0,
619633
func_name: str = "isolated_benchmark",
634+
dynamic_dims: Optional[List[int]] = None,
620635
):
621636
"""
622637
Execute a compiled GPU binary using wave_runtime.
@@ -629,6 +644,8 @@ def run_with_wave_runtime(
629644
block: Block dimensions (x, y, z)
630645
shared_memory_bytes: Shared memory size
631646
func_name: Function name in the binary (default: "isolated_benchmark")
647+
dynamic_dims: Optional list of concrete values for dynamic dimension
648+
symbols, passed as additional kernel arguments.
632649
"""
633650
import wave_runtime
634651

@@ -660,11 +677,13 @@ def run_with_wave_runtime(
660677
kern_args = [tensor.data_ptr() for tensor in all_tensors]
661678
kernel_args = wave_runtime.Int64Vector(kern_args)
662679

680+
dyn_dims = wave_runtime.Int64Vector(dynamic_dims or [])
681+
663682
# Prepare dynamic stride arguments
664683
stride_args = get_dynamic_stride_args(all_tensors)
665684

666685
# Launch
667-
wave_runtime.launch(kernel_launch_info, kernel_args, [], [], stride_args)
686+
wave_runtime.launch(kernel_launch_info, kernel_args, dyn_dims, [], stride_args)
668687

669688
# Sync
670689
torch.cuda.synchronize()

waveasm/include/waveasm/Dialect/WaveASMOps.td

Lines changed: 30 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -591,7 +591,14 @@ def WaveASM_V_LSHL_OR_B32 : VALUTernaryOp<"v_lshl_or_b32">;
591591
def WaveASM_V_LSHL_ADD_U32 : VALUTernaryOp<"v_lshl_add_u32">;
592592

593593
// Conditional mask and lane operations
594-
def WaveASM_V_CNDMASK_B32 : VALUTernaryOp<"v_cndmask_b32">;
594+
// V_CNDMASK_B32 implicitly reads VCC, so it must NOT have Pure or
595+
// ArithmeticOp traits (which would make CSE treat two instances with
596+
// identical explicit operands as equivalent even when VCC differs).
// For that reason the op is declared directly on WAVEASMOp with an empty
// trait list; it takes three explicit VALU sources and yields one VGPR.
597+
def WaveASM_V_CNDMASK_B32 : WAVEASMOp<"v_cndmask_b32", []> {
598+
let arguments = (ins WaveASM_VALUSrc:$src0, WaveASM_VALUSrc:$src1, WaveASM_VALUSrc:$src2);
599+
let results = (outs WaveASM_AnyVGPR:$dst);
600+
let assemblyFormat = "$src0 `,` $src1 `,` $src2 attr-dict `:` type($src0) `,` type($src1) `,` type($src2) `->` type($dst)";
601+
}
595602

596603
// Lane read operations (VGPR -> SGPR)
597604
def WaveASM_V_READLANE_B32 : WAVEASMOp<"v_readlane_b32", [Pure]> {
@@ -908,6 +915,28 @@ def WaveASM_S_MOV_B32_M0 : WAVEASMOp<"s_mov_b32_m0", [WaveASM_SpecialRegOp]> {
908915
let assemblyFormat = "$src attr-dict `:` type($src)";
909916
}
910917

918+
// s_and_saveexec_b64: declared with a single SGPR result ($dst) and no
// explicit operands; per the description, VCC is consumed implicitly.
def WaveASM_S_AND_SAVEEXEC_B64 : WAVEASMOp<"s_and_saveexec_b64", [WaveASM_SpecialRegOp]> {
919+
let summary = "Save exec to dst, then AND exec with VCC (implicit)";
920+
let description = [{
921+
dst = exec; exec &= vcc.
922+
VCC is read implicitly (set by a preceding V_CMP).
923+
Used for conditional execution: lanes where VCC is 0 become inactive.
924+
The saved exec is restored later via s_mov_b64_exec.
925+
}];
926+
let results = (outs WaveASM_AnySGPR:$dst);
927+
let assemblyFormat = "attr-dict `->` type($dst)";
928+
}
929+
930+
// s_mov_b64_exec: takes one SGPR operand ($src) and declares no results;
// the write to exec is implicit (see description).
def WaveASM_S_MOV_B64_EXEC : WAVEASMOp<"s_mov_b64_exec", [WaveASM_SpecialRegOp]> {
931+
let summary = "Restore exec from saved SGPR pair";
932+
let description = [{
933+
exec = src.
934+
Used to restore exec after a conditional execution region.
935+
}];
936+
let arguments = (ins WaveASM_AnySGPR:$src);
937+
let assemblyFormat = "$src attr-dict `:` type($src)";
938+
}
939+
911940
//===----------------------------------------------------------------------===//
912941
// VMEM Atomic Instructions
913942
//===----------------------------------------------------------------------===//

waveasm/include/waveasm/Transforms/Passes.td

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -223,6 +223,29 @@ def WAVEASMScalePackElimination : Pass<"waveasm-scale-pack-elimination"> {
223223
let dependentDialects = ["::waveasm::WaveASMDialect"];
224224
}
225225

226+
//===----------------------------------------------------------------------===//
227+
// Extract Scalarization Pass
228+
//===----------------------------------------------------------------------===//
229+
230+
def WAVEASMExtractScalarization
231+
: Pass<"waveasm-extract-scalarization"> {
232+
let summary = "Scalarize vector.extract from broadcast+dense-const patterns";
233+
let description = [{
234+
Pre-translation pass that rewrites
235+
vector.extract[k]( index_cast?( select?( addi(broadcast(x), dense<[...]>) )))
236+
into scalar operations: arith.addi %x, dense[k], with an optional scalar
237+
arith.select if the original chain included one.
238+
239+
This eliminates non-splat dense vector constants before the WaveASM
240+
translator runs, so translation handlers only see ordinary scalar IR.
241+
}];
242+
243+
let dependentDialects = [
244+
"::mlir::arith::ArithDialect",
245+
"::mlir::vector::VectorDialect"
246+
];
247+
}
248+
226249
//===----------------------------------------------------------------------===//
227250
// Memory Offset Optimization Pass
228251
//===----------------------------------------------------------------------===//

waveasm/include/waveasm/Transforms/TranslateFromMLIR.h

Lines changed: 42 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -341,9 +341,18 @@ class TranslationContext {
341341
int64_t srdBaseIndex; // SGPR index for SRD (e.g., 8 for s[8:11])
342342
};
343343

344+
/// Information about a pending scalar kernel argument load (index, i32, etc.)
/// Entries are held in pendingScalarArgs and are included in the count
/// returned by getNumKernelArgs().
/// NOTE(review): presumably populated by queueScalarArgLoad(); its definition
/// is not visible here -- confirm.
345+
struct PendingScalarArg {
346+
mlir::Value blockArg; // The MLIR block argument
347+
int64_t argIndex; // Position in function signature
348+
};
349+
344350
/// Queue an SRD setup for a binding
345351
void queueSRDSetup(mlir::Value memref, int64_t argIndex, int64_t bufferSize);
346352

353+
/// Queue a scalar argument load from the kernarg buffer
354+
void queueScalarArgLoad(mlir::Value blockArg, int64_t argIndex);
355+
347356
/// Emit all pending SRD setup instructions (called at start of kernel body)
348357
void emitSRDPrologue();
349358

@@ -398,8 +407,10 @@ class TranslationContext {
398407
/// Update buffer size for a pending SRD (called when we see reinterpret_cast)
399408
void updateSRDBufferSize(mlir::Value memref, int64_t bufferSize);
400409

401-
/// Get the number of kernel arguments (based on pending SRD count)
402-
size_t getNumKernelArgs() const { return pendingSRDs.size(); }
410+
/// Get the number of kernel arguments (bindings + scalar args)
/// Computed as the number of pending SRD entries plus the number of pending
/// scalar argument entries, so it reflects only what has been queued so far.
411+
size_t getNumKernelArgs() const {
412+
return pendingSRDs.size() + pendingScalarArgs.size();
413+
}
403414

404415
//===--------------------------------------------------------------------===//
405416
// Split Vector Result Tracking
@@ -514,6 +525,30 @@ class TranslationContext {
514525
return ldsBaseOffsetMap.contains(memref);
515526
}
516527

528+
//===--------------------------------------------------------------------===//
529+
// Dynamic Stride Tracking (for memref.reinterpret_cast with runtime strides)
530+
//===--------------------------------------------------------------------===//
531+
532+
/// Store a dynamic (runtime) stride value for a memref dimension.
533+
/// \p strideValue is the mapped WaveASM SSA value holding the element stride.
/// Any value previously stored for the same (memref, dim) pair is overwritten.
534+
void setDynamicStride(mlir::Value memref, unsigned dim,
535+
mlir::Value strideValue) {
536+
// operator[] creates the per-memref inner map on first use.
dynamicStrideMap[memref][dim] = strideValue;
537+
}
538+
539+
/// Get the dynamic stride value for a memref dimension.
540+
/// Returns nullopt when no stride was recorded for this (memref, dim) pair
/// via setDynamicStride(); callers presumably treat that as a static stride.
541+
std::optional<mlir::Value> getDynamicStride(mlir::Value memref,
542+
unsigned dim) const {
543+
auto it = dynamicStrideMap.find(memref);
544+
// No dynamic strides recorded for this memref at all.
if (it == dynamicStrideMap.end())
545+
return std::nullopt;
546+
auto dimIt = it->second.find(dim);
547+
// The memref has dynamic strides, but not for this dimension.
if (dimIt == it->second.end())
548+
return std::nullopt;
549+
return dimIt->second;
550+
}
551+
517552
/// Track a pending per-workgroup SRD base adjustment for a linearized memref
518553
struct PendingSRDBaseAdjust {
519554
mlir::Value elementOffset;
@@ -690,6 +725,9 @@ class TranslationContext {
690725

691726
llvm::DenseMap<mlir::Value, PendingSRDBaseAdjust> pendingSRDBaseAdjustMap;
692727
llvm::SmallVector<PendingSRD, 4> pendingSRDs;
728+
llvm::SmallVector<PendingScalarArg, 2> pendingScalarArgs;
729+
llvm::DenseMap<mlir::Value, llvm::DenseMap<unsigned, mlir::Value>>
730+
dynamicStrideMap;
693731
llvm::StringMap<mlir::Value> exprCache;
694732
int64_t nextSRDIndex =
695733
-1; // Will be computed lazily, starts after user+system SGPRs
@@ -739,7 +777,8 @@ struct VOffsetResult {
739777
VOffsetResult computeVOffsetFromIndices(mlir::MemRefType memrefType,
740778
mlir::ValueRange indices,
741779
TranslationContext &ctx,
742-
mlir::Location loc);
780+
mlir::Location loc,
781+
mlir::Value base = nullptr);
743782

744783
/// Emit inline SRD base adjustment for per-workgroup buffer addressing.
745784
/// Allocates a new SRD (5 SGPRs: base pair, hi/lo temporaries, offset temp),

waveasm/lib/Transforms/AssemblyEmitter.cpp

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -471,6 +471,18 @@ std::optional<std::string> KernelGenerator::generateOp(Operation *op) {
471471
return result;
472472
})
473473

474+
.Case<S_AND_SAVEEXEC_B64>(
475+
[&](S_AND_SAVEEXEC_B64 saveOp) -> std::optional<std::string> {
476+
std::string dst = resolveValue(saveOp.getDst());
477+
return " s_and_saveexec_b64 " + dst + ", vcc";
478+
})
479+
480+
.Case<S_MOV_B64_EXEC>(
481+
[&](S_MOV_B64_EXEC restoreOp) -> std::optional<std::string> {
482+
std::string src = resolveValue(restoreOp.getSrc());
483+
return " s_mov_b64 exec, " + src;
484+
})
485+
474486
.Case<S_BRANCH>([&](S_BRANCH branchOp) {
475487
return std::string(" s_branch ") +
476488
branchOp.getTarget().getRootReference().str();

0 commit comments

Comments
 (0)