diff --git a/lib/Plugins/CMakeLists.txt b/lib/Plugins/CMakeLists.txt index 593335d222..1987910922 100644 --- a/lib/Plugins/CMakeLists.txt +++ b/lib/Plugins/CMakeLists.txt @@ -35,9 +35,16 @@ foreach( plugin ${TRITON_PLUGIN_PASSES} ) "$<$:-undefined dynamic_lookup>" ) - set_target_properties(${plugin} PROPERTIES + # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python + # build. It is empty when building directly from the root + # CMakeLists.txt file. Therefore, if not building from Python, just + # use the default CMake shared library path; otherwise this causes + # a hard build error. + if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + set_target_properties(${plugin} PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../plugins") + endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) target_compile_options(${plugin} PRIVATE -fvisibility=hidden) endforeach() diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 122b76f827..033813e612 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6762,4 +6762,4 @@ def kernel(X, Y, Z, RANK: tl.constexpr, TRANS_A: tl.constexpr, TRANS_B: tl.const d = a.to(torch.float32) @ b.to(torch.float32) - assert torch.equal(c, d) + assert torch.allclose(c, d, rtol=1e-3, atol=1e-2) diff --git a/python/triton/tools/ragged_tma.py b/python/triton/tools/ragged_tma.py index 728dfcd42b..c7730e182a 100644 --- a/python/triton/tools/ragged_tma.py +++ b/python/triton/tools/ragged_tma.py @@ -12,7 +12,7 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0): of potentially unequal size. The load_ragged and store_ragged device functions can be used to read - and write from subarrays T[batch_offset : batch_offset + batch_size] + and write from subarrays T[slice_off : slice_off + slice_size] with hardware bounds-checking preventing any sort of leakage outside the subarray. """ @@ -46,22 +46,22 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0): @triton.jit -def to_ragged_indices(batch_offset, batch_size, row): +def to_ragged_indices(slice_off, slice_size, row): """ Helper function for load_ragged and store_ragged. """ billion = 0x40000000 # == 2**30 - x = billion - batch_size + row - y = batch_offset + batch_size + x = billion - slice_size + row + y = slice_off + slice_size return billion, y, x @triton.jit -def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0): +def load_ragged(TMA, slice_off, slice_size, coords, ragged_dim: tl.constexpr = 0): """ - Read from a subarray T[batch_offset : batch_offset + batch_size] with + Read from a subarray T[slice_off : slice_off + slice_size] with hardware bounds-checking, where reading outside the subarray gives zeros.
Coords should be an appropriately-sized list of integers, just like in @@ -70,16 +70,16 @@ def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor") - c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim]) + c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim]) data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:]) data = tl.reshape(data, data.shape[2:]) return data @triton.jit -def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0): +def store_ragged(TMA, slice_off, slice_size, coords, data, ragged_dim: tl.constexpr = 0): """ - Write to a subarray T[batch_offset : batch_offset + batch_size] with + Write to a subarray T[slice_off : slice_off + slice_size] with hardware bounds-checking, where writes outside the subarray are masked correctly. @@ -87,15 +87,15 @@ def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.con TMA.store(). """ - c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim]) + c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim]) data = tl.reshape(data, [1, 1] + data.shape) TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data) @triton.jit -def atomic_add_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0): +def atomic_add_ragged(TMA, slice_off, slice_size, coords, data, ragged_dim: tl.constexpr = 0): """ - Atomic add into a subarray T[batch_offset : batch_offset + batch_size] with + Atomic add into a subarray T[slice_off : slice_off + slice_size] with hardware bounds-checking, where adds outside the subarray are masked correctly. @@ -103,6 +103,6 @@ def atomic_add_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: t TMA.atomic_add(). 
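The to_ragged_indices / load_ragged changes above are a pure rename (batch_offset/batch_size -> slice_off/slice_size), but the index trick they rely on is worth spelling out. Below is a plain-Python sketch of the arithmetic, under the assumption that the guard dimension added by create_ragged_descriptor has extent 2**30 so the TMA hardware rejects any coordinate at or above it (the docstrings only promise hardware bounds-checking; that extent is inferred here). The emulate_* helpers are illustrative and not part of this patch:

BILLION = 0x40000000  # == 2**30, same constant as in ragged_tma.py

def emulate_to_ragged_indices(slice_off, slice_size, row):
    # mirrors to_ragged_indices() above
    x = BILLION - slice_size + row
    y = slice_off + slice_size
    return BILLION, y, x

def emulate_ragged_read(T, slice_off, slice_size, row):
    c0, c1, c2 = emulate_to_ragged_indices(slice_off, slice_size, row)
    # c2 >= 2**30 exactly when row >= slice_size, which is the case the
    # hardware bounds check rejects; out-of-bounds reads then return zeros.
    if c2 >= BILLION:
        return 0
    # otherwise the coordinates recombine to the intended element:
    # (c1 - c0) + c2 == slice_off + row
    return T[(c1 - c0) + c2]

assert emulate_ragged_read(list(range(100)), slice_off=10, slice_size=5, row=3) == 13
assert emulate_ragged_read(list(range(100)), slice_off=10, slice_size=5, row=7) == 0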
""" - c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim]) + c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim]) data = tl.reshape(data, [1, 1] + data.shape) TMA.atomic_add([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data) diff --git a/python/triton_kernels/bench/bench_mlp.py b/python/triton_kernels/bench/bench_mlp.py index ca3a977a3e..ff891c901f 100644 --- a/python/triton_kernels/bench/bench_mlp.py +++ b/python/triton_kernels/bench/bench_mlp.py @@ -7,7 +7,7 @@ import triton_kernels import triton_kernels.roofline as roofline import triton_kernels.swiglu -from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation +from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation from triton_kernels.target_info import get_cdna_version import distributed as triton_dist from triton_kernels.tensor_details import layout @@ -71,7 +71,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d input_x = torch.randn((batch // DP, dim1), device=dev) expt_assignment = triton_dist.create_expt_assignment(EP, n_expts_tot, torch.device(dev)) - triton_dist.initialize_matmul_ogs(batch, dim1, dim2, n_expts_act, n_expts_tot, input_x.dtype) + triton_dist.initialize_matmul(batch, dim1, dim2, n_expts_act, n_expts_tot, input_x.dtype) # run layer fpath = Path(tempfile.mktemp()) @@ -80,7 +80,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d xg = input_x.to(wg.dtype if n_expts_tot > 1 else input_x.dtype) for i in range(100): if n_expts_tot > 1: # sparse - logits = matmul_ogs(xg, wg, bg, precision_config=pcg) + logits = matmul(xg, wg, bg, precision_config=pcg) x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(input_x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment, mode="ep_sharding") @@ -88,9 +88,8 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d x = triton_dist.all_gather(input_x, dim=0) rdata, gather_indx, scatter_indx, metadata = None, None, None, None if x.nelement() > 0: - x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act) - x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx, - precision_config=pc2) + x = matmul(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act) + x = matmul(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx, precision_config=pc2) x = triton_dist.reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment) proton.finalize() return roofline.parse_profile(fpath.with_suffix(".hatchet"), useful_op_regex=".*matmul.*") diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index e425e56ac1..1a5dea995e 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -10,12 +10,11 @@ import triton_kernels import triton_kernels.swiglu from triton_kernels.reduce import reduce -from triton_kernels.matmul_ogs import RoutingData, GatherIndx, ScatterIndx from triton_kernels.topk import topk -from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation +from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq from 
triton_kernels.tensor_details import layout -from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata +from triton_kernels.tensor import RaggedTensorMetadata, make_ragged_tensor_metadata, remap_ragged_tensor_metadata from triton_kernels.distributed import make_expt_dict_uniform, make_expt_assignment, convert_dp_to_ep, convert_ep_to_dp, ExptAssignment, symm_mem_pool from bench_utils import quantize_weight @@ -40,7 +39,7 @@ def create_expt_assignment(EP: int, n_expts_tot: int, device: torch.device) -> O return make_expt_assignment(EP, n_expts_tot, expt_dict, device) -def initialize_matmul_ogs( +def initialize_matmul( batch: int, dim1: int, dim2: int, @@ -52,7 +51,7 @@ def initialize_matmul_ogs( return world_size = dist.get_world_size() device = torch.cuda.current_device() - symm_mem_pool.initialize_matmul_ogs( + symm_mem_pool.initialize_matmul( n_tokens_global=batch, d_input=dim1, d_model=dim2, @@ -146,8 +145,7 @@ def routing( TP: int = 1, expt_assignment: Optional[ExptAssignment] = None, mode: Optional[str] = None, -) -> Tuple[torch.Tensor, RoutingData, GatherIndx, ScatterIndx, Optional[ReduceScatterMetadata]]: - n_expts_tot = logits.shape[-1] +) -> Tuple[torch.Tensor, RaggedTensorMetadata, torch.Tensor, torch.Tensor, Optional[ReduceScatterMetadata]]: if _is_distributed_launch() and mode: if mode == "ep_sharding": if not expt_assignment: @@ -170,15 +168,13 @@ def routing( logits_global_metadata = make_ragged_tensor_metadata(expt_sizes, dispatch_indx.shape[0]) x = convert_dp_to_ep(x, expt_assignment, active_indx, dispatch_indx) logits_local_metadata = remap_ragged_tensor_metadata(logits_global_metadata, expt_map) - gate_scal = logits_global.vals.flatten()[combine_indx] - rdata = RoutingData(gate_scal, expt_sizes, n_expts_tot // EP, n_expts_act, logits_local_metadata) reduce_scatter_metadata = ReduceScatterMetadata( mode=mode, active_indx=active_indx, dispatch_indx=dispatch_indx, combine_indx=combine_indx, ) - return x, rdata, None, None, reduce_scatter_metadata + return x, logits_local_metadata, None, None, reduce_scatter_metadata else: raise NotImplementedError(f"Distributed routing mode {mode} is not implemented yet.") else: @@ -186,13 +182,10 @@ def routing( logits = topk(logits, n_expts_act, y_indx=y_indx, apply_softmax=not sm_first) dispatch_indx = logits.mask_metadata.row_sorted_indx combine_indx = logits.mask_metadata.col_sorted_indx - ragged_batch_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0]) - gate_scal = logits.vals.flatten()[combine_indx] - routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act, - ragged_batch_metadata) - gather_indx = GatherIndx(combine_indx, dispatch_indx) - scatter_indx = ScatterIndx(dispatch_indx, combine_indx) - return x, routing_data, gather_indx, scatter_indx, None + ragged_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0]) + gather_indx = combine_indx // n_expts_act + scatter_indx = combine_indx + return x, ragged_metadata, gather_indx, scatter_indx, None def gather_ep(rank, world_size, param, TP, EP): @@ -276,14 +269,14 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac w1_full = w2_full = w1_flex_full = w2_flex_full = w1_scale_full = w2_scale_full = None # precision configs - pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale) + pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), b_mx_scale=wg_scale) act = 
FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2), (1.0, 1.0)) - pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale) - pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale) + pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), b_mx_scale=w1_scale) + pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), b_mx_scale=w2_scale) if rank == 0: - pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), weight_scale=w1_scale_full) - pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), weight_scale=w2_scale_full) + pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), b_mx_scale=w1_scale_full) + pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), b_mx_scale=w2_scale_full) else: pc1_full = pc2_full = None @@ -296,7 +289,7 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac xd = torch.randn((batch // world_size, dim1), device=dev).to(dtype_map[x_dtype]) x0 = all_gather(xd, dim=0) expt_assignment = create_expt_assignment(EP, n_expts_tot, torch.device(dev)) - symm_mem_pool.initialize_matmul_ogs( + symm_mem_pool.initialize_matmul( n_tokens_global=batch, d_input=dim1, d_model=dim2, @@ -312,25 +305,25 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac def single(x): xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype) if n_expts_tot > 1: - logits = matmul_ogs(xg, wg, bg, precision_config=pcg) + logits = matmul(xg, wg, bg, precision_config=pcg) x, rdata, gi, si, _ = routing(x, logits, n_expts_act) else: rdata = gi = si = None - x = matmul_ogs(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act) - return matmul_ogs(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full) + x = matmul(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act) + return matmul(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full) # distributed pass def distributed(x): xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype) if n_expts_tot > 1: # sparse - logits = matmul_ogs(xg, wg, bg, precision_config=pcg) + logits = matmul(xg, wg, bg, precision_config=pcg) x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment, mode="ep_sharding") else: # dense x = all_gather(x, dim=0) rdata = gi = si = metadata = None - x = matmul_ogs(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act) - x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2) + x = matmul(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act) + x = matmul(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2) x = reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment) # gather the result from all GPUs, just for verification return all_gather(x, dim=0) diff --git a/python/triton_kernels/tests/test_distributed.py b/python/triton_kernels/tests/test_distributed.py index 44463f2059..e764bf0ee6 100644 --- a/python/triton_kernels/tests/test_distributed.py +++ b/python/triton_kernels/tests/test_distributed.py @@ -9,7 +9,7 @@ from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_dict_random, make_expt_assignment, symm_mem_pool from triton_kernels.reduce import reduce from triton_kernels.topk import 
topk -from triton_kernels.matmul_ogs import matmul_ogs, RoutingData, GatherIndx, ScatterIndx +from triton_kernels.matmul import matmul from triton_kernels.target_info import is_hip from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata import pytest @@ -122,17 +122,18 @@ def routing(logits, n_expts_act, all_gather=False, y_indx=None): dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx combine_indx = sparse_logits.mask_metadata.col_sorted_indx ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]) - gate_scal = sparse_logits.vals.flatten()[combine_indx] - routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, logits.shape[-1], n_expts_act, - ragged_batch_metadata) - gather_idx = GatherIndx(combine_indx, dispatch_indx) - scatter_idx = ScatterIndx(dispatch_indx, combine_indx) - return routing_data, gather_idx, scatter_idx, sparse_logits.indx + gather_idx = torch.div(combine_indx, n_expts_act, rounding_mode="trunc") + scatter_idx = combine_indx + return ragged_batch_metadata, gather_idx, scatter_idx, sparse_logits.indx def mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act, y_indx=None): rdata, combine_indx, dispatch_indx, _ = routing(l_global, n_expts_act, y_indx=y_indx) - y_global = matmul_ogs(x_global, w_global, b_global, rdata, gather_indx=combine_indx, scatter_indx=dispatch_indx) + y_global = matmul(x_global, w_global, b_global, rdata, gather_indx=combine_indx, scatter_indx=dispatch_indx) + y_mask = (dispatch_indx != -1).view(y_global.shape[-2] // n_expts_act, n_expts_act, 1) + y_global = y_global.view(y_global.shape[-2] // n_expts_act, n_expts_act, -1) + y_mask = y_mask.expand_as(y_global) + y_global, _ = reduce(y_global, dim=1, mask=y_mask) return y_global @@ -153,9 +154,7 @@ def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, ex y_ep_local = convert_dp_to_ep(x_dp_local, expt_assignment, active_indx, dispatch_indx) y_ep_local_metadata = remap_ragged_tensor_metadata(x_global_metadata, expt_map) # matrix multiply - # TODO: clean-up API. `RoutingData` should not exist; we should be passing `y_ep_local_metadata`. 
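The new routing() contract above returns raw tensors in place of RoutingData / GatherIndx / ScatterIndx: the ragged-tensor metadata, a gather index that is simply combine_indx // n_expts_act (one source row per (token, active-expert) pair), and combine_indx itself as the scatter index. The scattered matmul output is then folded back to one row per token by a masked reduction over the expert axis, as in mixture_of_expt_nosharded above. A torch-only sketch of that combine step, with a plain masked sum standing in for triton_kernels.reduce and toy shapes; this is an illustration, not the library implementation:

import torch

n_tokens, n_expts_act, d_model = 4, 2, 8
y = torch.randn(n_tokens * n_expts_act, d_model)   # one row per (token, expert) pair
scatter_indx = torch.arange(n_tokens * n_expts_act)
scatter_indx[3] = -1                                # pretend one slot is inactive

y = y.view(n_tokens, n_expts_act, d_model)
mask = (scatter_indx != -1).view(n_tokens, n_expts_act, 1).expand_as(y)
combined = (y * mask).sum(dim=1)                    # back to one row per token
assert combined.shape == (n_tokens, d_model)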
- rdata_ep_local = RoutingData(None, expt_sizes, w_ep_local.shape[0], n_expts_act, y_ep_local_metadata) - y_ep_local = matmul_ogs(y_ep_local, w_ep_local, b_ep_local, rdata_ep_local) + y_ep_local = matmul(y_ep_local, w_ep_local, b_ep_local, a_ragged_metadata=y_ep_local_metadata) # convert x from expert-sorted, ep-local to token-sorted, dp-local y_dp_local = convert_ep_to_dp(y_ep_local, expt_assignment, active_indx, combine_indx) # weighted average of the output token from experts @@ -208,7 +207,7 @@ def run_mixture(): y_indx=y_indx_global, ) - symm_mem_pool.initialize_matmul_ogs( + symm_mem_pool.initialize_matmul( n_tokens_global=n_tokens_global, d_input=d_model, d_model=d_model, diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index f74a2af059..0a2b70141c 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -1,169 +1,38 @@ # isort: off # fmt: off -from dataclasses import dataclass, fields, replace +from dataclasses import dataclass, fields import itertools import pytest import torch from typing import Union import triton # matmul utilities -import triton_kernels.matmul_ogs_details.opt_flags as opt_flags -from triton_kernels.matmul_ogs import FlexCtx, RoutingData, InnerRoutingData, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue -from triton_kernels.matmul_ogs import GatherIndx, ScatterIndx -from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch -from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig -from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4, make_ragged_tensor_metadata -from triton_kernels.tensor_details import layout -from triton_kernels.topk import topk +import triton_kernels.matmul_details.opt_flags as opt_flags +from triton_kernels.matmul import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue +from triton_kernels.matmul import matmul_set_idle_sms, matmul, matmul_torch # numerics utilities from triton_kernels.numerics import InFlexData, OutFlexData -from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp, quantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE +from triton_kernels.numerics_details.mxfp import upcast_from_mxfp, quantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE # testing utilities -from triton_kernels.testing import assert_close, compute_actual_scale +from triton_kernels.testing import assert_close, make_random_tensor # target-specific utilities -from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4 - -# --------------- -# initialize data -# --------------- - - -def alloc_rand(shape, device, dtype, requires_grad=True): - if dtype.itemsize == 1: - tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) - return tmp.to(dtype).requires_grad_(requires_grad) - return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) - - -def alloc_rand_like(x): - return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad) - - -def init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter, device="cuda"): - logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True) - sparse_logits = topk(logits, n_expts_act) - dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx - combine_indx = sparse_logits.mask_metadata.col_sorted_indx - 
ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]) - routing_data = RoutingData(None, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act, ragged_batch_metadata) - gather_idx = GatherIndx(combine_indx, dispatch_indx) if do_gather else None - scatter_idx = ScatterIndx(dispatch_indx, combine_indx) if do_scatter else None - return m, routing_data, gather_idx, scatter_idx - - -def init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode, act_dtype, weight_dtype, - has_y_gammas, requires_grad=True, device="cuda", - inner_expt_opt=None, padding_block_k=None): - torch.manual_seed(0) - assert mode in {'batched', "plain", 'ragged'} - if inner_expt_opt is not None: - assert gindx is None and sindx is None - # Rotate tensor shapes (dw = x.T @ dy) - m, k = k, m * n_expts_act + padding_block_k * n_expts_tot - in_m = m - else: - in_m = m * (n_expts_act if gindx is None else 1) - shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k) - shape_batch = tuple() if (mode == "plain" or inner_expt_opt is not None) else (n_expts_tot, ) - x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad) - w = alloc_rand(shape_batch + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad) - bias = alloc_rand(shape_batch + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad) - gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad) - gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad) - gs0 = gs0.detach().requires_grad_(requires_grad) - gs1 = gs1.detach().requires_grad_(requires_grad) - if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2): - gs0 = None - gs1 = None - if is_cuda() and "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10: - w = w.transpose(-1, -2).contiguous().transpose(-1, -2) - - def _apply_padding_and_fill_unused_part_with_nan(t, is_padded): - nan_val = float("nan") - if t.element_size() == 1: - t = t.view(torch.int8) - nan_val = 127 - - start = 0 - if is_padded: - for this_expt_nrows in rdata.expt_hist.tolist(): - end = start + this_expt_nrows - padding_end = start + triton.cdiv(this_expt_nrows, padding_block_k) * padding_block_k - t[end:padding_end, :] = 0 - start = padding_end - assert start <= t.shape[0] - t[start:, :] = nan_val - else: - n_actual_rows = rdata.expt_hist.sum().item() - if n_actual_rows + padding_block_k < t.shape[0]: - t[n_actual_rows+padding_block_k:, :] = nan_val - - if inner_expt_opt is not None: - bias = None - _apply_padding_and_fill_unused_part_with_nan(x.T, "pad_x" in inner_expt_opt) - _apply_padding_and_fill_unused_part_with_nan(w, "pad_w" in inner_expt_opt) - - return x, w, bias, gs0, gs1 - +from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4, is_xpu +from triton_kernels.swiglu import swiglu, swiglu_fn +from triton_kernels.swiglu import PrecisionConfig as SwiGLUPrecisionConfig # --------------- # numerics stuff # --------------- +class DType: -def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, mode, n_expts_tot=1, expt_is_inner=False, device="cuda"): - weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp - # flexpoint - make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) + - ([val0] - if n_expts_tot % 2 else 
[]), dtype=torch.float32, device=device) - make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device=device) - make = lambda val0, val1, is_tensor: make_tensor(val0, val1) if is_tensor else make_scalar(val0) - - in_flex_data = lambda scale, use_flex: InFlexData(dtype=out_dtype, scale=make_scalar(scale) - ) if use_flex else InFlexData() - flex_ctx = FlexCtx( - lhs_data=in_flex_data(1.25, act_use_flexpoint), - rhs_data=InFlexData( - dtype=weight_dtype, - scale=make(1.50, 1.25, not expt_is_inner), - ) if weight_use_flexpoint else InFlexData(), - out_data=OutFlexData( - dtype=out_dtype, - expected_scale=make_scalar(4.00), - actual_scale=make_scalar(0), - checksum_scale=None, - ) if act_use_flexpoint else OutFlexData(), - ) - return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=2.0 if act_use_flexpoint or weight_use_flexpoint else 1.0, - out_dtype=out_dtype) - - -def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config): - flex_ctx = precision_config.flex_ctx - - def apply(x, scale): - if scale is None: - x = x.clone() - elif scale.numel() == 1: - x = x.float() * scale - else: - assert x.ndim == 3 - assert scale.numel() == x.shape[0] - x = x.float() * scale[:, None, None] - return x.detach().requires_grad_() - - return ( - apply(x_tri, flex_ctx.lhs_data.scale), - apply(w_tri, flex_ctx.rhs_data.scale), - None if bias_tri is None else apply(bias_tri, None), - None if gs0_tri is None else apply(gs0_tri, None), - None if gs1_tri is None else apply(gs1_tri, None), - ) - + def __init__(self, dtype_str): + self.has_global_scale = dtype_str.startswith("float8") + self.has_mx_scale = dtype_str.startswith("mx") + to_torch_dtype = lambda name: torch.uint8 if name == "float4_e2m1" else getattr(torch, name) + self.torch_dtype = to_torch_dtype(dtype_str.strip("mx")) + self.is_mxfloat4 = self.has_mx_scale and "float4" in dtype_str -def dtype_str_to_torch(dtype_str: str) -> torch.dtype: - return torch.uint8 if dtype_str == "float4_e2m1" else getattr(torch, dtype_str) # Scope to ensure that the opt_flags_constraints are reset after the test @@ -173,6 +42,22 @@ def opt_flags_scope(request): opt_flags.reset_opt_flags_constraints() +def make_constraints(block_m, split_k, is_persistent, epilogue_subtile, hbm_swizzling, weight_dtype_str): + constraints = { + "block_m": block_m, + "split_k": split_k, + "is_persistent": is_persistent, + "epilogue_subtile": epilogue_subtile, + } + if is_hip() and hbm_swizzling and "float4" in weight_dtype_str: + # Minimum block size to satisfy scale preshuffling + constraints.update({ + "block_m": 32, + "block_n": 32, + "block_k": 256 + }) + return constraints + # --------------- # unit tests # --------------- @@ -186,121 +71,116 @@ class Case: mode: str act_dtype_str: str weight_dtype_str: str - n_expts_tot: int = 1 - n_expts_act: int = 1 + n_slices: int = None split_k: int = 1 hbm_swizzling: bool = False epilogue_subtile: Union[int, None] = None - x_transpose: bool = False - w_transpose: bool = False - y_transpose: bool = False + a_transpose: bool = False + b_transpose: bool = False + c_transpose: bool = False colmajor_mxfp_weight: bool = True - + swiglu_opts: tuple[float, float] = None + + def __post_init__(self): + if self.n_slices is None: + self.n_slices = 1 if self.mode == "plain" else 10 + +def _build_test_op_cases(): + test_cases = [] + # zero-sized + test_cases.extend([ + Case(m, n, k, mode, "float16", "float16") + for mode in ("ragged", "batched") + for (m, n, k) in ((0, 5, 7), (5, 0, 7), (5, 7, 0)) + ]) + odd_shape1 = (727, 577, 859) + 
odd_shape2 = (720, 576, 768) + even_shape = (768, 512, 1024) + # canonical float16 + test_cases.extend([ + Case(*shape, mode, "float16", "float16", split_k=split_k) + for shape in [odd_shape1, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5] + ]) + # native float8 + test_cases.extend([ + Case(*shape, mode, "float8_e5m2", "float8_e5m2", split_k=split_k) + for shape in [odd_shape1, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5] + ]) + test_cases.extend([ + Case(*even_shape, "ragged", "float8_e5m2", "float8_e5m2", epilogue_subtile=val) + for val in (1, 2, 4) + ]) + # bfloat16 x mx + for shape in [odd_shape2, even_shape]: + test_cases.extend([ + Case(*shape, "plain", "bfloat16", "mxfloat4_e2m1"), + Case(*shape, "plain", "bfloat16", "mxfloat4_e2m1", hbm_swizzling=True), + Case(*shape, "batched", "bfloat16", "mxfloat4_e2m1"), + Case(*shape, "batched", "bfloat16", "mxfloat4_e2m1", hbm_swizzling=True), + Case(*shape, "ragged", "bfloat16", "mxfloat4_e2m1"), + Case(*shape, "ragged", "bfloat16", "mxfloat4_e2m1", hbm_swizzling=True), + Case(*shape, "ragged", "bfloat16", "mxfloat4_e2m1", split_k=9), + Case(*shape, "ragged", "bfloat16", "mxfloat4_e2m1", split_k=9, hbm_swizzling=True), + Case(*shape, "ragged", "bfloat16", "mxfloat8_e4m3fn"), + Case(*shape, "ragged", "bfloat16", "mxfloat8_e4m3fn", hbm_swizzling=True) + ]) + # float8 x mxfloat + test_cases.extend([ + Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", hbm_swizzling=True), + Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1", hbm_swizzling=True), + Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1"), + Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9), + Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9, hbm_swizzling=True), + Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn"), + Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1"), + Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn"), + ]) + # mxfloat x mxfloat + test_cases.extend([ + Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", hbm_swizzling=True), + Case(1024, 1024, 1024, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", split_k=9, hbm_swizzling=True), + Case(1024, 1024, 1024, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", split_k=9, colmajor_mxfp_weight=False), + Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"), + Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", hbm_swizzling=True), + Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"), + Case(1024, 1024, 1024, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", hbm_swizzling=True), + ]) + # amd-specific float8 + test_cases.extend([ + Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"), + Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"), + Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", split_k=2), + Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"), + ]) + # transposes / permutes + test_cases.extend([ + Case(320, 400, 400, "batched", "float16", "float16", + a_transpose=a_tr, b_transpose=b_tr, c_transpose=c_tr) + for a_tr, b_tr, c_tr in itertools.product((False, True), repeat=3) + ]) + test_cases.extend([ + Case(320, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", + a_transpose=False, b_transpose=True, c_transpose=False), + Case(320, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", + a_transpose=True, b_transpose=True, c_transpose=True), + 
]) + # swiglu + test_cases.extend([ + Case(*shape, mode, "bfloat16", "bfloat16", split_k=split_k, swiglu_opts=(1.1, 1.4)) + for shape in [odd_shape2, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5] + ]) + test_cases.extend([ + Case(*even_shape, "ragged", "bfloat16", "bfloat16", epilogue_subtile=val, swiglu_opts=(1.1, 1.4)) + for val in (1, 2, 4) + ]) + + return test_cases @pytest.mark.parametrize( ", ".join(f.name for f in fields(Case)), [ - tuple(getattr(case, f.name) for f in fields(Case)) for case in [ - # Zero-sized args: - Case(0, 5, 7, "ragged", "float16", "float16"), - Case(5, 0, 7, "ragged", "float16", "float16"), - Case(5, 7, 0, "ragged", "float16", "float16"), - Case(0, 5, 7, "batched", "float16", "float16"), - Case(5, 0, 7, "batched", "float16", "float16"), - Case(5, 7, 0, "batched", "float16", "float16"), - # Non-mx types: - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4), - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), - Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1), - Case(16, 256, 256, "batched", "float16", "float16", 5, 1), - Case(16, 256, 256, "ragged", "float16", "float16", 3, 1), - Case(256, 256, 256, "ragged", "float16", "float16", 4, 1), - Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3), - Case(300, 400, 400, "batched", "float16", "float16", 5, 1), - Case(300, 400, 400, "ragged", "float16", "float16"), - Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"), - Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1), - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1), - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2), - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4), - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2), - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2), - Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1), - Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2), - Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9), - Case(16, 16, 1000, "batched", "float16", "float16", 5, 1, split_k=None), - Case(16, 16, 1000, "batched", "float8_e5m2", "float8_e5m2", 5, 1, split_k=None), - Case(16, 16, 2048, "batched", "float8_e5m2", "float8_e5m2", 6, 1, split_k=5), - # mx types: - Case(1, 1024, 1024, "plain", "bfloat16", "mxfloat8_e4m3fn", 1, 1), - Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1), - Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), - Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True, epilogue_subtile=4), - Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1), - Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), - Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), - Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), - Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9), - Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), - Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4), - Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), - Case(300, 400, 400, "batched", "bfloat16", 
"mxfloat8_e5m2", 32, 4), - Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), - # Cover (N or K) % 128 == 64 (https://github.com/triton-lang/triton/pull/7203) - Case(1, 1472, 1472, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4), - Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), - Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), - Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), - Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1), - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9), - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2), - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), - Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4), - Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), - Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4), - Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True), - Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4), - Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True), - Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), - Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False), - Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), - Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), - Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1), - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9), - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False), - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2), - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), - Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), - Case(300, 512, 512, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), - Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), - Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4), - Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True), - Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4), - Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True), - # AMD - Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"), - Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1), - Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2), - Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2), - Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"), - Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1), - Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2), - ] + [ - Case(320, 400, 400, mode, 
dtype, dtype, n_expts_tot, n_expts_act, - x_transpose=x_transpose, w_transpose=w_transpose, y_transpose=y_transpose) - for (mode, n_expts_tot, n_expts_act) in ( - ("batched", 1, 1), - ("ragged", 8, 4), - ("ragged", 32, 4), - ) - for dtype in ("float16", "float8_e5m2") - for x_transpose in (False, True) - for w_transpose in (False, True) - for y_transpose in (False, True) - ] + tuple(getattr(case, f.name) for f in fields(Case)) + for case in _build_test_op_cases() ], ) @pytest.mark.parametrize("block_m", [16, 128]) @@ -308,36 +188,34 @@ class Case: (False, False, None), (True, False, None), (False, True, None), - (False, True, None), - (True, True, None), (True, True, None), - (False, False, "pad_w"), - (False, False, "pad_x"), + (False, False, "pad_b"), + (False, False, "pad_a"), ]) -@pytest.mark.parametrize("has_y_gammas", [False, True]) +@pytest.mark.parametrize("do_gamma", [False, True]) @pytest.mark.parametrize("is_persistent", [False, True]) -def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, - n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, - x_transpose, w_transpose, y_transpose, - device, opt_flags_scope): +def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, n_slices, + mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + a_transpose, b_transpose, c_transpose, + swiglu_opts, device, opt_flags_scope): # We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to # the frame that called pytest.skip, including all the tensors, leading to OOM. skip_message = None try: - _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, - n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, - x_transpose, w_transpose, y_transpose, - device, opt_flags_scope) + _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, n_slices, + mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + a_transpose, b_transpose, c_transpose, + swiglu_opts, device, opt_flags_scope) except pytest.skip.Exception as e: skip_message = str(e) if skip_message is not None: pytest.skip(skip_message) -def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, - n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, - x_transpose, w_transpose, y_transpose, - device, opt_flags_scope): +def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, n_slices, + mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + a_transpose, b_transpose, c_transpose, + swiglu_opts, device, opt_flags_scope): # TODO: remove when Triton FP8 supports proper RTNE if is_cuda(): if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9: @@ -347,6 +225,8 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm if weight_dtype_str.startswith("mx"): if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10: pytest.skip("float8 x mx not supported with cuda capability < 10") + if swiglu_opts is not None and do_gamma: + pytest.skip("NYI: swiglu and gamma not 
supported together") elif is_hip(): if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4(): @@ -357,9 +237,16 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm pytest.skip("NYI: mx x mx not tested on AMD GPU") if is_persistent: pytest.skip("NYI: Persistent kernel not supported on AMD GPU") + # FIXME: this works on nvidia; looks like some sort of bug on AMD? + if do_gamma and swiglu_opts is not None: + pytest.skip("NYI: gamma and swiglu not supported together on AMD GPU") if split_k is not None and split_k > 1: pytest.skip("splitK hasn't been fully tested on AMD GPU.") + elif is_xpu(): + if swiglu_opts is not None and do_gamma: + pytest.xfail("NYI: swiglu and gamma not supported together") + if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3(): pytest.xfail("float8_e4m3fnuz only tested on AMD CDNA3 Platform") @@ -380,11 +267,11 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm if expt_is_inner: if mode != "ragged": pytest.skip("inner_expt_opt only meaningful with ragged") - if "mx" in act_dtype_str and inner_expt_opt != "pad_x": - pytest.skip("inner_expt_opt and act mx only supported with pad_x") + if "mx" in act_dtype_str and inner_expt_opt != "pad_a": + pytest.skip("inner_expt_opt and act mx only supported with pad_a") if "mx" in weight_dtype_str: - if inner_expt_opt != "pad_w": - pytest.skip("inner_expt_opt and weight mx only supported with pad_w") + if inner_expt_opt != "pad_b": + pytest.skip("inner_expt_opt and weight mx only supported with pad_b") if is_persistent and not hbm_swizzling: pytest.skip("FIXME: Fatal Python error: Aborted") if is_hip(): @@ -392,407 +279,138 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm pytest.skip("FIXME: failed to translate module to LLVM IR") if hbm_swizzling: pytest.skip("NYI: nner_expt_opt and HBM swizzling") + if not colmajor_mxfp_weight: + if is_cuda() and torch.cuda.get_device_capability()[0] < 10: + pytest.skip("transposed mxfp weight not supported with cuda capability < 10") + if block_m == 16: + pytest.skip("PassManager::run failed from Triton compiler") + # TODO: should construct the test case differently rather than overriding here + if "float8" in weight_dtype_str and is_cuda() and torch.cuda.get_device_capability()[0] < 10: + b_transpose = True - # launch metadata for batched / mx types may not work yet. torch.manual_seed(0) - block_k = None - if is_cuda() and is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10: - # Override block_k for testing correctness. The default is temporarily 128 for - # performance reasons which doesn't work with persistent matmul. 
- # TODO: revisit when Triton is better for H100 + MXFP4 - block_k = 256 - - constraints = { - "block_m": block_m, - "block_k": block_k, - "split_k": split_k, - "is_persistent": is_persistent, - "epilogue_subtile": epilogue_subtile, - } - - if is_hip() and hbm_swizzling and "float4" in weight_dtype_str: - # Minimum block size to satisfy scale preshuffling - constraints.update({ - "block_m": 32, - "block_n": 32, - "block_k": 256 - }) - + # set opt flags constraints + constraints = make_constraints(block_m, split_k, is_persistent, epilogue_subtile, hbm_swizzling, weight_dtype_str) opt_flags.update_opt_flags_constraints(constraints) - weight_mxfp = weight_dtype_str.startswith("mx") - weight_mxfp4 = weight_mxfp and "float4" in weight_dtype_str - if weight_mxfp: - weight_dtype_str = weight_dtype_str[2:] - act_mxfp8 = act_dtype_str.startswith("mx") - act_is_float8 = act_dtype_str.startswith("float8") - if act_mxfp8: - act_dtype_str = act_dtype_str[2:] - quantize_mxfp8_spec = FnSpecs( - FnName.QUANTIZE_MXFP8.name, quantize_mxfp8_fn, (), () - ) - - test_bwd = False - weight_dtype = dtype_str_to_torch(weight_dtype_str) - act_dtype = dtype_str_to_torch(act_dtype_str) - precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, - mode, n_expts_tot, expt_is_inner, device=device) - # precision_opt.x_pad_trans_requires_flexpoint = False - if mode == "ragged": - m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter, - device=device) - else: - rdata = gindx = sindx = None - - padding_block_k = 32 - if hbm_swizzling: - if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10: - # Blackwell scale swizzling constraint - # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py#L45 - padding_block_k = 128 - elif not is_persistent: - padding_block_k = 64 - x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, - mode, torch.bfloat16 if act_mxfp8 else act_dtype, # - torch.bfloat16 if weight_mxfp else weight_dtype, - has_y_gammas, requires_grad=test_bwd, device=device, - inner_expt_opt=inner_expt_opt, padding_block_k=padding_block_k) - x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt) - - if x_transpose: - x_tri = x_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd) - if w_transpose: - w_tri = w_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd) - if y_transpose: - if mode == "batched": - yT_shape = (n_expts_tot, n, x_tri.shape[-2]) - elif expt_is_inner: - yT_shape = (n_expts_tot, n, k) - elif sindx is not None: - yT_shape = (n, m) - else: - n_rows = x_tri.shape[-2] if gindx is None else gindx.dst_indx.shape[0] - yT_shape = (n, n_rows) - y_tri_in = torch.empty(yT_shape, dtype=act_dtype, device=device).transpose(-1, -2) - else: - y_tri_in = None - - if w_tri.shape[0] == 1 and mode != "batched": - # Test the case when weight has dim 2, i.e., shape (K, N). 
- w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd) - w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd) - - if weight_mxfp: - mx_axis = w_tri.ndim - 2 - # compute layouts - w_layout, w_layout_opts = layout.StridedLayout, dict() - w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict() - if hbm_swizzling and weight_mxfp4: - w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis) - w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( - mx_axis=mx_axis, num_warps=8) - # downcast to mxfp - w_tri_orig = w_tri - if colmajor_mxfp_weight: - w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) - w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis) - w_tri_dtype = FP4 if weight_mxfp4 else weight_dtype - w_tri = wrap_torch_tensor(w_tri, w_tri_dtype) - w_scale_tri = wrap_torch_tensor(w_scale_tri) - # convert layouts - w_tri = convert_layout(w_tri, w_layout, **w_layout_opts) - w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts) - else: - if is_cuda() and torch.cuda.get_device_capability()[0] < 10: - pytest.skip("transposed mxfp weight not supported with cuda capability < 10") - if block_m == 16: - pytest.skip("PassManager::run failed from Triton compiler") - # TODO: swizzling for rowmajor - - # A typical use case is we already quantized col-major weight, - # and we want matmul with its transposed row-major weight w/o - # requantization. - - # put abs_max of each 32x32 block to diagonal so scales of transposed agree - w_ndim = w_tri.ndim - if w_ndim == 2: - w_tri = w_tri.unsqueeze(0) - BLOCK_SIZE = int(MXFP_BLOCK_SIZE) - for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)): - i_end = min(i+BLOCK_SIZE, w_tri.shape[1]) - j_end = min(j+BLOCK_SIZE, w_tri.shape[2]) - block = w_tri[e, i:i_end, j:j_end] - m_abs = block.abs().max() - i_len = i_end - i - j_len = j_end - j - min_len = min(i_len, j_len) - signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1 - block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs - if j_len > i_len: - block[i_len - 1, i_len:] = signs[min_len:] * m_abs - elif i_len > j_len: - block[j_len:, j_len - 1] = signs[min_len:] * m_abs - if w_ndim == 2: - w_tri = w_tri.squeeze(0) - - # matmul with rowmajor weight expects scale is separately - # constructed (not much additional memory needed). 
- _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) - # reuse quantized value from colmajor - w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis) - w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous() - w_tri = w_tri_rowmajor.data.mT - - def _pad_and_block(x: torch.Tensor) -> torch.Tensor: - x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate") - return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE) - - # check if generated scale is transpose-invariant as intended construction - # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)] - w_scale_tri_blocked = _pad_and_block(w_scale_tri) - w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1] - # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)] - w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor) - w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1] - assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked) - assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked) - assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT) - - precision_opt.weight_scale = w_scale_tri + a_dtype = DType(act_dtype_str) + b_dtype = DType(weight_dtype_str) + c_dtype = DType(act_dtype_str) + + # --- create conditionals --- + do_bias = inner_expt_opt is None + do_gather = do_gather and mode != "batched" + do_scatter = do_scatter and mode != "batched" + + # --- create inputs --- + a, a_scales, a_ragged_metadata = make_random_tensor( + shape=(m, k), + n_slices = n_slices, + dtype = a_dtype, + device = device, + ragged_dim = None if mode != "ragged" else 1 if expt_is_inner else 0, + mxfp_dim = -1 if a_dtype.has_mx_scale else None, + transpose = a_transpose, + ragged_padding = inner_expt_opt is not None and "pad_a" in inner_expt_opt, + squeeze_batch_dim = mode == "plain", + ) + b, b_scale_tri, b_ragged_metadata = make_random_tensor( + shape=(k, n), + n_slices = n_slices, + dtype = b_dtype, + device = device, + ragged_dim = None if mode != "ragged" or inner_expt_opt is None else 0, + mxfp_dim = -2 if b_dtype.has_mx_scale else None, + transpose = b_transpose, + ragged_padding = inner_expt_opt is not None and "pad_b" in inner_expt_opt, + squeeze_batch_dim = mode == "plain", + hbm_swizzling = hbm_swizzling, + is_mx_rowmajor = not colmajor_mxfp_weight, + ) + gather_indx = None if not do_gather else torch.randint(0, max(m, 1), (m, ), dtype=torch.int32, device=device) + scatter_indx = None if not do_scatter else torch.randperm(m, dtype=torch.int32, device=device) + bias = None if not do_bias else torch.randn(b.shape[:-2] + b.shape[-1:], dtype=torch.float32, device=device) + gammas = None if not do_gamma else 2**torch.randint(-5, 0, (m, ), dtype=torch.float32, device=device) + + # --- create fused activation --- + fused_activation = None + if swiglu_opts is not None: + fused_activation = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2), swiglu_opts) + + # --- initialize output --- + c_shape = (n_slices,) if mode == "batched" or inner_expt_opt is not None else tuple() # batch dim + c_shape += (scatter_indx.shape[0] if do_scatter else a.shape[-2],) # row dim + c_shape += (b.shape[-1] // (1 if fused_activation is None else fused_activation.specs.reduction_n) ,) # col dim + c = torch.empty(c_shape, 
dtype=c_dtype.torch_dtype, device=device) + if c_transpose: + c = c.mT.contiguous().mT + + # --- create precision config --- + wrap_list = lambda vals: torch.tensor(vals, dtype=torch.float32, device=device) + flex_a = InFlexData(c_dtype.torch_dtype, wrap_list([1.25])) if c_dtype.has_global_scale else InFlexData() + flex_b = InFlexData(b_dtype.torch_dtype, wrap_list([1.25])) if b_dtype.has_global_scale else InFlexData() + flex_c = OutFlexData(c_dtype.torch_dtype, wrap_list([4.00]), wrap_list([0]), None) if c_dtype.has_global_scale else OutFlexData() + precision_opt = PrecisionConfig( + flex_ctx=FlexCtx(flex_a, flex_b, flex_c), + acc_scale=2.0 if c_dtype.has_global_scale or b_dtype.has_global_scale else 1.0, + out_dtype=c_dtype.torch_dtype, + a_mx_scale=a_scales, + b_mx_scale=b_scale_tri, + ) + + # --- create epilogue --- epilogue = None - if act_mxfp8: - x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1) - x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1) - is_input_batched = x_tri.ndim == 3 - y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape - n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0] - y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1]) - if sindx is None or mode == "batched": - if not is_input_batched: - y_shape = (y_shape[1], y_shape[2]) - else: - y_shape = (n_rows // rdata.n_expts_act, y_shape[-1]) - y_scale_shape = y_shape[:-1] + (triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE),) - y_scale = torch.empty(y_scale_shape, dtype=torch.uint8, device=x_tri.device) - precision_opt = replace(precision_opt, act_scale=x_mx_scales_tri, out_scale=y_scale) - epilogue = Epilogue(quantize_mxfp8_spec, tuple(), tuple(), effective_itemsize=6.0) - else: - y_scale = None - - if mode == "batched": - rdata, gindx, sindx = None, None, None - flex = precision_opt.flex_ctx + if c_dtype.has_mx_scale: + c_scale_shape = c_shape[:-1] + (triton.cdiv(c_shape[-1], MXFP_BLOCK_SIZE),) + c_scale = torch.empty(c_scale_shape, dtype=torch.uint8, device=a.device) + precision_opt.c_mx_scale = c_scale + epilogue_spec = FnSpecs(FnName.QUANTIZE_MXFP8.name, quantize_mxfp8_fn, (), ()) + epilogue = Epilogue(epilogue_spec, tuple(), tuple(), effective_itemsize=6.0) - if expt_is_inner: - inner_routing_data = InnerRoutingData( - base=rdata, block_k=padding_block_k, - x_is_padded="pad_x" in inner_expt_opt, - w_is_padded="pad_w" in inner_expt_opt, - ) - rdata = None - else: - inner_routing_data = None - - # triton + + # --- triton implementation --- try: - tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, - gammas=gs1_ref, epilogue=epilogue, y=y_tri_in, - inner_routing_data=inner_routing_data) + tri_y = matmul(a, b, bias, + a_ragged_metadata, b_ragged_metadata, + gather_indx, scatter_indx, precision_opt, + gammas=gammas, epilogue=epilogue, c=c, + fused_activation=fused_activation) + if c_dtype.has_global_scale: + tri_y_scale = precision_opt.flex_ctx.out_data.actual_scale.clone() except (opt_flags.InapplicableConstraint, NotImplementedError) as e: pytest.xfail(f"inapplicable opt_flags constraint {e}") - if y_tri_in is not None: - assert tri_y.data_ptr() == y_tri_in.data_ptr() - assert tri_y.shape == y_tri_in.shape - assert tri_y.stride() == y_tri_in.stride() - # If split_k > 1, then the intermediate tensor is fp32. 
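For reference, the rename exercised above: matmul_ogs is now matmul and matmul_ogs_torch is matmul_torch, both imported from triton_kernels.matmul rather than triton_kernels.matmul_ogs. A minimal dense invocation sketched from the call sites in this diff (e.g. bench_mlp.py's matmul(x, w, bias, precision_config=...)); the shapes, dtypes, and default-constructed PrecisionConfig are illustrative assumptions, not something this patch guarantees:

import torch
from triton_kernels.matmul import matmul, PrecisionConfig  # new module name in this PR

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
w = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(512, device="cuda", dtype=torch.float32)
# ragged metadata and gather/scatter indices are omitted here and presumably
# default to None, as they did for the old matmul_ogs entry point
y = matmul(x, w, bias, precision_config=PrecisionConfig())
assert y.shape == (128, 512)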
- sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1 - sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1 - y_scale = flex.out_data.expected_scale if act_is_float8 else 1 - - def round_x(x, idx): - return x.to(act_dtype).to(torch.float32) if sep_gather else x - - round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y - ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref, # - rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, - inner_routing_data=inner_routing_data, device=device) - - def scale(val, scal): - if scal is None: - return val - elif scal.numel() == 1: - return val / scal - else: - assert val.ndim == 3 - return val / scal[:, None, None] - - if act_mxfp8: - tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, target_dtype=torch.bfloat16, axis=-1).to(ref_y.dtype) - ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1) - ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1) - maxtol = 4e-1 - rmstol = 4e-2 - elif weight_mxfp4: - if act_is_float8: - maxtol = 8e-2 - else: - maxtol = 3e-2 - rmstol = None - else: - maxtol = None - rmstol = None - assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y, maxtol=maxtol, rmstol=rmstol) - - if act_is_float8: - tri_y_scale = flex.out_data.actual_scale.clone() - ref_y_scale = compute_actual_scale(ref_y, tri_y.dtype, tri_y_scale.numel() > 1) + # --- torch implementation --- + ref_y = matmul_torch(a, b, bias, # + a_ragged_metadata, b_ragged_metadata, + gather_indx, scatter_indx, precision_opt, + gammas=gammas) + if swiglu_opts is not None: + ref_y = swiglu(ref_y, alpha=swiglu_opts[0], precision_config=SwiGLUPrecisionConfig(swiglu_opts[1])) + if c_dtype.has_global_scale: + ref_y_scale = precision_opt.flex_ctx.out_data.actual_scale.clone() + + # --- check results --- + if c_dtype.has_mx_scale: + tri_y = upcast_from_mxfp(tri_y, precision_opt.c_mx_scale, target_dtype=torch.bfloat16, axis=-1).to(ref_y.dtype) + ref_y = upcast_from_mxfp_torch(*downcast_to_mxfp_torch(ref_y, c_dtype.torch_dtype, axis=-1), target_dtype=ref_y.dtype, axis=-1) + maxtol, rmstol = None, None + if c_dtype.has_mx_scale: + maxtol, rmstol = 4e-1, 4e-2 + elif b_dtype.is_mxfloat4: + maxtol, rmstol = 3e-2, None + assert_close(ref_y, tri_y, maxtol=maxtol, rmstol=rmstol) + if c_dtype.has_global_scale: assert torch.all((ref_y_scale - tri_y_scale).abs() < 1e-10), \ f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}" -# Test that we don't use unsupported block sizes. 
-@pytest.mark.parametrize("m", [8, 16, 32, 64, 128]) -@pytest.mark.parametrize("n", [8, 16, 32, 64, 128]) -@pytest.mark.parametrize("k", [8, 16, 32, 64, 128]) -def test_small_batch_matmul(m, n, k, device): - if is_hip(): - pytest.skip("Not fully tested on AMD") - - if m * n * k > 16384: - pytest.xfail() - - BATCH_SIZE = 10000 - - def _make_tensor(shape, dtype, trans): - if trans: - shape = (shape[0], shape[2], shape[1]) - t = alloc_rand(shape, device, dtype) - return t.transpose(1, 2) if trans else t - - for x_transpose, w_transpose, bias, dtype in itertools.product( - (False, True), - (False, True), - (False, True), - (torch.float16, torch.bfloat16, torch.float8_e5m2), - ): - if is_cuda() and ( - torch.cuda.get_device_capability()[0] < 10 - and dtype is torch.float8_e5m2 - and (not w_transpose) - ): - continue # Not supported - - x = _make_tensor((BATCH_SIZE, m, k), dtype, x_transpose) - w = _make_tensor((BATCH_SIZE, k, n), dtype, w_transpose) - bias = _make_tensor((BATCH_SIZE, n), torch.float32, False) if bias else None - tri_y = matmul_ogs(x, w, bias) - - # ref_y = matmul_ogs_torch(x.float(), w.float(), bias) - - # This is faster than matmul_ogs_torch. - ref_y = torch.bmm(x.float(), w.float()) - if bias is not None: - ref_y += bias[:, None, :] - - assert_close( - ref_y, - tri_y, - maxtol=4e-1 if dtype is torch.float8_e5m2 else None, - rmstol=4e-2 if dtype is torch.float8_e5m2 else None, - ) - - def test_set_idle_sms(): if not is_cuda(): pytest.skip("Only supported on CUDA") - from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags + from triton_kernels.matmul_details.opt_flags import make_opt_flags num_idle_sms = 24 - matmul_ogs_set_idle_sms(num_idle_sms) + matmul_set_idle_sms(num_idle_sms) flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \ 1, 1024, 1024, 1024, None, True, False, 1, False, False, None) assert flags.idle_sms == num_idle_sms - - -@pytest.mark.parametrize("m, n, k, mode", [ - (1200, 704, 608, "ragged"), - (800, 800, 400, "batched"), -]) -@pytest.mark.parametrize("split_k", [1, 2]) -@pytest.mark.parametrize("do_gather, do_scatter", [ - (False, False), - (True, False), - (False, True), - (True, True), -]) -@pytest.mark.parametrize("is_persistent, epilogue_subtile", [ - (False, None), - (True, 1), - (True, 4), -]) -@pytest.mark.parametrize("swiglu_alpha, swiglu_limit", [ - (1.1, 1.4), - (1.0, 1.2), - (0.7, 1.0), -]) -def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, is_persistent, epilogue_subtile, - swiglu_alpha, swiglu_limit, device, opt_flags_scope): - torch.manual_seed(0) - constraints = { - "is_persistent": is_persistent, - "epilogue_subtile": epilogue_subtile, - "split_k": split_k, - } - n_expts_tot, n_expts_act = 1, 1 - opt_flags.update_opt_flags_constraints(constraints) - - weight_dtype, act_dtype = torch.float16, torch.float16 - if mode == "ragged": - m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter, - device=device) - else: - rdata = gindx = sindx = None - - precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, mode, n_expts_tot, device=device) - x, w, bias, _, _ = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode, - act_dtype, weight_dtype, False, requires_grad=False, device=device) - - if mode == "batched": - rdata, gindx, sindx = None, None, None - - try: - a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha, - 
precision_config=SwiGLUPrecisionConfig(swiglu_limit)) - b = matmul_ogs( - x, w, bias, rdata, gindx, sindx, precision_opt, - fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2), - (swiglu_alpha, swiglu_limit))) - except opt_flags.InapplicableConstraint: - pytest.xfail("inapplicable constraint") - - assert_close(a, b) - - -@pytest.mark.parametrize("m, n, k", [ - (320, 2**19, 0), - (4096, 4096, 0), -]) -@pytest.mark.parametrize("view_x_as_zero_cols", [False, True]) -def test_zero_reduction_dim(m, n, k, view_x_as_zero_cols, device): - torch.manual_seed(0) - - if view_x_as_zero_cols: - x = torch.randn(m, m, device=device, dtype=torch.bfloat16) - x = x[:0, :].transpose(-1, -2) - else: - x = torch.randn(m, k, device=device, dtype=torch.bfloat16) - w = torch.randn(k, n, device=device, dtype=torch.bfloat16) - bias = torch.randn(n, device=device, dtype=torch.float32) - - try: - tri_y = matmul_ogs(x, w, bias) - except opt_flags.InapplicableConstraint: - pytest.xfail("inapplicable constraint") - ref_y = matmul_ogs_torch(x, w, bias, round_x=lambda x, idx: x, round_y=lambda y: y, device=device) - - assert_close(ref_y, tri_y) diff --git a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py index 3ce12f8e7d..f48408daea 100644 --- a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py +++ b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py @@ -5,15 +5,15 @@ import torch -import triton_kernels.matmul_ogs_details.opt_flags as opt_flags +import triton_kernels.matmul_details.opt_flags as opt_flags class _DummyPrecisionConfig: def __init__(self): - self.weight_scale = None + self.b_mx_scale = None self.max_num_imprecise_acc = None - self.act_scale = None - self.out_scale = None + self.a_mx_scale = None + self.c_mx_scale = None self.enforce_bitwise_invariance = False diff --git a/python/triton_kernels/tests/test_reduce.py b/python/triton_kernels/tests/test_reduce.py index 7bae07391a..6ca8a09e22 100644 --- a/python/triton_kernels/tests/test_reduce.py +++ b/python/triton_kernels/tests/test_reduce.py @@ -60,11 +60,11 @@ def plus_a_reduce(x, a): @pytest.mark.parametrize("dim", [0, 1, 2]) def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn, device): is_hip = triton.runtime.driver.active.get_current_target().backend == "hip" - is_pre_h100 = torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0) + is_pre_h100 = device == "cuda" and torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0) if (is_hip or is_pre_h100) and "float8" in dtype_str: pytest.skip("float8 not supported on CUDA < 9.0") torch.manual_seed(0) - x = torch.randn((B, M, N), device=device, dtype=torch.float32) + x = torch.randn((B, M, N), device=device, dtype=torch.float32, requires_grad=True) x_mscale, x_flex = None, None y_flex_tri, y_flex_ref = None, None if is_mx := dtype_str.startswith("mx"): @@ -90,16 +90,25 @@ def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn, device): postprocess_fn_ref = lambda x: (x + 10).reshape([x.shape[0], x.shape[1] // 2, 2]).sum(dim=2) else: postprocess_fn_tri = postprocess_fn_ref = None - y_tri, y_tri_mxscale = reduce(x, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_tri, + # run forward pass + x_tri = x.clone().detach().requires_grad_(True) + x_ref = x.clone().detach().requires_grad_(True) + y_tri, y_tri_mxscale = reduce(x_tri, dim=dim, mask=mask, 
x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_tri, postprocess_fn1=postprocess_fn_tri) - y_ref, y_ref_mxscale = reduce_torch(x, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_ref, + y_ref, y_ref_mxscale = reduce_torch(x_ref, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_ref, postprocess_fn1=postprocess_fn_ref) if is_mx: y_ref = upcast_from_mxfp_torch(y_ref, y_ref_mxscale, torch.float16, axis=-1) y_tri = upcast_from_mxfp_torch(y_tri, y_tri_mxscale, torch.float16, axis=-1) + assert torch.allclose(y_tri.float(), y_ref.float(), atol=1e-3, rtol=1e-3) if is_flex: torch.allclose(y_flex_tri.actual_scale, y_flex_ref.actual_scale, atol=1e-3, rtol=1e-3) - assert torch.allclose(y_tri.float(), y_ref.float(), atol=1e-3, rtol=1e-3) + run_bwd = postprocess_fn is None and "float8" not in dtype_str + if run_bwd: + dy = torch.randn_like(y_tri) + y_tri.backward(dy) + y_ref.backward(dy) + assert torch.allclose(x_tri.grad.float(), x_ref.grad.float(), atol=1e-3, rtol=1e-3) def bench_reduce(B: int = 4, M: int = 4096, N: int = 4096, *, dim: int = 0, dtype: torch.dtype = torch.float16, diff --git a/python/triton_kernels/triton_kernels/distributed.py b/python/triton_kernels/triton_kernels/distributed.py index 8102337a97..e27fb2a275 100644 --- a/python/triton_kernels/triton_kernels/distributed.py +++ b/python/triton_kernels/triton_kernels/distributed.py @@ -135,7 +135,7 @@ def _initialize( self._is_initialized = True - def initialize_matmul_ogs( + def initialize_matmul( self, n_tokens_global: int, d_input: int, diff --git a/python/triton_kernels/triton_kernels/matmul.py b/python/triton_kernels/triton_kernels/matmul.py new file mode 100644 index 0000000000..e6d6bb283c --- /dev/null +++ b/python/triton_kernels/triton_kernels/matmul.py @@ -0,0 +1,739 @@ +# isort: off +# fmt: off +from dataclasses import dataclass +import itertools +import torch +import triton +from enum import Enum, auto +import math +# utilities +from triton_kernels import target_info +from triton_kernels.numerics import InFlexData, OutFlexData +from triton_kernels.target_info import is_cuda +# details +from .matmul_details._matmul import _matmul +from .matmul_details._p_matmul import _p_matmul, get_per_device_per_stream_alloc_fn +from .numerics_details.mxfp import MXFP_BLOCK_SIZE +from .tensor_details.layout_details.strided import StridedLayout +from .matmul_details.opt_flags import make_opt_flags, update_opt_flags_constraints +from .specialize import FnSpecs, SpecializationModule, ClosureArg +from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor, RaggedTensorMetadata +from .reduce import reduce +from .reduce import PostprocessFn as ReducePostprocessFn +from .tensor_details.ragged_tensor import ragged_metadata_fields + + +@dataclass(frozen=True) +class FusedActivation: + specs: FnSpecs = FnSpecs.default() + fn_args: tuple[object] = tuple() + + +@dataclass(frozen=True) +class Epilogue: + specs: FnSpecs = FnSpecs.default() + fn_arg_values_matmul: tuple[object] = tuple() + fn_arg_values_finalize: tuple[object] = tuple() + effective_itemsize: float = None + +class FnName(Enum): + QUANTIZE_MXFP8 = auto() + + +@dataclass(frozen=True) +class FusedComm: + out_handles: torch.Tensor + scatter_shard_indx: torch.Tensor | None = None + reduce_rank: int = 0 + n_reduce_shards: int = 1 + +specializations = SpecializationModule("matmul", + kernels=[("_matmul", _matmul), ("_p_matmul", _p_matmul)], + closure_args={ + "epilogue": ClosureArg("EPILOGUE_FN", "epilogue_fn_args"), # + "activation": 
ClosureArg("ACTIVATION_FN", "activation_fn_args"), # + }, +) +# ----------------------------------------------------------------------------- +# Matrix Multiplication + Outer Gather/Scatter +# ----------------------------------------------------------------------------- + + +def can_overflow_int32(tensor: torch.Tensor): + max_int32 = (1 << 31) - 1 + offset = 0 + for i in range(tensor.ndim): + offset += (tensor.shape[i] - 1) * tensor.stride(i) + return offset > max_int32 + + +def should_upcast_indices(*args): + return any(tensor is not None and can_overflow_int32(tensor) for tensor in args) + + +# --------------------- +# Numerics +# --------------------- + +# fmt: off + +@dataclass(frozen=True) +class FlexCtx: + lhs_data: InFlexData = InFlexData() + rhs_data: InFlexData = InFlexData() + out_data: OutFlexData = OutFlexData() + acc_data: InFlexData = InFlexData() + +@dataclass +class PrecisionConfig: + max_num_imprecise_acc: int = None + allow_tf32: bool = True + flex_ctx: FlexCtx = FlexCtx() + acc_scale: int = 1.0 + flexpoint_saturate_inf: bool = False + report_quantization_err_fn: callable = None + a_mx_scale: Tensor | None = None + b_mx_scale: Tensor| None = None + c_mx_scale: Tensor | None = None + out_dtype: torch.dtype = None + enforce_bitwise_invariance: bool = False + + +# TODO: merge in opt_flags +def get_swap_xw(precision_config, opt_flags): + if target_info.cuda_capability_geq(10, 0): + return precision_config.b_mx_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent + + return False + +# --------------------- +# Allocation +# --------------------- + +@dataclass +class MatmulAllocation: + device: str + output: tuple[tuple[int], torch.dtype] + scratchpads: dict[str, tuple] + +def init_allocation(x, w, precision_config, fused_activation, + gather_indx, scatter_indx, batch_dim, + n_reduce_shards, opt_flags): + # ---- output ------ + N = w.shape[-1] + # by default - M is number of rows in the activations + M = x.shape[-2] + # if the activations are gathered, then M is number of gather indices + if gather_indx is not None: + M = gather_indx.shape[0] + if scatter_indx is not None: + M = scatter_indx.shape[0] + y_rows = M + y_rows *= n_reduce_shards + out_shape = (batch_dim, y_rows, N // fused_activation.specs.reduction_n) + out_dtype = precision_config.out_dtype or x.dtype + output = (out_shape, out_dtype) + # ---- scratchpad -----# + scratchpad = dict() + N_scratch = N // fused_activation.specs.reduction_n if opt_flags.split_k == 1 else N + if opt_flags.split_k > 1: + scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype + scratchpad["matmul"] = ((opt_flags.split_k, batch_dim, M, N_scratch), scratch_out_dtype) + if "matmul" in scratchpad and precision_config.c_mx_scale is not None: + assert batch_dim == 1, "batch_dim > 1 not supported yet" + scratchpad["mx_c_mx_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N_scratch, MXFP_BLOCK_SIZE)), torch.uint8) + return MatmulAllocation(x.device, output, scratchpad) + +def apply_allocation(allocation: MatmulAllocation, output): + ret = dict() + if output is None: + output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1]) + else: + if output.ndim == 2: + output = output[None, :, :] + assert output.shape == allocation.output[0] + ret["output"] = output[None, :, :] + ret["scratchpad"] = { + k: torch.empty(v[0], device=allocation.device, dtype=v[1]) + for k, v in allocation.scratchpads.items() + } + return ret + +# 
-----------------------------------------------------------------------------
+# Canonicalize
+# -----------------------------------------------------------------------------
+# the `matmul` kernel can operate on 2D or 3D inputs depending on the mode being used
+# we can canonicalize storages to make the implementation more uniform
+
+def _canonicalize_storage(storage, out_ndim, flex_data):
+    assert out_ndim >= storage.data.ndim
+    # Need to use as_strided instead of view because a tensor with
+    # shape[-2] == 1 can be ambiguous about whether it is col-wise. For example,
+    # > t = torch.randn(2, 5, 1).mT
+    # > t_view = t.view(t.shape)
+    # > t.stride(), t_view.stride()
+    # ((5, 1, 1), (5, 5, 1))
+    # Our check that t_view is col-wise fails since t_view.stride(-2) != 1
+    # This case is covered by (m, n, k) == (1000, 700, 2) in test_matmul.py
+    new_storage_shape = [1] * (out_ndim - storage.data.ndim) + list(storage.data.shape)
+    new_storage_stride = [0] * (out_ndim - storage.data.ndim) + list(storage.data.stride())
+    new_storage_data = storage.data.as_strided(new_storage_shape, new_storage_stride)
+    if flex_data is not None:
+        new_storage_data = flex_data.reinterpret(new_storage_data)
+    return Storage(new_storage_data, storage.layout)
+
+
+# -----------------------------------------------------------------------------
+# Triton Implementation
+# -----------------------------------------------------------------------------
+
+def matmul_set_idle_sms(num_idle_sms):
+    """
+    Persistent kernels will leave `num_idle_sms` SMs idle.
+    """
+    update_opt_flags_constraints({"idle_sms": num_idle_sms})
+
+def matmul(a, b, bias,
+           a_ragged_metadata: RaggedTensorMetadata | None = None,
+           b_ragged_metadata: RaggedTensorMetadata | None = None,
+           gather_indx: torch.Tensor | None = None,
+           scatter_indx: torch.Tensor | None = None,
+           precision_config: PrecisionConfig | None = None,
+           betas: torch.Tensor | None = None,
+           gammas: torch.Tensor | None = None,
+           out_alpha: float | None = None,
+           c: torch.Tensor | None = None,
+           fused_comm: FusedComm | None = None,
+           fused_activation: FusedActivation | None = None,
+           epilogue: Epilogue | None = None,
+           c_acc_in: torch.Tensor | None = None,
+):
+    """
+    Y[:, :] = 0.
+    for e in num_experts:
+        Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])
+
+    matmul can optionally be fused with an all-gather or a scatter of the output.
+    When fused_comm is specified and scatter_shard_indx is not None, the m-th row of
+    the output is stored to the (m * n_reduce_shards + reduce_rank)-th row on every
+    rank whose id is in [scatter_shard_indx[m] * n_reduce_shards,
+    (scatter_shard_indx[m] + 1) * n_reduce_shards); otherwise the output is
+    all-gathered across all reduce ranks. When scatter_shard_indx is specified, the
+    caller must ensure that the indices of different shards do not conflict.
+
+    The output buffer for fused comm should be pre-allocated and passed in via
+    fused_comm.out_handles, which contains IPC handles to the output tensors, each
+    with shape (n_rows * n_reduce_shards, n_cols).
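+
+    For example (illustrative numbers, applying the rule above): with
+    n_reduce_shards = 2 and reduce_rank = 1, local output row m = 3 is written to
+    row 3 * 2 + 1 = 7 of the destination buffers; if scatter_shard_indx[3] == 2,
+    it is written only on ranks 4 and 5, i.e. the range [2 * 2, 3 * 2).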
+ """ + is_input_batched = a.ndim == 3 + if is_input_batched: + assert gather_indx is None, "gather not supported in batched mode" + assert scatter_indx is None, "scatter not supported in batched mode" + assert b_ragged_metadata is None, "w cannot be ragged in batched mode" + assert a_ragged_metadata is None, "x cannot be ragged in batched mode" + assert fused_comm is None, "fused comm is not supported in batched mode" + assert b.ndim == 3 and b.shape[0] == a.shape[0] + if b_ragged_metadata is not None: + assert gather_indx is None + assert scatter_indx is None + # canonicalize inputs + if precision_config is None: + precision_config = PrecisionConfig() + if fused_activation is None: + fused_activation = FusedActivation(FnSpecs.default(), tuple()) + if epilogue is None: + epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False) + n_slices = max(1, b.shape[0]) if a_ragged_metadata is None else a_ragged_metadata.n_slices + # unpack scales + b_scale = precision_config.b_mx_scale + b_has_mx = b_scale is not None + is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(b.dtype) == 8 + if is_hopper_fp8: assert b.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10" + if not isinstance(b, Tensor): + # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real + dtype = FP4 if b.dtype == torch.uint8 else b.dtype + b = wrap_torch_tensor(b, dtype=dtype) + if b_has_mx and is_cuda() and (torch.cuda.get_device_capability()[0] < 10 or b.storage.layout is not None and not isinstance(b.storage.layout, StridedLayout)): + assert b.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)" + if b_scale is not None and not isinstance(b_scale, Tensor): + b_scale = Tensor(b_scale) + if b_scale is not None: + b_scale.storage.data = b_scale.data.view(torch.uint8) + b_scale.dtype = torch.uint8 + a_scale = precision_config.a_mx_scale + a_has_mx = a_scale is not None + if a_has_mx: assert a.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp" + if a_scale is not None and not isinstance(a_scale, Tensor): + a_scale = Tensor(a_scale) + if not isinstance(a, Tensor): + a = Tensor(a, dtype=a.dtype) + a_transpose = a.stride(-1) != 1 + # determine shapes + has_gather = gather_indx is not None + has_scatter = scatter_indx is not None + is_a_ragged = a_ragged_metadata is not None + is_b_ragged = b_ragged_metadata is not None + is_c_ragged = is_a_ragged and b_ragged_metadata is None + ragged_dimension = "K" if is_b_ragged else "M" if is_a_ragged else None + M = a.shape[-2] if gather_indx is None else gather_indx.shape[0] + if ragged_dimension == "K": + batch_size = b_ragged_metadata.n_slices + elif ragged_dimension is None and b.ndim == 3: + batch_size = b.shape[0] + else: + batch_size = 1 + if c_acc_in is not None: + c_acc_is_c = c_acc_in.data_ptr() == c.data_ptr() and c_acc_in.stride() == c.stride() + else: + c_acc_is_c = None + K = a.shape[-1] + K_W, N = b.shape[-2:] + if a.ndim == 3 and b.ndim == 3: + assert a.shape[0] == b.shape[0] + # compute optimization flags + out_dtype = precision_config.out_dtype or a.dtype + can_use_tma = ( + a.numel() > 0 and a.storage.is_tma_compliant() and + b.numel() > 0 and b.storage.is_tma_compliant() and + (b_scale is None or b_scale.storage.is_tma_compliant()) and + (ragged_dimension != "M" or a.stride(-1) == 1) and + # Currently we don't support tma if y is column major; may revisit later if this becomes 
an issue. + (c is None or c.stride(-1) == 1) and + (c_acc_in is None or c_acc_is_c) and + # if ragged dimension is K, w must be either padded or row major to ensure alignment + (ragged_dimension != "K" or b.stride(-1) == 1 or b_ragged_metadata.slice_sizes_divisibility is not None) + ) + if b_scale is not None and isinstance(b_scale.storage.layout, StridedLayout) and b_scale.storage.data.stride()[-1] != 1: + # In this case, we need to transpose b_scale. Then the reduction dim + # becomes the last dim that will be divided by 32. This to be a multiple + # of 16 to be TMA-compliant requires block_k to be a multiple of 512, + # which is too big. + can_use_tma = False + has_gather_tma = has_gather and target_info.has_tma_gather() + # hopper w/ mxfp4 doesn't support TMA + can_use_tma = can_use_tma and is_cuda() and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(b.dtype) != 4) + can_use_split_k = scatter_indx is None and not a_has_mx and not b_has_mx and ragged_dimension != "K" + block_k = None + if ragged_dimension == "K": + block_k = a_ragged_metadata.slice_sizes_divisibility or b_ragged_metadata.slice_sizes_divisibility + opt_flags = make_opt_flags(out_dtype, a.dtype, b.dtype, precision_config, + batch_size, M, N, b.shape[-2], a_ragged_metadata, + can_use_tma, can_use_split_k, epilogue.effective_itemsize, + a_transpose, c_acc_in is not None, + block_k = block_k, + ) + # there seems to be a bug on A100 + # pytest -vs test_matmul.py::test_op[False-False-False-False-pad_b-16-768-512-1024-ragged-float16-float16-10-1-False-None-False-False-False-True-None] + if ragged_dimension == "K" and is_cuda() and torch.cuda.get_device_capability()[0] < 9: + opt_flags.num_stages = 1 + if ragged_dimension == "K": + a_has_tma = opt_flags.is_persistent and (a.stride(-1) != 1 or (a_ragged_metadata.slice_sizes_divisibility is not None)) + # If TMA is used, limit is handled automatically, so we can pretend K is "even". + # (For unpadded input, we assume that the first block_k unused rows are zero-filled, + # when routing_data.expt_hist.sum() is less than K or K_W.) 
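+        # "Even" K here means the K loop needs no tail masking: either TMA
+        # bounds-checks the loads (a_has_tma), or the ragged slice sizes are
+        # declared divisible by the block size (slice_sizes_divisibility is not
+        # None), so slice boundaries land on block boundaries.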
+ if opt_flags.is_persistent: + even_K = a_has_tma or (a_ragged_metadata.slice_sizes_divisibility is not None) + else: + even_K = a_ragged_metadata.slice_sizes_divisibility is not None and b_ragged_metadata.slice_sizes_divisibility is not None + else: + batch_size = b.shape[0] if a_ragged_metadata is None and b.ndim == 3 else 1 + assert K == K_W + a_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather) + even_K = (K % opt_flags.block_k == 0) + if b_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp(): + raise NotImplementedError("Must use non-persistent kernel for simulated MXFP") + if b_scale is not None and b_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp(): + raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP") + # fused activation + matmul_fused_activation = fused_activation + reduce_fused_activation = FusedActivation() + if opt_flags.split_k > 1: + matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation + # allocate output/scratchpad memory + allocation = init_allocation(a, b, precision_config, fused_activation, + gather_indx, scatter_indx, batch_size, + fused_comm.n_reduce_shards if fused_comm is not None else 1, + opt_flags) + memory = apply_allocation(allocation, c) + # early exit + if batch_size * M * N == 0: + ret = memory["output"].squeeze(0) + if not is_input_batched: + ret = ret.squeeze(0) + return ret + # TMA descriptors require a global memory allocation + if opt_flags.is_persistent: + triton.set_allocator(get_per_device_per_stream_alloc_fn(a.device)) + # Intermediate tensors and postprocess kernels for each situation + has_scratchpad = "matmul" in memory["scratchpad"] + # Canonical output tensor (matmul scratchpad if present, otherwise final output tensor) + out_matmul = memory["scratchpad"].get("matmul", memory["output"]) + out_matmul_flex = OutFlexData() if out_matmul.dtype == torch.float32 else precision_config.flex_ctx.out_data + # Unified mx-scale pointer; when scratchpad exists, prefer its mx buffer + out_matmul_scale = precision_config.c_mx_scale + if out_matmul_scale is not None: + out_matmul_scale = out_matmul_scale.data.view(torch.uint8) + if has_scratchpad and "mx_c_mx_scale" in memory["scratchpad"]: + out_matmul_scale = memory["scratchpad"]["mx_c_mx_scale"] + out_matmul_has_mx = out_matmul_scale is not None and out_matmul.element_size() == 1 + # matrix multiplication + flex = precision_config.flex_ctx + bias_stride = None if bias is None else bias.stride(0) + # moe metadata + expt_data_w = tuple([None] * 6) if ragged_dimension != "K" else ragged_metadata_fields(b_ragged_metadata, opt_flags.block_k) + expt_data_x = tuple([None] * 6) if ragged_dimension is None else ragged_metadata_fields(a_ragged_metadata, opt_flags.block_m if ragged_dimension == "M" else opt_flags.block_k) + # spmd grid + grid_m = triton.cdiv(M, opt_flags.block_m) + if ragged_dimension == "M": + grid_m = a_ragged_metadata.n_blocks(a_ragged_metadata.n_slices, M, opt_flags.block_m) + grid_n = triton.cdiv(N, opt_flags.block_n) + max_grid = batch_size * grid_m * grid_n * opt_flags.split_k + grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid + # canonicalize storage + has_scatter_tma = scatter_indx is not None and target_info.has_tma_gather() + c = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if has_scatter else 
out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:])) + a_storage = _canonicalize_storage(a.storage, 2 if has_gather_tma else 3, flex.lhs_data) + b_storage = _canonicalize_storage(b.storage, 3, flex.rhs_data) + c_storage = _canonicalize_storage(c.storage, 2 if has_scatter_tma else 3, flex.out_data) + # create tma descriptor for x + if c_acc_in is not None: + assert opt_flags.split_k == 1, "c_acc_in + split_k is not supported." + assert scatter_indx is None, "c_acc_in + scatter is not supported." + if c_acc_in.ndim == 2: + c_acc_in = c_acc_in.unsqueeze(0) + assert c_acc_in.shape == out_matmul.shape[-3:] + c_acc_strides = c_acc_in.stride() + else: + c_acc_strides = (None, None, None) + + a_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k] + a_tma_mode = None if not a_has_tma else "ragged" if ragged_dimension == "M" and not has_gather_tma else "dense" + a_tensor_or_tma = a_storage.make_tma(a_tma_block_size, a_tma_mode) if a_has_tma else a_storage.data + # create tma descriptor for y + c_has_tma = ( + opt_flags.is_persistent and (scatter_indx is None or has_scatter_tma) + and (c_acc_in is None or c_acc_is_c) + ) + block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.specs.reduction_n + c_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n] + c_tma_mode = None if not c_has_tma else "ragged" if is_c_ragged and not has_scatter_tma else "dense" + c_tensor_or_tma = c_storage.make_tma(c_tma_block_size, c_tma_mode) if c_has_tma else c_storage.data + # create tma descriptor for w + b_has_tma = opt_flags.is_persistent + b_tensor_or_tma = b_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if b_has_tma else b_storage.data + # create tma descriptor for w_scale + b_scale_has_tma = opt_flags.is_persistent and b_scale is not None + b_transpose = b_storage.data.stride()[-2] == 1 + if b_scale_has_tma: + scale_block_k = opt_flags.block_k // int(MXFP_BLOCK_SIZE) + b_scale_storage = b_scale.storage + b_scale_tma_block_size = [opt_flags.block_n, scale_block_k] if b_transpose else [scale_block_k, opt_flags.block_n] + if isinstance(b_scale.storage.layout, StridedLayout): + b_scale_storage = _canonicalize_storage(b_scale.storage, 3, None) + b_scale_tma_block_size = [1] + b_scale_tma_block_size + b_scale_tensor_or_tma = b_scale_storage.make_tma(b_scale_tma_block_size, "dense", is_scale=True) + else: + b_scale_tensor_or_tma = b_scale + # canonicalize strides + a_strides = [0]*(3 - a_storage.data.ndim) + list(a_storage.data.stride()) + a_scale_strides = a_scale.stride() if a_has_mx else (None, None, None) + a_scale_strides = (0, ) * (3 - len(a_scale_strides)) + a_scale_strides + b_scale_strides = b_scale.stride() if b_has_mx and not b_scale_has_tma else (None, None, None) + b_scale_strides = (0, ) * (3 - len(b_scale_strides)) + b_scale_strides + out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None) + out_matmul_scale_strides = (0, ) * (4 - len(out_matmul_scale_strides)) + out_matmul_scale_strides + # launch kernel + kernels = specializations.get(epilogue=epilogue.specs, activation=matmul_fused_activation.specs) + # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed + # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose + # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs. 
+ # w_transpose = w_storage.data.stride()[-1] != 1 + fused_comm_kwargs = { + "pYPtrs": fused_comm.out_handles, + "ScatterShardIndx": fused_comm.scatter_shard_indx, + "reduce_rank": fused_comm.reduce_rank, + "n_reduce_shards": fused_comm.n_reduce_shards, + } if fused_comm is not None else {} + n_valid_slices = b_tensor_or_tma.shape[0] if ragged_dimension == "M" else n_slices + (kernels._p_matmul if opt_flags.is_persistent else kernels._matmul)[(grid,)]( + c_tensor_or_tma, c_storage.data, *out_matmul.stride(), + *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex), + *out_matmul_scale_strides[-4:], + a_tensor_or_tma, a_storage.data, *a_strides, a_transpose, + flex.lhs_data.scale, + None if a_scale is None else a_scale.data.view(torch.uint8), *a_scale_strides, + b_tensor_or_tma, b_storage.data, *b_storage.data.stride(), b_transpose, + flex.rhs_data.scale, + b_scale_tensor_or_tma, *b_scale_strides, + flex.acc_data.reinterpret(c_acc_in), *c_acc_strides, + flex.acc_data.scale, c_acc_is_c, + bias, bias_stride, + None if ragged_dimension == "M" else a.shape[-2], + N, K, K_W, + betas, gammas, + gather_indx, + scatter_indx, + None if scatter_indx is None else scatter_indx.shape[0], + ragged_dimension, + *expt_data_x, + *expt_data_w, + batch_size, grid_m, grid_n, + out_alpha, + *matmul_fused_activation.fn_args, matmul_fused_activation.specs.reduction_n, + *epilogue.fn_arg_values_matmul, + n_valid_slices, + precision_config.max_num_imprecise_acc, + precision_config.allow_tf32, + precision_config.flexpoint_saturate_inf, + flex.rhs_data.is_per_batch, + out_matmul_flex.is_per_batch, + flex.acc_data.is_per_batch, + opt_flags.block_m, + opt_flags.block_n, + opt_flags.block_k, + opt_flags.group_m, + XCD_SWIZZLE=opt_flags.xcd_swizzle, + SWIZZLE_MX_VALUE=b.storage.layout.name, + SWIZZLE_MX_SCALE=None if b_scale is None else b_scale.storage.layout.name, + EPILOGUE_SUBTILE=opt_flags.epilogue_subtile, + SPLIT_K=opt_flags.split_k, + EVEN_K=even_K, + W_CACHE_MODIFIER=opt_flags.w_cache_modifier, + TOKENS_PER_EXPT_FOR_ANNOTATION=None if a_ragged_metadata is None else a_ragged_metadata.expected_slice_size, + num_warps=opt_flags.num_warps, + num_stages=opt_flags.num_stages, + arch=opt_flags.arch, + UPCAST_INDICES=should_upcast_indices(a, b, out_matmul), + X_TMA_MODE=a_tma_mode, + Y_TMA_MODE=c_tma_mode, + SWAP_XW=get_swap_xw(precision_config, opt_flags), + IS_EPILOGUE_QUANT_MXFP8=epilogue.specs.name == FnName.QUANTIZE_MXFP8.name, + NUM_SMS = grid if opt_flags.is_persistent else 0, + **fused_comm_kwargs, + **opt_flags.target_kernel_kwargs) + + assert not (opt_flags.split_k > 1 and scatter_indx is not None) + out_final_mx_scale = None + if opt_flags.split_k > 1: + assert not out_matmul_has_mx + postprocess_fn1 = ReducePostprocessFn(specs=reduce_fused_activation.specs, fn_args=reduce_fused_activation.fn_args) + postprocess_fn2 = ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize) + c, y_mx_scale = reduce( + x = out_matmul.view(out_matmul.shape[0], -1, out_matmul.shape[-1]), + dim = 0, + # output data/metadata + y = memory["output"].view(-1, memory["output"].shape[-1]), + y_dtype = memory["output"].dtype, + y_flex = precision_config.flex_ctx.out_data, + y_flex_saturate_inf = precision_config.flexpoint_saturate_inf, + y_has_mx = precision_config.c_mx_scale is not None, + # fused functions + postprocess_fn1 = postprocess_fn1, + postprocess_fn2 = postprocess_fn2, + ) + y_shape = out_matmul.shape[1:-1] + (out_matmul.shape[-1] // reduce_fused_activation.specs.reduction_n,) + 
out_final = c.view(*y_shape) + if y_mx_scale is not None: + out_final_mx_scale = y_mx_scale.view(out_matmul.shape[-2], triton.cdiv(out_matmul.shape[-1], 32)) + else: + out_final = out_matmul.squeeze(0) + out_final_mx_scale = out_matmul_scale + + if not (is_input_batched or b_ragged_metadata is not None): + out_final = out_final.squeeze(0) + if out_final_mx_scale is not None: + precision_config.c_mx_scale = out_final_mx_scale + return out_final + +# ----------------------------------------------------------------------------- +# Reference Implementation +# ----------------------------------------------------------------------------- + +def apply_precision(x_tri, w_tri, precision_config): + from .tensor import convert_layout + from .tensor_details import layout + from .numerics_details.mxfp import upcast_from_mxfp + + flex_ctx = precision_config.flex_ctx + + def apply(x, scale): + if scale is None: + return x.clone() + return x.float() * scale + + if precision_config.a_mx_scale is not None: + mx_axis = x_tri.storage.data.ndim -1 + x_tri = convert_layout(x_tri, layout.StridedLayout) + x_tri_scale = convert_layout(precision_config.a_mx_scale, layout.StridedLayout) + x_ref = upcast_from_mxfp(x_tri.storage.data, x_tri_scale.storage.data, torch.bfloat16, axis=mx_axis) + else: + x_ref = apply(x_tri, flex_ctx.lhs_data.scale) + + if precision_config.b_mx_scale is not None: + mx_axis = w_tri.storage.data.ndim - 2 + w_tri = convert_layout(w_tri, layout.StridedLayout) + w_tri_scale = convert_layout(precision_config.b_mx_scale, layout.StridedLayout) + w_ref = upcast_from_mxfp(w_tri.storage.data, w_tri_scale.storage.data, torch.bfloat16, axis=mx_axis) + else: + w_ref = apply(w_tri, flex_ctx.rhs_data.scale) + + return ( + x_ref, w_ref, + ) + + +def scale(val, scal): + if scal is None: + return val + elif scal.numel() == 1: + return val / scal + else: + assert val.ndim == 3 + return val / scal[:, None, None] + +def compute_actual_scale(x, dtype, per_batch_scale=False): + from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 + max_finite = { + torch.float8_e5m2: MAX_FINITE_FLOAT8E5, + torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV, + torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8, + }[dtype] + maxvals = x.abs().amax(dim=tuple(range(1, x.ndim))) if per_batch_scale else x.abs().max() + return maxvals / max_finite + + +def matmul_torch(a, b, bias, + a_ragged_metadata: RaggedTensorMetadata | None = None, + b_ragged_metadata: RaggedTensorMetadata | None = None, + gather_indx: torch.Tensor = None, + scatter_indx: torch.Tensor = None, + precision_config: PrecisionConfig = None, + betas = None, + gammas = None, + round_x = None, round_y = None, + ): + a, b = apply_precision(a, b, precision_config) + + if b_ragged_metadata is not None: + n_expts_tot = b_ragged_metadata.slice_sizes.shape[0] + m, n = a.shape[-2], b.shape[-1] + out = torch.zeros((n_expts_tot, m, n), dtype=torch.float32, device=a.device) + x_slice_offs = a_ragged_metadata.slice_offs + w_slice_offs = b_ragged_metadata.slice_offs + for expt in range(n_expts_tot): + k = int(b_ragged_metadata.slice_sizes[expt].item()) + if k == 0: + continue + x_start = int(x_slice_offs[expt].item()) + w_start = int(w_slice_offs[expt].item()) + x_slice = a[:, x_start:x_start + k] + w_slice = b[w_start:w_start + k, :] + out_expt = matmul_torch( + x_slice, w_slice, None, None, + None, None, None, PrecisionConfig(), + betas, gammas, + round_x, round_y, + ) + out[expt] = out_expt.to(out.dtype) + actual_scale = 
precision_config.flex_ctx.out_data.actual_scale + if actual_scale is not None: + actual_scale.copy_(compute_actual_scale(out, precision_config.out_dtype)) + return scale(out, precision_config.flex_ctx.out_data.expected_scale) + + is_input_batched = a.ndim == 3 + assert a.dtype.itemsize > 1 + assert b.dtype.itemsize > 1 + if is_input_batched: + assert gather_indx is None, "gather not supported in batched mode" + assert scatter_indx is None, "scatter not supported in batched mode" + assert b.ndim == 3 and b.shape[0] == a.shape[0] + if round_x is None: + round_x = lambda x, idx: x + if round_y is None: + round_y = lambda x: x + if bias is not None and bias.ndim == 1: + bias = bias.view(1, *bias.shape) + if b.ndim == 2: + b = b.view(1, *b.shape) + if a.ndim == 2: + a = a.view(1, *a.shape) + # memory offsets + if a_ragged_metadata is not None and not is_input_batched: + sizes = a_ragged_metadata.slice_sizes + off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32) + off[1:] = torch.cumsum(sizes, 0) + offs = list(itertools.pairwise(off)) + else: + offs = [[0, a.shape[1]] for _ in range(b.shape[0])] + # compute + n_rows = a.shape[1] if gather_indx is None else gather_indx.shape[0] + y = torch.zeros((a.shape[0], n_rows, b.shape[-1]), device=a.device, dtype=a.dtype) + for i, (lo, hi) in enumerate(offs): + if gather_indx is None: + idx = torch.arange(lo, hi, device=a.device) + else: + idx = gather_indx[lo:hi] + batch = i if is_input_batched else 0 + out = torch.matmul(round_x(a[batch, idx, :], torch.arange(lo, hi, device=a.device)).float(), + b[i].float()) + if bias is not None: + out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None] + if gammas is not None: + out *= gammas[lo:hi, None] + y[batch, lo:hi, :] = round_y(out) + if not is_input_batched: + y = y.view(y.shape[1], y.shape[2]) + if scatter_indx is None: + out = y + else: + out = torch.zeros((scatter_indx.shape[0], y.shape[-1]), dtype=y.dtype, device=a.device) + msk = scatter_indx != -1 + out[scatter_indx[msk], :] = y[msk, :] + actual_scale = precision_config.flex_ctx.out_data.actual_scale + if actual_scale is not None: + actual_scale.copy_(compute_actual_scale(out, precision_config.out_dtype)) + return scale(out, precision_config.flex_ctx.out_data.expected_scale) + + +def post_matmul_comm_torch(y: torch.Tensor, rank: int, n_reduce_shards: int, + world_size: int, + scatter_shard_indx: torch.Tensor | None = None, +): + """ + Reference implementation of post matmul communication. + + y: the local matmul output + rank: the global rank + n_reduce_shards: the number of reduce shards + world_size: the world size + scatter_shard_indx: the shard indices for the scatter. None if all gather. + + Output shape: + (batch_size, n_rows, n_cols) -> (batch_size, n_rows * n_reduce_shards, n_cols) if batched, otherwise + (n_rows, n_cols) -> (n_rows * n_reduce_shards, n_cols) + """ + from torch import distributed as dist + # if n_reduce_shards == 1: + # return y + + ys = [torch.empty_like(y) for _ in range(world_size)] + dist.all_gather(ys, y) + out_shape = (*y.shape[:-2], y.shape[-2] * n_reduce_shards, y.shape[-1]) + + if scatter_shard_indx is None: + # all gather + assert n_reduce_shards == world_size + return torch.cat(ys, dim=-1).reshape(out_shape) + else: + # Note: when multiple ranks scatter to the same destination, the result is undefined. 
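+        # Ranks are grouped into reduce-shard groups of n_reduce_shards consecutive
+        # ranks; member j of a group writes its rows into output rows
+        # j, j + n_reduce_shards, j + 2 * n_reduce_shards, ... (the strided view
+        # built below), restricted to the rows whose scatter_shard_indx targets
+        # this reduce shard.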
+ scatter_shard_indx_global = torch.empty((world_size, *scatter_shard_indx.shape), device=scatter_shard_indx.device, dtype=scatter_shard_indx.dtype) + dist.all_gather([scatter_shard_indx_global[i] for i in range(world_size)], scatter_shard_indx) + + assert len(out_shape) == 2, "batched mode not supported" + result = torch.zeros(out_shape, device=y.device, dtype=y.dtype) + reduce_shard_id = rank // n_reduce_shards + + for i in range(world_size // n_reduce_shards): + scatter_mask = scatter_shard_indx_global[i * n_reduce_shards, :] == reduce_shard_id + for j in range(n_reduce_shards): + out_slice = result.as_strided( + (result.shape[0] // n_reduce_shards, result.shape[1]), + (result.stride(0) * n_reduce_shards, result.stride(1)), + storage_offset=j * result.stride(0), + ) + out_slice[scatter_mask, :] = ys[i * n_reduce_shards + j][scatter_mask, :] + return result diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py b/python/triton_kernels/triton_kernels/matmul_details/_common.py similarity index 63% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py rename to python/triton_kernels/triton_kernels/matmul_details/_common.py index 951be4c053..4c04b1aafc 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py +++ b/python/triton_kernels/triton_kernels/matmul_details/_common.py @@ -53,100 +53,80 @@ def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr): @triton.jit -def _load_tile_attrs( - tile_id, - num_tiles, - unpadded_m, - grid_n, - M, - K, - ExptData, - ExptHist, - ExptOffs, - ExptTileOffs, - EXPT_IS_INNER: tl.constexpr, - X_IS_PADDED: tl.constexpr, - W_IS_PADDED: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, - PACKED_BLOCK_K_W: tl.constexpr, - SPLIT_K: tl.constexpr, - GROUP_M: tl.constexpr, - XCD_SWIZZLE: tl.constexpr, - SWIZZLE_MX_VALUE: tl.constexpr, -): - # unpack and swizzle program ids - pid_emnk = tile_id +def compute_pids(block_id, grid_m, grid_n, num_blocks, XCD_SWIZZLE: tl.constexpr, GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr): + pid_zmnk = block_id if XCD_SWIZZLE != 1: - pid_emnk = xcd_swizzle(pid_emnk, num_tiles, XCD_SWIZZLE) - pid_e = pid_emnk // (unpadded_m * grid_n * SPLIT_K) - pid_mnk = pid_emnk % (unpadded_m * grid_n * SPLIT_K) + pid_zmnk = xcd_swizzle(pid_zmnk, num_blocks, XCD_SWIZZLE) + pid_z = pid_zmnk // (grid_m * grid_n * SPLIT_K) + pid_mnk = pid_zmnk % (grid_m * grid_n * SPLIT_K) if SPLIT_K > 1: pid_k = pid_mnk % SPLIT_K pid_mn = pid_mnk // SPLIT_K else: pid_k: tl.constexpr = 0 pid_mn = pid_mnk - pid_m, pid_n = swizzle2d(pid_mn, unpadded_m, grid_n, GROUP_M) + pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M) + return pid_z, pid_m, pid_n, pid_k - # unpack expert data - if EXPT_IS_INNER: - # pid_e indicates expert ID: experts are laid sequentially along the K dimension + +@triton.jit +def compute_offsets( + pid_z, + pid_m, + pid_k, + XBlockSchedule, + XSliceOffs, + X_SLICE_SIZE_DIVISIBILITY: tl.constexpr, + WBlockSchedule, + WSliceOffs, + W_SLICE_SIZE_DIVISIBILITY: tl.constexpr, + RAGGED_DIMENSION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K_X: tl.constexpr, + PACKED_BLOCK_K_W: tl.constexpr, + SPLIT_K: tl.constexpr, +): + if RAGGED_DIMENSION == "K": + # pid_z indicates slice ID: experts are laid sequentially along the K dimension # (i.e., we have columns for expert 0, and then expert 1, and then so on). # pid_k is meaningless (always zero). 
- tl.static_assert(X_IS_PADDED or W_IS_PADDED, "At least one input must be padded!") + tl.static_assert(X_SLICE_SIZE_DIVISIBILITY is not None or \ + W_SLICE_SIZE_DIVISIBILITY is not None, + "At least one input must be padded!") tl.static_assert(SPLIT_K == 1, "Not supported yet") - tl.static_assert(M is not None) - expt_id, pid_z, pid_z_out, start_m, block_id, eM = 0, 0, pid_e, 0, pid_m, M - k_tiles = tl.cdiv(tl.load(ExptHist + pid_e), BLOCK_K) - padded_start_off_raw = tl.load(ExptTileOffs + pid_e) - padded_start_off = padded_start_off_raw * BLOCK_K - unpadded_start_off = tl.load(ExptOffs + pid_e) - off_k_x = padded_start_off if X_IS_PADDED else unpadded_start_off - # K_W is only used for non-TMA kernel (W bound is handled by TMA on TMA kernel). - if W_IS_PADDED: - off_k_w = padded_start_off_raw * PACKED_BLOCK_K_W - K_W = tl.load(ExptTileOffs + pid_e + 1) * PACKED_BLOCK_K_W + off_x_k = tl.load(XSliceOffs + pid_z) + off_w_k = tl.load(WSliceOffs + pid_z) + if PACKED_BLOCK_K_W >= BLOCK_K_X: + off_w_k = off_w_k * (PACKED_BLOCK_K_W // BLOCK_K_X) else: - off_k_w = unpadded_start_off - K_W = tl.load(ExptOffs + pid_e + 1) + off_w_k = off_w_k // (BLOCK_K_X // PACKED_BLOCK_K_W) + off_x_m = BLOCK_M * pid_m + off_w_z, off_x_z, off_x_slice = 0, 0, 0 + off_y_z = pid_z + elif RAGGED_DIMENSION == "M": + off_x_k = pid_k * BLOCK_K_X + off_w_k = pid_k * PACKED_BLOCK_K_W + block_schedule = tl.load(XBlockSchedule + pid_m) + off_w_z = block_schedule & 0x0000FFFF + block_id = block_schedule >> 16 + off_x_slice = tl.load(XSliceOffs + off_w_z) + off_x_z, off_y_z = 0, 0 + off_x_m = BLOCK_M * block_id else: - off_k_x = pid_k * BLOCK_K - off_k_w = pid_k * PACKED_BLOCK_K_W - if PACKED_BLOCK_K_W >= BLOCK_K: - K_W = K * (PACKED_BLOCK_K_W // BLOCK_K) - else: - K_W = K // (BLOCK_K // PACKED_BLOCK_K_W) - if SWIZZLE_MX_VALUE == "HOPPER_VALUE": - K_W = tl.cdiv(K_W, 128) * 128 - k_tiles = tl.cdiv(K - off_k_x, BLOCK_K * SPLIT_K) - if ExptData is None: - tl.static_assert(M is not None) - expt_id, pid_z, pid_z_out, start_m, block_id, eM = pid_e, pid_e, pid_e, 0, pid_m, M - else: - tl.static_assert(M is None) - expt_data = tl.load(ExptData + pid_m) - expt_id = expt_data & 0x0000FFFF - block_id = expt_data >> 16 - eM = tl.load(ExptHist + expt_id) - start_m = tl.load(ExptOffs + expt_id) - pid_z, pid_z_out = 0, 0 - - off_m = BLOCK_M * block_id - + tl.static_assert(RAGGED_DIMENSION is None) + off_x_k = pid_k * BLOCK_K_X + off_w_k = pid_k * PACKED_BLOCK_K_W + off_w_z, off_x_z, off_y_z, off_x_slice = pid_z, pid_z, pid_z, 0 + off_x_m = BLOCK_M * pid_m return ( - expt_id, - pid_z, - pid_z_out, - start_m, - eM, - off_m, - pid_n, - k_tiles, - pid_k, - off_k_x, - off_k_w, - K_W, + off_w_z, + off_x_z, + off_y_z, + off_x_slice, + off_x_m, + off_x_k, + off_w_k, ) @@ -192,40 +172,36 @@ def matmul_launch_metadata(grid, kernel, args): ret = dict() M, N, K = args["M"], args["N"], args["K"] Y, X, W = args["YPtr"], args["XPtr"], args["WPtr"] - tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION") - hist = args["ExptHist"] + expected_slice_sizes = args.get("X_EXPECTED_SLICE_SIZE") + slice_sizes = args["XSliceSizes"] batch_size = args.get("batch_size", 1) - expt_is_inner = args["EXPT_IS_INNER"] - if hist is not None: + if slice_sizes is not None: # If annotation is given, use that to generate name for profiling. 
- if tokens_per_expt is not None: - n_rows = f"{tokens_per_expt}*" + if expected_slice_sizes is not None: + n_rows = f"{expected_slice_sizes}*" elif launch_metadata_allow_sync(): - n_rows = int(hist.float().mean()) + n_rows = int(slice_sizes.float().mean()) else: n_rows = "unknown" - if launch_metadata_allow_sync(): - n_tokens = float(hist.sum()) - n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum() - elif tokens_per_expt is not None: - n_tokens = tokens_per_expt * args["N_EXPTS_TOT"] + n_tokens = float(slice_sizes.sum()) + n_w_bytes = (W.numel() * W.element_size() // slice_sizes.numel()) * (slice_sizes > 0).sum() + elif expected_slice_sizes is not None: + n_tokens = expected_slice_sizes * args["N_SLICES"] # This may not be totally correct (e.g., we might not be using all experts) # but it's better than nothing. n_w_bytes = W.numel() * W.element_size() else: n_tokens = None n_w_bytes = 0 - # If annotation is given, use that to generate name for profiling. - tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION") - n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows + n_rows = f"{expected_slice_sizes}*" if expected_slice_sizes is not None else n_rows else: n_tokens = None n_w_bytes = W.numel() * W.element_size() - if expt_is_inner: + if args["RAGGED_DIMENSION"] == "K": K = None if n_tokens is None else int(n_tokens) - repr = lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}" + repr = lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(slice_sizes)}({s}) = {n_rows}" nbits = X.dtype.itemsize * 8 batch_repr = "" if batch_size > 1: @@ -235,20 +211,21 @@ def matmul_launch_metadata(grid, kernel, args): if ep_subtile is not None and ep_subtile > 1: ret["name"] += f" ep/{ep_subtile}" - if hist is not None and n_tokens is None: + if slice_sizes is not None and n_tokens is None: return ret # Don't fill metadata because we can't compute them properly. fM = M if M is not None else n_tokens - ret[f"flops{nbits}"] = 2.0 * fM * N * K * (1 if expt_is_inner else batch_size) + Z = 1 if args["RAGGED_DIMENSION"] == "K" else batch_size + ret[f"flops{nbits}"] = 2.0 * fM * N * K * Z # sindx = args.get("WriteBackIndx", None) n_x_bytes = X.numel() * X.element_size() n_y_bytes = Y.numel() * Y.element_size() - if hist is not None: + if slice_sizes is not None: assert n_tokens is not None n_read_rows = n_tokens - if expt_is_inner: + if args["RAGGED_DIMENSION"] == "K": n_x_bytes = n_read_rows * X.shape[-2] * X.element_size() # Here, we're computing dW = X.T@dY, so "W" is actually dY and "Y" is actually dW. 
n_y_bytes = Y.numel() * Y.element_size() * (2 if args["OutAcc"] is not None else 1) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_details/_matmul.py similarity index 85% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py rename to python/triton_kernels/triton_kernels/matmul_details/_matmul.py index d3f26b9299..b165534afe 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_details/_matmul.py @@ -9,36 +9,19 @@ from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE from ._common import ( - _load_tile_attrs, + compute_offsets, get_scaled_dot_format_string, make_matmul_repr, matmul_launch_metadata, - swizzle2d, - xcd_swizzle, threadfence_system, + compute_pids, ) -@triton.jit -def _zero_masked_rows( - pid_m, pid_n, - Y, stride_y_m, stride_y_n, - N, - ScatterSrcIndx, num_idxs, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - offs_m = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M) - offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N) - src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0) - YPtrs = Y + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n - mask_n = offs_n < N - mask = (src_idx == -1)[:, None] & mask_n[None, :] - tl.store(YPtrs, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask) - - -_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2]) +_matmul_repr = make_matmul_repr("_matmul", [0, 1, 2]) @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"], - repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata) -def _matmul_ogs( + repr=_matmul_repr, launch_metadata=matmul_launch_metadata) +def _matmul( Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n, YExpectedScale, YActualScale, YChecksumScale, stride_y_mx_k, stride_y_mx_z, stride_y_mx_m, stride_y_mx_n, @@ -54,14 +37,11 @@ def _matmul_ogs( M, N, K, K_W, # shapes # expt data Betas, Gammas, - GatherIndx, GatherDstIndx, # GatherDstIndx is only used for launch metadata. 
- ScatterSrcIndx, num_idxs, + GatherIndx, WriteBackIndx, writeback_size, - ExptHist, ExptOffs, ExptTileOffs, ExptData, - EXPT_IS_INNER: tl.constexpr, - X_IS_PADDED: tl.constexpr, - W_IS_PADDED: tl.constexpr, - ExptHistMax, + RAGGED_DIMENSION: tl.constexpr, + XSliceSizes, XSliceOffs, XBlockOffs, XBlockSchedule, X_EXPECTED_SLICE_SIZE: tl.constexpr, X_SLICE_SIZES_DIVISIBILITY: tl.constexpr, + WSliceSizes, WSliceOffs, WBlockOffs, WBlockSchedule, W_EXPECTED_SLICE_SIZE: tl.constexpr, _W_SLICE_SIZES_DIVISIBILITY: tl.constexpr, # true grid size batch_size, grid_m, grid_n, # Out scale @@ -81,7 +61,6 @@ def _matmul_ogs( # optimization config BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr, - INIT_OUTPUT_TO_ZERO: tl.constexpr, # One of ["HOPPER", "BLACKWELL", None] SWIZZLE_MX_VALUE: tl.constexpr, # One of ["HOPPER", "BLACKWELL", None] @@ -123,12 +102,14 @@ def _matmul_ogs( tl.assume(grid_m >= 0) tl.assume(grid_n >= 0) + + w_type: tl.constexpr = W.dtype.element_ty is_x_microscaled: tl.constexpr = XMxScale is not None is_w_microscaled: tl.constexpr = WMxScale is not None + is_w_mxfp4: tl.constexpr = w_type == tl.uint8 and is_w_microscaled + MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE if is_w_microscaled: - w_type: tl.constexpr = W.dtype.element_ty - is_mxfp4: tl.constexpr = w_type == tl.uint8 tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5), "mx_weight_ptr must be uint8 or fp8") tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8") @@ -137,7 +118,7 @@ def _matmul_ogs( # TODO: refactor if/else when triton front end improves if SWIZZLE_MX_VALUE == "HOPPER_VALUE": - tl.static_assert(is_mxfp4, "Only mxfp4 is supported for HOPPER swizzling") + tl.static_assert(is_w_mxfp4, "Only mxfp4 is supported for HOPPER swizzling") tl.static_assert(not is_x_microscaled) # We have pack 2 fp4 values in a byte but we divide the dimension by 2 # when swizzling @@ -146,7 +127,7 @@ def _matmul_ogs( W_N_DIVISOR: tl.constexpr = 4 else: # We have pack 2 fp4 values in a byte - W_K_DIVISOR: tl.constexpr = 2 if is_mxfp4 else 1 + W_K_DIVISOR: tl.constexpr = 2 if is_w_mxfp4 else 1 W_K_MULTIPLIER: tl.constexpr = 1 W_N_DIVISOR: tl.constexpr = 1 @@ -177,52 +158,71 @@ def _matmul_ogs( tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") is_out_microscaled: tl.constexpr = stride_y_mx_z is not None + if _W_SLICE_SIZES_DIVISIBILITY is None: + W_SLICE_SIZES_DIVISIBILITY: tl.constexpr = 1 + else: + if PACKED_BLOCK_K_W > BLOCK_K: + W_SLICE_SIZES_DIVISIBILITY: tl.constexpr = _W_SLICE_SIZES_DIVISIBILITY * (PACKED_BLOCK_K_W // BLOCK_K) + else: + W_SLICE_SIZES_DIVISIBILITY: tl.constexpr = _W_SLICE_SIZES_DIVISIBILITY // (BLOCK_K // PACKED_BLOCK_K_W) + OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N yN = N // ACTIVATION_REDUCTION_N pid = tl.program_id(0) - if ExptTileOffs is not None and (not EXPT_IS_INNER): - # Determine how much padding there is on the expert data. This allows us to - # know the true grid size and avoid processing padding tiles. 
- padding_m = grid_m - tl.load(ExptTileOffs + N_EXPTS_TOT) + if RAGGED_DIMENSION == "M": + padding_m = grid_m - tl.load(XBlockOffs + N_EXPTS_TOT) else: padding_m: tl.constexpr = 0 - HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32 unpadded_m = grid_m - padding_m tl.assume(unpadded_m >= 0) total_actual_tiles = batch_size * unpadded_m * grid_n * SPLIT_K - # set masked out rows to 0 - # We are tiling Y here, so the tiling is independent of matmul (where we - # tile X & W and scatter to different rows of Y). - # TODO: refactor (same code in _p_matmul_ogs) - if HAS_FUSED_SCATTER and INIT_OUTPUT_TO_ZERO: - tl.device_assert(batch_size == 1) - pid_mnk = pid - if XCD_SWIZZLE != 1: - pid_mnk = xcd_swizzle(pid_mnk, grid_m * grid_n * SPLIT_K, XCD_SWIZZLE) - pid_k = pid_mnk % SPLIT_K - pid_mn = pid_mnk // SPLIT_K - pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M) - _zero_masked_rows(pid_m, pid_n, - Y + pid_k.to(index_type) * stride_y_k, stride_y_m, stride_y_n, - yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N) - if padding_m > 0 and pid >= total_actual_tiles: return + pid_s, pid_m, pid_n, pid_k = compute_pids(pid, unpadded_m, grid_n, total_actual_tiles, XCD_SWIZZLE, GROUP_M, SPLIT_K) + loop_k = tl.multiple_of(tl.load(XSliceSizes + pid_s), X_SLICE_SIZES_DIVISIBILITY) if RAGGED_DIMENSION == "K" else K + ( expt_id, start_z, start_z_out, - start_m, eM, off_m, pid_n, - k_tiles, pid_k, off_k_x, off_k_w, K_W, - ) = _load_tile_attrs(pid, total_actual_tiles, unpadded_m, grid_n, - M, K, ExptData, ExptHist, ExptOffs, ExptTileOffs, - EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED, - BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K, - GROUP_M, XCD_SWIZZLE, SWIZZLE_MX_VALUE) + start_m, off_m, + off_k_x, off_k_w + ) = compute_offsets( + pid_s, pid_m, pid_k, + XBlockSchedule, XSliceOffs, X_SLICE_SIZES_DIVISIBILITY, + WBlockSchedule, WSliceOffs, W_SLICE_SIZES_DIVISIBILITY, + RAGGED_DIMENSION, + BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K + ) + if X_SLICE_SIZES_DIVISIBILITY is not None: + off_k_x = off_k_x // X_SLICE_SIZES_DIVISIBILITY * X_SLICE_SIZES_DIVISIBILITY + if W_SLICE_SIZES_DIVISIBILITY is not None: + off_k_w = off_k_w // W_SLICE_SIZES_DIVISIBILITY * W_SLICE_SIZES_DIVISIBILITY + + + if RAGGED_DIMENSION == "M": + eM = tl.multiple_of(tl.load(XSliceSizes + expt_id), X_SLICE_SIZES_DIVISIBILITY) + else: + eM = M + + + if RAGGED_DIMENSION == "K": + K_W = tl.multiple_of(tl.load(WSliceOffs + pid_s + 1), W_SLICE_SIZES_DIVISIBILITY) + if PACKED_BLOCK_K_W > BLOCK_K: + K_W = K_W * (PACKED_BLOCK_K_W // BLOCK_K) + else: + K_W = K_W // (BLOCK_K // PACKED_BLOCK_K_W) + K_X = tl.multiple_of(tl.load(XSliceOffs + pid_s + 1), X_SLICE_SIZES_DIVISIBILITY) + else: + K_W = K * (PACKED_BLOCK_K_W // BLOCK_K) if PACKED_BLOCK_K_W >= BLOCK_K else K // (BLOCK_K // PACKED_BLOCK_K_W) + K_X = K + + loop_k = tl.multiple_of(tl.load(XSliceSizes + pid_s), X_SLICE_SIZES_DIVISIBILITY) if RAGGED_DIMENSION == "K" else K - off_k_x + k_tiles = tl.cdiv(loop_k, BLOCK_K * SPLIT_K) # For split-k, advance to the output k slice if SPLIT_K > 1: @@ -246,6 +246,7 @@ def _matmul_ogs( offs_k = off_k_x + tl.arange(0, BLOCK_K) XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k + # TODO: refactor if/else when triton front end improves if is_w_microscaled: WMxScale += expt_id * stride_w_mx_e @@ -262,6 +263,7 @@ def _matmul_ogs( # TODO: support non W_TRANSPOSE with Hopper swizzling tl.static_assert(W_TRANSPOSE) n_warps: tl.constexpr = 
tl.extra.cuda.num_warps() + tl.static_assert(n_warps == 8) tl.static_assert(BLOCK_N % (2 * n_warps * 2 * 8) == 0) tl.static_assert(MX_SCALE_BLOCK_K % 2 == 0) PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32 @@ -281,7 +283,6 @@ def _matmul_ogs( offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N) # K dimension must be the last dimension for the scales - tl.static_assert(not EXPT_IS_INNER or W_IS_PADDED) offs_k_scale = off_k_w // PACKED_BLOCK_K_W * PACKED_MX_BLOCK + tl.arange(0, PACKED_MX_BLOCK) WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n else: @@ -309,27 +310,28 @@ def _matmul_ogs( WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n) # compute output acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - x_k_limit = K + BLOCK_K * SPLIT_K + x_k_limit = K_X + BLOCK_K * SPLIT_K w_k_limit = K_W + PACKED_BLOCK_K_W * SPLIT_K + for ki in range(k_tiles): x_k_limit -= BLOCK_K * SPLIT_K w_k_limit -= PACKED_BLOCK_K_W * SPLIT_K if EVEN_K: - mask_k = tl.full([BLOCK_K], True, dtype=tl.int1) + mask_k_x = tl.full([BLOCK_K], True, dtype=tl.int1) mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1) if is_w_microscaled and SWIZZLE_MX_SCALE is None: mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1) if is_x_microscaled: mask_x_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1) else: - mask_k = offs_k < x_k_limit + mask_k_x = offs_k < x_k_limit mask_k_w = offs_w_k < w_k_limit if is_w_microscaled and SWIZZLE_MX_SCALE is None: - mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < x_k_limit + mask_k_scale = offs_k_scale * MX_PACK_DIVISOR // 2 < w_k_limit if is_x_microscaled: mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < x_k_limit - x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0) + x = tl.load(XPtrs, mask=mask_k_x[None, :], other=0.0) w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER) if is_w_microscaled: x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype) @@ -348,6 +350,7 @@ def _matmul_ogs( elif SWIZZLE_MX_SCALE == "HOPPER_SCALE": # Handshake with the swizzling code num_warps: tl.constexpr = tl.extra.cuda.num_warps() + tl.static_assert(num_warps == 8) w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps) elif SWIZZLE_MX_SCALE == "CDNA4_SCALE": w_scales = unswizzle_mx_scale_cdna4(tl.load(WMxScalePtrs), BLOCK_N, MX_SCALE_BLOCK_K) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py similarity index 80% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py rename to python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py index 3a38b32cc1..ccb7a4a1a5 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py @@ -14,12 +14,12 @@ ) from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE from ._common import ( - _load_tile_attrs, + compute_offsets, get_scaled_dot_format_string, make_matmul_repr, matmul_launch_metadata, - swizzle2d, threadfence_system, + compute_pids, ) @@ -44,10 +44,10 @@ def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask): return 
(offs, mask) -_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2]) +_matmul_repr = make_matmul_repr("_p_matmul", [0, 1, 2]) @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"], - repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata) -def _p_matmul_ogs( + repr=_matmul_repr, launch_metadata=matmul_launch_metadata) +def _p_matmul( Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n, YExpectedScale, YActualScale, YChecksumScale, stride_y_mx_k, stride_y_mx_z, stride_y_mx_m, stride_y_mx_n, @@ -63,14 +63,11 @@ def _p_matmul_ogs( M, N, K, K_W, # shapes # expt data Betas, Gammas, - GatherIndx, GatherDstIndx, # GatherDstIndx is only used for launch metadata. - ScatterSrcIndx, num_idxs, + GatherIndx, WriteBackIndx, writeback_size, - ExptHist, ExptOffs, ExptTileOffs, ExptData, - EXPT_IS_INNER: tl.constexpr, - X_IS_PADDED: tl.constexpr, - W_IS_PADDED: tl.constexpr, - ExptHistMax, + RAGGED_DIMENSION: tl.constexpr, + XSliceSizes, XSliceOffs, XBlockOffs, XBlockSchedule, X_EXPECTED_SLICE_SIZE: tl.constexpr, X_SLICE_SIZES_DIVISIBILITY: tl.constexpr, + WSliceSizes, WSliceOffs, WBlockOffs, WBlockSchedule, W_EXPECTED_SLICE_SIZE: tl.constexpr, W_SLICE_SIZES_DIVISIBILITY: tl.constexpr, # true grid size batch_size, grid_m, grid_n, # Out scale @@ -80,7 +77,7 @@ def _p_matmul_ogs( # epilogue transform EPILOGUE_FN: tl.constexpr, epilogue_fn_args, # MoE config - N_EXPTS_TOT: tl.constexpr, + N_SLICES: tl.constexpr, # precision config MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr, FLEXPOINT_SATURATE_INF: tl.constexpr, @@ -90,7 +87,6 @@ def _p_matmul_ogs( # optimization config BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr, - INIT_OUTPUT_TO_ZERO: tl.constexpr, # NYI: Must be None SWIZZLE_MX_VALUE: tl.constexpr, # One of ["BLACKWELL", None] @@ -142,12 +138,10 @@ def _p_matmul_ogs( tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") is_out_microscaled: tl.constexpr = stride_y_mx_z is not None - if ExptTileOffs is not None and (not EXPT_IS_INNER): - # Determine how much padding there is on the expert data. This allows us to - # know the true grid size and avoid processing padding tiles. - padding_m = grid_m - tl.load(ExptTileOffs + N_EXPTS_TOT) + if RAGGED_DIMENSION == "M": + useful_grid_m = tl.load(XBlockOffs + N_SLICES) else: - padding_m: tl.constexpr = 0 + useful_grid_m = grid_m index_type: tl.constexpr = tl.int64 @@ -157,12 +151,9 @@ def _p_matmul_ogs( USE_GATHER_TMA: tl.constexpr = HAS_GATHER and X_TMA_MODE == "dense" USE_SCATTER_TMA: tl.constexpr = HAS_SCATTER and Y_TMA_MODE == "dense" - if EXPT_IS_INNER: + if RAGGED_DIMENSION == "K": tl.static_assert((OutAcc is None) or Y_ACC_IS_Y, "Using different y_acc is not supported with TMA kernel.") - tl.static_assert( - not (HAS_SCATTER or USE_GATHER_TMA or USE_SCATTER_TMA), - "Cannot be used with EXPT_IS_INNER" - ) + tl.static_assert(not (HAS_SCATTER or USE_GATHER_TMA or USE_SCATTER_TMA), "Cannot be used with RAGGED_DIMENSION == 'K'") if EPILOGUE_SUBTILE is None: SUBTILE_FACTOR: tl.constexpr = 1 @@ -172,28 +163,7 @@ def _p_matmul_ogs( OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N yN = N // ACTIVATION_REDUCTION_N - # set masked out rows to 0 - if HAS_SCATTER and INIT_OUTPUT_TO_ZERO: - # Iterate with reversed pids so that later pids will get more tiles if the number of - # tiles isn't evenly divisible by the number of SMs. 
- # The main loop after this iterates in the forward direction such that earlier - # pids get more tiles if the number of tiles isn't evenly divisible. - # This helps balance the work across the SMs. - for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS): - pid_k = pid_mnk % SPLIT_K - pid_mn = pid_mnk // SPLIT_K - pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M) - - z = tl.zeros([BLOCK_M, BLOCK_N // ACTIVATION_REDUCTION_N], dtype=tl.float32) - offs_m = z.shape[0] * pid_m + tl.arange(0, z.shape[0]) - offs_n = z.shape[1] * pid_n + tl.arange(0, z.shape[1]) - src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0) - YPtrs = YPtr + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n - mask_n = offs_n < yN - mask = (src_idx == -1)[:, None] & mask_n[None, :] - tl.store(YPtrs + pid_k * stride_y_k, z, mask=mask) - - num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K + num_blocks = batch_size * useful_grid_m * grid_n * SPLIT_K # If true, do not share loop-carried variables between the prologue and the # epilogue to enable better pipelining with mmav5 @@ -211,55 +181,71 @@ def _p_matmul_ogs( DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_w_microscaled and BLOCK_M * BLOCK_N >= 128 * 256 - for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=True): - expt_id, start_z, start_z_out, start_m, eM, off_m, pid_n, k_tiles, pid_k, off_k_x0, off_k_w0, _ = _load_tile_attrs( - tile_id, num_tiles, grid_m - padding_m, grid_n, - M, K, ExptData, ExptHist, ExptOffs, ExptTileOffs, - EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED, - BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K, - GROUP_M, XCD_SWIZZLE, SWIZZLE_MX_VALUE) - off_n = BLOCK_N * pid_n + for block_id in tl.range(tl.program_id(0), num_blocks, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=True): - # Base pointers and offsets. - if X_TMA_MODE is None: - XBase = X + start_z.to(index_type) * stride_x_z - offs_x_k = (off_k_x0.to(index_type) + tl.arange(0, BLOCK_K))[None, :] * stride_x_k + pid_z, pid_m, pid_n, pid_k = compute_pids(block_id, useful_grid_m, grid_n, num_blocks, XCD_SWIZZLE, GROUP_M, SPLIT_K) + + # ------------------------------------------------------------ + # prologue + # ------------------------------------------------------------ + off_w_z, off_x_z, off_y_z, slice_off_m, off_m, off_k_x0, off_k_w0 = compute_offsets( + pid_z, pid_m, pid_k, + XBlockSchedule, XSliceOffs, X_SLICE_SIZES_DIVISIBILITY, + WBlockSchedule, WSliceOffs, W_SLICE_SIZES_DIVISIBILITY, + RAGGED_DIMENSION, + BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K + ) + # TODO: if RAGGED_DIMENSION == "M" + if RAGGED_DIMENSION == "M": + shape_m = tl.load(XSliceSizes + off_w_z) + else: + shape_m = M + + off_n = BLOCK_N * pid_n + + # ---- offset x ------ if USE_GATHER_TMA: offs_m = off_m + tl.arange(0, BLOCK_M) - mask_m = offs_m < eM - if ExptData is None: - offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m) + mask_m = offs_m < shape_m + if XBlockSchedule is None: + offs_x_m = tl.load(GatherIndx + slice_off_m.to(index_type) + offs_m, mask=mask_m) # Bump rows to account for the Z offset. 
- offs_x_m += start_z * (stride_x_z // stride_x_m) + offs_x_m += off_x_z * (stride_x_z // stride_x_m) offs_x_m = tl.where(mask_m, offs_x_m, -1) else: - offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m, other=-1) - elif X_TMA_MODE is None or is_x_microscaled: + offs_x_m = tl.load(GatherIndx + slice_off_m.to(index_type) + offs_m, mask=mask_m, other=-1) + if X_TMA_MODE is None: + XBase = X + off_x_z.to(index_type) * stride_x_z offs_m = off_m + tl.arange(0, BLOCK_M) - offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M) + offs_m = tl.max_contiguous(tl.multiple_of(offs_m % shape_m, BLOCK_M), BLOCK_M) # no needs to bounds-check here because `offs_m` wraps around M dim if GatherIndx is not None: tl.static_assert(HAS_GATHER) - offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) + offs_m = tl.load(GatherIndx + slice_off_m.to(index_type) + offs_m) offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m + offs_x_k = (off_k_x0.to(index_type) + tl.arange(0, BLOCK_K))[None, :] * stride_x_k if is_x_microscaled: - XMxScalePtrs = XMxScale + start_z.to(index_type) * stride_x_mx_z + offs_m = off_m + tl.arange(0, BLOCK_M) + XMxScalePtrs = XMxScale + off_x_z.to(index_type) * stride_x_mx_z if GatherIndx is None: - XMxScalePtrs += start_m * stride_x_mx_m + XMxScalePtrs += slice_off_m * stride_x_mx_m offs_k_scale = off_k_x0 // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K) XMxScalePtrs += (offs_x_m if USE_GATHER_TMA else offs_m).to(index_type)[:, None] * stride_x_mx_m XMxScalePtrs += offs_k_scale.to(index_type)[None, :] * stride_x_mx_k - else: - XMxScalePtrs = None acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32) + # ------------------------------------------------------------ + # inner loop + # ------------------------------------------------------------ + loop_k = tl.load(XSliceSizes + pid_z) if RAGGED_DIMENSION == "K" else K - off_k_x0 + k_tiles = tl.cdiv(loop_k, BLOCK_K * SPLIT_K) loop_bound = tl.maximum(k_tiles, 1) tl.assume(loop_bound > 0) # Currently necessary for the compiler to flatten the loop properly. for ki in tl.range(loop_bound, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER): - if EXPT_IS_INNER and ki >= k_tiles: + if RAGGED_DIMENSION == "K" and ki >= k_tiles: # Tile #ki does not exist: use out-of-bound indices to mask all loads. 
off_k_x = K off_k_w = K_W @@ -272,13 +258,13 @@ def _p_matmul_ogs( x = X.gather(offs_x_m, off_k_x) elif X_TMA_MODE == "dense": if X_TRANSPOSE: - x = X.load([start_z, off_k_x, start_m + off_m]) + x = X.load([off_x_z, off_k_x, slice_off_m + off_m]) x = x.reshape(BLOCK_K, BLOCK_M).T else: - x = X.load([start_z, start_m + off_m, off_k_x]) + x = X.load([off_x_z, slice_off_m + off_m, off_k_x]) x = x.reshape(BLOCK_M, BLOCK_K) elif X_TMA_MODE == "ragged": - x = load_ragged(X, start_m, eM, [start_z, off_m, off_k_x], ragged_dim=1) + x = load_ragged(X, slice_off_m, shape_m, [off_x_z, off_m, off_k_x], ragged_dim=1) x = x.reshape(BLOCK_M, BLOCK_K) else: tl.static_assert(X_TMA_MODE is None) @@ -293,37 +279,39 @@ def _p_matmul_ogs( else: x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0) + # --- load x_scale --- + x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype) + if is_x_microscaled: + off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_PACK_DIVISOR) + if EVEN_K: + mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1) + else: + mask_k_scale = off_k_mx + tl.arange(0, MX_SCALE_BLOCK_K) < tl.cdiv(K, MX_PACK_DIVISOR) + mask_m = off_m + tl.arange(0, BLOCK_M) < shape_m + x_scales = tl.load(XMxScalePtrs, mask=mask_k_scale[None, :] & mask_m[:, None], other=0.0) + elif x_format == "fp16" or x_format == "bf16": + x_scales: tl.constexpr = None + else: + x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8) + # --- load w --- if W_TRANSPOSE: - w = tl.reshape(W.load([expt_id, off_n, off_k_w]), W.block_shape[1:]).T + w = tl.reshape(W.load([off_w_z, off_n, off_k_w]), W.block_shape[1:]).T else: - w = tl.reshape(W.load([expt_id, off_k_w, off_n]), W.block_shape[1:]) + w = tl.reshape(W.load([off_w_z, off_k_w, off_n]), W.block_shape[1:]) # --- load w_scale --- + w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype) if is_w_microscaled: - x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype) - w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype) off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_PACK_DIVISOR) - - if is_x_microscaled: - if EVEN_K: - mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1) - else: - mask_k_scale = off_k_mx + tl.arange(0, MX_SCALE_BLOCK_K) < tl.cdiv(K, MX_PACK_DIVISOR) - mask_m = off_m + tl.arange(0, BLOCK_M) < eM - x_scales = tl.load(XMxScalePtrs, mask=mask_k_scale[None, :] & mask_m[:, None], other=0.0) - elif x_format == "fp16" or x_format == "bf16": - x_scales: tl.constexpr = None - else: - x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8) tl.static_assert(MX_PACK_DIVISOR % W_PACK_DIVISOR == 0) if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE": - flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128) + flattened_expt_n_idx = off_w_z * ((N + 127) // 128) + (off_n // 128) w_scales = WMxScale.load([0, flattened_expt_n_idx, off_k_mx // 4, 0, 0]) w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1])) w_scales = unswizzle_mx_scale_bw(w_scales) else: - w_scales = WMxScale.load([expt_id, off_k_mx, off_n]) + w_scales = WMxScale.load([off_w_z, off_k_mx, off_n]) w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T # --- update accumulator --- @@ -332,25 +320,35 @@ def _p_matmul_ogs( acc = tl.dot_scaled(w.T, w_scales, w_format, x.T, x_scales, x_format, acc=acc, fast_math=True) else: acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True) - if is_x_microscaled: - XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * 
stride_x_mx_k else: if SWAP_XW: acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) else: acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) + if is_x_microscaled: + XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k + + # ------------------------------------------------------------ + # epilogue + # ------------------------------------------------------------ if INDEPENDENT_EPILOGUE: tile_id1 += NUM_SMS - expt_id1, _, start_z1, start_m1, eM1, off_m1, pid_n1, _, pid_k1, _, _, _ = _load_tile_attrs( - tile_id1, num_tiles, grid_m - padding_m, grid_n, - M, K, ExptData, ExptHist, ExptOffs, ExptTileOffs, - EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED, - BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K, - GROUP_M, XCD_SWIZZLE, SWIZZLE_MX_VALUE) + pid_s1, pid_m1, pid_n1, pid_k1 = compute_pids(tile_id1, useful_grid_m, grid_n, num_blocks, XCD_SWIZZLE, GROUP_M, SPLIT_K) + expt_id1, _, start_z1, start_m1, off_m1, _, _ = compute_offsets( + pid_z, pid_m, pid_k, + XBlockSchedule, XSliceOffs, X_SLICE_SIZES_DIVISIBILITY, + WBlockSchedule, WSliceOffs, W_SLICE_SIZES_DIVISIBILITY, + RAGGED_DIMENSION, + BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K + ) off_n1 = pid_n1 * BLOCK_N + if RAGGED_DIMENSION == "M": + eM1 = tl.load(XSliceSizes + expt_id1) + else: + eM1 = M else: - tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z_out, start_m, eM + tile_id1, expt_id1, start_z1, start_m1, eM1 = block_id, off_w_z, off_y_z, slice_off_m, shape_m off_m1, off_n1, pid_k1 = off_m, off_n, pid_k offs_m = off_m1 + tl.arange(0, BLOCK_M) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py similarity index 88% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py rename to python/triton_kernels/triton_kernels/matmul_details/opt_flags.py index 52fce7b630..5c08d238b5 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py @@ -54,7 +54,7 @@ def make_default_opt_flags_intel( m, n, k, - routing_data, + ragged_metadata, can_use_persistent_tma, can_use_split_k, enforce_bitwise_invariance, @@ -65,13 +65,13 @@ def make_default_opt_flags_intel( ): constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "max_allowable_mn"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() - # tokens per expert - if routing_data is None: - tokens_per_expt = m - elif routing_data.expected_tokens_per_expt is None: - tokens_per_expt = max(1, m // routing_data.n_expts_tot) + # tokens per slice + if ragged_metadata is None: + slice_size = m + elif ragged_metadata.expected_slice_size is None: + slice_size = max(1, m // ragged_metadata.n_slices) else: - tokens_per_expt = routing_data.expected_tokens_per_expt + slice_size = ragged_metadata.expected_slice_size # pid swizzling group_m = 8 xcd_swizzle = 1 @@ -81,7 +81,7 @@ def make_default_opt_flags_intel( elif enforce_bitwise_invariance: block_m = 128 else: - block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) + block_m = max(16, min(triton.next_power_of_2(slice_size), 128)) # block n block_n = opt_flags_intel.compute_block_n(n) # is_persistent @@ -136,7 +136,7 @@ def make_default_opt_flags_amd( m, n, k, - routing_data, + ragged_metadata, can_use_persistent_tma, can_use_split_k, 
enforce_bitwise_invariance, @@ -147,13 +147,13 @@ def make_default_opt_flags_amd( ): constraints_supported = ["block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "max_allowable_mn"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() - # tokens per expert - if routing_data is None: - tokens_per_expt = m - elif routing_data.expected_tokens_per_expt is None: - tokens_per_expt = max(1, m // routing_data.n_expts_tot) + # tokens per slice + if ragged_metadata is None: + slice_size = m + elif ragged_metadata.expected_slice_size is None: + slice_size = max(1, m // ragged_metadata.n_slices) else: - tokens_per_expt = routing_data.expected_tokens_per_expt + slice_size = ragged_metadata.expected_slice_size is_cdna4 = get_cdna_version() == 4 # block_m @@ -161,15 +161,15 @@ def make_default_opt_flags_amd( block_m = constraints["block_m"] elif enforce_bitwise_invariance: block_m = 256 if is_cdna4 else 128 - elif tokens_per_expt >= 512 and n >= 2048: + elif slice_size >= 512 and n >= 2048: block_m = 256 if is_cdna4 else 128 elif is_cdna4 and m >= 512: block_m = 128 else: - block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64)) + block_m = max(32, min(triton.next_power_of_2(slice_size), 64)) - if routing_data is not None: - grid_m = routing_data.n_blocks(m, block_m) + if ragged_metadata is not None: + grid_m = ragged_metadata.n_blocks(ragged_metadata.n_slices, m, block_m) else: grid_m = triton.cdiv(m, block_m) # group_m: @@ -204,9 +204,13 @@ def make_default_opt_flags_amd( if epilogue_subtile is None: epilogue_subtile = 1 + # prevents OutOfSharedMemoryError for mxfp8 on CDNA3 + if get_cdna_version() == 3 and bitwidth(rhs_dtype) == 8 and precision_config.b_mx_scale is not None: + num_stages = 1 + # specific configs for F16 x MXFP4 on CDNA4 # Note that these configs will exceed LDS usage with async copy enabled - if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.weight_scale is not None: + if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.b_mx_scale is not None: split_k = 1 if m <= 1024: target_kernel_kwargs["waves_per_eu"] = 3 @@ -268,11 +272,11 @@ def make_default_opt_flags_nvidia( assert not any([c not in constraints_supported for c in constraints]), constraints.keys() # tokens per expert if routing_data is None or batch_size > 1: - tokens_per_expt = m - elif routing_data.expected_tokens_per_expt is None: - tokens_per_expt = max(1, m // routing_data.n_expts_tot) + slice_size = m + elif routing_data.expected_slice_size is None: + slice_size = max(1, m // routing_data.n_slices) else: - tokens_per_expt = routing_data.expected_tokens_per_expt + slice_size = routing_data.expected_slice_size # pid swizzling group_m = 8 xcd_swizzle = 1 @@ -282,14 +286,14 @@ def make_default_opt_flags_nvidia( elif enforce_bitwise_invariance: block_m = 128 else: - if tokens_per_expt <= 64 and routing_data is not None and routing_data.expt_hist is not None: + if slice_size <= 64 and routing_data is not None and routing_data.slice_sizes is not None: # Ragged and likely memory bound; set the block size higher to minimize loading weights more than once. 
- if lhs_dtype == torch.bfloat16 and rhs_dtype == FP4 and tokens_per_expt >= 16 and torch.cuda.get_device_capability()[0] >= 10: - block_m = max(16, min(triton.next_power_of_2(8 * tokens_per_expt), 128)) + if lhs_dtype == torch.bfloat16 and rhs_dtype == FP4 and slice_size >= 16 and torch.cuda.get_device_capability()[0] >= 10: + block_m = max(16, min(triton.next_power_of_2(8 * slice_size), 128)) else: - block_m = max(16, min(triton.next_power_of_2(2 * tokens_per_expt), 64)) + block_m = max(16, min(triton.next_power_of_2(2 * slice_size), 64)) else: - block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) + block_m = max(16, min(triton.next_power_of_2(slice_size), 128)) # block n arch = None block_n, block_n_tma = opt_flags_nvidia.compute_block_n(n, arch, precision_config) @@ -298,7 +302,7 @@ def make_default_opt_flags_nvidia( n_sms = torch.cuda.get_device_properties(0).multi_processor_count tiles_per_sm = grid_size_tma / n_sms supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9) - requires_persistent = (get_layout(precision_config.act_scale) is not None or get_layout(precision_config.weight_scale) is not None) and target_info.has_native_mxfp() + requires_persistent = (get_layout(precision_config.a_mx_scale) is not None or get_layout(precision_config.b_mx_scale) is not None) and target_info.has_native_mxfp() if constraints.get("is_persistent", None) is not None: is_persistent = constraints["is_persistent"] elif requires_persistent: @@ -313,7 +317,7 @@ def make_default_opt_flags_nvidia( block_n = block_n_tma if is_persistent else block_n # block k block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in) - if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and tokens_per_expt > 1 and lhs_dtype != torch.bfloat16: + if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and slice_size > 1 and lhs_dtype != torch.bfloat16: # Swap block_n and block_k for mxfp4 weights so that block_k is a full cacheline, so long as K is sufficiently large. 
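# Illustrative aside (not part of the kernel or this diff): the cache-line arithmetic
# behind the block_n/block_k swap described above, assuming mxfp4 packs two 4-bit
# values per byte and a 128-byte global-memory cache line on recent NVIDIA GPUs.
FP4_VALUES_PER_BYTE = 2     # assumed packing density for mxfp4 weights
CACHELINE_BYTES = 128       # assumed cache-line size
# before the swap: block_k = 128 -> 128 / 2 = 64 bytes along K, half a cache line
# after the swap:  block_k = 256 -> 256 / 2 = 128 bytes along K, one full cache line
assert 128 // FP4_VALUES_PER_BYTE == CACHELINE_BYTES // 2
assert 256 // FP4_VALUES_PER_BYTE == CACHELINE_BYTES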
# TODO: swizzle the HBM layout of the weights instead block_n, block_k = block_k, block_n @@ -414,7 +418,7 @@ def make_opt_flags( m, n, k, - routing_data, + ragged_metadata, can_use_persistent_tma, can_use_split_k, epilogue_effective_itemsize, @@ -439,7 +443,7 @@ def make_opt_flags( opt_flags_constraints = opt_flags_constraints.copy() opt_flags_constraints.update(block_k=block_k, split_k=1) args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, batch_size, m, n, k, - routing_data, can_use_persistent_tma, can_use_split_k, + ragged_metadata, can_use_persistent_tma, can_use_split_k, enforce_bitwise_invariance, epilogue_effective_itemsize, x_transpose, has_y_acc_in, opt_flags_constraints] backend = triton.runtime.driver.active.get_current_target().backend diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py similarity index 92% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py rename to python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py index ffe06c333f..593688a74d 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py @@ -28,6 +28,6 @@ def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precisi # TODO: block_k = 128 seems to work better for now. # perhaps due to increased number of k loops to pipeline - if precision_config.weight_scale is not None and get_cdna_version() != 4: + if precision_config.b_mx_scale is not None and get_cdna_version() != 4: block_k = 128 return block_n, block_k diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_intel.py similarity index 91% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py rename to python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_intel.py index 6c3a6ebc3f..586cc01095 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_intel.py @@ -4,7 +4,7 @@ def compute_grid_size(routing_data, m, n, block_m, block_n): if routing_data is not None: - grid_m = routing_data.n_blocks(m, block_m) + grid_m = routing_data.n_blocks(routing_data.n_slices, m, block_m) else: grid_m = triton.cdiv(m, block_m) grid_n = (n + block_n - 1) // block_n @@ -19,7 +19,7 @@ def compute_block_n(n: int): def compute_block_k(k: int | None, is_persistent: bool, precision_config): if k is not None: block_k = max(32, min(128, triton.next_power_of_2(k))) - has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None + has_mx_weight_scale = precision_config is not None and precision_config.b_mx_scale is not None if is_persistent and has_mx_weight_scale: block_k = min(block_k, 128) return block_k diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py similarity index 92% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py 
rename to python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py index 2cc0f3d41d..7998b56089 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py @@ -8,7 +8,7 @@ def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n): if routing_data is not None and batch_size == 1: - grid_m = routing_data.n_blocks(m, block_m) + grid_m = routing_data.n_blocks(routing_data.n_slices, m, block_m) else: grid_m = triton.cdiv(m, block_m) grid_n = (n + block_n - 1) // block_n @@ -17,10 +17,10 @@ def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n): def compute_block_n(n: int, arch, precision_config): # block_n: - layout = get_layout(precision_config.weight_scale) + layout = get_layout(precision_config.b_mx_scale) if isinstance(layout, HopperMXScaleLayout): if layout.num_warps in [4, 8]: - # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py#L265 + # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/matmul_details/_matmul.py#L265 block_n = 2 * layout.num_warps * 2 * 8 return block_n, block_n elif precision_config.max_num_imprecise_acc is None and n > 128: @@ -41,7 +41,7 @@ def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_d elif k is not None: min_block_k = 32 if is_persistent or lhs_width != 16 or rhs_width != 16 else 16 block_k = max(min_block_k, min(triton.next_power_of_2(k), block_k)) - has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None + has_mx_weight_scale = precision_config is not None and precision_config.b_mx_scale is not None if has_native_mxfp and is_persistent and has_mx_weight_scale: block_k = min(block_k, 128) if has_y_acc_in and lhs_width == rhs_width == 16 and not target_info.cuda_capability_geq(10, 0): @@ -62,7 +62,7 @@ def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int: def compute_num_warps(block_m, block_n, is_persistent: bool, precision_config): - layout = get_layout(precision_config.weight_scale) + layout = get_layout(precision_config.b_mx_scale) if isinstance(layout, HopperMXScaleLayout): return layout.num_warps return max(block_m * block_n // 4096, 4 if is_persistent else 1) @@ -90,7 +90,7 @@ def compute_num_stages( device_props = torch.cuda.get_device_properties(0) smem_capacity = device_props.shared_memory_per_block_optin has_native_mxfp = target_info.cuda_capability_geq(10, 0) - if has_native_mxfp and getattr(precision_config, "weight_scale", None) is not None: + if has_native_mxfp and getattr(precision_config, "b_mx_scale", None) is not None: if rhs_dtype == FP4: # 4-bit e2m1 weights are padded 2x # https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory @@ -114,7 +114,7 @@ def compute_num_stages( smem_capacity -= int((block_m + 4) * acc_block_n * acc_size) if x_transpose: smem_capacity -= block_m * block_k * lhs_dtype.itemsize - if precision_config.weight_scale is not None: + if precision_config.b_mx_scale is not None: # mx scales stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE)) elif has_native_mxfp: diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py 
deleted file mode 100644 index 281d68f4c6..0000000000 --- a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ /dev/null @@ -1,821 +0,0 @@ -# isort: off -# fmt: off -from dataclasses import dataclass, field -import itertools -import torch -import triton -from enum import Enum, auto -import math -# utilities -from triton_kernels import target_info -from triton_kernels.numerics import InFlexData, OutFlexData -from triton_kernels.target_info import is_cuda -# details -from .matmul_ogs_details._matmul_ogs import _matmul_ogs -from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn -from .numerics_details.mxfp import MXFP_BLOCK_SIZE -from .tensor_details.layout_details.strided import StridedLayout -from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints -from .specialize import FnSpecs, SpecializationModule, ClosureArg -from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor, RaggedTensorMetadata -from .reduce import reduce -from .reduce import PostprocessFn as ReducePostprocessFn - - -@dataclass -class GatherIndx: - """ - Indices for an operation that performs: - Y = X[src_idx, :] - """ - # array such that `dst_idx[src_idx] = arange(0, N)` - src_indx: torch.Tensor - dst_indx: torch.Tensor - - -@dataclass -class ScatterIndx: - """ - Indices for an operation that performs: - Y[dst_idx, :] = X - """ - # array such that `dst_idx[src_idx] = arange(0, N)` - src_indx: torch.Tensor - dst_indx: torch.Tensor - -@dataclass -class RoutingData: - gate_scal: torch.Tensor = field() - expt_hist: torch.Tensor = field() - n_expts_tot: int = field() - n_expts_act: int = field() - expt_data: RaggedTensorMetadata = None - - # Used to make perf annotation cleaner: when we use expert sharding, we can - # use this to tell the "expected" number of local tokens per expert, because - # the actual number can vary per each input. 
- expected_tokens_per_expt: int = field(default=None) - - def n_blocks(self, n_rows, block_m): - if n_rows <= self.n_expts_tot: - return n_rows - else: - return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1 - -@dataclass(frozen=True) -class FusedActivation: - specs: FnSpecs = FnSpecs.default() - fn_args: tuple[object] = tuple() - - -@dataclass(frozen=True) -class Epilogue: - specs: FnSpecs = FnSpecs.default() - fn_arg_values_matmul: tuple[object] = tuple() - fn_arg_values_finalize: tuple[object] = tuple() - effective_itemsize: float = None - -class FnName(Enum): - QUANTIZE_MXFP8 = auto() - - -@dataclass(frozen=True) -class FusedComm: - out_handles: torch.Tensor - scatter_shard_indx: torch.Tensor | None = None - reduce_rank: int = 0 - n_reduce_shards: int = 1 - -specializations = SpecializationModule("matmul_ogs", - kernels=[("_matmul_ogs", _matmul_ogs), ("_p_matmul_ogs", _p_matmul_ogs)], - closure_args={ - "epilogue": ClosureArg("EPILOGUE_FN", "epilogue_fn_args"), # - "activation": ClosureArg("ACTIVATION_FN", "activation_fn_args"), # - }, -) -# ----------------------------------------------------------------------------- -# Matrix Multiplication + Outer Gather/Scatter -# ----------------------------------------------------------------------------- - - -def can_overflow_int32(tensor: torch.Tensor): - max_int32 = (1 << 31) - 1 - offset = 0 - for i in range(tensor.ndim): - offset += (tensor.shape[i] - 1) * tensor.stride(i) - return offset > max_int32 - - -def should_upcast_indices(*args): - return any(tensor is not None and can_overflow_int32(tensor) for tensor in args) - - -# This supports computing dw for ragged matmul. Note the correspondence: -# fwd pass: y = matmul_ogs(x, w, ...) -# bwd pass (dw): dw = matmul_ogs(x.T, dy, ...) -# -# Thus, "our" x, w, and y (as seen by matmul_ogs) correspond to x.T, dy, dw, respectively. -# To avoid confusion, now we'll stick to "x, w, y" terminology. -# -# Assume that y.shape == (N_EXPTS, M, N), x.shape == (M, K), w.shape = (K_W, N). -# -# To make things feasible, we require that x and w satisfy the following condition: -# (1) We don't support gather/scatter indices: in x, all columns for expt #0 are grouped at -# the leftmost part, followed by expt #1, and so on. Ditto for w (top to bottom). -# (2) At least one of x and w are padded: each expert uses a multiple of block_k columns -# (or rows), and unused values are filled with zero. -# (3) No inf or nan are allowed in x or w (except for the final padding - see below). -# This is because we use "multiplying by padded zero region" in lieu of masking. -# (4) The number of actually used columns/rows equals self.base.expt_hist.sum() and may be -# less than K or K_W. In this case, the final "unused" values can be left uninitialized. -# However, if x or w is unpadded, the first block_k columns/rows of the unused part must -# not contain nan or inf. -# -# For example, assume N_EXPTS == 5, block_k == 32, and expt_hist == [60, 33, 0, 32, 25]. -# -# if unpadded if padded -# ----------- --------- -# x: expt #0: x[:, :60] x[:, :60] -# x[:, 60:64] - zero padded -# expt #1: x[:, 60:93] x[:, 64:97] -# x[:, 97:128] - zero padded -# expt #3: x[:, 93:125] x[:, 128:160] -# expt #4: x[:, 125:150] x[:, 160:185] -# x[:, 185:192] - zero padded -# x[:, 150:min(182, K)] - must not contain inf/nan -# -# x[:, 182:] x[:, 192:] - unused (may contain garbage, including inf/nan) -# -# w is the same, except that rows columns are flipped. 
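# Illustrative sketch (not from the original file): reproduces the padded vs. unpadded
# column ranges in the example above, for slice sizes [60, 33, 0, 32, 25] and block_k = 32.
# `slice_ranges` is a hypothetical helper used only for this illustration.
def slice_ranges(sizes, block_k, padded):
    ranges, off = [], 0
    for size in sizes:
        ranges.append((off, off + size))
        # padded layout: each slice occupies a multiple of block_k columns, zero-filled at the end
        off += -(-size // block_k) * block_k if padded else size
    return ranges

print(slice_ranges([60, 33, 0, 32, 25], 32, padded=False))
# -> [(0, 60), (60, 93), (93, 93), (93, 125), (125, 150)]
print(slice_ranges([60, 33, 0, 32, 25], 32, padded=True))
# -> [(0, 60), (64, 97), (128, 128), (128, 160), (160, 185)]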
-@dataclass -class InnerRoutingData: - base: RoutingData | None = None - block_k: int | None = None - x_is_padded: bool = False - w_is_padded: bool = False - - # Return value contains: ExptHist, ExptOffs, ExptTileOffs, ExptData, - # EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED, ExptHistMax - @staticmethod - def make_kernel_args(data, block_m): - if isinstance(data, RoutingData): - expt_data, block = data.expt_data, block_m - args = (False, False, False, None) - elif isinstance(data, InnerRoutingData): - expt_data, block = data.base.expt_data, data.block_k - args = ( - True, data.x_is_padded, data.w_is_padded, expt_data.slice_sizes.max() - ) - elif data is None: - expt_data = None - else: - assert None - - if expt_data is None: - return (None, None, None, None, False, False, False, None) - - return ( - expt_data.slice_sizes, - expt_data.slice_offs, - expt_data.block_offs(block), - expt_data.block_schedule(block), - ) + args - - -# --------------------- -# Numerics -# --------------------- - -# fmt: off - -@dataclass(frozen=True) -class FlexCtx: - lhs_data: InFlexData = InFlexData() - rhs_data: InFlexData = InFlexData() - out_data: OutFlexData = OutFlexData() - acc_data: InFlexData = InFlexData() - -@dataclass -class PrecisionConfig: - max_num_imprecise_acc: int = None - allow_tf32: bool = True - flex_ctx: FlexCtx = FlexCtx() - acc_scale: int = 1.0 - flexpoint_saturate_inf: bool = False - report_quantization_err_fn: callable = None - act_scale: Tensor | None = None - weight_scale: Tensor| None = None - out_scale: Tensor | None = None - out_dtype: torch.dtype = None - enforce_bitwise_invariance: bool = False - - -# TODO: merge in opt_flags -def get_swap_xw(precision_config, opt_flags): - if target_info.cuda_capability_geq(10, 0): - return precision_config.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent - return False - -# --------------------- -# Allocation -# --------------------- - -@dataclass -class MatmulAllocation: - device: str - output: tuple[tuple[int], torch.dtype] - scratchpads: dict[str, tuple] - -def init_allocation(x, w, precision_config, fused_activation, - routing_data, gather_indx, scatter_indx, inner_routing_data, - n_reduce_shards, opt_flags): - # ---- output ------ - N = w.shape[-1] - # by default - M is number of rows in the activations - M = x.shape[-2] - # if the activations are gathered, then M is number of gather indices - if gather_indx is not None: - M = gather_indx.src_indx.shape[0] - if scatter_indx is not None: - M = scatter_indx.src_indx.shape[0] - if scatter_indx is None: - y_rows = M - else: - y_rows = M // routing_data.n_expts_act - y_rows *= n_reduce_shards - if inner_routing_data is not None: - batch_dim = inner_routing_data.base.n_expts_tot - else: - batch_dim = x.shape[0] if x.ndim == 3 else 1 - out_shape = (batch_dim, y_rows, N // fused_activation.specs.reduction_n) - out_dtype = precision_config.out_dtype or x.dtype - output = (out_shape, out_dtype) - # ---- scratchpad -----# - scratchpad = dict() - N_scratch = N // fused_activation.specs.reduction_n if opt_flags.split_k == 1 else N - if opt_flags.split_k > 1 or (scatter_indx is not None and (not is_cuda() or routing_data.n_expts_act > 1)): - scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype - scratchpad["matmul"] = ((opt_flags.split_k, batch_dim, M, N_scratch), scratch_out_dtype) - if "matmul" in scratchpad and precision_config.out_scale is not None: - assert batch_dim == 1, "batch_dim > 1 not supported yet" - scratchpad["mx_out_scale"] = 
((opt_flags.split_k, 1, M, triton.cdiv(N_scratch, MXFP_BLOCK_SIZE)), torch.uint8) - return MatmulAllocation(x.device, output, scratchpad) - -def apply_allocation(allocation: MatmulAllocation, output): - ret = dict() - if output is None: - output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1]) - else: - if output.ndim == 2: - output = output[None, :, :] - assert output.shape == allocation.output[0] - ret["output"] = output[None, :, :] - ret["scratchpad"] = { - k: torch.empty(v[0], device=allocation.device, dtype=v[1]) - for k, v in allocation.scratchpads.items() - } - return ret - -# ----------------------------------------------------------------------------- -# Canonicalize -# ----------------------------------------------------------------------------- -# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used -# we can canonicalize storages to make the implementation more uniform - -def _canonicalize_storage(storage, out_ndim, flex_data): - assert out_ndim >= storage.data.ndim - # Need to use as_strided instead of view because for a tensor with - # shape[-2] == 1 can have ambuiguity related to col-wise. Fo example, - # > t = torch.randn(2, 5, 1).mT - # > t_view = t.view(t.shape) - # > t.stride(), t_view.stride() - # ((5, 1, 1), (5, 5, 1)) - # Our check t_view is col-wise fails since t_view.stride(-2) != 1 - # This case is covered by (m, n, k) == (1000, 700, 2) in test_matmul.py - new_storage_shape = [1] * (out_ndim - storage.data.ndim) + list(storage.data.shape) - new_storage_stride = [0] * (out_ndim - storage.data.ndim) + list(storage.data.stride()) - new_storage_data = storage.data.as_strided(new_storage_shape, new_storage_stride) - if flex_data is not None: - new_storage_data = flex_data.reinterpret(new_storage_data) - return Storage(new_storage_data, storage.layout) - - -# ----------------------------------------------------------------------------- -# Triton Implementation -# ----------------------------------------------------------------------------- - -def matmul_ogs_set_idle_sms(num_idle_sms): - """ - persistent kernels will leave `num_idle_sms` idle - """ - update_opt_flags_constraints({"idle_sms": num_idle_sms}) - -def matmul_ogs(x, w, bias, - routing_data: RoutingData | None = None, - gather_indx: GatherIndx | None = None, - scatter_indx: ScatterIndx | None = None, - precision_config: PrecisionConfig | None = None, - betas: torch.Tensor | None = None, - gammas: torch.Tensor | None = None, - out_alpha: float | None = None, - y: torch.Tensor | None = None, - fused_comm: FusedComm | None = None, - fused_activation: FusedActivation | None = None, - epilogue: Epilogue | None = None, - y_acc_in: torch.Tensor | None = None, - inner_routing_data: InnerRoutingData | None = None, -): - """ - Y[:, :] = 0. - for e in num_experts: - Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :]) - - matmul can be optionally fused with all gather or scatter at the end for the output. When fused_comm is specified, the m-th row of the output will be stored to (m * n_reduce_shards + reduce_rank) -th row - of each rank id in range [scatter_shard_indx[m] * n_reduce_shards, (scatter_shard_indx[m] + 1) * n_reduce_shards) if scatter_shard_indx is not None, otherwise the output will be all gathered across all reduce ranks. - When scatter_shard_indx is specified, the caller should ensure that the indices of different shards do not conflict. 
- - The output buffer for fused comm should be pre-allocated and passed in via fused_comm.out_handles, which contains ipc handles to the output tensors, each with shape (n_rows * n_reduce_shards, n_cols). - """ - is_input_batched = x.ndim == 3 - if is_input_batched: - assert gather_indx is None, "gather not supported in batched mode" - assert scatter_indx is None, "scatter not supported in batched mode" - assert routing_data is None, "routing not supported in batched mode" - assert inner_routing_data is None, "routing not supported in batched mode" - assert fused_comm is None, "fused comm is not supported in batched mode" - assert w.ndim == 3 and w.shape[0] == x.shape[0] - if inner_routing_data is not None: - assert routing_data is None - assert gather_indx is None - assert scatter_indx is None - routing_data = RoutingData( - None, None, inner_routing_data.base.n_expts_tot, 1, - expected_tokens_per_expt=inner_routing_data.base.expected_tokens_per_expt, - ) - # canonicalize inputs - if precision_config is None: - precision_config = PrecisionConfig() - if fused_activation is None: - fused_activation = FusedActivation(FnSpecs.default(), tuple()) - if epilogue is None: - epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False) - if routing_data is None: - routing_data = RoutingData(None, None, max(1, w.shape[0]), 1) - # unpack scales - w_scale = precision_config.weight_scale - w_has_mx = w_scale is not None - is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8 - if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10" - if not isinstance(w, Tensor): - # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real - dtype = FP4 if w.dtype == torch.uint8 else w.dtype - w = wrap_torch_tensor(w, dtype=dtype) - if w_has_mx and is_cuda() and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)): - assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)" - if w_scale is not None and not isinstance(w_scale, Tensor): - w_scale = Tensor(w_scale) - if w_scale is not None: - w_scale.storage.data = w_scale.data.view(torch.uint8) - w_scale.dtype = torch.uint8 - x_scale = precision_config.act_scale - x_has_mx = x_scale is not None - if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp" - if x_scale is not None and not isinstance(x_scale, Tensor): - x_scale = Tensor(x_scale) - if not isinstance(x, Tensor): - x = Tensor(x, dtype=x.dtype) - x_transpose = x.stride(-1) != 1 - # determine shapes - has_gather = gather_indx is not None - has_scatter = scatter_indx is not None - is_ragged = routing_data.expt_hist is not None - M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0] - if inner_routing_data is not None: - batch_size = inner_routing_data.base.n_expts_tot - else: - batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1 - if y_acc_in is not None: - y_acc_is_y = y_acc_in.data_ptr() == y.data_ptr() and y_acc_in.stride() == y.stride() - else: - y_acc_is_y = None - K = x.shape[-1] - K_W, N = w.shape[-2:] - if x.ndim == 3 and w.ndim == 3: - assert x.shape[0] == w.shape[0] - # compute optimization flags - out_dtype = precision_config.out_dtype or x.dtype - can_use_tma = ( - x.numel() > 0 and x.storage.is_tma_compliant() and - w.numel() > 0 
and w.storage.is_tma_compliant() and - (w_scale is None or w_scale.storage.is_tma_compliant()) and - (not is_ragged or x.stride(-1) == 1) and - # Currently we don't support tma if y is column major; may revisit later if this becomes an issue. - (y is None or y.stride(-1) == 1) and - (y_acc_in is None or y_acc_is_y) and - # If we use inner_routing_data, w must be either padded or row major, otherwise we get - # unaligned access. - (inner_routing_data is None or w.stride(-1) == 1 or inner_routing_data.w_is_padded) - ) - if w_scale is not None and isinstance(w_scale.storage.layout, StridedLayout) and w_scale.storage.data.stride()[-1] != 1: - # In this case, we need to transpose w_scale. Then the reduction dim - # becomes the last dim that will be divided by 32. This to be a multiple - # of 16 to be TMA-compliant requires block_k to be a multiple of 512, - # which is too big. - can_use_tma = False - has_gather_tma = has_gather and target_info.has_tma_gather() - # hopper w/ mxfp4 doesn't support TMA - can_use_tma = can_use_tma and is_cuda() and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4) - can_use_split_k = scatter_indx is None and not x_has_mx and not w_has_mx - opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config, - batch_size, M, N, w.shape[-2], routing_data, - can_use_tma, can_use_split_k, epilogue.effective_itemsize, - x_transpose, y_acc_in is not None, - inner_routing_data.block_k if inner_routing_data is not None else None, - ) - if inner_routing_data is not None: - assert opt_flags.block_k == inner_routing_data.block_k - assert opt_flags.split_k == 1 - batch_size = inner_routing_data.base.n_expts_tot - # For unpadded (row major) x, we cannot use tma because memory access isn't aligned. - x_has_tma = opt_flags.is_persistent and (x.stride(-1) != 1 or inner_routing_data.x_is_padded) - # If TMA is used, limit is handled automatically, so we can pretend K is "even". - # (For unpadded input, we assume that the first block_k unused rows are zero-filled, - # when routing_data.expt_hist.sum() is less than K or K_W.) 
- if opt_flags.is_persistent: - even_K = x_has_tma or inner_routing_data.x_is_padded - else: - even_K = inner_routing_data.x_is_padded and inner_routing_data.w_is_padded - else: - batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1 - assert K == K_W - x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather) - even_K = (K % opt_flags.block_k == 0) - if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp(): - raise NotImplementedError("Must use non-persistent kernel for simulated MXFP") - if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp(): - raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP") - # fused activation - matmul_fused_activation = fused_activation - reduce_fused_activation = FusedActivation() - if opt_flags.split_k > 1: - matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation - # allocate output/scratchpad memory - allocation = init_allocation(x, w, precision_config, fused_activation, - routing_data, gather_indx, scatter_indx, inner_routing_data, fused_comm.n_reduce_shards if fused_comm is not None else 1, opt_flags) - memory = apply_allocation(allocation, y) - # early exit - if batch_size * M * N == 0: - ret = memory["output"].squeeze(0) - if not is_input_batched: - ret = ret.squeeze(0) - return ret - # TMA descriptors require a global memory allocation - if opt_flags.is_persistent: - triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device)) - # Intermediate tensors and postprocess kernels for each situation - has_scratchpad = "matmul" in memory["scratchpad"] - # Canonical output tensor (matmul scratchpad if present, otherwise final output tensor) - out_matmul = memory["scratchpad"].get("matmul", memory["output"]) - out_matmul_flex = OutFlexData() if out_matmul.dtype == torch.float32 else precision_config.flex_ctx.out_data - # Unified mx-scale pointer; when scratchpad exists, prefer its mx buffer - out_matmul_scale = precision_config.out_scale - if out_matmul_scale is not None: - out_matmul_scale = out_matmul_scale.data.view(torch.uint8) - if has_scratchpad and "mx_out_scale" in memory["scratchpad"]: - out_matmul_scale = memory["scratchpad"]["mx_out_scale"] - out_matmul_has_mx = out_matmul_scale is not None and out_matmul.element_size() == 1 - # matrix multiplication - flex = precision_config.flex_ctx - bias_stride = None if bias is None else bias.stride(0) - num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0] - # moe metadata - block_m = opt_flags.block_m - expt_data_args = InnerRoutingData.make_kernel_args(inner_routing_data or routing_data, block_m) - # spmd grid - grid_m = triton.cdiv(M, opt_flags.block_m) - if routing_data.expt_data is not None: - grid_m = routing_data.n_blocks(M, opt_flags.block_m) - grid_n = triton.cdiv(N, opt_flags.block_n) - max_grid = batch_size * grid_m * grid_n * opt_flags.split_k - grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid - # canonicalize storage - has_scatter_tma = scatter_indx is not None and target_info.has_tma_gather() - y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if has_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:])) - x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data) - w_storage = 
_canonicalize_storage(w.storage, 3, flex.rhs_data) - y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data) - # create tma descriptor for x - if y_acc_in is not None: - assert opt_flags.split_k == 1, "y_acc_in + split_k is not supported." - assert scatter_indx is None, "y_acc_in + scatter is not supported." - if y_acc_in.ndim == 2: - y_acc_in = y_acc_in.unsqueeze(0) - assert y_acc_in.shape == out_matmul.shape[-3:] - y_acc_strides = y_acc_in.stride() - else: - y_acc_strides = (None, None, None) - - x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k] - x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense" - x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data - # create tma descriptor for y - y_has_tma = ( - opt_flags.is_persistent and (scatter_indx is None or has_scatter_tma) - and (y_acc_in is None or y_acc_is_y) - ) - block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.specs.reduction_n - y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n] - y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense" - y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data - # create tma descriptor for w - w_has_tma = opt_flags.is_persistent - w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data - # create tma descriptor for w_scale - w_scale_has_tma = opt_flags.is_persistent and w_scale is not None - # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed - # (i.e. column-major). This only matters when w_has_mx is True, and in that case - # w_transpose == True is the fast code path, so stride(-2) == 1 takes precedence - # over, e.g., w_transpose = w_storage.data.stride()[-1] != 1 - w_transpose = w_storage.data.stride()[-2] == 1 - if w_scale_has_tma: - w_scale_storage = w_scale.storage - scale_block_k = opt_flags.block_k // int(MXFP_BLOCK_SIZE) - # cancel out the transpose done inside make_tma, since - # BlackwellMXScaleLayout.swizzle_block_shape expects block_shape[1] to be - # the reduction dimension.
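# Illustrative numbers (assuming MXFP_BLOCK_SIZE == 32): with block_k=128 and block_n=64,
# scale_block_k == 4, so the descriptor block below is [64, 4] for a transposed
# BLACKWELL_SCALE layout and [4, 64] otherwise; a leading 1 is prepended in the
# StridedLayout case.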
- w_scale_tma_block_size = [opt_flags.block_n, scale_block_k] if w_transpose and w_scale.storage.layout.name == "BLACKWELL_SCALE" else [scale_block_k, opt_flags.block_n] - if isinstance(w_scale.storage.layout, StridedLayout): - assert w_scale_storage.data.stride()[-1] == 1, "w_scale should be contiguous with StridedLayout" - w_scale_storage = _canonicalize_storage(w_scale.storage, 3, None) - w_scale_tma_block_size = [1] + w_scale_tma_block_size - w_scale_tensor_or_tma = w_scale_storage.make_tma(w_scale_tma_block_size, "dense", is_scale=True) - else: - w_scale_tensor_or_tma = w_scale - # canonicalize strides - x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride()) - x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None) - x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides - w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None) - w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides - out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None) - out_matmul_scale_strides = (0, ) * (4 - len(out_matmul_scale_strides)) + out_matmul_scale_strides - # launch kernel - kernels = specializations.get(epilogue=epilogue.specs, activation=matmul_fused_activation.specs) - if gather_indx is not None: - gather_src_indx = torch.div(gather_indx.src_indx, routing_data.n_expts_act, rounding_mode='trunc') - fused_comm_kwargs = { - "pYPtrs": fused_comm.out_handles, - "ScatterShardIndx": fused_comm.scatter_shard_indx, - "reduce_rank": fused_comm.reduce_rank, - "n_reduce_shards": fused_comm.n_reduce_shards, - } if fused_comm is not None else {} - # if routing_data.n_expts_act > 1: - # y_storage.data.view(torch.uint8).zero_() - (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)]( - y_tensor_or_tma, y_storage.data, *out_matmul.stride(), - *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex), - *out_matmul_scale_strides[-4:], - x_tensor_or_tma, x_storage.data, *x_strides, x_transpose, - flex.lhs_data.scale, - None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides, - w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose, - flex.rhs_data.scale, - w_scale_tensor_or_tma, *w_scale_strides, - flex.acc_data.reinterpret(y_acc_in), *y_acc_strides, - flex.acc_data.scale, y_acc_is_y, - bias, bias_stride, - x.shape[-2] if routing_data.expt_hist is None else None, - N, K, K_W, - betas, gammas, - None if gather_indx is None else gather_src_indx, - None if gather_indx is None else gather_indx.dst_indx, # Only for launch_metadata - None if scatter_indx is None else scatter_indx.src_indx, - num_indx, - None if scatter_indx is None else scatter_indx.dst_indx, - None if scatter_indx is None else scatter_indx.dst_indx.shape[0], - *expt_data_args, - batch_size, grid_m, grid_n, - out_alpha, - *matmul_fused_activation.fn_args, matmul_fused_activation.specs.reduction_n, - *epilogue.fn_arg_values_matmul, - routing_data.n_expts_tot, - precision_config.max_num_imprecise_acc, - precision_config.allow_tf32, - precision_config.flexpoint_saturate_inf, - flex.rhs_data.is_per_batch, - out_matmul_flex.is_per_batch, - flex.acc_data.is_per_batch, - opt_flags.block_m, - opt_flags.block_n, - opt_flags.block_k, - opt_flags.group_m, - INIT_OUTPUT_TO_ZERO=routing_data.n_expts_act == 1, - XCD_SWIZZLE=opt_flags.xcd_swizzle, - SWIZZLE_MX_VALUE=w.storage.layout.name, - SWIZZLE_MX_SCALE=None if w_scale is None else 
w_scale.storage.layout.name, - EPILOGUE_SUBTILE=opt_flags.epilogue_subtile, - SPLIT_K=opt_flags.split_k, - EVEN_K=even_K, - W_CACHE_MODIFIER=opt_flags.w_cache_modifier, - TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt, - num_warps=opt_flags.num_warps, - num_stages=opt_flags.num_stages, - arch=opt_flags.arch, - UPCAST_INDICES=should_upcast_indices(x, w, out_matmul), - X_TMA_MODE=x_tma_mode, - Y_TMA_MODE=y_tma_mode, - SWAP_XW=get_swap_xw(precision_config, opt_flags), - IS_EPILOGUE_QUANT_MXFP8=epilogue.specs.name == FnName.QUANTIZE_MXFP8.name, - NUM_SMS = grid if opt_flags.is_persistent else 0, - **fused_comm_kwargs, - **opt_flags.target_kernel_kwargs) - - assert not (opt_flags.split_k > 1 and scatter_indx is not None) - out_final_mx_scale = None - if opt_flags.split_k > 1: - assert not out_matmul_has_mx - postprocess_fn1 = ReducePostprocessFn(specs=reduce_fused_activation.specs, fn_args=reduce_fused_activation.fn_args) - postprocess_fn2 = None if has_scatter else ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize) - y, y_mx_scale = reduce( - x = out_matmul.view(out_matmul.shape[0], -1, out_matmul.shape[-1]), - dim = 0, - # output data/metadata - y = memory["output"].view(-1, memory["output"].shape[-1]), - y_dtype = memory["output"].dtype, - y_flex = precision_config.flex_ctx.out_data, - y_flex_saturate_inf = precision_config.flexpoint_saturate_inf, - y_has_mx = precision_config.out_scale is not None, - # fused functions - postprocess_fn1 = postprocess_fn1, - postprocess_fn2 = postprocess_fn2, - ) - y_shape = out_matmul.shape[1:-1] + (out_matmul.shape[-1] // reduce_fused_activation.specs.reduction_n,) - out_matmul = y.view(*y_shape).unsqueeze(0) - if y_mx_scale is not None: - out_final_mx_scale = y_mx_scale.view(out_matmul.shape[-2], triton.cdiv(out_matmul.shape[-1], 32)) - # TODO: change `matmul_ogs` semantics and move this to another op! - if scatter_indx is not None and (not is_cuda() or routing_data.n_expts_act > 1): # The matmul_ogs kernel already fuses the scatter, so this is only needed when n_expts_act > 1.
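# In this branch each destination row has n_expts_act candidate rows in out_matmul
# (one per active expert); the masked reduce over dim=1 below sums those candidates and
# drops rows whose scatter index is -1.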
- mask = (scatter_indx.src_indx != -1).view(out_matmul.shape[-2]//routing_data.n_expts_act, routing_data.n_expts_act, 1) - out_matmul = out_matmul.view(out_matmul.shape[-2]//routing_data.n_expts_act, routing_data.n_expts_act, -1) - mask = mask.expand_as(out_matmul) - out_matmul_scale_shape = out_matmul.shape[:-1] + (triton.cdiv(out_matmul.shape[-1], 32),) - postprocess_fn = ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize) - x_flex = InFlexData(dtype=out_matmul_flex.dtype, scale=out_matmul_flex.expected_scale) - out_final, out_final_mx_scale = reduce(out_matmul, dim=1, postprocess_fn2=postprocess_fn, x_flex=x_flex, # - mask=mask, - y=memory["output"].squeeze(0).squeeze(0), - x_mxscale=out_matmul_scale.view(*out_matmul_scale_shape) if out_matmul_has_mx else None, - y_has_mx=precision_config.out_scale is not None, - y_flex=precision_config.flex_ctx.out_data, - y_flex_saturate_inf=precision_config.flexpoint_saturate_inf, - ) - out_final = out_final.unsqueeze(0) - else: - out_final = out_matmul.squeeze(0) - - if not (is_input_batched or inner_routing_data is not None): - out_final = out_final.squeeze(0) - if out_final_mx_scale is not None: - precision_config.out_scale = out_final_mx_scale - return out_final - -# ----------------------------------------------------------------------------- -# Reference Implementation -# ----------------------------------------------------------------------------- - -def matmul_ogs_torch(x, w, bias, - routing_data: RoutingData = None, - gather_indx: GatherIndx = None, - scatter_indx: ScatterIndx = None, - precision_config: PrecisionConfig = None, - betas = None, - gammas = None, - inner_routing_data: InnerRoutingData | None = None, - round_x = None, round_y = None, - device: str = "cuda"): - if inner_routing_data is not None: - assert bias is None, "Not supported yet" - m, n = x.shape[-2], w.shape[-1] - block_k = inner_routing_data.block_k - n_expts_tot = inner_routing_data.base.n_expts_tot - out = torch.zeros((n_expts_tot, m, n), dtype=torch.float32, device=x.device) - start_x = start_w = 0 - for expt in range(n_expts_tot): - k = inner_routing_data.base.expt_hist[expt].item() - if k > 0: - out[expt] = matmul_ogs_torch( - x[:, start_x:start_x+k], w[start_w:start_w+k, :], None, - None, None, None, None, betas, gammas, None, round_x, round_y, device - ) - padded_k = triton.cdiv(k, block_k) * block_k - start_x += padded_k if inner_routing_data.x_is_padded else k - start_w += padded_k if inner_routing_data.w_is_padded else k - return out - - is_input_batched = x.ndim == 3 - assert x.dtype.itemsize > 1 - assert w.dtype.itemsize > 1 - if is_input_batched: - assert gather_indx is None, "gather not supported in batched mode" - assert scatter_indx is None, "scatter not supported in batched mode" - assert routing_data is None, "routing not supported in batched mode" - assert w.ndim == 3 and w.shape[0] == x.shape[0] - if round_x is None: - round_x = lambda x, idx: x - if round_y is None: - round_y = lambda x: x - if bias is not None and bias.ndim == 1: - bias = bias.view(1, *bias.shape) - if w.ndim == 2: - w = w.view(1, *w.shape) - if x.ndim == 2: - x = x.view(1, *x.shape) - if routing_data is None: - routing_data = RoutingData(None, None, w.shape[0], 1) - n_expts_act = routing_data.n_expts_act - # memory offsets - if routing_data.n_expts_tot > 1 and not is_input_batched: - sizes = routing_data.expt_hist - off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32) - off[1:] = torch.cumsum(sizes, 0) - offs = list(itertools.pairwise(off)) - 
else: - offs = [[0, x.shape[1]] for _ in range(w.shape[0])] - # compute - n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0] - y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype) - for i, (lo, hi) in enumerate(offs): - if gather_indx is None: - idx = torch.arange(lo, hi, device=x.device) - else: - idx = gather_indx.src_indx[lo:hi] // n_expts_act - batch = i if is_input_batched else 0 - out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(), - w[i].float()) - if bias is not None: - out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None] - if gammas is not None: - out *= gammas[lo:hi, None] - y[batch, lo:hi, :] = round_y(out) - if not is_input_batched: - y = y.view(y.shape[1], y.shape[2]) - if scatter_indx is None: - return y - # accumulate output from all experts - n_rows = y.shape[0] // n_expts_act - out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device) - for i, (lo, hi) in enumerate(offs): - dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act - msk = dst_idx != -1 - out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float() - return out - - -def post_matmul_comm_torch(y: torch.Tensor, rank: int, n_reduce_shards: int, - world_size: int, - scatter_shard_indx: torch.Tensor | None = None, -): - """ - Reference implementation of post matmul communication. - - y: the local matmul output - rank: the global rank - n_reduce_shards: the number of reduce shards - world_size: the world size - scatter_shard_indx: the shard indices for the scatter. None if all gather. - - Output shape: - (batch_size, n_rows, n_cols) -> (batch_size, n_rows * n_reduce_shards, n_cols) if batched, otherwise - (n_rows, n_cols) -> (n_rows * n_reduce_shards, n_cols) - """ - from torch import distributed as dist - # if n_reduce_shards == 1: - # return y - - ys = [torch.empty_like(y) for _ in range(world_size)] - dist.all_gather(ys, y) - out_shape = (*y.shape[:-2], y.shape[-2] * n_reduce_shards, y.shape[-1]) - - if scatter_shard_indx is None: - # all gather - assert n_reduce_shards == world_size - return torch.cat(ys, dim=-1).reshape(out_shape) - else: - # Note: when multiple ranks scatter to the same destination, the result is undefined. 
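# Every rank gathers all ranks' scatter_shard_indx, keeps the rows whose destination
# shard equals its own reduce_shard_id, and interleaves the n_reduce_shards copies of
# each kept row through the strided output view built below.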
- scatter_shard_indx_global = torch.empty((world_size, *scatter_shard_indx.shape), device=scatter_shard_indx.device, dtype=scatter_shard_indx.dtype) - dist.all_gather([scatter_shard_indx_global[i] for i in range(world_size)], scatter_shard_indx) - - assert len(out_shape) == 2, "batched mode not supported" - result = torch.zeros(out_shape, device=y.device, dtype=y.dtype) - reduce_shard_id = rank // n_reduce_shards - - for i in range(world_size // n_reduce_shards): - scatter_mask = scatter_shard_indx_global[i * n_reduce_shards, :] == reduce_shard_id - for j in range(n_reduce_shards): - out_slice = result.as_strided( - (result.shape[0] // n_reduce_shards, result.shape[1]), - (result.stride(0) * n_reduce_shards, result.stride(1)), - storage_offset=j * result.stride(0), - ) - out_slice[scatter_mask, :] = ys[i * n_reduce_shards + j][scatter_mask, :] - return result diff --git a/python/triton_kernels/triton_kernels/reduce.py b/python/triton_kernels/triton_kernels/reduce.py index 4d64173a9d..f9bf58be02 100644 --- a/python/triton_kernels/triton_kernels/reduce.py +++ b/python/triton_kernels/triton_kernels/reduce.py @@ -16,29 +16,30 @@ class PostprocessFn: @triton.jit -def _reduce(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x tensor (input) - XMx, stride_xmxr, stride_xmx0, stride_xmx1, # x mx scale - Y, stride_y0: tl.int64, stride_y1, # y tensor (output) - YMx, stride_ymx0, stride_ymx1, # y mx scale - Mask, stride_mr, stride_m0, stride_m1, # mask tensor - Scale, stride_sr, stride_s0, stride_s1, # scale tensor - K, S0, X_S1, Y_S1, # shape (K = reduction dim; S0, IN_S1 = input dims, OUT_S1 = output dims) - POSTPROCESS_FN1: tl.constexpr, postprocess_fn1_args, # - POSTPROCESS_FN2: tl.constexpr, postprocess_fn2_args, # - XFlex, # x flex (global) scale - YFlexExpected, YFlexActual, YFlexChecksum, Y_FLEX_SATURATE_INF: tl.constexpr, # y flex (global) scale - IS_MASK_NONE: tl.constexpr, # - BROADCAST_R: tl.constexpr, # - BROADCAST_S0: tl.constexpr, # - BROADCAST_S1: tl.constexpr, # - IS_SCALE_NONE: tl.constexpr, # - SCALE_BROADCAST_R: tl.constexpr, # - SCALE_BROADCAST_S0: tl.constexpr, # - SCALE_BROADCAST_S1: tl.constexpr, # - BLOCK_S0: tl.constexpr, # - BLOCK_X_S1: tl.constexpr, # - BLOCK_Y_S1: tl.constexpr, # - ): +def _reduce_forward(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x tensor (input) + XMx, stride_xmxr, stride_xmx0, stride_xmx1, # x mx scale + Y, stride_y0: tl.int64, stride_y1, # y tensor (output) + YMx, stride_ymx0, stride_ymx1, # y mx scale + Mask, stride_mr, stride_m0, stride_m1, # mask tensor + Scale, stride_sr, stride_s0, stride_s1, # scale tensor + K, S0, X_S1, Y_S1, # shape (K = reduction dim; S0, IN_S1 = input dims, OUT_S1 = output dims) + POSTPROCESS_FN1: tl.constexpr, postprocess_fn1_args, # + POSTPROCESS_FN2: tl.constexpr, postprocess_fn2_args, # + XFlex, # x flex (global) scale + YFlexExpected, YFlexActual, YFlexChecksum, + Y_FLEX_SATURATE_INF: tl.constexpr, # y flex (global) scale + IS_MASK_NONE: tl.constexpr, # + BROADCAST_R: tl.constexpr, # + BROADCAST_S0: tl.constexpr, # + BROADCAST_S1: tl.constexpr, # + IS_SCALE_NONE: tl.constexpr, # + SCALE_BROADCAST_R: tl.constexpr, # + SCALE_BROADCAST_S0: tl.constexpr, # + SCALE_BROADCAST_S1: tl.constexpr, # + BLOCK_S0: tl.constexpr, # + BLOCK_X_S1: tl.constexpr, # + BLOCK_Y_S1: tl.constexpr, # + ): pid_s0 = tl.program_id(0) pid_s1 = tl.program_id(1) tl.static_assert(BLOCK_X_S1 % 32 == 0) @@ -95,9 +96,9 @@ def _reduce(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x tensor tl.store(y_ptrs, y, 
mask=valid_s0[:, None] & valid_y_s1[None, :]) -specializations = SpecializationModule( - "reduce", - kernels=[("_reduce", _reduce)], +forward_specializations = SpecializationModule( + "reduce_forward", + kernels=[("_reduce_forward", _reduce_forward)], closure_args={ "postprocess_fn1": ClosureArg("POSTPROCESS_FN1", "postprocess_fn1_args"), "postprocess_fn2": ClosureArg("POSTPROCESS_FN2", "postprocess_fn2_args"), @@ -105,7 +106,7 @@ def _reduce(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x tensor ) -def reduce( +def reduce_forward( x: torch.Tensor, dim: int, mask: Optional[torch.Tensor] = None, @@ -216,8 +217,8 @@ def reduce( grid = (triton.cdiv(S0, BLOCK_S0), triton.cdiv(Y_S1, BLOCK_Y_S1)) mask_arg = mask if mask is not None else None scale_arg = scale if scale is not None else None - reduce_kernel = specializations.get(postprocess_fn1=postprocess_fn1.specs, - postprocess_fn2=postprocess_fn2.specs)._reduce + reduce_kernel = forward_specializations.get(postprocess_fn1=postprocess_fn1.specs, + postprocess_fn2=postprocess_fn2.specs)._reduce_forward reduce_kernel[grid]( x_flex.reinterpret(x), stride_xr, stride_x0, stride_x1, # x_mxscale, stride_xmxr, stride_xmx0, stride_xmx1, # @@ -245,6 +246,311 @@ def reduce( return y, y_mxscale +# ------------------------------------------------------------ + + +@triton.jit +def _reduce_backward( + dY, + stride_y0: tl.int64, + stride_y1, # upstream grad (S0, Y_S1) + dX, + stride_xr: tl.int64, + stride_x0: tl.int64, + stride_x1, # grad wrt X (K, S0, X_S1) in the chosen layout + XMx, + stride_xmxr, + stride_xmx0, + stride_xmx1, # input micro-scales (optional) + Mask, + stride_mr, + stride_m0, + stride_m1, # mask (optional) + Scale, + stride_sr, + stride_s0, + stride_s1, # scale (optional) + K, + S0, + X_S1, + Y_S1, # shapes + XFlex, # global input flex scale (scalar device buffer) + IS_MASK_NONE: tl.constexpr, + BROADCAST_R: tl.constexpr, + BROADCAST_S0: tl.constexpr, + BROADCAST_S1: tl.constexpr, + IS_SCALE_NONE: tl.constexpr, + SCALE_BROADCAST_R: tl.constexpr, + SCALE_BROADCAST_S0: tl.constexpr, + SCALE_BROADCAST_S1: tl.constexpr, + REDUCTION_N: tl.constexpr, # maps X_S1 -> Y_S1 (grouped sum in fwd) + BLOCK_S0: tl.constexpr, + BLOCK_X_S1: tl.constexpr, +): + # Tile over (S0, X_S1). We loop over the reduction K dimension. + pid_s0 = tl.program_id(0) + pid_s1 = tl.program_id(1) + + tl.static_assert(BLOCK_X_S1 % 32 == 0) + BLOCK_X_SMX1: tl.constexpr = BLOCK_X_S1 // 32 + + offs_s0 = pid_s0 * BLOCK_S0 + tl.arange(0, BLOCK_S0) + offs_x_s1 = pid_s1 * BLOCK_X_S1 + tl.arange(0, BLOCK_X_S1) + offs_x_smx1 = pid_s1 * BLOCK_X_SMX1 + tl.arange(0, BLOCK_X_SMX1) + + valid_s0 = offs_s0 < S0 + valid_x_s1 = offs_x_s1 < X_S1 + valid_in_smx1 = offs_x_smx1 < tl.cdiv(X_S1, 32) + + # Map X_S1 positions to their Y_S1 group index (grouped-sum fwd) + offs_y_from_x = offs_x_s1 // REDUCTION_N + valid_y_from_x = offs_y_from_x < Y_S1 + + # Load upstream grad; broadcasting over the REDUCTION_N group happens via indexing. 
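# e.g. with REDUCTION_N == 2, dX columns (0, 1) both read dY column 0, columns (2, 3)
# read dY column 1, and so on -- each dY element fans out to REDUCTION_N dX columns.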
+ dy_ptrs = dY + offs_s0[:, None] * stride_y0 + offs_y_from_x[None, :] * stride_y1 + dy = tl.load(dy_ptrs, mask=valid_s0[:, None] & valid_y_from_x[None, :], other=0.0).to(tl.float32) + + # Global flex scale (scalar) + x_flex_scale = load_scale(XFlex) + + # Loop over the reduced dimension + for k in tl.range(0, K, num_stages=2): + g = dy + # Multiply by input micro-scale per group of 32 lanes if present + if XMx is not None: + xmx_ptrs = XMx + k * stride_xmxr + offs_s0[:, None] * stride_xmx0 + offs_x_smx1[None, :] * stride_xmx1 + xmx = tl.load(xmx_ptrs, mask=valid_s0[:, None] & valid_in_smx1[None, :], other=0) + xmx = (xmx.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + g = (g.reshape([BLOCK_S0, BLOCK_X_S1 // 32, 32]) * xmx[:, :, None]).reshape([BLOCK_S0, BLOCK_X_S1]) + # Multiply by global input flex scale + g = g * x_flex_scale + # Multiply by per-element Scale if provided + if not IS_SCALE_NONE: + k_term_s = 0 if SCALE_BROADCAST_R else (k * stride_sr) + s0_term_s = 0 if SCALE_BROADCAST_S0 else (offs_s0[:, None] * stride_s0) + s1_term_s = 0 if SCALE_BROADCAST_S1 else (offs_x_s1[None, :] * stride_s1) + s_ptrs = Scale + k_term_s + s0_term_s + s1_term_s + s = tl.load(s_ptrs, mask=valid_s0[:, None] & valid_x_s1[None, :], other=1) + g = g * s + # Apply mask if provided + if not IS_MASK_NONE: + k_term = 0 if BROADCAST_R else (k * stride_mr) + s0_term = 0 if BROADCAST_S0 else (offs_s0[:, None] * stride_m0) + s1_term = 0 if BROADCAST_S1 else (offs_x_s1[None, :] * stride_m1) + m_ptrs = Mask + k_term + s0_term + s1_term + m = tl.load(m_ptrs, mask=valid_s0[:, None] & valid_x_s1[None, :], other=1) + g = tl.where(m != 0, g, 0.0) + # + dx_ptrs = dX + k * stride_xr + offs_s0[:, None] * stride_x0 + offs_x_s1[None, :] * stride_x1 + tl.store(dx_ptrs, g, mask=valid_s0[:, None] & valid_x_s1[None, :]) + + +def reduce_backward( + dy: torch.Tensor, + x_shape: tuple[int, int, int], + dim: int, + *, + mask: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + x_mxscale: Optional[torch.Tensor], + x_flex: Optional[InFlexData], + postprocess_fn1: Optional[PostprocessFn], + x_strides: tuple[int, int, int], + x_mx_strides: Optional[tuple[int, int, int]], + mask_strides: Optional[tuple[int, int, int]], + scale_strides: Optional[tuple[int, int, int]], + dx: torch.Tensor, +): + # Shapes/axes handling mirrors `reduce(...)` + if dim < 0: + dim += 3 + dims = (0, 1, 2) + nonred = tuple(d for d in dims if d != dim) + + S0, X_S1 = x_shape[nonred[0]], x_shape[nonred[1]] + K = x_shape[dim] + + # Postprocess grouping (grouped sum). Default is identity (1). + reduction_n = (postprocess_fn1.specs.reduction_n if postprocess_fn1 is not None else FnSpecs.default().reduction_n) + Y_S1 = X_S1 // reduction_n + assert dy.shape == (S0, Y_S1), f"dY shape {dy.shape} mismatch with (S0={S0}, Y_S1={Y_S1})" + + # Strides for dX must match the element size of the tensor passed to the kernel. + # If we reinterpret the dtype (e.g., flex/float8), use the reinterpreted view's strides. 
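# The kernel's (r, 0, 1) axes are the reduced dimension followed by the two surviving
# dimensions of x, so each stride below is selected according to `dim` and `nonred`.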
+ dx_view = x_flex.reinterpret(dx) + dx_str0, dx_str1, dx_str2 = dx_view.stride() + stride_xr = (dx_str0 if dim == 0 else (dx_str1 if dim == 1 else dx_str2)) + stride_x0 = (dx_str0 if nonred[0] == 0 else (dx_str1 if nonred[0] == 1 else dx_str2)) + stride_x1 = (dx_str0 if nonred[1] == 0 else (dx_str1 if nonred[1] == 1 else dx_str2)) + stride_xmxr = stride_xmx0 = stride_xmx1 = 0 + if x_mxscale is not None: + stride_xmxr, stride_xmx0, stride_xmx1 = x_mx_strides + + if mask is not None: + mstr0, mstr1, mstr2 = mask_strides + stride_mr = (mstr0 if dim == 0 else (mstr1 if dim == 1 else mstr2)) + stride_m0 = (mstr0 if nonred[0] == 0 else (mstr1 if nonred[0] == 1 else mstr2)) + stride_m1 = (mstr0 if nonred[1] == 0 else (mstr1 if nonred[1] == 1 else mstr2)) + else: + stride_mr = stride_m0 = stride_m1 = 0 + + if scale is not None: + sstr0, sstr1, sstr2 = scale_strides + stride_sr = (sstr0 if dim == 0 else (sstr1 if dim == 1 else sstr2)) + stride_s0 = (sstr0 if nonred[0] == 0 else (sstr1 if nonred[0] == 1 else sstr2)) + stride_s1 = (sstr0 if nonred[1] == 0 else (sstr1 if nonred[1] == 1 else sstr2)) + else: + stride_sr = stride_s0 = stride_s1 = 0 + + # Launch configuration mirrors forward (but we tile over X_S1, not Y_S1) + BLOCK_S0 = 64 + BLOCK_X_S1 = 128 + grid = (triton.cdiv(S0, BLOCK_S0), triton.cdiv(X_S1, BLOCK_X_S1)) + + _reduce_backward[grid]( + dy, + dy.stride(0), + dy.stride(1), + dx_view, + stride_xr, + stride_x0, + stride_x1, + x_mxscale, + stride_xmxr, + stride_xmx0, + stride_xmx1, + mask, + stride_mr, + stride_m0, + stride_m1, + scale, + stride_sr, + stride_s0, + stride_s1, + K, + S0, + X_S1, + Y_S1, + x_flex.scale, + IS_MASK_NONE=(mask is None), + BROADCAST_R=(stride_mr == 0), + BROADCAST_S0=(stride_m0 == 0), + BROADCAST_S1=(stride_m1 == 0), + IS_SCALE_NONE=(scale is None), + SCALE_BROADCAST_R=(stride_sr == 0), + SCALE_BROADCAST_S0=(stride_s0 == 0), + SCALE_BROADCAST_S1=(stride_s1 == 0), + REDUCTION_N=reduction_n, + BLOCK_S0=BLOCK_S0, + BLOCK_X_S1=BLOCK_X_S1, + num_warps=4, + ) + + +# ------------------------------------------------------------ + +backward_specializations = SpecializationModule( + "reduce_backward", + kernels=[("_reduce_backward", _reduce_backward)], + closure_args={ + "postprocess_fn1": ClosureArg("POSTPROCESS_FN1", "postprocess_fn1_args"), + "postprocess_fn2": ClosureArg("POSTPROCESS_FN2", "postprocess_fn2_args"), + }, +) + + +class _ReduceAutograd(torch.autograd.Function): + + @staticmethod + def forward(ctx, x: torch.Tensor, dim: int, mask: Optional[torch.Tensor], scale: Optional[torch.Tensor], + x_mxscale: Optional[torch.Tensor], x_flex: Optional[InFlexData], y_dtype: Optional[torch.dtype], + y_flex: Optional[OutFlexData], y_flex_saturate_inf: bool, y_has_mx: Optional[bool], + y: Optional[torch.Tensor], postprocess_fn1: Optional[PostprocessFn], + postprocess_fn2: Optional[PostprocessFn]): + # Run your existing Triton forward + y, y_mx = reduce_forward( + x=x, + dim=dim, + mask=mask, + scale=scale, + x_mxscale=x_mxscale, + x_flex=x_flex, + y_dtype=y_dtype, + y_flex=y_flex, + y_flex_saturate_inf=y_flex_saturate_inf, + y_has_mx=y_has_mx, + y=y, + postprocess_fn1=postprocess_fn1, + postprocess_fn2=postprocess_fn2, + ) + + # Save everything needed for backward (no tensors are modified) + ctx.dim = dim + ctx.x_shape = tuple(x.shape) + ctx.x_dtype = x.dtype + ctx.device = x.device + ctx.mask = mask + ctx.scale = scale + ctx.x_mxscale = x_mxscale + ctx.x_flex = x_flex if x_flex is not None else InFlexData() + ctx.postprocess_fn1 = postprocess_fn1 if postprocess_fn1 is not 
None else PostprocessFn() + ctx.x_strides = tuple(x.stride()) + ctx.x_mx_strides = tuple(x_mxscale.stride()) if x_mxscale is not None else None + ctx.mask_strides = tuple(mask.stride()) if mask is not None else None + ctx.scale_strides = tuple(scale.stride()) if scale is not None else None + ctx.y_has_mx = bool(y_mx is not None) + + return y, y_mx + + @staticmethod + def backward(ctx, grad_y: torch.Tensor, grad_y_mxscale: Optional[torch.Tensor] = None): + # We do not support grads through MX-quantized outputs (no torch compute in bwd) + if ctx.y_has_mx: + raise NotImplementedError("Backward with y_mxscale (MX-quantized outputs) is not supported.") + + # Allocate grad for x; (no torch compute) + dx = torch.empty(ctx.x_shape, dtype=ctx.x_dtype, device=grad_y.device) + + reduce_backward( + dy=grad_y, + x_shape=ctx.x_shape, + dim=ctx.dim, + mask=ctx.mask, + scale=ctx.scale, + x_mxscale=ctx.x_mxscale, + x_flex=ctx.x_flex, + postprocess_fn1=ctx.postprocess_fn1, + x_strides=ctx.x_strides, + x_mx_strides=ctx.x_mx_strides, + mask_strides=ctx.mask_strides, + scale_strides=ctx.scale_strides, + dx=dx, + ) + return dx, None, None, None, None, None, None, None, None, None, None, None, None + + +def reduce( + x: torch.Tensor, + dim: int, + mask: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + x_mxscale: Optional[torch.Tensor] = None, + x_flex: Optional[InFlexData] = InFlexData(), + y: Optional[torch.Tensor] = None, + y_dtype: Optional[torch.dtype] = None, + y_flex: Optional[OutFlexData] = OutFlexData(), + y_flex_saturate_inf: bool = False, + y_has_mx: Optional[bool] = None, + postprocess_fn1: Optional[PostprocessFn] = None, + postprocess_fn2: Optional[PostprocessFn] = None, +): + return _ReduceAutograd.apply(x, dim, mask, scale, x_mxscale, x_flex, y_dtype, y_flex, # + y_flex_saturate_inf, y_has_mx, y, postprocess_fn1, postprocess_fn2) + + +# ------------------------------------------------------------ + + def compute_actual_scale(x, dtype, per_batch_scale=False): max_finite = { torch.float8_e5m2: MAX_FINITE_FLOAT8E5, diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py index f3e9582a4c..9656fd0f08 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py @@ -13,11 +13,13 @@ class HopperMXScaleLayout(Layout): name: str = "HOPPER_SCALE" def __init__(self, shape, mx_axis, num_warps=8) -> None: + if mx_axis is not None and mx_axis < 0: + mx_axis += len(shape) assert num_warps & (num_warps - 1) == 0, "warps_n must be a power of 2" super().__init__(shape) self.mx_axis = mx_axis self.num_warps = num_warps - *self.leading_shape, _, _ = shape + *self.leading_shape, self.M, self.K = shape def _maybe_mT(self, data): if self.mx_axis == len(self.leading_shape): @@ -25,6 +27,7 @@ def _maybe_mT(self, data): return data def swizzle_data(self, data): + assert data.shape == (*self.leading_shape, self.M, self.K) data = self._maybe_mT(data).contiguous() *batch, M, K = data.shape SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 8 @@ -59,7 +62,7 @@ def unswizzle_data(self, data): data = data.permute(*perm) data = data.reshape(*batch, M * 32, K // 32) data = self._maybe_mT(data) - return data + return data[..., :self.M, :self.K] def swizzle_block_shape(self, block_shape): return block_shape @@ -70,6 +73,8 @@ def unswizzle_mxfp4_scale_hopper(x, mx_axis: 
tl.constexpr, num_warps: tl.constex """ Triton inverse of swizzle_mxfp4_scale_hopper """ + if mx_axis is not None and mx_axis < 0: + mx_axis += len(x.shape) tl.static_assert(len(x.shape) == 2, "NYI") # implementation assumes mxfp data is packed along the last dimension x = x.trans() if mx_axis == 0 else x diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py index 118388c275..9128037bf7 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py @@ -92,7 +92,8 @@ class HopperMXValueLayout(Layout): def __init__(self, shape, mx_axis, mma_version=3): super().__init__(shape) - assert mx_axis in range(len(shape)) + if mx_axis < 0: + mx_axis += len(shape) self.mx_axis = mx_axis self.mma_version = mma_version *self.leading_shape, self.K, self.N, = shape diff --git a/python/triton_kernels/triton_kernels/tensor_details/ragged_tensor.py b/python/triton_kernels/triton_kernels/tensor_details/ragged_tensor.py index 277141d95a..9456fb6bf6 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/ragged_tensor.py +++ b/python/triton_kernels/triton_kernels/tensor_details/ragged_tensor.py @@ -45,6 +45,10 @@ class RaggedTensorMetadata: # NOTE 2: because the size of `block_schedule_data[k]` is data-dependent, we pad it with -1s # up to an user-provided upper bound block_schedule_data: torch.Tensor + # expected slice size (for heuristics) + expected_slice_size: int | None = None + # divisibility hint for values in `slice_sizes` + slice_sizes_divisibility: int = None def __post_init__(self): assert self.block_offs_data.shape[0] == len(RaggedTensorMetadata.block_sizes()) @@ -56,6 +60,10 @@ def __post_init__(self): if self.slice_offs is not None: assert self.slice_offs.dtype == torch.int32 + @property + def n_slices(self): + return self.slice_sizes.shape[0] + def block_offs(self, block_size): return self.block_offs_data[RaggedTensorMetadata.block_sizes().index(block_size)] @@ -63,10 +71,14 @@ def block_schedule(self, block_size): return self.block_schedule_data[RaggedTensorMetadata.block_sizes().index(block_size)] @staticmethod - def max_n_tiles(n_slices, n_total_rows): + def n_blocks(n_slices, n_total_rows, block_size): if n_total_rows <= n_slices: return n_total_rows - return n_slices - 1 - ((n_slices - n_total_rows - 1) // min(RaggedTensorMetadata.block_sizes())) + return n_slices - 1 - ((n_slices - n_total_rows - 1) // block_size) + + @staticmethod + def max_n_blocks(n_slices, n_total_rows): + return RaggedTensorMetadata.n_blocks(n_slices, n_total_rows, min(RaggedTensorMetadata.block_sizes())) @staticmethod def block_sizes_log2(): @@ -77,6 +89,11 @@ def block_sizes(): return [2**x for x in RaggedTensorMetadata.block_sizes_log2()] +def ragged_metadata_fields(metadata, block_size): + return (metadata.slice_sizes, metadata.slice_offs, metadata.block_offs(block_size), + metadata.block_schedule(block_size), metadata.expected_slice_size, metadata.slice_sizes_divisibility or 1) + + # utilities # --------------------------------------------------------- # @@ -171,7 +188,7 @@ def make_ragged_tensor_metadata(slice_sizes, n_total_rows): MEMSET_BLOCK = 512 dtype = torch.int32 device = slice_sizes.device - max_n_blocks = RaggedTensorMetadata.max_n_tiles(n_slices, n_total_rows) + max_n_blocks = RaggedTensorMetadata.max_n_blocks(n_slices, n_total_rows) 
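# Illustrative bound: RaggedTensorMetadata.n_blocks(4, 100, 16) == 10, attained by slice
# sizes (1, 1, 1, 97) -> ceil(1/16) * 3 + ceil(97/16) = 3 + 7 = 10; max_n_blocks evaluates
# this at the smallest supported block size, which sizes block_schedule_data below.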
slice_offs_combined, _ = empty_aligned((block_size_num + 1, n_slices + 1), dtype, device, MEMSET_BLOCK) block_schedule_data, n_memset_elts = empty_aligned((block_size_num, max_n_blocks), dtype, device, MEMSET_BLOCK) slice_offs, block_offs_data = slice_offs_combined[0], slice_offs_combined[1:] @@ -200,7 +217,7 @@ def make_ragged_tensor_metadata(slice_sizes, n_total_rows): def make_ragged_tensor_metadata_torch(slice_sizes, n_total_rows): assert slice_sizes.ndim == 1 n_slices = slice_sizes.shape[0] - max_n_blocks = RaggedTensorMetadata.max_n_tiles(n_slices, n_total_rows) + max_n_blocks = RaggedTensorMetadata.max_n_blocks(n_slices, n_total_rows) # offset for each experts device = slice_sizes.device slice_offs = torch.cumsum(slice_sizes, dim=0) @@ -226,11 +243,11 @@ def _build_schedule(block_off, n_blocks): block_offs = dict() block_pid_map = dict() for block_size in RaggedTensorMetadata.block_sizes(): - n_tiles = (slice_sizes + block_size - 1) // block_size - block = torch.cumsum(n_tiles, dim=0) + n_blocks = (slice_sizes + block_size - 1) // block_size + block = torch.cumsum(n_blocks, dim=0) block = torch.cat((torch.zeros(1, device=device), block)).int() block_offs[block_size] = block - block_pid_map[block_size] = _build_schedule(block, n_tiles) + block_pid_map[block_size] = _build_schedule(block, n_blocks) block_offs = torch.stack(list(block_offs.values())) block_pid_map = torch.stack(list(block_pid_map.values())) return RaggedTensorMetadata(slice_sizes, slice_offs, block_offs, block_pid_map) diff --git a/python/triton_kernels/triton_kernels/testing.py b/python/triton_kernels/triton_kernels/testing.py index fd7169a656..0b49a79dc3 100644 --- a/python/triton_kernels/triton_kernels/testing.py +++ b/python/triton_kernels/triton_kernels/testing.py @@ -5,6 +5,11 @@ import sys import torch from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 +from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4, make_ragged_tensor_metadata +from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, MXFP_BLOCK_SIZE +from triton_kernels.tensor_details import layout +import itertools +from dataclasses import replace def assert_equal(ref, tri): @@ -194,3 +199,131 @@ def compute_actual_scale(x, dtype, per_batch_scale=False): }[dtype] maxvals = x.abs().amax(dim=tuple(range(1, x.ndim))) if per_batch_scale else x.abs().max() return maxvals / max_finite + + +# --- create tensor --- + + +def normalize_blocks(x, BLOCK_SIZE=None): + if BLOCK_SIZE is None: + BLOCK_SIZE = int(MXFP_BLOCK_SIZE) + x_ndim = x.ndim + if x_ndim == 2: + x = x.unsqueeze(0) + for e, i, j in itertools.product(range(x.shape[0]), range(0, x.shape[1], BLOCK_SIZE), + range(0, x.shape[2], BLOCK_SIZE)): + i_end = min(i + BLOCK_SIZE, x.shape[1]) + j_end = min(j + BLOCK_SIZE, x.shape[2]) + block = x[e, i:i_end, j:j_end] + m_abs = block.abs().max() + i_len = i_end - i + j_len = j_end - j + min_len = min(i_len, j_len) + signs = torch.randint(0, 2, (max(i_len, j_len), ), device=x.device) * 2 - 1 + block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs + if j_len > i_len: + block[i_len - 1, i_len:] = signs[min_len:] * m_abs + elif i_len > j_len: + block[j_len:, j_len - 1] = signs[min_len:] * m_abs + if x_ndim == 2: + x = x.squeeze(0) + return x + + +def alloc_rand(shape, device, dtype, requires_grad=False): + if dtype.itemsize == 1: + tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) + return tmp.to(dtype).requires_grad_(requires_grad) + ret = 
torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) + ret = normalize_blocks(ret) + return ret + + +def make_slice_sizes(n_slices, total_size, device="cuda"): + torch.manual_seed(0) + dtype = torch.int32 + if total_size < 0: + raise ValueError("total_size must be non-negative") + if n_slices <= 0: + return torch.zeros((0, ), dtype=dtype, device=device) + if total_size == 0: + return torch.zeros((n_slices, ), dtype=dtype, device=device) + # always set one slice size to zero + probs = torch.ones(n_slices, device=device) / n_slices + if n_slices > 1: + probs[2] += probs[1] + probs[1] = 0. + assignments = torch.multinomial(probs, total_size, replacement=True) + counts = torch.bincount(assignments, minlength=n_slices).to(dtype) + assert counts.sum().item() == total_size + assert len(counts) == n_slices + return counts + + +def pad_rows_to_multiples(A, indices, multiple=128, pad_value=float('nan')): + """ + Insert padding so that each row A[i] (for i in indices) + appears at an output row index that is a multiple of `multiple`. + """ + D = A.size(1) + out = [] + for i_cur, i_next in zip(indices[:-1], indices[1:]): + size = (i_next - i_cur) + size_padded = ((size + multiple - 1) // multiple) * multiple + cur = torch.full((size_padded, D), pad_value, dtype=A.dtype, device=A.device) + cur[:size, :] = A[i_cur:i_next, :] + out.append(cur) + return torch.vstack(out) + + +def pad_ragged_tensor(x, x_ragged_metadata, hbm_swizzling, transpose): + multiple = 128 if hbm_swizzling else 64 + if transpose: + y = pad_rows_to_multiples(x.T, x_ragged_metadata.slice_offs, multiple=multiple, pad_value=0).T.contiguous() + else: + y = pad_rows_to_multiples(x, x_ragged_metadata.slice_offs, multiple=multiple, pad_value=0).contiguous() + + y_ragged_metadata = replace(x_ragged_metadata, slice_offs=x_ragged_metadata.block_offs(multiple) * multiple, + slice_sizes_divisibility=multiple) + return y, y_ragged_metadata + + +def make_random_tensor(shape, n_slices, ragged_dim, ragged_padding, device, dtype, mxfp_dim, transpose, + squeeze_batch_dim, hbm_swizzling=False, is_mx_rowmajor=False): + # allocate buffer + buffer_shape = ((n_slices, ) if ragged_dim is None else tuple()) + shape + buffer_dtype = torch.bfloat16 if dtype.has_mx_scale else dtype.torch_dtype + buffer = alloc_rand(buffer_shape, device=device, dtype=buffer_dtype) + if squeeze_batch_dim: + buffer = buffer.squeeze(0) + # handle raggedness + ragged_metadata = None + if ragged_dim is not None: + slice_sizes = make_slice_sizes(n_slices, shape[ragged_dim], device=device) + ragged_metadata = make_ragged_tensor_metadata(slice_sizes, shape[ragged_dim]) + if ragged_padding: + buffer, ragged_metadata = pad_ragged_tensor(buffer, ragged_metadata, hbm_swizzling, ragged_dim == 1) + # handle transpose + if transpose: + buffer = buffer.mT.contiguous().mT + # handle mxfp + scales = None + if mxfp_dim is not None: + assert dtype.has_mx_scale + buffer_dtype = dtype.torch_dtype + if is_mx_rowmajor: + scales = downcast_to_mxfp(buffer, buffer_dtype, axis=mxfp_dim)[1] + buffer = downcast_to_mxfp(buffer.mT.contiguous(), buffer_dtype, axis=mxfp_dim)[0].mT + else: + buffer, scales = downcast_to_mxfp(buffer, buffer_dtype, axis=mxfp_dim) + buffer = wrap_torch_tensor(buffer, FP4 if dtype.is_mxfloat4 else buffer_dtype) + scales = wrap_torch_tensor(scales) + if dtype.is_mxfloat4 and hbm_swizzling and not is_mx_rowmajor: + # convert buffer to swizzled hbm layout + buffer_layout, buffer_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mxfp_dim) + buffer = 
convert_layout(buffer, buffer_layout, **buffer_layout_opts) + # convert scales to swizzled hbm layout + scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( + mx_axis=mxfp_dim, num_warps=8) + scales = convert_layout(scales, scale_layout, **scale_layout_opts) + return buffer, scales, ragged_metadata diff --git a/scripts/skiplist/default/triton_kernels.txt b/scripts/skiplist/default/triton_kernels.txt index 7870a886fd..e69de29bb2 100644 --- a/scripts/skiplist/default/triton_kernels.txt +++ b/scripts/skiplist/default/triton_kernels.txt @@ -1,17 +0,0 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/5074 -tests/test_matmul.py::test_op[False-False-False-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] diff --git a/scripts/skiplist/xe2/triton_kernels.txt b/scripts/skiplist/xe2/triton_kernels.txt index 108cf912ba..e69de29bb2 100644 --- a/scripts/skiplist/xe2/triton_kernels.txt +++ b/scripts/skiplist/xe2/triton_kernels.txt @@ -1,49 +0,0 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/5074 -tests/test_matmul.py::test_op[False-False-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] 
-tests/test_matmul.py::test_op[False-False-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-False-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-False-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] 
-tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-True-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-True-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] 
-tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]