From d8119d51f39a25313e295a7b8ba5940283a08d37 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 23 Nov 2025 13:27:07 -0800 Subject: [PATCH 1/4] [TRITON_KERNELS][BW-BREAKING] matrix multiplication refactor (#8765) --- python/triton/tools/ragged_tma.py | 26 +- python/triton_kernels/bench/bench_mlp.py | 11 +- python/triton_kernels/bench/distributed.py | 51 +- .../triton_kernels/tests/test_distributed.py | 23 +- python/triton_kernels/tests/test_matmul.py | 928 +++++------------- .../test_opt_flags_split_k.py | 8 +- python/triton_kernels/tests/test_reduce.py | 17 +- .../triton_kernels/distributed.py | 2 +- .../{matmul_ogs.py => matmul.py} | 682 ++++++------- .../_common.py | 179 ++-- .../_matmul.py} | 139 +-- .../_p_matmul.py} | 200 ++-- .../opt_flags.py | 54 +- .../opt_flags_details/opt_flags_amd.py | 2 +- .../opt_flags_details/opt_flags_nvidia.py | 14 +- .../triton_kernels/triton_kernels/reduce.py | 364 ++++++- .../layout_details/hopper_scale.py | 9 +- .../layout_details/hopper_value.py | 3 +- .../tensor_details/ragged_tensor.py | 31 +- .../triton_kernels/triton_kernels/testing.py | 133 +++ 20 files changed, 1426 insertions(+), 1450 deletions(-) rename python/triton_kernels/triton_kernels/{matmul_ogs.py => matmul.py} (50%) rename python/triton_kernels/triton_kernels/{matmul_ogs_details => matmul_details}/_common.py (63%) rename python/triton_kernels/triton_kernels/{matmul_ogs_details/_matmul_ogs.py => matmul_details/_matmul.py} (85%) rename python/triton_kernels/triton_kernels/{matmul_ogs_details/_p_matmul_ogs.py => matmul_details/_p_matmul.py} (80%) rename python/triton_kernels/triton_kernels/{matmul_ogs_details => matmul_details}/opt_flags.py (88%) rename python/triton_kernels/triton_kernels/{matmul_ogs_details => matmul_details}/opt_flags_details/opt_flags_amd.py (92%) rename python/triton_kernels/triton_kernels/{matmul_ogs_details => matmul_details}/opt_flags_details/opt_flags_nvidia.py (92%) diff --git a/python/triton/tools/ragged_tma.py b/python/triton/tools/ragged_tma.py index 728dfcd42b..c7730e182a 100644 --- a/python/triton/tools/ragged_tma.py +++ b/python/triton/tools/ragged_tma.py @@ -12,7 +12,7 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0): of potentially unequal size. The load_ragged and store_ragged device functions can be used to read - and write from subarrays T[batch_offset : batch_offset + batch_size] + and write from subarrays T[slice_off : slice_off + slice_size] with hardware bounds-checking preventing any sort of leakage outside the subarray. """ @@ -46,22 +46,22 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0): @triton.jit -def to_ragged_indices(batch_offset, batch_size, row): +def to_ragged_indices(slice_off, slice_size, row): """ Helper function for load_ragged and store_ragged. """ billion = 0x40000000 # == 2**30 - x = billion - batch_size + row - y = batch_offset + batch_size + x = billion - slice_size + row + y = slice_off + slice_size return billion, y, x @triton.jit -def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0): +def load_ragged(TMA, slice_off, slice_size, coords, ragged_dim: tl.constexpr = 0): """ - Read from a subarray T[batch_offset : batch_offset + batch_size] with + Read from a subarray T[slice_off : slice_off + slice_size] with hardware bounds-checking, where reading outside the subarray gives zeros. Coords should be an appropriately-sized list of integers, just like in @@ -70,16 +70,16 @@ def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor") - c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim]) + c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim]) data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:]) data = tl.reshape(data, data.shape[2:]) return data @triton.jit -def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0): +def store_ragged(TMA, slice_off, slice_size, coords, data, ragged_dim: tl.constexpr = 0): """ - Write to a subarray T[batch_offset : batch_offset + batch_size] with + Write to a subarray T[slice_off : slice_off + slice_size] with hardware bounds-checking, where writes outside the subarray are masked correctly. @@ -87,15 +87,15 @@ def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.con TMA.store(). """ - c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim]) + c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim]) data = tl.reshape(data, [1, 1] + data.shape) TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data) @triton.jit -def atomic_add_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0): +def atomic_add_ragged(TMA, slice_off, slice_size, coords, data, ragged_dim: tl.constexpr = 0): """ - Atomic add into a subarray T[batch_offset : batch_offset + batch_size] with + Atomic add into a subarray T[slice_off : slice_off + slice_size] with hardware bounds-checking, where adds outside the subarray are masked correctly. @@ -103,6 +103,6 @@ def atomic_add_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: t TMA.atomic_add(). """ - c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim]) + c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim]) data = tl.reshape(data, [1, 1] + data.shape) TMA.atomic_add([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data) diff --git a/python/triton_kernels/bench/bench_mlp.py b/python/triton_kernels/bench/bench_mlp.py index ca3a977a3e..ff891c901f 100644 --- a/python/triton_kernels/bench/bench_mlp.py +++ b/python/triton_kernels/bench/bench_mlp.py @@ -7,7 +7,7 @@ import triton_kernels import triton_kernels.roofline as roofline import triton_kernels.swiglu -from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation +from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation from triton_kernels.target_info import get_cdna_version import distributed as triton_dist from triton_kernels.tensor_details import layout @@ -71,7 +71,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d input_x = torch.randn((batch // DP, dim1), device=dev) expt_assignment = triton_dist.create_expt_assignment(EP, n_expts_tot, torch.device(dev)) - triton_dist.initialize_matmul_ogs(batch, dim1, dim2, n_expts_act, n_expts_tot, input_x.dtype) + triton_dist.initialize_matmul(batch, dim1, dim2, n_expts_act, n_expts_tot, input_x.dtype) # run layer fpath = Path(tempfile.mktemp()) @@ -80,7 +80,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d xg = input_x.to(wg.dtype if n_expts_tot > 1 else input_x.dtype) for i in range(100): if n_expts_tot > 1: # sparse - logits = matmul_ogs(xg, wg, bg, precision_config=pcg) + logits = matmul(xg, wg, bg, precision_config=pcg) x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(input_x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment, mode="ep_sharding") @@ -88,9 +88,8 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d x = triton_dist.all_gather(input_x, dim=0) rdata, gather_indx, scatter_indx, metadata = None, None, None, None if x.nelement() > 0: - x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act) - x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx, - precision_config=pc2) + x = matmul(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act) + x = matmul(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx, precision_config=pc2) x = triton_dist.reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment) proton.finalize() return roofline.parse_profile(fpath.with_suffix(".hatchet"), useful_op_regex=".*matmul.*") diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index e425e56ac1..1a5dea995e 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -10,12 +10,11 @@ import triton_kernels import triton_kernels.swiglu from triton_kernels.reduce import reduce -from triton_kernels.matmul_ogs import RoutingData, GatherIndx, ScatterIndx from triton_kernels.topk import topk -from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation +from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq from triton_kernels.tensor_details import layout -from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata +from triton_kernels.tensor import RaggedTensorMetadata, make_ragged_tensor_metadata, remap_ragged_tensor_metadata from triton_kernels.distributed import make_expt_dict_uniform, make_expt_assignment, convert_dp_to_ep, convert_ep_to_dp, ExptAssignment, symm_mem_pool from bench_utils import quantize_weight @@ -40,7 +39,7 @@ def create_expt_assignment(EP: int, n_expts_tot: int, device: torch.device) -> O return make_expt_assignment(EP, n_expts_tot, expt_dict, device) -def initialize_matmul_ogs( +def initialize_matmul( batch: int, dim1: int, dim2: int, @@ -52,7 +51,7 @@ def initialize_matmul_ogs( return world_size = dist.get_world_size() device = torch.cuda.current_device() - symm_mem_pool.initialize_matmul_ogs( + symm_mem_pool.initialize_matmul( n_tokens_global=batch, d_input=dim1, d_model=dim2, @@ -146,8 +145,7 @@ def routing( TP: int = 1, expt_assignment: Optional[ExptAssignment] = None, mode: Optional[str] = None, -) -> Tuple[torch.Tensor, RoutingData, GatherIndx, ScatterIndx, Optional[ReduceScatterMetadata]]: - n_expts_tot = logits.shape[-1] +) -> Tuple[torch.Tensor, RaggedTensorMetadata, torch.Tensor, torch.Tensor, Optional[ReduceScatterMetadata]]: if _is_distributed_launch() and mode: if mode == "ep_sharding": if not expt_assignment: @@ -170,15 +168,13 @@ def routing( logits_global_metadata = make_ragged_tensor_metadata(expt_sizes, dispatch_indx.shape[0]) x = convert_dp_to_ep(x, expt_assignment, active_indx, dispatch_indx) logits_local_metadata = remap_ragged_tensor_metadata(logits_global_metadata, expt_map) - gate_scal = logits_global.vals.flatten()[combine_indx] - rdata = RoutingData(gate_scal, expt_sizes, n_expts_tot // EP, n_expts_act, logits_local_metadata) reduce_scatter_metadata = ReduceScatterMetadata( mode=mode, active_indx=active_indx, dispatch_indx=dispatch_indx, combine_indx=combine_indx, ) - return x, rdata, None, None, reduce_scatter_metadata + return x, logits_local_metadata, None, None, reduce_scatter_metadata else: raise NotImplementedError(f"Distributed routing mode {mode} is not implemented yet.") else: @@ -186,13 +182,10 @@ def routing( logits = topk(logits, n_expts_act, y_indx=y_indx, apply_softmax=not sm_first) dispatch_indx = logits.mask_metadata.row_sorted_indx combine_indx = logits.mask_metadata.col_sorted_indx - ragged_batch_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0]) - gate_scal = logits.vals.flatten()[combine_indx] - routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act, - ragged_batch_metadata) - gather_indx = GatherIndx(combine_indx, dispatch_indx) - scatter_indx = ScatterIndx(dispatch_indx, combine_indx) - return x, routing_data, gather_indx, scatter_indx, None + ragged_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0]) + gather_indx = combine_indx // n_expts_act + scatter_indx = combine_indx + return x, ragged_metadata, gather_indx, scatter_indx, None def gather_ep(rank, world_size, param, TP, EP): @@ -276,14 +269,14 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac w1_full = w2_full = w1_flex_full = w2_flex_full = w1_scale_full = w2_scale_full = None # precision configs - pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale) + pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), b_mx_scale=wg_scale) act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2), (1.0, 1.0)) - pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale) - pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale) + pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), b_mx_scale=w1_scale) + pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), b_mx_scale=w2_scale) if rank == 0: - pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), weight_scale=w1_scale_full) - pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), weight_scale=w2_scale_full) + pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), b_mx_scale=w1_scale_full) + pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), b_mx_scale=w2_scale_full) else: pc1_full = pc2_full = None @@ -296,7 +289,7 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac xd = torch.randn((batch // world_size, dim1), device=dev).to(dtype_map[x_dtype]) x0 = all_gather(xd, dim=0) expt_assignment = create_expt_assignment(EP, n_expts_tot, torch.device(dev)) - symm_mem_pool.initialize_matmul_ogs( + symm_mem_pool.initialize_matmul( n_tokens_global=batch, d_input=dim1, d_model=dim2, @@ -312,25 +305,25 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac def single(x): xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype) if n_expts_tot > 1: - logits = matmul_ogs(xg, wg, bg, precision_config=pcg) + logits = matmul(xg, wg, bg, precision_config=pcg) x, rdata, gi, si, _ = routing(x, logits, n_expts_act) else: rdata = gi = si = None - x = matmul_ogs(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act) - return matmul_ogs(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full) + x = matmul(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act) + return matmul(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full) # distributed pass def distributed(x): xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype) if n_expts_tot > 1: # sparse - logits = matmul_ogs(xg, wg, bg, precision_config=pcg) + logits = matmul(xg, wg, bg, precision_config=pcg) x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment, mode="ep_sharding") else: # dense x = all_gather(x, dim=0) rdata = gi = si = metadata = None - x = matmul_ogs(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act) - x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2) + x = matmul(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act) + x = matmul(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2) x = reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment) # gather the result from all GPUs, just for verification return all_gather(x, dim=0) diff --git a/python/triton_kernels/tests/test_distributed.py b/python/triton_kernels/tests/test_distributed.py index c27e024d89..a48fcd6c48 100644 --- a/python/triton_kernels/tests/test_distributed.py +++ b/python/triton_kernels/tests/test_distributed.py @@ -9,7 +9,7 @@ from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_dict_random, make_expt_assignment, symm_mem_pool from triton_kernels.reduce import reduce from triton_kernels.topk import topk -from triton_kernels.matmul_ogs import matmul_ogs, RoutingData, GatherIndx, ScatterIndx +from triton_kernels.matmul import matmul from triton_kernels.target_info import is_hip from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata import pytest @@ -123,17 +123,18 @@ def routing(logits, n_expts_act, all_gather=False, y_indx=None): dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx combine_indx = sparse_logits.mask_metadata.col_sorted_indx ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]) - gate_scal = sparse_logits.vals.flatten()[combine_indx] - routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, logits.shape[-1], n_expts_act, - ragged_batch_metadata) - gather_idx = GatherIndx(combine_indx, dispatch_indx) - scatter_idx = ScatterIndx(dispatch_indx, combine_indx) - return routing_data, gather_idx, scatter_idx, sparse_logits.indx + gather_idx = torch.div(combine_indx, n_expts_act, rounding_mode="trunc") + scatter_idx = combine_indx + return ragged_batch_metadata, gather_idx, scatter_idx, sparse_logits.indx def mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act, y_indx=None): rdata, combine_indx, dispatch_indx, _ = routing(l_global, n_expts_act, y_indx=y_indx) - y_global = matmul_ogs(x_global, w_global, b_global, rdata, gather_indx=combine_indx, scatter_indx=dispatch_indx) + y_global = matmul(x_global, w_global, b_global, rdata, gather_indx=combine_indx, scatter_indx=dispatch_indx) + y_mask = (dispatch_indx != -1).view(y_global.shape[-2] // n_expts_act, n_expts_act, 1) + y_global = y_global.view(y_global.shape[-2] // n_expts_act, n_expts_act, -1) + y_mask = y_mask.expand_as(y_global) + y_global, _ = reduce(y_global, dim=1, mask=y_mask) return y_global @@ -154,9 +155,7 @@ def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, ex y_ep_local = convert_dp_to_ep(x_dp_local, expt_assignment, active_indx, dispatch_indx) y_ep_local_metadata = remap_ragged_tensor_metadata(x_global_metadata, expt_map) # matrix multiply - # TODO: clean-up API. `RoutingData` should not exist; we should be passing `y_ep_local_metadata`. - rdata_ep_local = RoutingData(None, expt_sizes, w_ep_local.shape[0], n_expts_act, y_ep_local_metadata) - y_ep_local = matmul_ogs(y_ep_local, w_ep_local, b_ep_local, rdata_ep_local) + y_ep_local = matmul(y_ep_local, w_ep_local, b_ep_local, a_ragged_metadata=y_ep_local_metadata) # convert x from expert-sorted, ep-local to token-sorted, dp-local y_dp_local = convert_ep_to_dp(y_ep_local, expt_assignment, active_indx, combine_indx) # weighted average of the output token from experts @@ -209,7 +208,7 @@ def run_mixture(): y_indx=y_indx_global, ) - symm_mem_pool.initialize_matmul_ogs( + symm_mem_pool.initialize_matmul( n_tokens_global=n_tokens_global, d_input=d_model, d_model=d_model, diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index 664de12bdd..6b50ca8a11 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -1,169 +1,38 @@ # isort: off # fmt: off -from dataclasses import dataclass, fields, replace +from dataclasses import dataclass, fields import itertools import pytest import torch from typing import Union import triton # matmul utilities -import triton_kernels.matmul_ogs_details.opt_flags as opt_flags -from triton_kernels.matmul_ogs import FlexCtx, RoutingData, InnerRoutingData, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue -from triton_kernels.matmul_ogs import GatherIndx, ScatterIndx -from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch -from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig -from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4, make_ragged_tensor_metadata -from triton_kernels.tensor_details import layout -from triton_kernels.topk import topk +import triton_kernels.matmul_details.opt_flags as opt_flags +from triton_kernels.matmul import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue +from triton_kernels.matmul import matmul_set_idle_sms, matmul, matmul_torch # numerics utilities from triton_kernels.numerics import InFlexData, OutFlexData -from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp, quantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE +from triton_kernels.numerics_details.mxfp import upcast_from_mxfp, quantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE # testing utilities -from triton_kernels.testing import assert_close, compute_actual_scale +from triton_kernels.testing import assert_close, make_random_tensor # target-specific utilities from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4 - -# --------------- -# initialize data -# --------------- - - -def alloc_rand(shape, device, dtype, requires_grad=True): - if dtype.itemsize == 1: - tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) - return tmp.to(dtype).requires_grad_(requires_grad) - return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) - - -def alloc_rand_like(x): - return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad) - - -def init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter, device="cuda"): - logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True) - sparse_logits = topk(logits, n_expts_act) - dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx - combine_indx = sparse_logits.mask_metadata.col_sorted_indx - ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]) - routing_data = RoutingData(None, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act, ragged_batch_metadata) - gather_idx = GatherIndx(combine_indx, dispatch_indx) if do_gather else None - scatter_idx = ScatterIndx(dispatch_indx, combine_indx) if do_scatter else None - return m, routing_data, gather_idx, scatter_idx - - -def init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode, act_dtype, weight_dtype, - has_y_gammas, requires_grad=True, device="cuda", - inner_expt_opt=None, padding_block_k=None): - torch.manual_seed(0) - assert mode in {'batched', "plain", 'ragged'} - if inner_expt_opt is not None: - assert gindx is None and sindx is None - # Rotate tensor shapes (dw = x.T @ dy) - m, k = k, m * n_expts_act + padding_block_k * n_expts_tot - in_m = m - else: - in_m = m * (n_expts_act if gindx is None else 1) - shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k) - shape_batch = tuple() if (mode == "plain" or inner_expt_opt is not None) else (n_expts_tot, ) - x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad) - w = alloc_rand(shape_batch + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad) - bias = alloc_rand(shape_batch + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad) - gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad) - gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad) - gs0 = gs0.detach().requires_grad_(requires_grad) - gs1 = gs1.detach().requires_grad_(requires_grad) - if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2): - gs0 = None - gs1 = None - if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10: - w = w.transpose(-1, -2).contiguous().transpose(-1, -2) - - def _apply_padding_and_fill_unused_part_with_nan(t, is_padded): - nan_val = float("nan") - if t.element_size() == 1: - t = t.view(torch.int8) - nan_val = 127 - - start = 0 - if is_padded: - for this_expt_nrows in rdata.expt_hist.tolist(): - end = start + this_expt_nrows - padding_end = start + triton.cdiv(this_expt_nrows, padding_block_k) * padding_block_k - t[end:padding_end, :] = 0 - start = padding_end - assert start <= t.shape[0] - t[start:, :] = nan_val - else: - n_actual_rows = rdata.expt_hist.sum().item() - if n_actual_rows + padding_block_k < t.shape[0]: - t[n_actual_rows+padding_block_k:, :] = nan_val - - if inner_expt_opt is not None: - bias = None - _apply_padding_and_fill_unused_part_with_nan(x.T, "pad_x" in inner_expt_opt) - _apply_padding_and_fill_unused_part_with_nan(w, "pad_w" in inner_expt_opt) - - return x, w, bias, gs0, gs1 - +from triton_kernels.swiglu import swiglu, swiglu_fn +from triton_kernels.swiglu import PrecisionConfig as SwiGLUPrecisionConfig # --------------- # numerics stuff # --------------- +class DType: -def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, mode, n_expts_tot=1, expt_is_inner=False, device="cuda"): - weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp - # flexpoint - make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) + - ([val0] - if n_expts_tot % 2 else []), dtype=torch.float32, device=device) - make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device=device) - make = lambda val0, val1, is_tensor: make_tensor(val0, val1) if is_tensor else make_scalar(val0) - - in_flex_data = lambda scale, use_flex: InFlexData(dtype=out_dtype, scale=make_scalar(scale) - ) if use_flex else InFlexData() - flex_ctx = FlexCtx( - lhs_data=in_flex_data(1.25, act_use_flexpoint), - rhs_data=InFlexData( - dtype=weight_dtype, - scale=make(1.50, 1.25, not expt_is_inner), - ) if weight_use_flexpoint else InFlexData(), - out_data=OutFlexData( - dtype=out_dtype, - expected_scale=make_scalar(4.00), - actual_scale=make_scalar(0), - checksum_scale=None, - ) if act_use_flexpoint else OutFlexData(), - ) - return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=2.0 if act_use_flexpoint or weight_use_flexpoint else 1.0, - out_dtype=out_dtype) - - -def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config): - flex_ctx = precision_config.flex_ctx - - def apply(x, scale): - if scale is None: - x = x.clone() - elif scale.numel() == 1: - x = x.float() * scale - else: - assert x.ndim == 3 - assert scale.numel() == x.shape[0] - x = x.float() * scale[:, None, None] - return x.detach().requires_grad_() - - return ( - apply(x_tri, flex_ctx.lhs_data.scale), - apply(w_tri, flex_ctx.rhs_data.scale), - None if bias_tri is None else apply(bias_tri, None), - None if gs0_tri is None else apply(gs0_tri, None), - None if gs1_tri is None else apply(gs1_tri, None), - ) - + def __init__(self, dtype_str): + self.has_global_scale = dtype_str.startswith("float8") + self.has_mx_scale = dtype_str.startswith("mx") + to_torch_dtype = lambda name: torch.uint8 if name == "float4_e2m1" else getattr(torch, name) + self.torch_dtype = to_torch_dtype(dtype_str.strip("mx")) + self.is_mxfloat4 = self.has_mx_scale and "float4" in dtype_str -def dtype_str_to_torch(dtype_str: str) -> torch.dtype: - return torch.uint8 if dtype_str == "float4_e2m1" else getattr(torch, dtype_str) # Scope to ensure that the opt_flags_constraints are reset after the test @@ -173,6 +42,22 @@ def opt_flags_scope(request): opt_flags.reset_opt_flags_constraints() +def make_constraints(block_m, split_k, is_persistent, epilogue_subtile, hbm_swizzling, weight_dtype_str): + constraints = { + "block_m": block_m, + "split_k": split_k, + "is_persistent": is_persistent, + "epilogue_subtile": epilogue_subtile, + } + if is_hip() and hbm_swizzling and "float4" in weight_dtype_str: + # Minimum block size to satisfy scale preshuffling + constraints.update({ + "block_m": 32, + "block_n": 32, + "block_k": 256 + }) + return constraints + # --------------- # unit tests # --------------- @@ -186,121 +71,116 @@ class Case: mode: str act_dtype_str: str weight_dtype_str: str - n_expts_tot: int = 1 - n_expts_act: int = 1 + n_slices: int = None split_k: int = 1 hbm_swizzling: bool = False epilogue_subtile: Union[int, None] = None - x_transpose: bool = False - w_transpose: bool = False - y_transpose: bool = False + a_transpose: bool = False + b_transpose: bool = False + c_transpose: bool = False colmajor_mxfp_weight: bool = True - + swiglu_opts: tuple[float, float] = None + + def __post_init__(self): + if self.n_slices is None: + self.n_slices = 1 if self.mode == "plain" else 10 + +def _build_test_op_cases(): + test_cases = [] + # zero-sized + test_cases.extend([ + Case(m, n, k, mode, "float16", "float16") + for mode in ("ragged", "batched") + for (m, n, k) in ((0, 5, 7), (5, 0, 7), (5, 7, 0)) + ]) + odd_shape1 = (727, 577, 859) + odd_shape2 = (720, 576, 768) + even_shape = (768, 512, 1024) + # canonical float16 + test_cases.extend([ + Case(*shape, mode, "float16", "float16", split_k=split_k) + for shape in [odd_shape1, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5] + ]) + # native float8 + test_cases.extend([ + Case(*shape, mode, "float8_e5m2", "float8_e5m2", split_k=split_k) + for shape in [odd_shape1, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5] + ]) + test_cases.extend([ + Case(*even_shape, "ragged", "float8_e5m2", "float8_e5m2", epilogue_subtile=val) + for val in (1, 2, 4) + ]) + # bfloat16 x mx + for shape in [odd_shape2, even_shape]: + test_cases.extend([ + Case(*shape, "plain", "bfloat16", "mxfloat4_e2m1"), + Case(*shape, "plain", "bfloat16", "mxfloat4_e2m1", hbm_swizzling=True), + Case(*shape, "batched", "bfloat16", "mxfloat4_e2m1"), + Case(*shape, "batched", "bfloat16", "mxfloat4_e2m1", hbm_swizzling=True), + Case(*shape, "ragged", "bfloat16", "mxfloat4_e2m1"), + Case(*shape, "ragged", "bfloat16", "mxfloat4_e2m1", hbm_swizzling=True), + Case(*shape, "ragged", "bfloat16", "mxfloat4_e2m1", split_k=9), + Case(*shape, "ragged", "bfloat16", "mxfloat4_e2m1", split_k=9, hbm_swizzling=True), + Case(*shape, "ragged", "bfloat16", "mxfloat8_e4m3fn"), + Case(*shape, "ragged", "bfloat16", "mxfloat8_e4m3fn", hbm_swizzling=True) + ]) + # float8 x mxfloat + test_cases.extend([ + Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", hbm_swizzling=True), + Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1", hbm_swizzling=True), + Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1"), + Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9), + Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9, hbm_swizzling=True), + Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn"), + Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1"), + Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn"), + ]) + # mxfloat x mxfloat + test_cases.extend([ + Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", hbm_swizzling=True), + Case(1024, 1024, 1024, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", split_k=9, hbm_swizzling=True), + Case(1024, 1024, 1024, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", split_k=9, colmajor_mxfp_weight=False), + Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"), + Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", hbm_swizzling=True), + Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"), + Case(1024, 1024, 1024, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", hbm_swizzling=True), + ]) + # amd-specific float8 + test_cases.extend([ + Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"), + Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"), + Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", split_k=2), + Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"), + ]) + # transposes / permutes + test_cases.extend([ + Case(320, 400, 400, "batched", "float16", "float16", + a_transpose=a_tr, b_transpose=b_tr, c_transpose=c_tr) + for a_tr, b_tr, c_tr in itertools.product((False, True), repeat=3) + ]) + test_cases.extend([ + Case(320, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", + a_transpose=False, b_transpose=True, c_transpose=False), + Case(320, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", + a_transpose=True, b_transpose=True, c_transpose=True), + ]) + # swiglu + test_cases.extend([ + Case(*shape, mode, "bfloat16", "bfloat16", split_k=split_k, swiglu_opts=(1.1, 1.4)) + for shape in [odd_shape2, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5] + ]) + test_cases.extend([ + Case(*even_shape, "ragged", "bfloat16", "bfloat16", epilogue_subtile=val, swiglu_opts=(1.1, 1.4)) + for val in (1, 2, 4) + ]) + + return test_cases @pytest.mark.parametrize( ", ".join(f.name for f in fields(Case)), [ - tuple(getattr(case, f.name) for f in fields(Case)) for case in [ - # Zero-sized args: - Case(0, 5, 7, "ragged", "float16", "float16"), - Case(5, 0, 7, "ragged", "float16", "float16"), - Case(5, 7, 0, "ragged", "float16", "float16"), - Case(0, 5, 7, "batched", "float16", "float16"), - Case(5, 0, 7, "batched", "float16", "float16"), - Case(5, 7, 0, "batched", "float16", "float16"), - # Non-mx types: - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4), - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), - Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1), - Case(16, 256, 256, "batched", "float16", "float16", 5, 1), - Case(16, 256, 256, "ragged", "float16", "float16", 3, 1), - Case(256, 256, 256, "ragged", "float16", "float16", 4, 1), - Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3), - Case(300, 400, 400, "batched", "float16", "float16", 5, 1), - Case(300, 400, 400, "ragged", "float16", "float16"), - Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"), - Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1), - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1), - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2), - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4), - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2), - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2), - Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1), - Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2), - Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9), - Case(16, 16, 1000, "batched", "float16", "float16", 5, 1, split_k=None), - Case(16, 16, 1000, "batched", "float8_e5m2", "float8_e5m2", 5, 1, split_k=None), - Case(16, 16, 2048, "batched", "float8_e5m2", "float8_e5m2", 6, 1, split_k=5), - # mx types: - Case(1, 1024, 1024, "plain", "bfloat16", "mxfloat8_e4m3fn", 1, 1), - Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1), - Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), - Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True, epilogue_subtile=4), - Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1), - Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), - Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), - Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), - Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9), - Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), - Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4), - Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), - Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4), - Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), - # Cover (N or K) % 128 == 64 (https://github.com/triton-lang/triton/pull/7203) - Case(1, 1472, 1472, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4), - Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), - Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), - Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), - Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1), - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9), - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2), - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), - Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4), - Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), - Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4), - Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True), - Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4), - Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True), - Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), - Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False), - Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), - Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), - Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1), - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9), - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False), - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2), - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), - Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), - Case(300, 512, 512, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), - Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), - Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4), - Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True), - Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4), - Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True), - # AMD - Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"), - Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1), - Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2), - Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2), - Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"), - Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1), - Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2), - ] + [ - Case(320, 400, 400, mode, dtype, dtype, n_expts_tot, n_expts_act, - x_transpose=x_transpose, w_transpose=w_transpose, y_transpose=y_transpose) - for (mode, n_expts_tot, n_expts_act) in ( - ("batched", 1, 1), - ("ragged", 8, 4), - ("ragged", 32, 4), - ) - for dtype in ("float16", "float8_e5m2") - for x_transpose in (False, True) - for w_transpose in (False, True) - for y_transpose in (False, True) - ] + tuple(getattr(case, f.name) for f in fields(Case)) + for case in _build_test_op_cases() ], ) @pytest.mark.parametrize("block_m", [16, 128]) @@ -308,36 +188,34 @@ class Case: (False, False, None), (True, False, None), (False, True, None), - (False, True, None), - (True, True, None), (True, True, None), - (False, False, "pad_w"), - (False, False, "pad_x"), + (False, False, "pad_b"), + (False, False, "pad_a"), ]) -@pytest.mark.parametrize("has_y_gammas", [False, True]) +@pytest.mark.parametrize("do_gamma", [False, True]) @pytest.mark.parametrize("is_persistent", [False, True]) -def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, - n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, - x_transpose, w_transpose, y_transpose, - device, opt_flags_scope): +def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, n_slices, + mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + a_transpose, b_transpose, c_transpose, + swiglu_opts, device, opt_flags_scope): # We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to # the frame that called pytest.skip, including all the tensors, leading to OOM. skip_message = None try: - _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, - n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, - x_transpose, w_transpose, y_transpose, - device, opt_flags_scope) + _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, n_slices, + mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + a_transpose, b_transpose, c_transpose, + swiglu_opts, device, opt_flags_scope) except pytest.skip.Exception as e: skip_message = str(e) if skip_message is not None: pytest.skip(skip_message) -def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, - n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, - x_transpose, w_transpose, y_transpose, - device, opt_flags_scope): +def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, n_slices, + mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + a_transpose, b_transpose, c_transpose, + swiglu_opts, device, opt_flags_scope): # TODO: remove when Triton FP8 supports proper RTNE if is_cuda(): if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9: @@ -347,6 +225,8 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm if weight_dtype_str.startswith("mx"): if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10: pytest.skip("float8 x mx not supported with cuda capability < 10") + if swiglu_opts is not None and do_gamma: + pytest.skip("NYI: swiglu and gamma not supported together") elif is_hip(): if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4(): @@ -357,6 +237,9 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm pytest.skip("NYI: mx x mx not tested on AMD GPU") if is_persistent: pytest.skip("NYI: Persistent kernel not supported on AMD GPU") + # FIXME: this works on nvidia; looks like some sort of bug on AMD? + if do_gamma and swiglu_opts is not None: + pytest.skip("NYI: gamma and swiglu not supported together on AMD GPU") if split_k is not None and split_k > 1: pytest.skip("splitK hasn't been fully tested on AMD GPU.") @@ -379,11 +262,11 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm if expt_is_inner: if mode != "ragged": pytest.skip("inner_expt_opt only meaningful with ragged") - if "mx" in act_dtype_str and inner_expt_opt != "pad_x": - pytest.skip("inner_expt_opt and act mx only supported with pad_x") + if "mx" in act_dtype_str and inner_expt_opt != "pad_a": + pytest.skip("inner_expt_opt and act mx only supported with pad_a") if "mx" in weight_dtype_str: - if inner_expt_opt != "pad_w": - pytest.skip("inner_expt_opt and weight mx only supported with pad_w") + if inner_expt_opt != "pad_b": + pytest.skip("inner_expt_opt and weight mx only supported with pad_b") if is_persistent and not hbm_swizzling: pytest.skip("FIXME: Fatal Python error: Aborted") if is_hip(): @@ -391,407 +274,138 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm pytest.skip("FIXME: failed to translate module to LLVM IR") if hbm_swizzling: pytest.skip("NYI: nner_expt_opt and HBM swizzling") + if not colmajor_mxfp_weight: + if torch.cuda.get_device_capability()[0] < 10: + pytest.skip("transposed mxfp weight not supported with cuda capability < 10") + if block_m == 16: + pytest.skip("PassManager::run failed from Triton compiler") + # TODO: should construct the test case differently rather than overriding here + if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10: + b_transpose = True - # launch metadata for batched / mx types may not work yet. torch.manual_seed(0) - block_k = None - if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10: - # Override block_k for testing correctness. The default is temporarily 128 for - # performance reasons which doesn't work with persistent matmul. - # TODO: revisit when Triton is better for H100 + MXFP4 - block_k = 256 - - constraints = { - "block_m": block_m, - "block_k": block_k, - "split_k": split_k, - "is_persistent": is_persistent, - "epilogue_subtile": epilogue_subtile, - } - - if is_hip() and hbm_swizzling and "float4" in weight_dtype_str: - # Minimum block size to satisfy scale preshuffling - constraints.update({ - "block_m": 32, - "block_n": 32, - "block_k": 256 - }) - + # set opt flags constraints + constraints = make_constraints(block_m, split_k, is_persistent, epilogue_subtile, hbm_swizzling, weight_dtype_str) opt_flags.update_opt_flags_constraints(constraints) - weight_mxfp = weight_dtype_str.startswith("mx") - weight_mxfp4 = weight_mxfp and "float4" in weight_dtype_str - if weight_mxfp: - weight_dtype_str = weight_dtype_str[2:] - act_mxfp8 = act_dtype_str.startswith("mx") - act_is_float8 = act_dtype_str.startswith("float8") - if act_mxfp8: - act_dtype_str = act_dtype_str[2:] - quantize_mxfp8_spec = FnSpecs( - FnName.QUANTIZE_MXFP8.name, quantize_mxfp8_fn, (), () - ) - - test_bwd = False - weight_dtype = dtype_str_to_torch(weight_dtype_str) - act_dtype = dtype_str_to_torch(act_dtype_str) - precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, - mode, n_expts_tot, expt_is_inner, device=device) - # precision_opt.x_pad_trans_requires_flexpoint = False - if mode == "ragged": - m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter, - device=device) - else: - rdata = gindx = sindx = None - - padding_block_k = 32 - if hbm_swizzling: - if torch.cuda.get_device_capability()[0] >= 10: - # Blackwell scale swizzling constraint - # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py#L45 - padding_block_k = 128 - elif not is_persistent: - padding_block_k = 64 - x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, - mode, torch.bfloat16 if act_mxfp8 else act_dtype, # - torch.bfloat16 if weight_mxfp else weight_dtype, - has_y_gammas, requires_grad=test_bwd, device=device, - inner_expt_opt=inner_expt_opt, padding_block_k=padding_block_k) - x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt) - - if x_transpose: - x_tri = x_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd) - if w_transpose: - w_tri = w_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd) - if y_transpose: - if mode == "batched": - yT_shape = (n_expts_tot, n, x_tri.shape[-2]) - elif expt_is_inner: - yT_shape = (n_expts_tot, n, k) - elif sindx is not None: - yT_shape = (n, m) - else: - n_rows = x_tri.shape[-2] if gindx is None else gindx.dst_indx.shape[0] - yT_shape = (n, n_rows) - y_tri_in = torch.empty(yT_shape, dtype=act_dtype, device=device).transpose(-1, -2) - else: - y_tri_in = None - - if w_tri.shape[0] == 1 and mode != "batched": - # Test the case when weight has dim 2, i.e., shape (K, N). - w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd) - w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd) - - if weight_mxfp: - mx_axis = w_tri.ndim - 2 - # compute layouts - w_layout, w_layout_opts = layout.StridedLayout, dict() - w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict() - if hbm_swizzling and weight_mxfp4: - w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis) - w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( - mx_axis=mx_axis, num_warps=8) - # downcast to mxfp - w_tri_orig = w_tri - if colmajor_mxfp_weight: - w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) - w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis) - w_tri_dtype = FP4 if weight_mxfp4 else weight_dtype - w_tri = wrap_torch_tensor(w_tri, w_tri_dtype) - w_scale_tri = wrap_torch_tensor(w_scale_tri) - # convert layouts - w_tri = convert_layout(w_tri, w_layout, **w_layout_opts) - w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts) - else: - if torch.cuda.get_device_capability()[0] < 10: - pytest.skip("transposed mxfp weight not supported with cuda capability < 10") - if block_m == 16: - pytest.skip("PassManager::run failed from Triton compiler") - # TODO: swizzling for rowmajor - - # A typical use case is we already quantized col-major weight, - # and we want matmul with its transposed row-major weight w/o - # requantization. - - # put abs_max of each 32x32 block to diagonal so scales of transposed agree - w_ndim = w_tri.ndim - if w_ndim == 2: - w_tri = w_tri.unsqueeze(0) - BLOCK_SIZE = int(MXFP_BLOCK_SIZE) - for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)): - i_end = min(i+BLOCK_SIZE, w_tri.shape[1]) - j_end = min(j+BLOCK_SIZE, w_tri.shape[2]) - block = w_tri[e, i:i_end, j:j_end] - m_abs = block.abs().max() - i_len = i_end - i - j_len = j_end - j - min_len = min(i_len, j_len) - signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1 - block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs - if j_len > i_len: - block[i_len - 1, i_len:] = signs[min_len:] * m_abs - elif i_len > j_len: - block[j_len:, j_len - 1] = signs[min_len:] * m_abs - if w_ndim == 2: - w_tri = w_tri.squeeze(0) - - # matmul with rowmajor weight expects scale is separately - # constructed (not much additional memory needed). - _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) - # reuse quantized value from colmajor - w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis) - w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous() - w_tri = w_tri_rowmajor.data.mT - - def _pad_and_block(x: torch.Tensor) -> torch.Tensor: - x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate") - return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE) - - # check if generated scale is transpose-invariant as intended construction - # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)] - w_scale_tri_blocked = _pad_and_block(w_scale_tri) - w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1] - # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)] - w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor) - w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1] - assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked) - assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked) - assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT) - - precision_opt.weight_scale = w_scale_tri + a_dtype = DType(act_dtype_str) + b_dtype = DType(weight_dtype_str) + c_dtype = DType(act_dtype_str) + + # --- create conditionals --- + do_bias = inner_expt_opt is None + do_gather = do_gather and mode != "batched" + do_scatter = do_scatter and mode != "batched" + + # --- create inputs --- + a, a_scales, a_ragged_metadata = make_random_tensor( + shape=(m, k), + n_slices = n_slices, + dtype = a_dtype, + device = device, + ragged_dim = None if mode != "ragged" else 1 if expt_is_inner else 0, + mxfp_dim = -1 if a_dtype.has_mx_scale else None, + transpose = a_transpose, + ragged_padding = inner_expt_opt is not None and "pad_a" in inner_expt_opt, + squeeze_batch_dim = mode == "plain", + ) + b, b_scale_tri, b_ragged_metadata = make_random_tensor( + shape=(k, n), + n_slices = n_slices, + dtype = b_dtype, + device = device, + ragged_dim = None if mode != "ragged" or inner_expt_opt is None else 0, + mxfp_dim = -2 if b_dtype.has_mx_scale else None, + transpose = b_transpose, + ragged_padding = inner_expt_opt is not None and "pad_b" in inner_expt_opt, + squeeze_batch_dim = mode == "plain", + hbm_swizzling = hbm_swizzling, + is_mx_rowmajor = not colmajor_mxfp_weight, + ) + gather_indx = None if not do_gather else torch.randint(0, max(m, 1), (m, ), dtype=torch.int32, device=device) + scatter_indx = None if not do_scatter else torch.randperm(m, dtype=torch.int32, device=device) + bias = None if not do_bias else torch.randn(b.shape[:-2] + b.shape[-1:], dtype=torch.float32, device=device) + gammas = None if not do_gamma else 2**torch.randint(-5, 0, (m, ), dtype=torch.float32, device=device) + + # --- create fused activation --- + fused_activation = None + if swiglu_opts is not None: + fused_activation = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2), swiglu_opts) + + # --- initialize output --- + c_shape = (n_slices,) if mode == "batched" or inner_expt_opt is not None else tuple() # batch dim + c_shape += (scatter_indx.shape[0] if do_scatter else a.shape[-2],) # row dim + c_shape += (b.shape[-1] // (1 if fused_activation is None else fused_activation.specs.reduction_n) ,) # col dim + c = torch.empty(c_shape, dtype=c_dtype.torch_dtype, device=device) + if c_transpose: + c = c.mT.contiguous().mT + + # --- create precision config --- + wrap_list = lambda vals: torch.tensor(vals, dtype=torch.float32, device=device) + flex_a = InFlexData(c_dtype.torch_dtype, wrap_list([1.25])) if c_dtype.has_global_scale else InFlexData() + flex_b = InFlexData(b_dtype.torch_dtype, wrap_list([1.25])) if b_dtype.has_global_scale else InFlexData() + flex_c = OutFlexData(c_dtype.torch_dtype, wrap_list([4.00]), wrap_list([0]), None) if c_dtype.has_global_scale else OutFlexData() + precision_opt = PrecisionConfig( + flex_ctx=FlexCtx(flex_a, flex_b, flex_c), + acc_scale=2.0 if c_dtype.has_global_scale or b_dtype.has_global_scale else 1.0, + out_dtype=c_dtype.torch_dtype, + a_mx_scale=a_scales, + b_mx_scale=b_scale_tri, + ) + + # --- create epilogue --- epilogue = None - if act_mxfp8: - x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1) - x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1) - is_input_batched = x_tri.ndim == 3 - y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape - n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0] - y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1]) - if sindx is None or mode == "batched": - if not is_input_batched: - y_shape = (y_shape[1], y_shape[2]) - else: - y_shape = (n_rows // rdata.n_expts_act, y_shape[-1]) - y_scale_shape = y_shape[:-1] + (triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE),) - y_scale = torch.empty(y_scale_shape, dtype=torch.uint8, device=x_tri.device) - precision_opt = replace(precision_opt, act_scale=x_mx_scales_tri, out_scale=y_scale) - epilogue = Epilogue(quantize_mxfp8_spec, tuple(), tuple(), effective_itemsize=6.0) - else: - y_scale = None - - if mode == "batched": - rdata, gindx, sindx = None, None, None - flex = precision_opt.flex_ctx + if c_dtype.has_mx_scale: + c_scale_shape = c_shape[:-1] + (triton.cdiv(c_shape[-1], MXFP_BLOCK_SIZE),) + c_scale = torch.empty(c_scale_shape, dtype=torch.uint8, device=a.device) + precision_opt.c_mx_scale = c_scale + epilogue_spec = FnSpecs(FnName.QUANTIZE_MXFP8.name, quantize_mxfp8_fn, (), ()) + epilogue = Epilogue(epilogue_spec, tuple(), tuple(), effective_itemsize=6.0) - if expt_is_inner: - inner_routing_data = InnerRoutingData( - base=rdata, block_k=padding_block_k, - x_is_padded="pad_x" in inner_expt_opt, - w_is_padded="pad_w" in inner_expt_opt, - ) - rdata = None - else: - inner_routing_data = None - - # triton + + # --- triton implementation --- try: - tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, - gammas=gs1_ref, epilogue=epilogue, y=y_tri_in, - inner_routing_data=inner_routing_data) + tri_y = matmul(a, b, bias, + a_ragged_metadata, b_ragged_metadata, + gather_indx, scatter_indx, precision_opt, + gammas=gammas, epilogue=epilogue, c=c, + fused_activation=fused_activation) + if c_dtype.has_global_scale: + tri_y_scale = precision_opt.flex_ctx.out_data.actual_scale.clone() except (opt_flags.InapplicableConstraint, NotImplementedError) as e: pytest.skip(f"inapplicable opt_flags constraint {e}") - if y_tri_in is not None: - assert tri_y.data_ptr() == y_tri_in.data_ptr() - assert tri_y.shape == y_tri_in.shape - assert tri_y.stride() == y_tri_in.stride() - # If split_k > 1, then the intermediate tensor is fp32. - sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1 - sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1 - y_scale = flex.out_data.expected_scale if act_is_float8 else 1 - - def round_x(x, idx): - return x.to(act_dtype).to(torch.float32) if sep_gather else x - - round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y - ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref, # - rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, - inner_routing_data=inner_routing_data) - - def scale(val, scal): - if scal is None: - return val - elif scal.numel() == 1: - return val / scal - else: - assert val.ndim == 3 - return val / scal[:, None, None] - - if act_mxfp8: - tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, target_dtype=torch.bfloat16, axis=-1).to(ref_y.dtype) - ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1) - ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1) - maxtol = 4e-1 - rmstol = 4e-2 - elif weight_mxfp4: - if act_is_float8: - maxtol = 8e-2 - else: - maxtol = 3e-2 - rmstol = None - else: - maxtol = None - rmstol = None - assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y, maxtol=maxtol, rmstol=rmstol) - - if act_is_float8: - tri_y_scale = flex.out_data.actual_scale.clone() - ref_y_scale = compute_actual_scale(ref_y, tri_y.dtype, tri_y_scale.numel() > 1) + # --- torch implementation --- + ref_y = matmul_torch(a, b, bias, # + a_ragged_metadata, b_ragged_metadata, + gather_indx, scatter_indx, precision_opt, + gammas=gammas) + if swiglu_opts is not None: + ref_y = swiglu(ref_y, alpha=swiglu_opts[0], precision_config=SwiGLUPrecisionConfig(swiglu_opts[1])) + if c_dtype.has_global_scale: + ref_y_scale = precision_opt.flex_ctx.out_data.actual_scale.clone() + + # --- check results --- + if c_dtype.has_mx_scale: + tri_y = upcast_from_mxfp(tri_y, precision_opt.c_mx_scale, target_dtype=torch.bfloat16, axis=-1).to(ref_y.dtype) + ref_y = upcast_from_mxfp_torch(*downcast_to_mxfp_torch(ref_y, c_dtype.torch_dtype, axis=-1), target_dtype=ref_y.dtype, axis=-1) + maxtol, rmstol = None, None + if c_dtype.has_mx_scale: + maxtol, rmstol = 4e-1, 4e-2 + elif b_dtype.is_mxfloat4: + maxtol, rmstol = 3e-2, None + assert_close(ref_y, tri_y, maxtol=maxtol, rmstol=rmstol) + if c_dtype.has_global_scale: assert torch.all((ref_y_scale - tri_y_scale).abs() < 1e-10), \ f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}" -# Test that we don't use unsupported block sizes. -@pytest.mark.parametrize("m", [8, 16, 32, 64, 128]) -@pytest.mark.parametrize("n", [8, 16, 32, 64, 128]) -@pytest.mark.parametrize("k", [8, 16, 32, 64, 128]) -def test_small_batch_matmul(m, n, k): - if is_hip(): - pytest.skip("Not fully tested on AMD") - - if m * n * k > 16384: - pytest.skip() - - BATCH_SIZE = 10000 - - def _make_tensor(shape, dtype, trans): - if trans: - shape = (shape[0], shape[2], shape[1]) - t = alloc_rand(shape, "cuda", dtype) - return t.transpose(1, 2) if trans else t - - for x_transpose, w_transpose, bias, dtype in itertools.product( - (False, True), - (False, True), - (False, True), - (torch.float16, torch.bfloat16, torch.float8_e5m2), - ): - if ( - torch.cuda.get_device_capability()[0] < 10 - and dtype is torch.float8_e5m2 - and (not w_transpose) - ): - continue # Not supported - - x = _make_tensor((BATCH_SIZE, m, k), dtype, x_transpose) - w = _make_tensor((BATCH_SIZE, k, n), dtype, w_transpose) - bias = _make_tensor((BATCH_SIZE, n), torch.float32, False) if bias else None - tri_y = matmul_ogs(x, w, bias) - - # ref_y = matmul_ogs_torch(x.float(), w.float(), bias) - - # This is faster than matmul_ogs_torch. - ref_y = torch.bmm(x.float(), w.float()) - if bias is not None: - ref_y += bias[:, None, :] - - assert_close( - ref_y, - tri_y, - maxtol=4e-1 if dtype is torch.float8_e5m2 else None, - rmstol=4e-2 if dtype is torch.float8_e5m2 else None, - ) - - def test_set_idle_sms(): if not is_cuda(): pytest.skip("Only supported on CUDA") - from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags + from triton_kernels.matmul_details.opt_flags import make_opt_flags num_idle_sms = 24 - matmul_ogs_set_idle_sms(num_idle_sms) + matmul_set_idle_sms(num_idle_sms) flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \ 1, 1024, 1024, 1024, None, True, False, 1, False, False, None) assert flags.idle_sms == num_idle_sms - - -@pytest.mark.parametrize("m, n, k, mode", [ - (1200, 704, 608, "ragged"), - (800, 800, 400, "batched"), -]) -@pytest.mark.parametrize("split_k", [1, 2]) -@pytest.mark.parametrize("do_gather, do_scatter", [ - (False, False), - (True, False), - (False, True), - (True, True), -]) -@pytest.mark.parametrize("is_persistent, epilogue_subtile", [ - (False, None), - (True, 1), - (True, 4), -]) -@pytest.mark.parametrize("swiglu_alpha, swiglu_limit", [ - (1.1, 1.4), - (1.0, 1.2), - (0.7, 1.0), -]) -def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, is_persistent, epilogue_subtile, - swiglu_alpha, swiglu_limit, device, opt_flags_scope): - torch.manual_seed(0) - constraints = { - "is_persistent": is_persistent, - "epilogue_subtile": epilogue_subtile, - "split_k": split_k, - } - n_expts_tot, n_expts_act = 1, 1 - opt_flags.update_opt_flags_constraints(constraints) - - weight_dtype, act_dtype = torch.float16, torch.float16 - if mode == "ragged": - m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter, - device=device) - else: - rdata = gindx = sindx = None - - precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, mode, n_expts_tot, device=device) - x, w, bias, _, _ = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode, - act_dtype, weight_dtype, False, requires_grad=False, device=device) - - if mode == "batched": - rdata, gindx, sindx = None, None, None - - try: - a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha, - precision_config=SwiGLUPrecisionConfig(swiglu_limit)) - b = matmul_ogs( - x, w, bias, rdata, gindx, sindx, precision_opt, - fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2), - (swiglu_alpha, swiglu_limit))) - except opt_flags.InapplicableConstraint: - pytest.skip("inapplicable constraint") - - assert_close(a, b) - - -@pytest.mark.parametrize("m, n, k", [ - (320, 2**19, 0), - (4096, 4096, 0), -]) -@pytest.mark.parametrize("view_x_as_zero_cols", [False, True]) -def test_zero_reduction_dim(m, n, k, view_x_as_zero_cols): - torch.manual_seed(0) - - if view_x_as_zero_cols: - x = torch.randn(m, m, device="cuda", dtype=torch.bfloat16) - x = x[:0, :].transpose(-1, -2) - else: - x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) - w = torch.randn(k, n, device="cuda", dtype=torch.bfloat16) - bias = torch.randn(n, device="cuda", dtype=torch.float32) - - try: - tri_y = matmul_ogs(x, w, bias) - except opt_flags.InapplicableConstraint: - pytest.skip("inapplicable constraint") - ref_y = matmul_ogs_torch(x, w, bias, round_x=lambda x, idx: x, round_y=lambda y: y) - - assert_close(ref_y, tri_y) diff --git a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py index 3ce12f8e7d..f48408daea 100644 --- a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py +++ b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py @@ -5,15 +5,15 @@ import torch -import triton_kernels.matmul_ogs_details.opt_flags as opt_flags +import triton_kernels.matmul_details.opt_flags as opt_flags class _DummyPrecisionConfig: def __init__(self): - self.weight_scale = None + self.b_mx_scale = None self.max_num_imprecise_acc = None - self.act_scale = None - self.out_scale = None + self.a_mx_scale = None + self.c_mx_scale = None self.enforce_bitwise_invariance = False diff --git a/python/triton_kernels/tests/test_reduce.py b/python/triton_kernels/tests/test_reduce.py index 9f64ffb20f..bc0ffa86c6 100644 --- a/python/triton_kernels/tests/test_reduce.py +++ b/python/triton_kernels/tests/test_reduce.py @@ -65,7 +65,7 @@ def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn): pytest.skip("float8 not supported on CUDA < 9.0") torch.manual_seed(0) device = "cuda" - x = torch.randn((B, M, N), device=device, dtype=torch.float32) + x = torch.randn((B, M, N), device=device, dtype=torch.float32, requires_grad=True) x_mscale, x_flex = None, None y_flex_tri, y_flex_ref = None, None if is_mx := dtype_str.startswith("mx"): @@ -91,16 +91,25 @@ def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn): postprocess_fn_ref = lambda x: (x + 10).reshape([x.shape[0], x.shape[1] // 2, 2]).sum(dim=2) else: postprocess_fn_tri = postprocess_fn_ref = None - y_tri, y_tri_mxscale = reduce(x, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_tri, + # run forward pass + x_tri = x.clone().detach().requires_grad_(True) + x_ref = x.clone().detach().requires_grad_(True) + y_tri, y_tri_mxscale = reduce(x_tri, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_tri, postprocess_fn1=postprocess_fn_tri) - y_ref, y_ref_mxscale = reduce_torch(x, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_ref, + y_ref, y_ref_mxscale = reduce_torch(x_ref, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_ref, postprocess_fn1=postprocess_fn_ref) if is_mx: y_ref = upcast_from_mxfp_torch(y_ref, y_ref_mxscale, torch.float16, axis=-1) y_tri = upcast_from_mxfp_torch(y_tri, y_tri_mxscale, torch.float16, axis=-1) + assert torch.allclose(y_tri.float(), y_ref.float(), atol=1e-3, rtol=1e-3) if is_flex: torch.allclose(y_flex_tri.actual_scale, y_flex_ref.actual_scale, atol=1e-3, rtol=1e-3) - assert torch.allclose(y_tri.float(), y_ref.float(), atol=1e-3, rtol=1e-3) + run_bwd = postprocess_fn is None and "float8" not in dtype_str + if run_bwd: + dy = torch.randn_like(y_tri) + y_tri.backward(dy) + y_ref.backward(dy) + assert torch.allclose(x_tri.grad.float(), x_ref.grad.float(), atol=1e-3, rtol=1e-3) def bench_reduce(B: int = 4, M: int = 4096, N: int = 4096, *, dim: int = 0, dtype: torch.dtype = torch.float16, diff --git a/python/triton_kernels/triton_kernels/distributed.py b/python/triton_kernels/triton_kernels/distributed.py index 8102337a97..e27fb2a275 100644 --- a/python/triton_kernels/triton_kernels/distributed.py +++ b/python/triton_kernels/triton_kernels/distributed.py @@ -135,7 +135,7 @@ def _initialize( self._is_initialized = True - def initialize_matmul_ogs( + def initialize_matmul( self, n_tokens_global: int, d_input: int, diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul.py similarity index 50% rename from python/triton_kernels/triton_kernels/matmul_ogs.py rename to python/triton_kernels/triton_kernels/matmul.py index 24a18253d9..11088b64b1 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul.py @@ -1,6 +1,6 @@ # isort: off # fmt: off -from dataclasses import dataclass, field +from dataclasses import dataclass import itertools import torch import triton @@ -11,57 +11,18 @@ from triton_kernels.numerics import InFlexData, OutFlexData from triton_kernels.target_info import is_cuda # details -from .matmul_ogs_details._matmul_ogs import _matmul_ogs -from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn +from .matmul_details._matmul import _matmul +from .matmul_details._p_matmul import _p_matmul, get_per_device_per_stream_alloc_fn from .numerics_details.mxfp import MXFP_BLOCK_SIZE from .tensor_details.layout_details.strided import StridedLayout -from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints +from .matmul_details.opt_flags import make_opt_flags, update_opt_flags_constraints from .specialize import FnSpecs, SpecializationModule, ClosureArg from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor, RaggedTensorMetadata from .reduce import reduce from .reduce import PostprocessFn as ReducePostprocessFn +from .tensor_details.ragged_tensor import ragged_metadata_fields -@dataclass -class GatherIndx: - """ - Indices for an operation that performs: - Y = X[src_idx, :] - """ - # array such that `dst_idx[src_idx] = arange(0, N)` - src_indx: torch.Tensor - dst_indx: torch.Tensor - - -@dataclass -class ScatterIndx: - """ - Indices for an operation that performs: - Y[dst_idx, :] = X - """ - # array such that `dst_idx[src_idx] = arange(0, N)` - src_indx: torch.Tensor - dst_indx: torch.Tensor - -@dataclass -class RoutingData: - gate_scal: torch.Tensor = field() - expt_hist: torch.Tensor = field() - n_expts_tot: int = field() - n_expts_act: int = field() - expt_data: RaggedTensorMetadata = None - - # Used to make perf annotation cleaner: when we use expert sharding, we can - # use this to tell the "expected" number of local tokens per expert, because - # the actual number can vary per each input. - expected_tokens_per_expt: int = field(default=None) - - def n_blocks(self, n_rows, block_m): - if n_rows <= self.n_expts_tot: - return n_rows - else: - return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1 - @dataclass(frozen=True) class FusedActivation: specs: FnSpecs = FnSpecs.default() @@ -86,8 +47,8 @@ class FusedComm: reduce_rank: int = 0 n_reduce_shards: int = 1 -specializations = SpecializationModule("matmul_ogs", - kernels=[("_matmul_ogs", _matmul_ogs), ("_p_matmul_ogs", _p_matmul_ogs)], +specializations = SpecializationModule("matmul", + kernels=[("_matmul", _matmul), ("_p_matmul", _p_matmul)], closure_args={ "epilogue": ClosureArg("EPILOGUE_FN", "epilogue_fn_args"), # "activation": ClosureArg("ACTIVATION_FN", "activation_fn_args"), # @@ -110,78 +71,6 @@ def should_upcast_indices(*args): return any(tensor is not None and can_overflow_int32(tensor) for tensor in args) -# This supports computing dw for ragged matmul. Note the correspondence: -# fwd pass: y = matmul_ogs(x, w, ...) -# bwd pass (dw): dw = matmul_ogs(x.T, dy, ...) -# -# Thus, "our" x, w, and y (as seen by matmul_ogs) correspond to x.T, dy, dw, respectively. -# To avoid confusion, now we'll stick to "x, w, y" terminology. -# -# Assume that y.shape == (N_EXPTS, M, N), x.shape == (M, K), w.shape = (K_W, N). -# -# To make things feasible, we require that x and w satisfy the following condition: -# (1) We don't support gather/scatter indices: in x, all columns for expt #0 are grouped at -# the leftmost part, followed by expt #1, and so on. Ditto for w (top to bottom). -# (2) At least one of x and w are padded: each expert uses a multiple of block_k columns -# (or rows), and unused values are filled with zero. -# (3) No inf or nan are allowed in x or w (except for the final padding - see below). -# This is because we use "multiplying by padded zero region" in lieu of masking. -# (4) The number of actually used columns/rows equals self.base.expt_hist.sum() and may be -# less than K or K_W. In this case, the final "unused" values can be left uninitialized. -# However, if x or w is unpadded, the first block_k columns/rows of the unused part must -# not contain nan or inf. -# -# For example, assume N_EXPTS == 5, block_k == 32, and expt_hist == [60, 33, 0, 32, 25]. -# -# if unpadded if padded -# ----------- --------- -# x: expt #0: x[:, :60] x[:, :60] -# x[:, 60:64] - zero padded -# expt #1: x[:, 60:93] x[:, 64:97] -# x[:, 97:128] - zero padded -# expt #3: x[:, 93:125] x[:, 128:160] -# expt #4: x[:, 125:150] x[:, 160:185] -# x[:, 185:192] - zero padded -# x[:, 150:min(182, K)] - must not contain inf/nan -# -# x[:, 182:] x[:, 192:] - unused (may contain garbage, including inf/nan) -# -# w is the same, except that rows columns are flipped. -@dataclass -class InnerRoutingData: - base: RoutingData | None = None - block_k: int | None = None - x_is_padded: bool = False - w_is_padded: bool = False - - # Return value contains: ExptHist, ExptOffs, ExptTileOffs, ExptData, - # EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED, ExptHistMax - @staticmethod - def make_kernel_args(data, block_m): - if isinstance(data, RoutingData): - expt_data, block = data.expt_data, block_m - args = (False, False, False, None) - elif isinstance(data, InnerRoutingData): - expt_data, block = data.base.expt_data, data.block_k - args = ( - True, data.x_is_padded, data.w_is_padded, expt_data.slice_sizes.max() - ) - elif data is None: - expt_data = None - else: - assert None - - if expt_data is None: - return (None, None, None, None, False, False, False, None) - - return ( - expt_data.slice_sizes, - expt_data.slice_offs, - expt_data.block_offs(block), - expt_data.block_schedule(block), - ) + args - - # --------------------- # Numerics # --------------------- @@ -203,9 +92,9 @@ class PrecisionConfig: acc_scale: int = 1.0 flexpoint_saturate_inf: bool = False report_quantization_err_fn: callable = None - act_scale: Tensor | None = None - weight_scale: Tensor| None = None - out_scale: Tensor | None = None + a_mx_scale: Tensor | None = None + b_mx_scale: Tensor| None = None + c_mx_scale: Tensor | None = None out_dtype: torch.dtype = None enforce_bitwise_invariance: bool = False @@ -213,7 +102,8 @@ class PrecisionConfig: # TODO: merge in opt_flags def get_swap_xw(precision_config, opt_flags): if target_info.cuda_capability_geq(10, 0): - return precision_config.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent + return precision_config.b_mx_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent + return False # --------------------- @@ -227,7 +117,7 @@ class MatmulAllocation: scratchpads: dict[str, tuple] def init_allocation(x, w, precision_config, fused_activation, - routing_data, gather_indx, scatter_indx, inner_routing_data, + gather_indx, scatter_indx, batch_dim, n_reduce_shards, opt_flags): # ---- output ------ N = w.shape[-1] @@ -235,30 +125,23 @@ def init_allocation(x, w, precision_config, fused_activation, M = x.shape[-2] # if the activations are gathered, then M is number of gather indices if gather_indx is not None: - M = gather_indx.src_indx.shape[0] + M = gather_indx.shape[0] if scatter_indx is not None: - M = scatter_indx.src_indx.shape[0] - if scatter_indx is None: - y_rows = M - else: - y_rows = M // routing_data.n_expts_act + M = scatter_indx.shape[0] + y_rows = M y_rows *= n_reduce_shards - if inner_routing_data is not None: - batch_dim = inner_routing_data.base.n_expts_tot - else: - batch_dim = x.shape[0] if x.ndim == 3 else 1 out_shape = (batch_dim, y_rows, N // fused_activation.specs.reduction_n) out_dtype = precision_config.out_dtype or x.dtype output = (out_shape, out_dtype) # ---- scratchpad -----# scratchpad = dict() N_scratch = N // fused_activation.specs.reduction_n if opt_flags.split_k == 1 else N - if opt_flags.split_k > 1 or (scatter_indx is not None and (not is_cuda() or routing_data.n_expts_act > 1)): + if opt_flags.split_k > 1: scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype scratchpad["matmul"] = ((opt_flags.split_k, batch_dim, M, N_scratch), scratch_out_dtype) - if "matmul" in scratchpad and precision_config.out_scale is not None: + if "matmul" in scratchpad and precision_config.c_mx_scale is not None: assert batch_dim == 1, "batch_dim > 1 not supported yet" - scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N_scratch, MXFP_BLOCK_SIZE)), torch.uint8) + scratchpad["mx_c_mx_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N_scratch, MXFP_BLOCK_SIZE)), torch.uint8) return MatmulAllocation(x.device, output, scratchpad) def apply_allocation(allocation: MatmulAllocation, output): @@ -279,7 +162,7 @@ def apply_allocation(allocation: MatmulAllocation, output): # ----------------------------------------------------------------------------- # Canonicalize # ----------------------------------------------------------------------------- -# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used +# the `matmul` kernel can operate on 2D or 3D inputs depending on the mode being used # we can canonicalize storages to make the implementation more uniform def _canonicalize_storage(storage, out_ndim, flex_data): @@ -304,26 +187,26 @@ def _canonicalize_storage(storage, out_ndim, flex_data): # Triton Implementation # ----------------------------------------------------------------------------- -def matmul_ogs_set_idle_sms(num_idle_sms): +def matmul_set_idle_sms(num_idle_sms): """ persistent kernels will leave `num_idle_sms` idle """ update_opt_flags_constraints({"idle_sms": num_idle_sms}) -def matmul_ogs(x, w, bias, - routing_data: RoutingData | None = None, - gather_indx: GatherIndx | None = None, - scatter_indx: ScatterIndx | None = None, +def matmul(a, b, bias, + a_ragged_metadata: RaggedTensorMetadata | None = None, + b_ragged_metadata: RaggedTensorMetadata | None = None, + gather_indx: torch.Tensor | None = None, + scatter_indx: torch.Tensor | None = None, precision_config: PrecisionConfig | None = None, betas: torch.Tensor | None = None, gammas: torch.Tensor | None = None, out_alpha: float | None = None, - y: torch.Tensor | None = None, + c: torch.Tensor | None = None, fused_comm: FusedComm | None = None, fused_activation: FusedActivation | None = None, epilogue: Epilogue | None = None, - y_acc_in: torch.Tensor | None = None, - inner_routing_data: InnerRoutingData | None = None, + c_acc_in: torch.Tensor | None = None, ): """ Y[:, :] = 0. @@ -336,22 +219,17 @@ def matmul_ogs(x, w, bias, The output buffer for fused comm should be pre-allocated and passed in via fused_comm.out_handles, which contains ipc handles to the output tensors, each with shape (n_rows * n_reduce_shards, n_cols). """ - is_input_batched = x.ndim == 3 + is_input_batched = a.ndim == 3 if is_input_batched: assert gather_indx is None, "gather not supported in batched mode" assert scatter_indx is None, "scatter not supported in batched mode" - assert routing_data is None, "routing not supported in batched mode" - assert inner_routing_data is None, "routing not supported in batched mode" + assert b_ragged_metadata is None, "w cannot be ragged in batched mode" + assert a_ragged_metadata is None, "x cannot be ragged in batched mode" assert fused_comm is None, "fused comm is not supported in batched mode" - assert w.ndim == 3 and w.shape[0] == x.shape[0] - if inner_routing_data is not None: - assert routing_data is None + assert b.ndim == 3 and b.shape[0] == a.shape[0] + if b_ragged_metadata is not None: assert gather_indx is None assert scatter_indx is None - routing_data = RoutingData( - None, None, inner_routing_data.base.n_expts_tot, 1, - expected_tokens_per_expt=inner_routing_data.base.expected_tokens_per_expt, - ) # canonicalize inputs if precision_config is None: precision_config = PrecisionConfig() @@ -359,100 +237,106 @@ def matmul_ogs(x, w, bias, fused_activation = FusedActivation(FnSpecs.default(), tuple()) if epilogue is None: epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False) - if routing_data is None: - routing_data = RoutingData(None, None, max(1, w.shape[0]), 1) + n_slices = max(1, b.shape[0]) if a_ragged_metadata is None else a_ragged_metadata.n_slices # unpack scales - w_scale = precision_config.weight_scale - w_has_mx = w_scale is not None - is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8 - if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10" - if not isinstance(w, Tensor): + b_scale = precision_config.b_mx_scale + b_has_mx = b_scale is not None + is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(b.dtype) == 8 + if is_hopper_fp8: assert b.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10" + if not isinstance(b, Tensor): # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real - dtype = FP4 if w.dtype == torch.uint8 else w.dtype - w = wrap_torch_tensor(w, dtype=dtype) - if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)): - assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)" - if w_scale is not None and not isinstance(w_scale, Tensor): - w_scale = Tensor(w_scale) - if w_scale is not None: - w_scale.storage.data = w_scale.data.view(torch.uint8) - w_scale.dtype = torch.uint8 - x_scale = precision_config.act_scale - x_has_mx = x_scale is not None - if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp" - if x_scale is not None and not isinstance(x_scale, Tensor): - x_scale = Tensor(x_scale) - if not isinstance(x, Tensor): - x = Tensor(x, dtype=x.dtype) - x_transpose = x.stride(-1) != 1 + dtype = FP4 if b.dtype == torch.uint8 else b.dtype + b = wrap_torch_tensor(b, dtype=dtype) + if b_has_mx and (torch.cuda.get_device_capability()[0] < 10 or b.storage.layout is not None and not isinstance(b.storage.layout, StridedLayout)): + assert b.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)" + if b_scale is not None and not isinstance(b_scale, Tensor): + b_scale = Tensor(b_scale) + if b_scale is not None: + b_scale.storage.data = b_scale.data.view(torch.uint8) + b_scale.dtype = torch.uint8 + a_scale = precision_config.a_mx_scale + a_has_mx = a_scale is not None + if a_has_mx: assert a.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp" + if a_scale is not None and not isinstance(a_scale, Tensor): + a_scale = Tensor(a_scale) + if not isinstance(a, Tensor): + a = Tensor(a, dtype=a.dtype) + a_transpose = a.stride(-1) != 1 # determine shapes has_gather = gather_indx is not None has_scatter = scatter_indx is not None - is_ragged = routing_data.expt_hist is not None - M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0] - if inner_routing_data is not None: - batch_size = inner_routing_data.base.n_expts_tot + is_a_ragged = a_ragged_metadata is not None + is_b_ragged = b_ragged_metadata is not None + is_c_ragged = is_a_ragged and b_ragged_metadata is None + ragged_dimension = "K" if is_b_ragged else "M" if is_a_ragged else None + M = a.shape[-2] if gather_indx is None else gather_indx.shape[0] + if ragged_dimension == "K": + batch_size = b_ragged_metadata.n_slices + elif ragged_dimension is None and b.ndim == 3: + batch_size = b.shape[0] else: - batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1 - if y_acc_in is not None: - y_acc_is_y = y_acc_in.data_ptr() == y.data_ptr() and y_acc_in.stride() == y.stride() + batch_size = 1 + if c_acc_in is not None: + c_acc_is_c = c_acc_in.data_ptr() == c.data_ptr() and c_acc_in.stride() == c.stride() else: - y_acc_is_y = None - K = x.shape[-1] - K_W, N = w.shape[-2:] - if x.ndim == 3 and w.ndim == 3: - assert x.shape[0] == w.shape[0] + c_acc_is_c = None + K = a.shape[-1] + K_W, N = b.shape[-2:] + if a.ndim == 3 and b.ndim == 3: + assert a.shape[0] == b.shape[0] # compute optimization flags - out_dtype = precision_config.out_dtype or x.dtype + out_dtype = precision_config.out_dtype or a.dtype can_use_tma = ( - x.numel() > 0 and x.storage.is_tma_compliant() and - w.numel() > 0 and w.storage.is_tma_compliant() and - (w_scale is None or w_scale.storage.is_tma_compliant()) and - (not is_ragged or x.stride(-1) == 1) and + a.numel() > 0 and a.storage.is_tma_compliant() and + b.numel() > 0 and b.storage.is_tma_compliant() and + (b_scale is None or b_scale.storage.is_tma_compliant()) and + (ragged_dimension != "M" or a.stride(-1) == 1) and # Currently we don't support tma if y is column major; may revisit later if this becomes an issue. - (y is None or y.stride(-1) == 1) and - (y_acc_in is None or y_acc_is_y) and - # If we use inner_routing_data, w must be either padded or row major, otherwise we get - # unaligned access. - (inner_routing_data is None or w.stride(-1) == 1 or inner_routing_data.w_is_padded) + (c is None or c.stride(-1) == 1) and + (c_acc_in is None or c_acc_is_c) and + # if ragged dimension is K, w must be either padded or row major to ensure alignment + (ragged_dimension != "K" or b.stride(-1) == 1 or b_ragged_metadata.slice_sizes_divisibility is not None) ) - if w_scale is not None and isinstance(w_scale.storage.layout, StridedLayout) and w_scale.storage.data.stride()[-1] != 1: - # In this case, we need to transpose w_scale. Then the reduction dim + if b_scale is not None and isinstance(b_scale.storage.layout, StridedLayout) and b_scale.storage.data.stride()[-1] != 1: + # In this case, we need to transpose b_scale. Then the reduction dim # becomes the last dim that will be divided by 32. This to be a multiple # of 16 to be TMA-compliant requires block_k to be a multiple of 512, # which is too big. can_use_tma = False has_gather_tma = has_gather and target_info.has_tma_gather() # hopper w/ mxfp4 doesn't support TMA - can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4) - can_use_split_k = scatter_indx is None and not x_has_mx and not w_has_mx - opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config, - batch_size, M, N, w.shape[-2], routing_data, + can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(b.dtype) != 4) + can_use_split_k = scatter_indx is None and not a_has_mx and not b_has_mx and ragged_dimension != "K" + block_k = None + if ragged_dimension == "K": + block_k = a_ragged_metadata.slice_sizes_divisibility or b_ragged_metadata.slice_sizes_divisibility + opt_flags = make_opt_flags(out_dtype, a.dtype, b.dtype, precision_config, + batch_size, M, N, b.shape[-2], a_ragged_metadata, can_use_tma, can_use_split_k, epilogue.effective_itemsize, - x_transpose, y_acc_in is not None, - inner_routing_data.block_k if inner_routing_data is not None else None, + a_transpose, c_acc_in is not None, + block_k = block_k, ) - if inner_routing_data is not None: - assert opt_flags.block_k == inner_routing_data.block_k - assert opt_flags.split_k == 1 - batch_size = inner_routing_data.base.n_expts_tot - # For unpadded (row major) x, we cannot use tma because memory access isn't aligned. - x_has_tma = opt_flags.is_persistent and (x.stride(-1) != 1 or inner_routing_data.x_is_padded) + # there seems to be a bug on A100 + # pytest -vs test_matmul.py::test_op[False-False-False-False-pad_b-16-768-512-1024-ragged-float16-float16-10-1-False-None-False-False-False-True-None] + if ragged_dimension == "K" and torch.cuda.get_device_capability()[0] < 9: + opt_flags.num_stages = 1 + if ragged_dimension == "K": + a_has_tma = opt_flags.is_persistent and (a.stride(-1) != 1 or (a_ragged_metadata.slice_sizes_divisibility is not None)) # If TMA is used, limit is handled automatically, so we can pretend K is "even". # (For unpadded input, we assume that the first block_k unused rows are zero-filled, # when routing_data.expt_hist.sum() is less than K or K_W.) if opt_flags.is_persistent: - even_K = x_has_tma or inner_routing_data.x_is_padded + even_K = a_has_tma or (a_ragged_metadata.slice_sizes_divisibility is not None) else: - even_K = inner_routing_data.x_is_padded and inner_routing_data.w_is_padded + even_K = a_ragged_metadata.slice_sizes_divisibility is not None and b_ragged_metadata.slice_sizes_divisibility is not None else: - batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1 + batch_size = b.shape[0] if a_ragged_metadata is None and b.ndim == 3 else 1 assert K == K_W - x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather) + a_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather) even_K = (K % opt_flags.block_k == 0) - if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp(): + if b_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp(): raise NotImplementedError("Must use non-persistent kernel for simulated MXFP") - if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp(): + if b_scale is not None and b_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp(): raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP") # fused activation matmul_fused_activation = fused_activation @@ -460,9 +344,11 @@ def matmul_ogs(x, w, bias, if opt_flags.split_k > 1: matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation # allocate output/scratchpad memory - allocation = init_allocation(x, w, precision_config, fused_activation, - routing_data, gather_indx, scatter_indx, inner_routing_data, fused_comm.n_reduce_shards if fused_comm is not None else 1, opt_flags) - memory = apply_allocation(allocation, y) + allocation = init_allocation(a, b, precision_config, fused_activation, + gather_indx, scatter_indx, batch_size, + fused_comm.n_reduce_shards if fused_comm is not None else 1, + opt_flags) + memory = apply_allocation(allocation, c) # early exit if batch_size * M * N == 0: ret = memory["output"].squeeze(0) @@ -471,134 +357,125 @@ def matmul_ogs(x, w, bias, return ret # TMA descriptors require a global memory allocation if opt_flags.is_persistent: - triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device)) + triton.set_allocator(get_per_device_per_stream_alloc_fn(a.device)) # Intermediate tensors and postprocess kernels for each situation has_scratchpad = "matmul" in memory["scratchpad"] # Canonical output tensor (matmul scratchpad if present, otherwise final output tensor) out_matmul = memory["scratchpad"].get("matmul", memory["output"]) out_matmul_flex = OutFlexData() if out_matmul.dtype == torch.float32 else precision_config.flex_ctx.out_data # Unified mx-scale pointer; when scratchpad exists, prefer its mx buffer - out_matmul_scale = precision_config.out_scale + out_matmul_scale = precision_config.c_mx_scale if out_matmul_scale is not None: out_matmul_scale = out_matmul_scale.data.view(torch.uint8) - if has_scratchpad and "mx_out_scale" in memory["scratchpad"]: - out_matmul_scale = memory["scratchpad"]["mx_out_scale"] + if has_scratchpad and "mx_c_mx_scale" in memory["scratchpad"]: + out_matmul_scale = memory["scratchpad"]["mx_c_mx_scale"] out_matmul_has_mx = out_matmul_scale is not None and out_matmul.element_size() == 1 # matrix multiplication flex = precision_config.flex_ctx bias_stride = None if bias is None else bias.stride(0) - num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0] # moe metadata - block_m = opt_flags.block_m - expt_data_args = InnerRoutingData.make_kernel_args(inner_routing_data or routing_data, block_m) + expt_data_w = tuple([None] * 6) if ragged_dimension != "K" else ragged_metadata_fields(b_ragged_metadata, opt_flags.block_k) + expt_data_x = tuple([None] * 6) if ragged_dimension is None else ragged_metadata_fields(a_ragged_metadata, opt_flags.block_m if ragged_dimension == "M" else opt_flags.block_k) # spmd grid grid_m = triton.cdiv(M, opt_flags.block_m) - if routing_data.expt_data is not None: - grid_m = routing_data.n_blocks(M, opt_flags.block_m) + if ragged_dimension == "M": + grid_m = a_ragged_metadata.n_blocks(a_ragged_metadata.n_slices, M, opt_flags.block_m) grid_n = triton.cdiv(N, opt_flags.block_n) max_grid = batch_size * grid_m * grid_n * opt_flags.split_k grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid # canonicalize storage has_scatter_tma = scatter_indx is not None and target_info.has_tma_gather() - y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if has_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:])) - x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data) - w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data) - y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data) + c = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if has_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:])) + a_storage = _canonicalize_storage(a.storage, 2 if has_gather_tma else 3, flex.lhs_data) + b_storage = _canonicalize_storage(b.storage, 3, flex.rhs_data) + c_storage = _canonicalize_storage(c.storage, 2 if has_scatter_tma else 3, flex.out_data) # create tma descriptor for x - if y_acc_in is not None: - assert opt_flags.split_k == 1, "y_acc_in + split_k is not supported." - assert scatter_indx is None, "y_acc_in + scatter is not supported." - if y_acc_in.ndim == 2: - y_acc_in = y_acc_in.unsqueeze(0) - assert y_acc_in.shape == out_matmul.shape[-3:] - y_acc_strides = y_acc_in.stride() + if c_acc_in is not None: + assert opt_flags.split_k == 1, "c_acc_in + split_k is not supported." + assert scatter_indx is None, "c_acc_in + scatter is not supported." + if c_acc_in.ndim == 2: + c_acc_in = c_acc_in.unsqueeze(0) + assert c_acc_in.shape == out_matmul.shape[-3:] + c_acc_strides = c_acc_in.stride() else: - y_acc_strides = (None, None, None) + c_acc_strides = (None, None, None) - x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k] - x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense" - x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data + a_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k] + a_tma_mode = None if not a_has_tma else "ragged" if ragged_dimension == "M" and not has_gather_tma else "dense" + a_tensor_or_tma = a_storage.make_tma(a_tma_block_size, a_tma_mode) if a_has_tma else a_storage.data # create tma descriptor for y - y_has_tma = ( + c_has_tma = ( opt_flags.is_persistent and (scatter_indx is None or has_scatter_tma) - and (y_acc_in is None or y_acc_is_y) + and (c_acc_in is None or c_acc_is_c) ) block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.specs.reduction_n - y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n] - y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense" - y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data + c_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n] + c_tma_mode = None if not c_has_tma else "ragged" if is_c_ragged and not has_scatter_tma else "dense" + c_tensor_or_tma = c_storage.make_tma(c_tma_block_size, c_tma_mode) if c_has_tma else c_storage.data # create tma descriptor for w - w_has_tma = opt_flags.is_persistent - w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data + b_has_tma = opt_flags.is_persistent + b_tensor_or_tma = b_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if b_has_tma else b_storage.data # create tma descriptor for w_scale - w_scale_has_tma = opt_flags.is_persistent and w_scale is not None - # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed - # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose - # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs. - # w_transpose = w_storage.data.stride()[-1] != 1 - w_transpose = w_storage.data.stride()[-2] == 1 - if w_scale_has_tma: - w_scale_storage = w_scale.storage + b_scale_has_tma = opt_flags.is_persistent and b_scale is not None + b_transpose = b_storage.data.stride()[-2] == 1 + if b_scale_has_tma: scale_block_k = opt_flags.block_k // int(MXFP_BLOCK_SIZE) - # cancel out the transpose done inside make_tma since - # BlackwellMXScaleLayout.swizzle_block_shape expects block_shape[1] is - # the reduction dimension. - w_scale_tma_block_size = [opt_flags.block_n, scale_block_k] if w_transpose and w_scale.storage.layout.name == "BLACKWELL_SCALE" else [scale_block_k, opt_flags.block_n] - if isinstance(w_scale.storage.layout, StridedLayout): - assert w_scale_storage.data.stride()[-1] == 1, "w_scale should be contiguous with StridedLayout" - w_scale_storage = _canonicalize_storage(w_scale.storage, 3, None) - w_scale_tma_block_size = [1] + w_scale_tma_block_size - w_scale_tensor_or_tma = w_scale_storage.make_tma(w_scale_tma_block_size, "dense", is_scale=True) + b_scale_storage = b_scale.storage + b_scale_tma_block_size = [opt_flags.block_n, scale_block_k] if b_transpose else [scale_block_k, opt_flags.block_n] + if isinstance(b_scale.storage.layout, StridedLayout): + b_scale_storage = _canonicalize_storage(b_scale.storage, 3, None) + b_scale_tma_block_size = [1] + b_scale_tma_block_size + b_scale_tensor_or_tma = b_scale_storage.make_tma(b_scale_tma_block_size, "dense", is_scale=True) else: - w_scale_tensor_or_tma = w_scale + b_scale_tensor_or_tma = b_scale # canonicalize strides - x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride()) - x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None) - x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides - w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None) - w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides + a_strides = [0]*(3 - a_storage.data.ndim) + list(a_storage.data.stride()) + a_scale_strides = a_scale.stride() if a_has_mx else (None, None, None) + a_scale_strides = (0, ) * (3 - len(a_scale_strides)) + a_scale_strides + b_scale_strides = b_scale.stride() if b_has_mx and not b_scale_has_tma else (None, None, None) + b_scale_strides = (0, ) * (3 - len(b_scale_strides)) + b_scale_strides out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None) out_matmul_scale_strides = (0, ) * (4 - len(out_matmul_scale_strides)) + out_matmul_scale_strides # launch kernel kernels = specializations.get(epilogue=epilogue.specs, activation=matmul_fused_activation.specs) - if gather_indx is not None: - gather_src_indx = torch.div(gather_indx.src_indx, routing_data.n_expts_act, rounding_mode='trunc') + # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed + # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose + # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs. + # w_transpose = w_storage.data.stride()[-1] != 1 fused_comm_kwargs = { "pYPtrs": fused_comm.out_handles, "ScatterShardIndx": fused_comm.scatter_shard_indx, "reduce_rank": fused_comm.reduce_rank, "n_reduce_shards": fused_comm.n_reduce_shards, } if fused_comm is not None else {} - # if routing_data.n_expts_act > 1: - # y_storage.data.view(torch.uint8).zero_() - (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)]( - y_tensor_or_tma, y_storage.data, *out_matmul.stride(), + n_valid_slices = b_tensor_or_tma.shape[0] if ragged_dimension == "M" else n_slices + (kernels._p_matmul if opt_flags.is_persistent else kernels._matmul)[(grid,)]( + c_tensor_or_tma, c_storage.data, *out_matmul.stride(), *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex), *out_matmul_scale_strides[-4:], - x_tensor_or_tma, x_storage.data, *x_strides, x_transpose, + a_tensor_or_tma, a_storage.data, *a_strides, a_transpose, flex.lhs_data.scale, - None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides, - w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose, + None if a_scale is None else a_scale.data.view(torch.uint8), *a_scale_strides, + b_tensor_or_tma, b_storage.data, *b_storage.data.stride(), b_transpose, flex.rhs_data.scale, - w_scale_tensor_or_tma, *w_scale_strides, - flex.acc_data.reinterpret(y_acc_in), *y_acc_strides, - flex.acc_data.scale, y_acc_is_y, + b_scale_tensor_or_tma, *b_scale_strides, + flex.acc_data.reinterpret(c_acc_in), *c_acc_strides, + flex.acc_data.scale, c_acc_is_c, bias, bias_stride, - x.shape[-2] if routing_data.expt_hist is None else None, + None if ragged_dimension == "M" else a.shape[-2], N, K, K_W, betas, gammas, - None if gather_indx is None else gather_src_indx, - None if gather_indx is None else gather_indx.dst_indx, # Only for launch_metadata - None if scatter_indx is None else scatter_indx.src_indx, - num_indx, - None if scatter_indx is None else scatter_indx.dst_indx, - None if scatter_indx is None else scatter_indx.dst_indx.shape[0], - *expt_data_args, + gather_indx, + scatter_indx, + None if scatter_indx is None else scatter_indx.shape[0], + ragged_dimension, + *expt_data_x, + *expt_data_w, batch_size, grid_m, grid_n, out_alpha, *matmul_fused_activation.fn_args, matmul_fused_activation.specs.reduction_n, *epilogue.fn_arg_values_matmul, - routing_data.n_expts_tot, + n_valid_slices, precision_config.max_num_imprecise_acc, precision_config.allow_tf32, precision_config.flexpoint_saturate_inf, @@ -609,21 +486,20 @@ def matmul_ogs(x, w, bias, opt_flags.block_n, opt_flags.block_k, opt_flags.group_m, - INIT_OUTPUT_TO_ZERO=routing_data.n_expts_act == 1, XCD_SWIZZLE=opt_flags.xcd_swizzle, - SWIZZLE_MX_VALUE=w.storage.layout.name, - SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name, + SWIZZLE_MX_VALUE=b.storage.layout.name, + SWIZZLE_MX_SCALE=None if b_scale is None else b_scale.storage.layout.name, EPILOGUE_SUBTILE=opt_flags.epilogue_subtile, SPLIT_K=opt_flags.split_k, EVEN_K=even_K, W_CACHE_MODIFIER=opt_flags.w_cache_modifier, - TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt, + TOKENS_PER_EXPT_FOR_ANNOTATION=None if a_ragged_metadata is None else a_ragged_metadata.expected_slice_size, num_warps=opt_flags.num_warps, num_stages=opt_flags.num_stages, arch=opt_flags.arch, - UPCAST_INDICES=should_upcast_indices(x, w, out_matmul), - X_TMA_MODE=x_tma_mode, - Y_TMA_MODE=y_tma_mode, + UPCAST_INDICES=should_upcast_indices(a, b, out_matmul), + X_TMA_MODE=a_tma_mode, + Y_TMA_MODE=c_tma_mode, SWAP_XW=get_swap_xw(precision_config, opt_flags), IS_EPILOGUE_QUANT_MXFP8=epilogue.specs.name == FnName.QUANTIZE_MXFP8.name, NUM_SMS = grid if opt_flags.is_persistent else 0, @@ -635,8 +511,8 @@ def matmul_ogs(x, w, bias, if opt_flags.split_k > 1: assert not out_matmul_has_mx postprocess_fn1 = ReducePostprocessFn(specs=reduce_fused_activation.specs, fn_args=reduce_fused_activation.fn_args) - postprocess_fn2 = None if has_scatter else ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize) - y, y_mx_scale = reduce( + postprocess_fn2 = ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize) + c, y_mx_scale = reduce( x = out_matmul.view(out_matmul.shape[0], -1, out_matmul.shape[-1]), dim = 0, # output data/metadata @@ -644,114 +520,156 @@ def matmul_ogs(x, w, bias, y_dtype = memory["output"].dtype, y_flex = precision_config.flex_ctx.out_data, y_flex_saturate_inf = precision_config.flexpoint_saturate_inf, - y_has_mx = precision_config.out_scale is not None, + y_has_mx = precision_config.c_mx_scale is not None, # fused functions postprocess_fn1 = postprocess_fn1, postprocess_fn2 = postprocess_fn2, ) y_shape = out_matmul.shape[1:-1] + (out_matmul.shape[-1] // reduce_fused_activation.specs.reduction_n,) - out_matmul = y.view(*y_shape).unsqueeze(0) + out_final = c.view(*y_shape) if y_mx_scale is not None: out_final_mx_scale = y_mx_scale.view(out_matmul.shape[-2], triton.cdiv(out_matmul.shape[-1], 32)) - # TODO: change `matmul_ogs` semantics and move this to another op! - if scatter_indx is not None and (not is_cuda() or routing_data.n_expts_act > 1): # Matmul ogs kernel fuses scatter already, so only need for n_exps_act > 1. - mask = (scatter_indx.src_indx != -1).view(out_matmul.shape[-2]//routing_data.n_expts_act, routing_data.n_expts_act, 1) - out_matmul = out_matmul.view(out_matmul.shape[-2]//routing_data.n_expts_act, routing_data.n_expts_act, -1) - mask = mask.expand_as(out_matmul) - out_matmul_scale_shape = out_matmul.shape[:-1] + (triton.cdiv(out_matmul.shape[-1], 32),) - postprocess_fn = ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize) - x_flex = InFlexData(dtype=out_matmul_flex.dtype, scale=out_matmul_flex.expected_scale) - out_final, out_final_mx_scale = reduce(out_matmul, dim=1, postprocess_fn2=postprocess_fn, x_flex=x_flex, # - mask=mask, - y=memory["output"].squeeze(0).squeeze(0), - x_mxscale=out_matmul_scale.view(*out_matmul_scale_shape) if out_matmul_has_mx else None, - y_has_mx=precision_config.out_scale is not None, - y_flex=precision_config.flex_ctx.out_data, - y_flex_saturate_inf=precision_config.flexpoint_saturate_inf, - ) - out_final = out_final.unsqueeze(0) else: out_final = out_matmul.squeeze(0) + out_final_mx_scale = out_matmul_scale - if not (is_input_batched or inner_routing_data is not None): + if not (is_input_batched or b_ragged_metadata is not None): out_final = out_final.squeeze(0) if out_final_mx_scale is not None: - precision_config.out_scale = out_final_mx_scale + precision_config.c_mx_scale = out_final_mx_scale return out_final # ----------------------------------------------------------------------------- # Reference Implementation # ----------------------------------------------------------------------------- -def matmul_ogs_torch(x, w, bias, - routing_data: RoutingData = None, - gather_indx: GatherIndx = None, - scatter_indx: ScatterIndx = None, +def apply_precision(x_tri, w_tri, precision_config): + from .tensor import convert_layout + from .tensor_details import layout + from .numerics_details.mxfp import upcast_from_mxfp + + flex_ctx = precision_config.flex_ctx + + def apply(x, scale): + if scale is None: + return x.clone() + return x.float() * scale + + if precision_config.a_mx_scale is not None: + mx_axis = x_tri.storage.data.ndim -1 + x_tri = convert_layout(x_tri, layout.StridedLayout) + x_tri_scale = convert_layout(precision_config.a_mx_scale, layout.StridedLayout) + x_ref = upcast_from_mxfp(x_tri.storage.data, x_tri_scale.storage.data, torch.bfloat16, axis=mx_axis) + else: + x_ref = apply(x_tri, flex_ctx.lhs_data.scale) + + if precision_config.b_mx_scale is not None: + mx_axis = w_tri.storage.data.ndim - 2 + w_tri = convert_layout(w_tri, layout.StridedLayout) + w_tri_scale = convert_layout(precision_config.b_mx_scale, layout.StridedLayout) + w_ref = upcast_from_mxfp(w_tri.storage.data, w_tri_scale.storage.data, torch.bfloat16, axis=mx_axis) + else: + w_ref = apply(w_tri, flex_ctx.rhs_data.scale) + + return ( + x_ref, w_ref, + ) + + +def scale(val, scal): + if scal is None: + return val + elif scal.numel() == 1: + return val / scal + else: + assert val.ndim == 3 + return val / scal[:, None, None] + +def compute_actual_scale(x, dtype, per_batch_scale=False): + from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 + max_finite = { + torch.float8_e5m2: MAX_FINITE_FLOAT8E5, + torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV, + torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8, + }[dtype] + maxvals = x.abs().amax(dim=tuple(range(1, x.ndim))) if per_batch_scale else x.abs().max() + return maxvals / max_finite + + +def matmul_torch(a, b, bias, + a_ragged_metadata: RaggedTensorMetadata | None = None, + b_ragged_metadata: RaggedTensorMetadata | None = None, + gather_indx: torch.Tensor = None, + scatter_indx: torch.Tensor = None, precision_config: PrecisionConfig = None, betas = None, gammas = None, - inner_routing_data: InnerRoutingData | None = None, round_x = None, round_y = None, ): - if inner_routing_data is not None: - assert bias is None, "Not supported yet" - m, n = x.shape[-2], w.shape[-1] - block_k = inner_routing_data.block_k - n_expts_tot = inner_routing_data.base.n_expts_tot - out = torch.zeros((n_expts_tot, m, n), dtype=torch.float32, device=x.device) - start_x = start_w = 0 + a, b = apply_precision(a, b, precision_config) + + if b_ragged_metadata is not None: + n_expts_tot = b_ragged_metadata.slice_sizes.shape[0] + m, n = a.shape[-2], b.shape[-1] + out = torch.zeros((n_expts_tot, m, n), dtype=torch.float32, device=a.device) + x_slice_offs = a_ragged_metadata.slice_offs + w_slice_offs = b_ragged_metadata.slice_offs for expt in range(n_expts_tot): - k = inner_routing_data.base.expt_hist[expt].item() - if k > 0: - out[expt] = matmul_ogs_torch( - x[:, start_x:start_x+k], w[start_w:start_w+k, :], None, - None, None, None, None, betas, gammas, None, round_x, round_y - ) - padded_k = triton.cdiv(k, block_k) * block_k - start_x += padded_k if inner_routing_data.x_is_padded else k - start_w += padded_k if inner_routing_data.w_is_padded else k - return out - - is_input_batched = x.ndim == 3 - assert x.dtype.itemsize > 1 - assert w.dtype.itemsize > 1 + k = int(b_ragged_metadata.slice_sizes[expt].item()) + if k == 0: + continue + x_start = int(x_slice_offs[expt].item()) + w_start = int(w_slice_offs[expt].item()) + x_slice = a[:, x_start:x_start + k] + w_slice = b[w_start:w_start + k, :] + out_expt = matmul_torch( + x_slice, w_slice, None, None, + None, None, None, PrecisionConfig(), + betas, gammas, + round_x, round_y, + ) + out[expt] = out_expt.to(out.dtype) + actual_scale = precision_config.flex_ctx.out_data.actual_scale + if actual_scale is not None: + actual_scale.copy_(compute_actual_scale(out, precision_config.out_dtype)) + return scale(out, precision_config.flex_ctx.out_data.expected_scale) + + is_input_batched = a.ndim == 3 + assert a.dtype.itemsize > 1 + assert b.dtype.itemsize > 1 if is_input_batched: assert gather_indx is None, "gather not supported in batched mode" assert scatter_indx is None, "scatter not supported in batched mode" - assert routing_data is None, "routing not supported in batched mode" - assert w.ndim == 3 and w.shape[0] == x.shape[0] + assert b.ndim == 3 and b.shape[0] == a.shape[0] if round_x is None: round_x = lambda x, idx: x if round_y is None: round_y = lambda x: x if bias is not None and bias.ndim == 1: bias = bias.view(1, *bias.shape) - if w.ndim == 2: - w = w.view(1, *w.shape) - if x.ndim == 2: - x = x.view(1, *x.shape) - if routing_data is None: - routing_data = RoutingData(None, None, w.shape[0], 1) - n_expts_act = routing_data.n_expts_act + if b.ndim == 2: + b = b.view(1, *b.shape) + if a.ndim == 2: + a = a.view(1, *a.shape) # memory offsets - if routing_data.n_expts_tot > 1 and not is_input_batched: - sizes = routing_data.expt_hist + if a_ragged_metadata is not None and not is_input_batched: + sizes = a_ragged_metadata.slice_sizes off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32) off[1:] = torch.cumsum(sizes, 0) offs = list(itertools.pairwise(off)) else: - offs = [[0, x.shape[1]] for _ in range(w.shape[0])] + offs = [[0, a.shape[1]] for _ in range(b.shape[0])] # compute - n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0] - y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype) + n_rows = a.shape[1] if gather_indx is None else gather_indx.shape[0] + y = torch.zeros((a.shape[0], n_rows, b.shape[-1]), device=a.device, dtype=a.dtype) for i, (lo, hi) in enumerate(offs): if gather_indx is None: - idx = torch.arange(lo, hi, device=x.device) + idx = torch.arange(lo, hi, device=a.device) else: - idx = gather_indx.src_indx[lo:hi] // n_expts_act + idx = gather_indx[lo:hi] batch = i if is_input_batched else 0 - out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(), - w[i].float()) + out = torch.matmul(round_x(a[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(), + b[i].float()) if bias is not None: out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None] if gammas is not None: @@ -760,15 +678,15 @@ def matmul_ogs_torch(x, w, bias, if not is_input_batched: y = y.view(y.shape[1], y.shape[2]) if scatter_indx is None: - return y - # accumulate output from all experts - n_rows = y.shape[0] // n_expts_act - out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device) - for i, (lo, hi) in enumerate(offs): - dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act - msk = dst_idx != -1 - out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float() - return out + out = y + else: + out = torch.zeros((scatter_indx.shape[0], y.shape[-1]), dtype=y.dtype, device=a.device) + msk = scatter_indx != -1 + out[scatter_indx[msk], :] = y[msk, :] + actual_scale = precision_config.flex_ctx.out_data.actual_scale + if actual_scale is not None: + actual_scale.copy_(compute_actual_scale(out, precision_config.out_dtype)) + return scale(out, precision_config.flex_ctx.out_data.expected_scale) def post_matmul_comm_torch(y: torch.Tensor, rank: int, n_reduce_shards: int, diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py b/python/triton_kernels/triton_kernels/matmul_details/_common.py similarity index 63% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py rename to python/triton_kernels/triton_kernels/matmul_details/_common.py index 951be4c053..4c04b1aafc 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py +++ b/python/triton_kernels/triton_kernels/matmul_details/_common.py @@ -53,100 +53,80 @@ def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr): @triton.jit -def _load_tile_attrs( - tile_id, - num_tiles, - unpadded_m, - grid_n, - M, - K, - ExptData, - ExptHist, - ExptOffs, - ExptTileOffs, - EXPT_IS_INNER: tl.constexpr, - X_IS_PADDED: tl.constexpr, - W_IS_PADDED: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_K: tl.constexpr, - PACKED_BLOCK_K_W: tl.constexpr, - SPLIT_K: tl.constexpr, - GROUP_M: tl.constexpr, - XCD_SWIZZLE: tl.constexpr, - SWIZZLE_MX_VALUE: tl.constexpr, -): - # unpack and swizzle program ids - pid_emnk = tile_id +def compute_pids(block_id, grid_m, grid_n, num_blocks, XCD_SWIZZLE: tl.constexpr, GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr): + pid_zmnk = block_id if XCD_SWIZZLE != 1: - pid_emnk = xcd_swizzle(pid_emnk, num_tiles, XCD_SWIZZLE) - pid_e = pid_emnk // (unpadded_m * grid_n * SPLIT_K) - pid_mnk = pid_emnk % (unpadded_m * grid_n * SPLIT_K) + pid_zmnk = xcd_swizzle(pid_zmnk, num_blocks, XCD_SWIZZLE) + pid_z = pid_zmnk // (grid_m * grid_n * SPLIT_K) + pid_mnk = pid_zmnk % (grid_m * grid_n * SPLIT_K) if SPLIT_K > 1: pid_k = pid_mnk % SPLIT_K pid_mn = pid_mnk // SPLIT_K else: pid_k: tl.constexpr = 0 pid_mn = pid_mnk - pid_m, pid_n = swizzle2d(pid_mn, unpadded_m, grid_n, GROUP_M) + pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M) + return pid_z, pid_m, pid_n, pid_k - # unpack expert data - if EXPT_IS_INNER: - # pid_e indicates expert ID: experts are laid sequentially along the K dimension + +@triton.jit +def compute_offsets( + pid_z, + pid_m, + pid_k, + XBlockSchedule, + XSliceOffs, + X_SLICE_SIZE_DIVISIBILITY: tl.constexpr, + WBlockSchedule, + WSliceOffs, + W_SLICE_SIZE_DIVISIBILITY: tl.constexpr, + RAGGED_DIMENSION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K_X: tl.constexpr, + PACKED_BLOCK_K_W: tl.constexpr, + SPLIT_K: tl.constexpr, +): + if RAGGED_DIMENSION == "K": + # pid_z indicates slice ID: experts are laid sequentially along the K dimension # (i.e., we have columns for expert 0, and then expert 1, and then so on). # pid_k is meaningless (always zero). - tl.static_assert(X_IS_PADDED or W_IS_PADDED, "At least one input must be padded!") + tl.static_assert(X_SLICE_SIZE_DIVISIBILITY is not None or \ + W_SLICE_SIZE_DIVISIBILITY is not None, + "At least one input must be padded!") tl.static_assert(SPLIT_K == 1, "Not supported yet") - tl.static_assert(M is not None) - expt_id, pid_z, pid_z_out, start_m, block_id, eM = 0, 0, pid_e, 0, pid_m, M - k_tiles = tl.cdiv(tl.load(ExptHist + pid_e), BLOCK_K) - padded_start_off_raw = tl.load(ExptTileOffs + pid_e) - padded_start_off = padded_start_off_raw * BLOCK_K - unpadded_start_off = tl.load(ExptOffs + pid_e) - off_k_x = padded_start_off if X_IS_PADDED else unpadded_start_off - # K_W is only used for non-TMA kernel (W bound is handled by TMA on TMA kernel). - if W_IS_PADDED: - off_k_w = padded_start_off_raw * PACKED_BLOCK_K_W - K_W = tl.load(ExptTileOffs + pid_e + 1) * PACKED_BLOCK_K_W + off_x_k = tl.load(XSliceOffs + pid_z) + off_w_k = tl.load(WSliceOffs + pid_z) + if PACKED_BLOCK_K_W >= BLOCK_K_X: + off_w_k = off_w_k * (PACKED_BLOCK_K_W // BLOCK_K_X) else: - off_k_w = unpadded_start_off - K_W = tl.load(ExptOffs + pid_e + 1) + off_w_k = off_w_k // (BLOCK_K_X // PACKED_BLOCK_K_W) + off_x_m = BLOCK_M * pid_m + off_w_z, off_x_z, off_x_slice = 0, 0, 0 + off_y_z = pid_z + elif RAGGED_DIMENSION == "M": + off_x_k = pid_k * BLOCK_K_X + off_w_k = pid_k * PACKED_BLOCK_K_W + block_schedule = tl.load(XBlockSchedule + pid_m) + off_w_z = block_schedule & 0x0000FFFF + block_id = block_schedule >> 16 + off_x_slice = tl.load(XSliceOffs + off_w_z) + off_x_z, off_y_z = 0, 0 + off_x_m = BLOCK_M * block_id else: - off_k_x = pid_k * BLOCK_K - off_k_w = pid_k * PACKED_BLOCK_K_W - if PACKED_BLOCK_K_W >= BLOCK_K: - K_W = K * (PACKED_BLOCK_K_W // BLOCK_K) - else: - K_W = K // (BLOCK_K // PACKED_BLOCK_K_W) - if SWIZZLE_MX_VALUE == "HOPPER_VALUE": - K_W = tl.cdiv(K_W, 128) * 128 - k_tiles = tl.cdiv(K - off_k_x, BLOCK_K * SPLIT_K) - if ExptData is None: - tl.static_assert(M is not None) - expt_id, pid_z, pid_z_out, start_m, block_id, eM = pid_e, pid_e, pid_e, 0, pid_m, M - else: - tl.static_assert(M is None) - expt_data = tl.load(ExptData + pid_m) - expt_id = expt_data & 0x0000FFFF - block_id = expt_data >> 16 - eM = tl.load(ExptHist + expt_id) - start_m = tl.load(ExptOffs + expt_id) - pid_z, pid_z_out = 0, 0 - - off_m = BLOCK_M * block_id - + tl.static_assert(RAGGED_DIMENSION is None) + off_x_k = pid_k * BLOCK_K_X + off_w_k = pid_k * PACKED_BLOCK_K_W + off_w_z, off_x_z, off_y_z, off_x_slice = pid_z, pid_z, pid_z, 0 + off_x_m = BLOCK_M * pid_m return ( - expt_id, - pid_z, - pid_z_out, - start_m, - eM, - off_m, - pid_n, - k_tiles, - pid_k, - off_k_x, - off_k_w, - K_W, + off_w_z, + off_x_z, + off_y_z, + off_x_slice, + off_x_m, + off_x_k, + off_w_k, ) @@ -192,40 +172,36 @@ def matmul_launch_metadata(grid, kernel, args): ret = dict() M, N, K = args["M"], args["N"], args["K"] Y, X, W = args["YPtr"], args["XPtr"], args["WPtr"] - tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION") - hist = args["ExptHist"] + expected_slice_sizes = args.get("X_EXPECTED_SLICE_SIZE") + slice_sizes = args["XSliceSizes"] batch_size = args.get("batch_size", 1) - expt_is_inner = args["EXPT_IS_INNER"] - if hist is not None: + if slice_sizes is not None: # If annotation is given, use that to generate name for profiling. - if tokens_per_expt is not None: - n_rows = f"{tokens_per_expt}*" + if expected_slice_sizes is not None: + n_rows = f"{expected_slice_sizes}*" elif launch_metadata_allow_sync(): - n_rows = int(hist.float().mean()) + n_rows = int(slice_sizes.float().mean()) else: n_rows = "unknown" - if launch_metadata_allow_sync(): - n_tokens = float(hist.sum()) - n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum() - elif tokens_per_expt is not None: - n_tokens = tokens_per_expt * args["N_EXPTS_TOT"] + n_tokens = float(slice_sizes.sum()) + n_w_bytes = (W.numel() * W.element_size() // slice_sizes.numel()) * (slice_sizes > 0).sum() + elif expected_slice_sizes is not None: + n_tokens = expected_slice_sizes * args["N_SLICES"] # This may not be totally correct (e.g., we might not be using all experts) # but it's better than nothing. n_w_bytes = W.numel() * W.element_size() else: n_tokens = None n_w_bytes = 0 - # If annotation is given, use that to generate name for profiling. - tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION") - n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows + n_rows = f"{expected_slice_sizes}*" if expected_slice_sizes is not None else n_rows else: n_tokens = None n_w_bytes = W.numel() * W.element_size() - if expt_is_inner: + if args["RAGGED_DIMENSION"] == "K": K = None if n_tokens is None else int(n_tokens) - repr = lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}" + repr = lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(slice_sizes)}({s}) = {n_rows}" nbits = X.dtype.itemsize * 8 batch_repr = "" if batch_size > 1: @@ -235,20 +211,21 @@ def matmul_launch_metadata(grid, kernel, args): if ep_subtile is not None and ep_subtile > 1: ret["name"] += f" ep/{ep_subtile}" - if hist is not None and n_tokens is None: + if slice_sizes is not None and n_tokens is None: return ret # Don't fill metadata because we can't compute them properly. fM = M if M is not None else n_tokens - ret[f"flops{nbits}"] = 2.0 * fM * N * K * (1 if expt_is_inner else batch_size) + Z = 1 if args["RAGGED_DIMENSION"] == "K" else batch_size + ret[f"flops{nbits}"] = 2.0 * fM * N * K * Z # sindx = args.get("WriteBackIndx", None) n_x_bytes = X.numel() * X.element_size() n_y_bytes = Y.numel() * Y.element_size() - if hist is not None: + if slice_sizes is not None: assert n_tokens is not None n_read_rows = n_tokens - if expt_is_inner: + if args["RAGGED_DIMENSION"] == "K": n_x_bytes = n_read_rows * X.shape[-2] * X.element_size() # Here, we're computing dW = X.T@dY, so "W" is actually dY and "Y" is actually dW. n_y_bytes = Y.numel() * Y.element_size() * (2 if args["OutAcc"] is not None else 1) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_details/_matmul.py similarity index 85% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py rename to python/triton_kernels/triton_kernels/matmul_details/_matmul.py index d3f26b9299..b165534afe 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_details/_matmul.py @@ -9,36 +9,19 @@ from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE from ._common import ( - _load_tile_attrs, + compute_offsets, get_scaled_dot_format_string, make_matmul_repr, matmul_launch_metadata, - swizzle2d, - xcd_swizzle, threadfence_system, + compute_pids, ) -@triton.jit -def _zero_masked_rows( - pid_m, pid_n, - Y, stride_y_m, stride_y_n, - N, - ScatterSrcIndx, num_idxs, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - offs_m = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M) - offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N) - src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0) - YPtrs = Y + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n - mask_n = offs_n < N - mask = (src_idx == -1)[:, None] & mask_n[None, :] - tl.store(YPtrs, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask) - - -_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2]) +_matmul_repr = make_matmul_repr("_matmul", [0, 1, 2]) @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"], - repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata) -def _matmul_ogs( + repr=_matmul_repr, launch_metadata=matmul_launch_metadata) +def _matmul( Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n, YExpectedScale, YActualScale, YChecksumScale, stride_y_mx_k, stride_y_mx_z, stride_y_mx_m, stride_y_mx_n, @@ -54,14 +37,11 @@ def _matmul_ogs( M, N, K, K_W, # shapes # expt data Betas, Gammas, - GatherIndx, GatherDstIndx, # GatherDstIndx is only used for launch metadata. - ScatterSrcIndx, num_idxs, + GatherIndx, WriteBackIndx, writeback_size, - ExptHist, ExptOffs, ExptTileOffs, ExptData, - EXPT_IS_INNER: tl.constexpr, - X_IS_PADDED: tl.constexpr, - W_IS_PADDED: tl.constexpr, - ExptHistMax, + RAGGED_DIMENSION: tl.constexpr, + XSliceSizes, XSliceOffs, XBlockOffs, XBlockSchedule, X_EXPECTED_SLICE_SIZE: tl.constexpr, X_SLICE_SIZES_DIVISIBILITY: tl.constexpr, + WSliceSizes, WSliceOffs, WBlockOffs, WBlockSchedule, W_EXPECTED_SLICE_SIZE: tl.constexpr, _W_SLICE_SIZES_DIVISIBILITY: tl.constexpr, # true grid size batch_size, grid_m, grid_n, # Out scale @@ -81,7 +61,6 @@ def _matmul_ogs( # optimization config BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr, - INIT_OUTPUT_TO_ZERO: tl.constexpr, # One of ["HOPPER", "BLACKWELL", None] SWIZZLE_MX_VALUE: tl.constexpr, # One of ["HOPPER", "BLACKWELL", None] @@ -123,12 +102,14 @@ def _matmul_ogs( tl.assume(grid_m >= 0) tl.assume(grid_n >= 0) + + w_type: tl.constexpr = W.dtype.element_ty is_x_microscaled: tl.constexpr = XMxScale is not None is_w_microscaled: tl.constexpr = WMxScale is not None + is_w_mxfp4: tl.constexpr = w_type == tl.uint8 and is_w_microscaled + MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE if is_w_microscaled: - w_type: tl.constexpr = W.dtype.element_ty - is_mxfp4: tl.constexpr = w_type == tl.uint8 tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5), "mx_weight_ptr must be uint8 or fp8") tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8") @@ -137,7 +118,7 @@ def _matmul_ogs( # TODO: refactor if/else when triton front end improves if SWIZZLE_MX_VALUE == "HOPPER_VALUE": - tl.static_assert(is_mxfp4, "Only mxfp4 is supported for HOPPER swizzling") + tl.static_assert(is_w_mxfp4, "Only mxfp4 is supported for HOPPER swizzling") tl.static_assert(not is_x_microscaled) # We have pack 2 fp4 values in a byte but we divide the dimension by 2 # when swizzling @@ -146,7 +127,7 @@ def _matmul_ogs( W_N_DIVISOR: tl.constexpr = 4 else: # We have pack 2 fp4 values in a byte - W_K_DIVISOR: tl.constexpr = 2 if is_mxfp4 else 1 + W_K_DIVISOR: tl.constexpr = 2 if is_w_mxfp4 else 1 W_K_MULTIPLIER: tl.constexpr = 1 W_N_DIVISOR: tl.constexpr = 1 @@ -177,52 +158,71 @@ def _matmul_ogs( tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") is_out_microscaled: tl.constexpr = stride_y_mx_z is not None + if _W_SLICE_SIZES_DIVISIBILITY is None: + W_SLICE_SIZES_DIVISIBILITY: tl.constexpr = 1 + else: + if PACKED_BLOCK_K_W > BLOCK_K: + W_SLICE_SIZES_DIVISIBILITY: tl.constexpr = _W_SLICE_SIZES_DIVISIBILITY * (PACKED_BLOCK_K_W // BLOCK_K) + else: + W_SLICE_SIZES_DIVISIBILITY: tl.constexpr = _W_SLICE_SIZES_DIVISIBILITY // (BLOCK_K // PACKED_BLOCK_K_W) + OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N yN = N // ACTIVATION_REDUCTION_N pid = tl.program_id(0) - if ExptTileOffs is not None and (not EXPT_IS_INNER): - # Determine how much padding there is on the expert data. This allows us to - # know the true grid size and avoid processing padding tiles. - padding_m = grid_m - tl.load(ExptTileOffs + N_EXPTS_TOT) + if RAGGED_DIMENSION == "M": + padding_m = grid_m - tl.load(XBlockOffs + N_EXPTS_TOT) else: padding_m: tl.constexpr = 0 - HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32 unpadded_m = grid_m - padding_m tl.assume(unpadded_m >= 0) total_actual_tiles = batch_size * unpadded_m * grid_n * SPLIT_K - # set masked out rows to 0 - # We are tiling Y here, so the tiling is independent of matmul (where we - # tile X & W and scatter to different rows of Y). - # TODO: refactor (same code in _p_matmul_ogs) - if HAS_FUSED_SCATTER and INIT_OUTPUT_TO_ZERO: - tl.device_assert(batch_size == 1) - pid_mnk = pid - if XCD_SWIZZLE != 1: - pid_mnk = xcd_swizzle(pid_mnk, grid_m * grid_n * SPLIT_K, XCD_SWIZZLE) - pid_k = pid_mnk % SPLIT_K - pid_mn = pid_mnk // SPLIT_K - pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M) - _zero_masked_rows(pid_m, pid_n, - Y + pid_k.to(index_type) * stride_y_k, stride_y_m, stride_y_n, - yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N) - if padding_m > 0 and pid >= total_actual_tiles: return + pid_s, pid_m, pid_n, pid_k = compute_pids(pid, unpadded_m, grid_n, total_actual_tiles, XCD_SWIZZLE, GROUP_M, SPLIT_K) + loop_k = tl.multiple_of(tl.load(XSliceSizes + pid_s), X_SLICE_SIZES_DIVISIBILITY) if RAGGED_DIMENSION == "K" else K + ( expt_id, start_z, start_z_out, - start_m, eM, off_m, pid_n, - k_tiles, pid_k, off_k_x, off_k_w, K_W, - ) = _load_tile_attrs(pid, total_actual_tiles, unpadded_m, grid_n, - M, K, ExptData, ExptHist, ExptOffs, ExptTileOffs, - EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED, - BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K, - GROUP_M, XCD_SWIZZLE, SWIZZLE_MX_VALUE) + start_m, off_m, + off_k_x, off_k_w + ) = compute_offsets( + pid_s, pid_m, pid_k, + XBlockSchedule, XSliceOffs, X_SLICE_SIZES_DIVISIBILITY, + WBlockSchedule, WSliceOffs, W_SLICE_SIZES_DIVISIBILITY, + RAGGED_DIMENSION, + BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K + ) + if X_SLICE_SIZES_DIVISIBILITY is not None: + off_k_x = off_k_x // X_SLICE_SIZES_DIVISIBILITY * X_SLICE_SIZES_DIVISIBILITY + if W_SLICE_SIZES_DIVISIBILITY is not None: + off_k_w = off_k_w // W_SLICE_SIZES_DIVISIBILITY * W_SLICE_SIZES_DIVISIBILITY + + + if RAGGED_DIMENSION == "M": + eM = tl.multiple_of(tl.load(XSliceSizes + expt_id), X_SLICE_SIZES_DIVISIBILITY) + else: + eM = M + + + if RAGGED_DIMENSION == "K": + K_W = tl.multiple_of(tl.load(WSliceOffs + pid_s + 1), W_SLICE_SIZES_DIVISIBILITY) + if PACKED_BLOCK_K_W > BLOCK_K: + K_W = K_W * (PACKED_BLOCK_K_W // BLOCK_K) + else: + K_W = K_W // (BLOCK_K // PACKED_BLOCK_K_W) + K_X = tl.multiple_of(tl.load(XSliceOffs + pid_s + 1), X_SLICE_SIZES_DIVISIBILITY) + else: + K_W = K * (PACKED_BLOCK_K_W // BLOCK_K) if PACKED_BLOCK_K_W >= BLOCK_K else K // (BLOCK_K // PACKED_BLOCK_K_W) + K_X = K + + loop_k = tl.multiple_of(tl.load(XSliceSizes + pid_s), X_SLICE_SIZES_DIVISIBILITY) if RAGGED_DIMENSION == "K" else K - off_k_x + k_tiles = tl.cdiv(loop_k, BLOCK_K * SPLIT_K) # For split-k, advance to the output k slice if SPLIT_K > 1: @@ -246,6 +246,7 @@ def _matmul_ogs( offs_k = off_k_x + tl.arange(0, BLOCK_K) XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k + # TODO: refactor if/else when triton front end improves if is_w_microscaled: WMxScale += expt_id * stride_w_mx_e @@ -262,6 +263,7 @@ def _matmul_ogs( # TODO: support non W_TRANSPOSE with Hopper swizzling tl.static_assert(W_TRANSPOSE) n_warps: tl.constexpr = tl.extra.cuda.num_warps() + tl.static_assert(n_warps == 8) tl.static_assert(BLOCK_N % (2 * n_warps * 2 * 8) == 0) tl.static_assert(MX_SCALE_BLOCK_K % 2 == 0) PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32 @@ -281,7 +283,6 @@ def _matmul_ogs( offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N) # K dimension must be the last dimension for the scales - tl.static_assert(not EXPT_IS_INNER or W_IS_PADDED) offs_k_scale = off_k_w // PACKED_BLOCK_K_W * PACKED_MX_BLOCK + tl.arange(0, PACKED_MX_BLOCK) WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n else: @@ -309,27 +310,28 @@ def _matmul_ogs( WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n) # compute output acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - x_k_limit = K + BLOCK_K * SPLIT_K + x_k_limit = K_X + BLOCK_K * SPLIT_K w_k_limit = K_W + PACKED_BLOCK_K_W * SPLIT_K + for ki in range(k_tiles): x_k_limit -= BLOCK_K * SPLIT_K w_k_limit -= PACKED_BLOCK_K_W * SPLIT_K if EVEN_K: - mask_k = tl.full([BLOCK_K], True, dtype=tl.int1) + mask_k_x = tl.full([BLOCK_K], True, dtype=tl.int1) mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1) if is_w_microscaled and SWIZZLE_MX_SCALE is None: mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1) if is_x_microscaled: mask_x_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1) else: - mask_k = offs_k < x_k_limit + mask_k_x = offs_k < x_k_limit mask_k_w = offs_w_k < w_k_limit if is_w_microscaled and SWIZZLE_MX_SCALE is None: - mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < x_k_limit + mask_k_scale = offs_k_scale * MX_PACK_DIVISOR // 2 < w_k_limit if is_x_microscaled: mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < x_k_limit - x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0) + x = tl.load(XPtrs, mask=mask_k_x[None, :], other=0.0) w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER) if is_w_microscaled: x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype) @@ -348,6 +350,7 @@ def _matmul_ogs( elif SWIZZLE_MX_SCALE == "HOPPER_SCALE": # Handshake with the swizzling code num_warps: tl.constexpr = tl.extra.cuda.num_warps() + tl.static_assert(num_warps == 8) w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps) elif SWIZZLE_MX_SCALE == "CDNA4_SCALE": w_scales = unswizzle_mx_scale_cdna4(tl.load(WMxScalePtrs), BLOCK_N, MX_SCALE_BLOCK_K) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py similarity index 80% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py rename to python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py index 3a38b32cc1..ccb7a4a1a5 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py @@ -14,12 +14,12 @@ ) from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE from ._common import ( - _load_tile_attrs, + compute_offsets, get_scaled_dot_format_string, make_matmul_repr, matmul_launch_metadata, - swizzle2d, threadfence_system, + compute_pids, ) @@ -44,10 +44,10 @@ def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask): return (offs, mask) -_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2]) +_matmul_repr = make_matmul_repr("_p_matmul", [0, 1, 2]) @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"], - repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata) -def _p_matmul_ogs( + repr=_matmul_repr, launch_metadata=matmul_launch_metadata) +def _p_matmul( Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n, YExpectedScale, YActualScale, YChecksumScale, stride_y_mx_k, stride_y_mx_z, stride_y_mx_m, stride_y_mx_n, @@ -63,14 +63,11 @@ def _p_matmul_ogs( M, N, K, K_W, # shapes # expt data Betas, Gammas, - GatherIndx, GatherDstIndx, # GatherDstIndx is only used for launch metadata. - ScatterSrcIndx, num_idxs, + GatherIndx, WriteBackIndx, writeback_size, - ExptHist, ExptOffs, ExptTileOffs, ExptData, - EXPT_IS_INNER: tl.constexpr, - X_IS_PADDED: tl.constexpr, - W_IS_PADDED: tl.constexpr, - ExptHistMax, + RAGGED_DIMENSION: tl.constexpr, + XSliceSizes, XSliceOffs, XBlockOffs, XBlockSchedule, X_EXPECTED_SLICE_SIZE: tl.constexpr, X_SLICE_SIZES_DIVISIBILITY: tl.constexpr, + WSliceSizes, WSliceOffs, WBlockOffs, WBlockSchedule, W_EXPECTED_SLICE_SIZE: tl.constexpr, W_SLICE_SIZES_DIVISIBILITY: tl.constexpr, # true grid size batch_size, grid_m, grid_n, # Out scale @@ -80,7 +77,7 @@ def _p_matmul_ogs( # epilogue transform EPILOGUE_FN: tl.constexpr, epilogue_fn_args, # MoE config - N_EXPTS_TOT: tl.constexpr, + N_SLICES: tl.constexpr, # precision config MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr, FLEXPOINT_SATURATE_INF: tl.constexpr, @@ -90,7 +87,6 @@ def _p_matmul_ogs( # optimization config BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr, - INIT_OUTPUT_TO_ZERO: tl.constexpr, # NYI: Must be None SWIZZLE_MX_VALUE: tl.constexpr, # One of ["BLACKWELL", None] @@ -142,12 +138,10 @@ def _p_matmul_ogs( tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") is_out_microscaled: tl.constexpr = stride_y_mx_z is not None - if ExptTileOffs is not None and (not EXPT_IS_INNER): - # Determine how much padding there is on the expert data. This allows us to - # know the true grid size and avoid processing padding tiles. - padding_m = grid_m - tl.load(ExptTileOffs + N_EXPTS_TOT) + if RAGGED_DIMENSION == "M": + useful_grid_m = tl.load(XBlockOffs + N_SLICES) else: - padding_m: tl.constexpr = 0 + useful_grid_m = grid_m index_type: tl.constexpr = tl.int64 @@ -157,12 +151,9 @@ def _p_matmul_ogs( USE_GATHER_TMA: tl.constexpr = HAS_GATHER and X_TMA_MODE == "dense" USE_SCATTER_TMA: tl.constexpr = HAS_SCATTER and Y_TMA_MODE == "dense" - if EXPT_IS_INNER: + if RAGGED_DIMENSION == "K": tl.static_assert((OutAcc is None) or Y_ACC_IS_Y, "Using differernt y_acc is not supported with TMA kernel.") - tl.static_assert( - not (HAS_SCATTER or USE_GATHER_TMA or USE_SCATTER_TMA), - "Cannot be used with EXPT_IS_INNER" - ) + tl.static_assert(not (HAS_SCATTER or USE_GATHER_TMA or USE_SCATTER_TMA), "Cannot be used with RAGGED_DIMENSION == 'K'") if EPILOGUE_SUBTILE is None: SUBTILE_FACTOR: tl.constexpr = 1 @@ -172,28 +163,7 @@ def _p_matmul_ogs( OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N yN = N // ACTIVATION_REDUCTION_N - # set masked out rows to 0 - if HAS_SCATTER and INIT_OUTPUT_TO_ZERO: - # Iterate with reversed pids so that later pids will get more tiles if the number of - # tiles isn't evenly divisible by the number of SMs. - # The main loop after this iterates in the forward direction such that earlier - # pids get more tiles if the number of tiles isn't evenly divisible. - # This helps balance the work across the SMs. - for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS): - pid_k = pid_mnk % SPLIT_K - pid_mn = pid_mnk // SPLIT_K - pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M) - - z = tl.zeros([BLOCK_M, BLOCK_N // ACTIVATION_REDUCTION_N], dtype=tl.float32) - offs_m = z.shape[0] * pid_m + tl.arange(0, z.shape[0]) - offs_n = z.shape[1] * pid_n + tl.arange(0, z.shape[1]) - src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0) - YPtrs = YPtr + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n - mask_n = offs_n < yN - mask = (src_idx == -1)[:, None] & mask_n[None, :] - tl.store(YPtrs + pid_k * stride_y_k, z, mask=mask) - - num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K + num_blocks = batch_size * useful_grid_m * grid_n * SPLIT_K # If true, do not share loop-carried variables between the prologue and the # epilogue to enable better pipelining with mmav5 @@ -211,55 +181,71 @@ def _p_matmul_ogs( DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_w_microscaled and BLOCK_M * BLOCK_N >= 128 * 256 - for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=True): - expt_id, start_z, start_z_out, start_m, eM, off_m, pid_n, k_tiles, pid_k, off_k_x0, off_k_w0, _ = _load_tile_attrs( - tile_id, num_tiles, grid_m - padding_m, grid_n, - M, K, ExptData, ExptHist, ExptOffs, ExptTileOffs, - EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED, - BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K, - GROUP_M, XCD_SWIZZLE, SWIZZLE_MX_VALUE) - off_n = BLOCK_N * pid_n + for block_id in tl.range(tl.program_id(0), num_blocks, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=True): - # Base pointers and offsets. - if X_TMA_MODE is None: - XBase = X + start_z.to(index_type) * stride_x_z - offs_x_k = (off_k_x0.to(index_type) + tl.arange(0, BLOCK_K))[None, :] * stride_x_k + pid_z, pid_m, pid_n, pid_k = compute_pids(block_id, useful_grid_m, grid_n, num_blocks, XCD_SWIZZLE, GROUP_M, SPLIT_K) + + # ------------------------------------------------------------ + # prologue + # ------------------------------------------------------------ + off_w_z, off_x_z, off_y_z, slice_off_m, off_m, off_k_x0, off_k_w0 = compute_offsets( + pid_z, pid_m, pid_k, + XBlockSchedule, XSliceOffs, X_SLICE_SIZES_DIVISIBILITY, + WBlockSchedule, WSliceOffs, W_SLICE_SIZES_DIVISIBILITY, + RAGGED_DIMENSION, + BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K + ) + # TODO: if RAGGED_DIMENSION == "M" + if RAGGED_DIMENSION == "M": + shape_m = tl.load(XSliceSizes + off_w_z) + else: + shape_m = M + + off_n = BLOCK_N * pid_n + + # ---- offset x ------ if USE_GATHER_TMA: offs_m = off_m + tl.arange(0, BLOCK_M) - mask_m = offs_m < eM - if ExptData is None: - offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m) + mask_m = offs_m < shape_m + if XBlockSchedule is None: + offs_x_m = tl.load(GatherIndx + slice_off_m.to(index_type) + offs_m, mask=mask_m) # Bump rows to account for the Z offset. - offs_x_m += start_z * (stride_x_z // stride_x_m) + offs_x_m += off_x_z * (stride_x_z // stride_x_m) offs_x_m = tl.where(mask_m, offs_x_m, -1) else: - offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m, other=-1) - elif X_TMA_MODE is None or is_x_microscaled: + offs_x_m = tl.load(GatherIndx + slice_off_m.to(index_type) + offs_m, mask=mask_m, other=-1) + if X_TMA_MODE is None: + XBase = X + off_x_z.to(index_type) * stride_x_z offs_m = off_m + tl.arange(0, BLOCK_M) - offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M) + offs_m = tl.max_contiguous(tl.multiple_of(offs_m % shape_m, BLOCK_M), BLOCK_M) # no needs to bounds-check here because `offs_m` wraps around M dim if GatherIndx is not None: tl.static_assert(HAS_GATHER) - offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) + offs_m = tl.load(GatherIndx + slice_off_m.to(index_type) + offs_m) offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m + offs_x_k = (off_k_x0.to(index_type) + tl.arange(0, BLOCK_K))[None, :] * stride_x_k if is_x_microscaled: - XMxScalePtrs = XMxScale + start_z.to(index_type) * stride_x_mx_z + offs_m = off_m + tl.arange(0, BLOCK_M) + XMxScalePtrs = XMxScale + off_x_z.to(index_type) * stride_x_mx_z if GatherIndx is None: - XMxScalePtrs += start_m * stride_x_mx_m + XMxScalePtrs += slice_off_m * stride_x_mx_m offs_k_scale = off_k_x0 // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K) XMxScalePtrs += (offs_x_m if USE_GATHER_TMA else offs_m).to(index_type)[:, None] * stride_x_mx_m XMxScalePtrs += offs_k_scale.to(index_type)[None, :] * stride_x_mx_k - else: - XMxScalePtrs = None acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32) + # ------------------------------------------------------------ + # inner loop + # ------------------------------------------------------------ + loop_k = tl.load(XSliceSizes + pid_z) if RAGGED_DIMENSION == "K" else K - off_k_x0 + k_tiles = tl.cdiv(loop_k, BLOCK_K * SPLIT_K) loop_bound = tl.maximum(k_tiles, 1) tl.assume(loop_bound > 0) # Currently necessary for the compiler to flatten the loop properly. for ki in tl.range(loop_bound, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER): - if EXPT_IS_INNER and ki >= k_tiles: + if RAGGED_DIMENSION == "K" and ki >= k_tiles: # Tile #ki does not exist: use out-of-bound indices to mask all loads. off_k_x = K off_k_w = K_W @@ -272,13 +258,13 @@ def _p_matmul_ogs( x = X.gather(offs_x_m, off_k_x) elif X_TMA_MODE == "dense": if X_TRANSPOSE: - x = X.load([start_z, off_k_x, start_m + off_m]) + x = X.load([off_x_z, off_k_x, slice_off_m + off_m]) x = x.reshape(BLOCK_K, BLOCK_M).T else: - x = X.load([start_z, start_m + off_m, off_k_x]) + x = X.load([off_x_z, slice_off_m + off_m, off_k_x]) x = x.reshape(BLOCK_M, BLOCK_K) elif X_TMA_MODE == "ragged": - x = load_ragged(X, start_m, eM, [start_z, off_m, off_k_x], ragged_dim=1) + x = load_ragged(X, slice_off_m, shape_m, [off_x_z, off_m, off_k_x], ragged_dim=1) x = x.reshape(BLOCK_M, BLOCK_K) else: tl.static_assert(X_TMA_MODE is None) @@ -293,37 +279,39 @@ def _p_matmul_ogs( else: x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0) + # --- load x_scale --- + x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype) + if is_x_microscaled: + off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_PACK_DIVISOR) + if EVEN_K: + mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1) + else: + mask_k_scale = off_k_mx + tl.arange(0, MX_SCALE_BLOCK_K) < tl.cdiv(K, MX_PACK_DIVISOR) + mask_m = off_m + tl.arange(0, BLOCK_M) < shape_m + x_scales = tl.load(XMxScalePtrs, mask=mask_k_scale[None, :] & mask_m[:, None], other=0.0) + elif x_format == "fp16" or x_format == "bf16": + x_scales: tl.constexpr = None + else: + x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8) + # --- load w --- if W_TRANSPOSE: - w = tl.reshape(W.load([expt_id, off_n, off_k_w]), W.block_shape[1:]).T + w = tl.reshape(W.load([off_w_z, off_n, off_k_w]), W.block_shape[1:]).T else: - w = tl.reshape(W.load([expt_id, off_k_w, off_n]), W.block_shape[1:]) + w = tl.reshape(W.load([off_w_z, off_k_w, off_n]), W.block_shape[1:]) # --- load w_scale --- + w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype) if is_w_microscaled: - x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype) - w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype) off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_PACK_DIVISOR) - - if is_x_microscaled: - if EVEN_K: - mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1) - else: - mask_k_scale = off_k_mx + tl.arange(0, MX_SCALE_BLOCK_K) < tl.cdiv(K, MX_PACK_DIVISOR) - mask_m = off_m + tl.arange(0, BLOCK_M) < eM - x_scales = tl.load(XMxScalePtrs, mask=mask_k_scale[None, :] & mask_m[:, None], other=0.0) - elif x_format == "fp16" or x_format == "bf16": - x_scales: tl.constexpr = None - else: - x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8) tl.static_assert(MX_PACK_DIVISOR % W_PACK_DIVISOR == 0) if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE": - flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128) + flattened_expt_n_idx = off_w_z * ((N + 127) // 128) + (off_n // 128) w_scales = WMxScale.load([0, flattened_expt_n_idx, off_k_mx // 4, 0, 0]) w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1])) w_scales = unswizzle_mx_scale_bw(w_scales) else: - w_scales = WMxScale.load([expt_id, off_k_mx, off_n]) + w_scales = WMxScale.load([off_w_z, off_k_mx, off_n]) w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T # --- update accumulator --- @@ -332,25 +320,35 @@ def _p_matmul_ogs( acc = tl.dot_scaled(w.T, w_scales, w_format, x.T, x_scales, x_format, acc=acc, fast_math=True) else: acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True) - if is_x_microscaled: - XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k else: if SWAP_XW: acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) else: acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) + if is_x_microscaled: + XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k + + # ------------------------------------------------------------ + # epilogue + # ------------------------------------------------------------ if INDEPENDENT_EPILOGUE: tile_id1 += NUM_SMS - expt_id1, _, start_z1, start_m1, eM1, off_m1, pid_n1, _, pid_k1, _, _, _ = _load_tile_attrs( - tile_id1, num_tiles, grid_m - padding_m, grid_n, - M, K, ExptData, ExptHist, ExptOffs, ExptTileOffs, - EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED, - BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K, - GROUP_M, XCD_SWIZZLE, SWIZZLE_MX_VALUE) + pid_s1, pid_m1, pid_n1, pid_k1 = compute_pids(tile_id1, useful_grid_m, grid_n, num_blocks, XCD_SWIZZLE, GROUP_M, SPLIT_K) + expt_id1, _, start_z1, start_m1, off_m1, _, _ = compute_offsets( + pid_z, pid_m, pid_k, + XBlockSchedule, XSliceOffs, X_SLICE_SIZES_DIVISIBILITY, + WBlockSchedule, WSliceOffs, W_SLICE_SIZES_DIVISIBILITY, + RAGGED_DIMENSION, + BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K + ) off_n1 = pid_n1 * BLOCK_N + if RAGGED_DIMENSION == "M": + eM1 = tl.load(XSliceSizes + expt_id1) + else: + eM1 = M else: - tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z_out, start_m, eM + tile_id1, expt_id1, start_z1, start_m1, eM1 = block_id, off_w_z, off_y_z, slice_off_m, shape_m off_m1, off_n1, pid_k1 = off_m, off_n, pid_k offs_m = off_m1 + tl.arange(0, BLOCK_M) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py similarity index 88% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py rename to python/triton_kernels/triton_kernels/matmul_details/opt_flags.py index 8aa185b0ae..de2c3f2bd0 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py @@ -54,7 +54,7 @@ def make_default_opt_flags_amd( m, n, k, - routing_data, + ragged_metadata, can_use_persistent_tma, can_use_split_k, enforce_bitwise_invariance, @@ -65,13 +65,13 @@ def make_default_opt_flags_amd( ): constraints_supported = ["block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "max_allowable_mn"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() - # tokens per expert - if routing_data is None: - tokens_per_expt = m - elif routing_data.expected_tokens_per_expt is None: - tokens_per_expt = max(1, m // routing_data.n_expts_tot) + # tokens per slice + if ragged_metadata is None: + slice_size = m + elif ragged_metadata.expected_slice_size is None: + slice_size = max(1, m // ragged_metadata.n_slices) else: - tokens_per_expt = routing_data.expected_tokens_per_expt + slice_size = ragged_metadata.expected_slice_size is_cdna4 = get_cdna_version() == 4 # block_m @@ -79,15 +79,15 @@ def make_default_opt_flags_amd( block_m = constraints["block_m"] elif enforce_bitwise_invariance: block_m = 256 if is_cdna4 else 128 - elif tokens_per_expt >= 512 and n >= 2048: + elif slice_size >= 512 and n >= 2048: block_m = 256 if is_cdna4 else 128 elif is_cdna4 and m >= 512: block_m = 128 else: - block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64)) + block_m = max(32, min(triton.next_power_of_2(slice_size), 64)) - if routing_data is not None: - grid_m = routing_data.n_blocks(m, block_m) + if ragged_metadata is not None: + grid_m = ragged_metadata.n_blocks(ragged_metadata.n_slices, m, block_m) else: grid_m = triton.cdiv(m, block_m) # group_m: @@ -122,9 +122,13 @@ def make_default_opt_flags_amd( if epilogue_subtile is None: epilogue_subtile = 1 + # prevents OutOfSharedMemoryError for mxfp8 on CDNA3 + if get_cdna_version() == 3 and bitwidth(rhs_dtype) == 8 and precision_config.b_mx_scale is not None: + num_stages = 1 + # specific configs for F16 x MXFP4 on CDNA4 # Note that these configs will exceed LDS usage with async copy enabled - if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.weight_scale is not None: + if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.b_mx_scale is not None: split_k = 1 if m <= 1024: target_kernel_kwargs["waves_per_eu"] = 3 @@ -186,11 +190,11 @@ def make_default_opt_flags_nvidia( assert not any([c not in constraints_supported for c in constraints]), constraints.keys() # tokens per expert if routing_data is None or batch_size > 1: - tokens_per_expt = m - elif routing_data.expected_tokens_per_expt is None: - tokens_per_expt = max(1, m // routing_data.n_expts_tot) + slice_size = m + elif routing_data.expected_slice_size is None: + slice_size = max(1, m // routing_data.n_slices) else: - tokens_per_expt = routing_data.expected_tokens_per_expt + slice_size = routing_data.expected_slice_size # pid swizzling group_m = 8 xcd_swizzle = 1 @@ -200,14 +204,14 @@ def make_default_opt_flags_nvidia( elif enforce_bitwise_invariance: block_m = 128 else: - if tokens_per_expt <= 64 and routing_data is not None and routing_data.expt_hist is not None: + if slice_size <= 64 and routing_data is not None and routing_data.slice_sizes is not None: # Ragged and likely memory bound; set the block size higher to minimize loading weights more than once. - if lhs_dtype == torch.bfloat16 and rhs_dtype == FP4 and tokens_per_expt >= 16 and torch.cuda.get_device_capability()[0] >= 10: - block_m = max(16, min(triton.next_power_of_2(8 * tokens_per_expt), 128)) + if lhs_dtype == torch.bfloat16 and rhs_dtype == FP4 and slice_size >= 16 and torch.cuda.get_device_capability()[0] >= 10: + block_m = max(16, min(triton.next_power_of_2(8 * slice_size), 128)) else: - block_m = max(16, min(triton.next_power_of_2(2 * tokens_per_expt), 64)) + block_m = max(16, min(triton.next_power_of_2(2 * slice_size), 64)) else: - block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) + block_m = max(16, min(triton.next_power_of_2(slice_size), 128)) # block n arch = None block_n, block_n_tma = opt_flags_nvidia.compute_block_n(n, arch, precision_config) @@ -216,7 +220,7 @@ def make_default_opt_flags_nvidia( n_sms = torch.cuda.get_device_properties(0).multi_processor_count tiles_per_sm = grid_size_tma / n_sms supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9) - requires_persistent = (get_layout(precision_config.act_scale) is not None or get_layout(precision_config.weight_scale) is not None) and target_info.has_native_mxfp() + requires_persistent = (get_layout(precision_config.a_mx_scale) is not None or get_layout(precision_config.b_mx_scale) is not None) and target_info.has_native_mxfp() if constraints.get("is_persistent", None) is not None: is_persistent = constraints["is_persistent"] elif requires_persistent: @@ -231,7 +235,7 @@ def make_default_opt_flags_nvidia( block_n = block_n_tma if is_persistent else block_n # block k block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in) - if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and tokens_per_expt > 1 and lhs_dtype != torch.bfloat16: + if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and slice_size > 1 and lhs_dtype != torch.bfloat16: # Swap block_n and block_k for mxfp4 weights so that block_k is a full cacheline, so long as K is sufficiently large. # TODO: swizzle the HBM layout of the weights instead block_n, block_k = block_k, block_n @@ -332,7 +336,7 @@ def make_opt_flags( m, n, k, - routing_data, + ragged_metadata, can_use_persistent_tma, can_use_split_k, epilogue_effective_itemsize, @@ -357,7 +361,7 @@ def make_opt_flags( opt_flags_constraints = opt_flags_constraints.copy() opt_flags_constraints.update(block_k=block_k, split_k=1) args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, batch_size, m, n, k, - routing_data, can_use_persistent_tma, can_use_split_k, + ragged_metadata, can_use_persistent_tma, can_use_split_k, enforce_bitwise_invariance, epilogue_effective_itemsize, x_transpose, has_y_acc_in, opt_flags_constraints] backend = triton.runtime.driver.active.get_current_target().backend diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py similarity index 92% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py rename to python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py index ffe06c333f..593688a74d 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py @@ -28,6 +28,6 @@ def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precisi # TODO: block_k = 128 seems to work better for now. # perhaps due to increased number of k loops to pipeline - if precision_config.weight_scale is not None and get_cdna_version() != 4: + if precision_config.b_mx_scale is not None and get_cdna_version() != 4: block_k = 128 return block_n, block_k diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py similarity index 92% rename from python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py rename to python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py index 2cc0f3d41d..7998b56089 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py @@ -8,7 +8,7 @@ def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n): if routing_data is not None and batch_size == 1: - grid_m = routing_data.n_blocks(m, block_m) + grid_m = routing_data.n_blocks(routing_data.n_slices, m, block_m) else: grid_m = triton.cdiv(m, block_m) grid_n = (n + block_n - 1) // block_n @@ -17,10 +17,10 @@ def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n): def compute_block_n(n: int, arch, precision_config): # block_n: - layout = get_layout(precision_config.weight_scale) + layout = get_layout(precision_config.b_mx_scale) if isinstance(layout, HopperMXScaleLayout): if layout.num_warps in [4, 8]: - # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py#L265 + # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/matmul_details/_matmul.py#L265 block_n = 2 * layout.num_warps * 2 * 8 return block_n, block_n elif precision_config.max_num_imprecise_acc is None and n > 128: @@ -41,7 +41,7 @@ def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_d elif k is not None: min_block_k = 32 if is_persistent or lhs_width != 16 or rhs_width != 16 else 16 block_k = max(min_block_k, min(triton.next_power_of_2(k), block_k)) - has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None + has_mx_weight_scale = precision_config is not None and precision_config.b_mx_scale is not None if has_native_mxfp and is_persistent and has_mx_weight_scale: block_k = min(block_k, 128) if has_y_acc_in and lhs_width == rhs_width == 16 and not target_info.cuda_capability_geq(10, 0): @@ -62,7 +62,7 @@ def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int: def compute_num_warps(block_m, block_n, is_persistent: bool, precision_config): - layout = get_layout(precision_config.weight_scale) + layout = get_layout(precision_config.b_mx_scale) if isinstance(layout, HopperMXScaleLayout): return layout.num_warps return max(block_m * block_n // 4096, 4 if is_persistent else 1) @@ -90,7 +90,7 @@ def compute_num_stages( device_props = torch.cuda.get_device_properties(0) smem_capacity = device_props.shared_memory_per_block_optin has_native_mxfp = target_info.cuda_capability_geq(10, 0) - if has_native_mxfp and getattr(precision_config, "weight_scale", None) is not None: + if has_native_mxfp and getattr(precision_config, "b_mx_scale", None) is not None: if rhs_dtype == FP4: # 4-bit e2m1 weights are padded 2x # https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory @@ -114,7 +114,7 @@ def compute_num_stages( smem_capacity -= int((block_m + 4) * acc_block_n * acc_size) if x_transpose: smem_capacity -= block_m * block_k * lhs_dtype.itemsize - if precision_config.weight_scale is not None: + if precision_config.b_mx_scale is not None: # mx scales stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE)) elif has_native_mxfp: diff --git a/python/triton_kernels/triton_kernels/reduce.py b/python/triton_kernels/triton_kernels/reduce.py index 4d64173a9d..f9bf58be02 100644 --- a/python/triton_kernels/triton_kernels/reduce.py +++ b/python/triton_kernels/triton_kernels/reduce.py @@ -16,29 +16,30 @@ class PostprocessFn: @triton.jit -def _reduce(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x tensor (input) - XMx, stride_xmxr, stride_xmx0, stride_xmx1, # x mx scale - Y, stride_y0: tl.int64, stride_y1, # y tensor (output) - YMx, stride_ymx0, stride_ymx1, # y mx scale - Mask, stride_mr, stride_m0, stride_m1, # mask tensor - Scale, stride_sr, stride_s0, stride_s1, # scale tensor - K, S0, X_S1, Y_S1, # shape (K = reduction dim; S0, IN_S1 = input dims, OUT_S1 = output dims) - POSTPROCESS_FN1: tl.constexpr, postprocess_fn1_args, # - POSTPROCESS_FN2: tl.constexpr, postprocess_fn2_args, # - XFlex, # x flex (global) scale - YFlexExpected, YFlexActual, YFlexChecksum, Y_FLEX_SATURATE_INF: tl.constexpr, # y flex (global) scale - IS_MASK_NONE: tl.constexpr, # - BROADCAST_R: tl.constexpr, # - BROADCAST_S0: tl.constexpr, # - BROADCAST_S1: tl.constexpr, # - IS_SCALE_NONE: tl.constexpr, # - SCALE_BROADCAST_R: tl.constexpr, # - SCALE_BROADCAST_S0: tl.constexpr, # - SCALE_BROADCAST_S1: tl.constexpr, # - BLOCK_S0: tl.constexpr, # - BLOCK_X_S1: tl.constexpr, # - BLOCK_Y_S1: tl.constexpr, # - ): +def _reduce_forward(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x tensor (input) + XMx, stride_xmxr, stride_xmx0, stride_xmx1, # x mx scale + Y, stride_y0: tl.int64, stride_y1, # y tensor (output) + YMx, stride_ymx0, stride_ymx1, # y mx scale + Mask, stride_mr, stride_m0, stride_m1, # mask tensor + Scale, stride_sr, stride_s0, stride_s1, # scale tensor + K, S0, X_S1, Y_S1, # shape (K = reduction dim; S0, IN_S1 = input dims, OUT_S1 = output dims) + POSTPROCESS_FN1: tl.constexpr, postprocess_fn1_args, # + POSTPROCESS_FN2: tl.constexpr, postprocess_fn2_args, # + XFlex, # x flex (global) scale + YFlexExpected, YFlexActual, YFlexChecksum, + Y_FLEX_SATURATE_INF: tl.constexpr, # y flex (global) scale + IS_MASK_NONE: tl.constexpr, # + BROADCAST_R: tl.constexpr, # + BROADCAST_S0: tl.constexpr, # + BROADCAST_S1: tl.constexpr, # + IS_SCALE_NONE: tl.constexpr, # + SCALE_BROADCAST_R: tl.constexpr, # + SCALE_BROADCAST_S0: tl.constexpr, # + SCALE_BROADCAST_S1: tl.constexpr, # + BLOCK_S0: tl.constexpr, # + BLOCK_X_S1: tl.constexpr, # + BLOCK_Y_S1: tl.constexpr, # + ): pid_s0 = tl.program_id(0) pid_s1 = tl.program_id(1) tl.static_assert(BLOCK_X_S1 % 32 == 0) @@ -95,9 +96,9 @@ def _reduce(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x tensor tl.store(y_ptrs, y, mask=valid_s0[:, None] & valid_y_s1[None, :]) -specializations = SpecializationModule( - "reduce", - kernels=[("_reduce", _reduce)], +forward_specializations = SpecializationModule( + "reduce_forward", + kernels=[("_reduce_forward", _reduce_forward)], closure_args={ "postprocess_fn1": ClosureArg("POSTPROCESS_FN1", "postprocess_fn1_args"), "postprocess_fn2": ClosureArg("POSTPROCESS_FN2", "postprocess_fn2_args"), @@ -105,7 +106,7 @@ def _reduce(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x tensor ) -def reduce( +def reduce_forward( x: torch.Tensor, dim: int, mask: Optional[torch.Tensor] = None, @@ -216,8 +217,8 @@ def reduce( grid = (triton.cdiv(S0, BLOCK_S0), triton.cdiv(Y_S1, BLOCK_Y_S1)) mask_arg = mask if mask is not None else None scale_arg = scale if scale is not None else None - reduce_kernel = specializations.get(postprocess_fn1=postprocess_fn1.specs, - postprocess_fn2=postprocess_fn2.specs)._reduce + reduce_kernel = forward_specializations.get(postprocess_fn1=postprocess_fn1.specs, + postprocess_fn2=postprocess_fn2.specs)._reduce_forward reduce_kernel[grid]( x_flex.reinterpret(x), stride_xr, stride_x0, stride_x1, # x_mxscale, stride_xmxr, stride_xmx0, stride_xmx1, # @@ -245,6 +246,311 @@ def reduce( return y, y_mxscale +# ------------------------------------------------------------ + + +@triton.jit +def _reduce_backward( + dY, + stride_y0: tl.int64, + stride_y1, # upstream grad (S0, Y_S1) + dX, + stride_xr: tl.int64, + stride_x0: tl.int64, + stride_x1, # grad wrt X (K, S0, X_S1) in the chosen layout + XMx, + stride_xmxr, + stride_xmx0, + stride_xmx1, # input micro-scales (optional) + Mask, + stride_mr, + stride_m0, + stride_m1, # mask (optional) + Scale, + stride_sr, + stride_s0, + stride_s1, # scale (optional) + K, + S0, + X_S1, + Y_S1, # shapes + XFlex, # global input flex scale (scalar device buffer) + IS_MASK_NONE: tl.constexpr, + BROADCAST_R: tl.constexpr, + BROADCAST_S0: tl.constexpr, + BROADCAST_S1: tl.constexpr, + IS_SCALE_NONE: tl.constexpr, + SCALE_BROADCAST_R: tl.constexpr, + SCALE_BROADCAST_S0: tl.constexpr, + SCALE_BROADCAST_S1: tl.constexpr, + REDUCTION_N: tl.constexpr, # maps X_S1 -> Y_S1 (grouped sum in fwd) + BLOCK_S0: tl.constexpr, + BLOCK_X_S1: tl.constexpr, +): + # Tile over (S0, X_S1). We loop over the reduction K dimension. + pid_s0 = tl.program_id(0) + pid_s1 = tl.program_id(1) + + tl.static_assert(BLOCK_X_S1 % 32 == 0) + BLOCK_X_SMX1: tl.constexpr = BLOCK_X_S1 // 32 + + offs_s0 = pid_s0 * BLOCK_S0 + tl.arange(0, BLOCK_S0) + offs_x_s1 = pid_s1 * BLOCK_X_S1 + tl.arange(0, BLOCK_X_S1) + offs_x_smx1 = pid_s1 * BLOCK_X_SMX1 + tl.arange(0, BLOCK_X_SMX1) + + valid_s0 = offs_s0 < S0 + valid_x_s1 = offs_x_s1 < X_S1 + valid_in_smx1 = offs_x_smx1 < tl.cdiv(X_S1, 32) + + # Map X_S1 positions to their Y_S1 group index (grouped-sum fwd) + offs_y_from_x = offs_x_s1 // REDUCTION_N + valid_y_from_x = offs_y_from_x < Y_S1 + + # Load upstream grad; broadcasting over the REDUCTION_N group happens via indexing. + dy_ptrs = dY + offs_s0[:, None] * stride_y0 + offs_y_from_x[None, :] * stride_y1 + dy = tl.load(dy_ptrs, mask=valid_s0[:, None] & valid_y_from_x[None, :], other=0.0).to(tl.float32) + + # Global flex scale (scalar) + x_flex_scale = load_scale(XFlex) + + # Loop over the reduced dimension + for k in tl.range(0, K, num_stages=2): + g = dy + # Multiply by input micro-scale per group of 32 lanes if present + if XMx is not None: + xmx_ptrs = XMx + k * stride_xmxr + offs_s0[:, None] * stride_xmx0 + offs_x_smx1[None, :] * stride_xmx1 + xmx = tl.load(xmx_ptrs, mask=valid_s0[:, None] & valid_in_smx1[None, :], other=0) + xmx = (xmx.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + g = (g.reshape([BLOCK_S0, BLOCK_X_S1 // 32, 32]) * xmx[:, :, None]).reshape([BLOCK_S0, BLOCK_X_S1]) + # Multiply by global input flex scale + g = g * x_flex_scale + # Multiply by per-element Scale if provided + if not IS_SCALE_NONE: + k_term_s = 0 if SCALE_BROADCAST_R else (k * stride_sr) + s0_term_s = 0 if SCALE_BROADCAST_S0 else (offs_s0[:, None] * stride_s0) + s1_term_s = 0 if SCALE_BROADCAST_S1 else (offs_x_s1[None, :] * stride_s1) + s_ptrs = Scale + k_term_s + s0_term_s + s1_term_s + s = tl.load(s_ptrs, mask=valid_s0[:, None] & valid_x_s1[None, :], other=1) + g = g * s + # Apply mask if provided + if not IS_MASK_NONE: + k_term = 0 if BROADCAST_R else (k * stride_mr) + s0_term = 0 if BROADCAST_S0 else (offs_s0[:, None] * stride_m0) + s1_term = 0 if BROADCAST_S1 else (offs_x_s1[None, :] * stride_m1) + m_ptrs = Mask + k_term + s0_term + s1_term + m = tl.load(m_ptrs, mask=valid_s0[:, None] & valid_x_s1[None, :], other=1) + g = tl.where(m != 0, g, 0.0) + # + dx_ptrs = dX + k * stride_xr + offs_s0[:, None] * stride_x0 + offs_x_s1[None, :] * stride_x1 + tl.store(dx_ptrs, g, mask=valid_s0[:, None] & valid_x_s1[None, :]) + + +def reduce_backward( + dy: torch.Tensor, + x_shape: tuple[int, int, int], + dim: int, + *, + mask: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + x_mxscale: Optional[torch.Tensor], + x_flex: Optional[InFlexData], + postprocess_fn1: Optional[PostprocessFn], + x_strides: tuple[int, int, int], + x_mx_strides: Optional[tuple[int, int, int]], + mask_strides: Optional[tuple[int, int, int]], + scale_strides: Optional[tuple[int, int, int]], + dx: torch.Tensor, +): + # Shapes/axes handling mirrors `reduce(...)` + if dim < 0: + dim += 3 + dims = (0, 1, 2) + nonred = tuple(d for d in dims if d != dim) + + S0, X_S1 = x_shape[nonred[0]], x_shape[nonred[1]] + K = x_shape[dim] + + # Postprocess grouping (grouped sum). Default is identity (1). + reduction_n = (postprocess_fn1.specs.reduction_n if postprocess_fn1 is not None else FnSpecs.default().reduction_n) + Y_S1 = X_S1 // reduction_n + assert dy.shape == (S0, Y_S1), f"dY shape {dy.shape} mismatch with (S0={S0}, Y_S1={Y_S1})" + + # Strides for dX must match the element size of the tensor passed to the kernel. + # If we reinterpret the dtype (e.g., flex/float8), use the reinterpreted view's strides. + dx_view = x_flex.reinterpret(dx) + dx_str0, dx_str1, dx_str2 = dx_view.stride() + stride_xr = (dx_str0 if dim == 0 else (dx_str1 if dim == 1 else dx_str2)) + stride_x0 = (dx_str0 if nonred[0] == 0 else (dx_str1 if nonred[0] == 1 else dx_str2)) + stride_x1 = (dx_str0 if nonred[1] == 0 else (dx_str1 if nonred[1] == 1 else dx_str2)) + stride_xmxr = stride_xmx0 = stride_xmx1 = 0 + if x_mxscale is not None: + stride_xmxr, stride_xmx0, stride_xmx1 = x_mx_strides + + if mask is not None: + mstr0, mstr1, mstr2 = mask_strides + stride_mr = (mstr0 if dim == 0 else (mstr1 if dim == 1 else mstr2)) + stride_m0 = (mstr0 if nonred[0] == 0 else (mstr1 if nonred[0] == 1 else mstr2)) + stride_m1 = (mstr0 if nonred[1] == 0 else (mstr1 if nonred[1] == 1 else mstr2)) + else: + stride_mr = stride_m0 = stride_m1 = 0 + + if scale is not None: + sstr0, sstr1, sstr2 = scale_strides + stride_sr = (sstr0 if dim == 0 else (sstr1 if dim == 1 else sstr2)) + stride_s0 = (sstr0 if nonred[0] == 0 else (sstr1 if nonred[0] == 1 else sstr2)) + stride_s1 = (sstr0 if nonred[1] == 0 else (sstr1 if nonred[1] == 1 else sstr2)) + else: + stride_sr = stride_s0 = stride_s1 = 0 + + # Launch configuration mirrors forward (but we tile over X_S1, not Y_S1) + BLOCK_S0 = 64 + BLOCK_X_S1 = 128 + grid = (triton.cdiv(S0, BLOCK_S0), triton.cdiv(X_S1, BLOCK_X_S1)) + + _reduce_backward[grid]( + dy, + dy.stride(0), + dy.stride(1), + dx_view, + stride_xr, + stride_x0, + stride_x1, + x_mxscale, + stride_xmxr, + stride_xmx0, + stride_xmx1, + mask, + stride_mr, + stride_m0, + stride_m1, + scale, + stride_sr, + stride_s0, + stride_s1, + K, + S0, + X_S1, + Y_S1, + x_flex.scale, + IS_MASK_NONE=(mask is None), + BROADCAST_R=(stride_mr == 0), + BROADCAST_S0=(stride_m0 == 0), + BROADCAST_S1=(stride_m1 == 0), + IS_SCALE_NONE=(scale is None), + SCALE_BROADCAST_R=(stride_sr == 0), + SCALE_BROADCAST_S0=(stride_s0 == 0), + SCALE_BROADCAST_S1=(stride_s1 == 0), + REDUCTION_N=reduction_n, + BLOCK_S0=BLOCK_S0, + BLOCK_X_S1=BLOCK_X_S1, + num_warps=4, + ) + + +# ------------------------------------------------------------ + +backward_specializations = SpecializationModule( + "reduce_backward", + kernels=[("_reduce_backward", _reduce_backward)], + closure_args={ + "postprocess_fn1": ClosureArg("POSTPROCESS_FN1", "postprocess_fn1_args"), + "postprocess_fn2": ClosureArg("POSTPROCESS_FN2", "postprocess_fn2_args"), + }, +) + + +class _ReduceAutograd(torch.autograd.Function): + + @staticmethod + def forward(ctx, x: torch.Tensor, dim: int, mask: Optional[torch.Tensor], scale: Optional[torch.Tensor], + x_mxscale: Optional[torch.Tensor], x_flex: Optional[InFlexData], y_dtype: Optional[torch.dtype], + y_flex: Optional[OutFlexData], y_flex_saturate_inf: bool, y_has_mx: Optional[bool], + y: Optional[torch.Tensor], postprocess_fn1: Optional[PostprocessFn], + postprocess_fn2: Optional[PostprocessFn]): + # Run your existing Triton forward + y, y_mx = reduce_forward( + x=x, + dim=dim, + mask=mask, + scale=scale, + x_mxscale=x_mxscale, + x_flex=x_flex, + y_dtype=y_dtype, + y_flex=y_flex, + y_flex_saturate_inf=y_flex_saturate_inf, + y_has_mx=y_has_mx, + y=y, + postprocess_fn1=postprocess_fn1, + postprocess_fn2=postprocess_fn2, + ) + + # Save everything needed for backward (no tensors are modified) + ctx.dim = dim + ctx.x_shape = tuple(x.shape) + ctx.x_dtype = x.dtype + ctx.device = x.device + ctx.mask = mask + ctx.scale = scale + ctx.x_mxscale = x_mxscale + ctx.x_flex = x_flex if x_flex is not None else InFlexData() + ctx.postprocess_fn1 = postprocess_fn1 if postprocess_fn1 is not None else PostprocessFn() + ctx.x_strides = tuple(x.stride()) + ctx.x_mx_strides = tuple(x_mxscale.stride()) if x_mxscale is not None else None + ctx.mask_strides = tuple(mask.stride()) if mask is not None else None + ctx.scale_strides = tuple(scale.stride()) if scale is not None else None + ctx.y_has_mx = bool(y_mx is not None) + + return y, y_mx + + @staticmethod + def backward(ctx, grad_y: torch.Tensor, grad_y_mxscale: Optional[torch.Tensor] = None): + # We do not support grads through MX-quantized outputs (no torch compute in bwd) + if ctx.y_has_mx: + raise NotImplementedError("Backward with y_mxscale (MX-quantized outputs) is not supported.") + + # Allocate grad for x; (no torch compute) + dx = torch.empty(ctx.x_shape, dtype=ctx.x_dtype, device=grad_y.device) + + reduce_backward( + dy=grad_y, + x_shape=ctx.x_shape, + dim=ctx.dim, + mask=ctx.mask, + scale=ctx.scale, + x_mxscale=ctx.x_mxscale, + x_flex=ctx.x_flex, + postprocess_fn1=ctx.postprocess_fn1, + x_strides=ctx.x_strides, + x_mx_strides=ctx.x_mx_strides, + mask_strides=ctx.mask_strides, + scale_strides=ctx.scale_strides, + dx=dx, + ) + return dx, None, None, None, None, None, None, None, None, None, None, None, None + + +def reduce( + x: torch.Tensor, + dim: int, + mask: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + x_mxscale: Optional[torch.Tensor] = None, + x_flex: Optional[InFlexData] = InFlexData(), + y: Optional[torch.Tensor] = None, + y_dtype: Optional[torch.dtype] = None, + y_flex: Optional[OutFlexData] = OutFlexData(), + y_flex_saturate_inf: bool = False, + y_has_mx: Optional[bool] = None, + postprocess_fn1: Optional[PostprocessFn] = None, + postprocess_fn2: Optional[PostprocessFn] = None, +): + return _ReduceAutograd.apply(x, dim, mask, scale, x_mxscale, x_flex, y_dtype, y_flex, # + y_flex_saturate_inf, y_has_mx, y, postprocess_fn1, postprocess_fn2) + + +# ------------------------------------------------------------ + + def compute_actual_scale(x, dtype, per_batch_scale=False): max_finite = { torch.float8_e5m2: MAX_FINITE_FLOAT8E5, diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py index f3e9582a4c..9656fd0f08 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py @@ -13,11 +13,13 @@ class HopperMXScaleLayout(Layout): name: str = "HOPPER_SCALE" def __init__(self, shape, mx_axis, num_warps=8) -> None: + if mx_axis is not None and mx_axis < 0: + mx_axis += len(shape) assert num_warps & (num_warps - 1) == 0, "warps_n must be a power of 2" super().__init__(shape) self.mx_axis = mx_axis self.num_warps = num_warps - *self.leading_shape, _, _ = shape + *self.leading_shape, self.M, self.K = shape def _maybe_mT(self, data): if self.mx_axis == len(self.leading_shape): @@ -25,6 +27,7 @@ def _maybe_mT(self, data): return data def swizzle_data(self, data): + assert data.shape == (*self.leading_shape, self.M, self.K) data = self._maybe_mT(data).contiguous() *batch, M, K = data.shape SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 8 @@ -59,7 +62,7 @@ def unswizzle_data(self, data): data = data.permute(*perm) data = data.reshape(*batch, M * 32, K // 32) data = self._maybe_mT(data) - return data + return data[..., :self.M, :self.K] def swizzle_block_shape(self, block_shape): return block_shape @@ -70,6 +73,8 @@ def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constex """ Triton inverse of swizzle_mxfp4_scale_hopper """ + if mx_axis is not None and mx_axis < 0: + mx_axis += len(x.shape) tl.static_assert(len(x.shape) == 2, "NYI") # implementation assumes mxfp data is packed along the last dimension x = x.trans() if mx_axis == 0 else x diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py index 118388c275..9128037bf7 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py @@ -92,7 +92,8 @@ class HopperMXValueLayout(Layout): def __init__(self, shape, mx_axis, mma_version=3): super().__init__(shape) - assert mx_axis in range(len(shape)) + if mx_axis < 0: + mx_axis += len(shape) self.mx_axis = mx_axis self.mma_version = mma_version *self.leading_shape, self.K, self.N, = shape diff --git a/python/triton_kernels/triton_kernels/tensor_details/ragged_tensor.py b/python/triton_kernels/triton_kernels/tensor_details/ragged_tensor.py index 277141d95a..9456fb6bf6 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/ragged_tensor.py +++ b/python/triton_kernels/triton_kernels/tensor_details/ragged_tensor.py @@ -45,6 +45,10 @@ class RaggedTensorMetadata: # NOTE 2: because the size of `block_schedule_data[k]` is data-dependent, we pad it with -1s # up to an user-provided upper bound block_schedule_data: torch.Tensor + # expected slice size (for heuristics) + expected_slice_size: int | None = None + # divisibility hint for values in `slice_sizes` + slice_sizes_divisibility: int = None def __post_init__(self): assert self.block_offs_data.shape[0] == len(RaggedTensorMetadata.block_sizes()) @@ -56,6 +60,10 @@ def __post_init__(self): if self.slice_offs is not None: assert self.slice_offs.dtype == torch.int32 + @property + def n_slices(self): + return self.slice_sizes.shape[0] + def block_offs(self, block_size): return self.block_offs_data[RaggedTensorMetadata.block_sizes().index(block_size)] @@ -63,10 +71,14 @@ def block_schedule(self, block_size): return self.block_schedule_data[RaggedTensorMetadata.block_sizes().index(block_size)] @staticmethod - def max_n_tiles(n_slices, n_total_rows): + def n_blocks(n_slices, n_total_rows, block_size): if n_total_rows <= n_slices: return n_total_rows - return n_slices - 1 - ((n_slices - n_total_rows - 1) // min(RaggedTensorMetadata.block_sizes())) + return n_slices - 1 - ((n_slices - n_total_rows - 1) // block_size) + + @staticmethod + def max_n_blocks(n_slices, n_total_rows): + return RaggedTensorMetadata.n_blocks(n_slices, n_total_rows, min(RaggedTensorMetadata.block_sizes())) @staticmethod def block_sizes_log2(): @@ -77,6 +89,11 @@ def block_sizes(): return [2**x for x in RaggedTensorMetadata.block_sizes_log2()] +def ragged_metadata_fields(metadata, block_size): + return (metadata.slice_sizes, metadata.slice_offs, metadata.block_offs(block_size), + metadata.block_schedule(block_size), metadata.expected_slice_size, metadata.slice_sizes_divisibility or 1) + + # utilities # --------------------------------------------------------- # @@ -171,7 +188,7 @@ def make_ragged_tensor_metadata(slice_sizes, n_total_rows): MEMSET_BLOCK = 512 dtype = torch.int32 device = slice_sizes.device - max_n_blocks = RaggedTensorMetadata.max_n_tiles(n_slices, n_total_rows) + max_n_blocks = RaggedTensorMetadata.max_n_blocks(n_slices, n_total_rows) slice_offs_combined, _ = empty_aligned((block_size_num + 1, n_slices + 1), dtype, device, MEMSET_BLOCK) block_schedule_data, n_memset_elts = empty_aligned((block_size_num, max_n_blocks), dtype, device, MEMSET_BLOCK) slice_offs, block_offs_data = slice_offs_combined[0], slice_offs_combined[1:] @@ -200,7 +217,7 @@ def make_ragged_tensor_metadata(slice_sizes, n_total_rows): def make_ragged_tensor_metadata_torch(slice_sizes, n_total_rows): assert slice_sizes.ndim == 1 n_slices = slice_sizes.shape[0] - max_n_blocks = RaggedTensorMetadata.max_n_tiles(n_slices, n_total_rows) + max_n_blocks = RaggedTensorMetadata.max_n_blocks(n_slices, n_total_rows) # offset for each experts device = slice_sizes.device slice_offs = torch.cumsum(slice_sizes, dim=0) @@ -226,11 +243,11 @@ def _build_schedule(block_off, n_blocks): block_offs = dict() block_pid_map = dict() for block_size in RaggedTensorMetadata.block_sizes(): - n_tiles = (slice_sizes + block_size - 1) // block_size - block = torch.cumsum(n_tiles, dim=0) + n_blocks = (slice_sizes + block_size - 1) // block_size + block = torch.cumsum(n_blocks, dim=0) block = torch.cat((torch.zeros(1, device=device), block)).int() block_offs[block_size] = block - block_pid_map[block_size] = _build_schedule(block, n_tiles) + block_pid_map[block_size] = _build_schedule(block, n_blocks) block_offs = torch.stack(list(block_offs.values())) block_pid_map = torch.stack(list(block_pid_map.values())) return RaggedTensorMetadata(slice_sizes, slice_offs, block_offs, block_pid_map) diff --git a/python/triton_kernels/triton_kernels/testing.py b/python/triton_kernels/triton_kernels/testing.py index fd7169a656..0b49a79dc3 100644 --- a/python/triton_kernels/triton_kernels/testing.py +++ b/python/triton_kernels/triton_kernels/testing.py @@ -5,6 +5,11 @@ import sys import torch from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 +from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4, make_ragged_tensor_metadata +from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, MXFP_BLOCK_SIZE +from triton_kernels.tensor_details import layout +import itertools +from dataclasses import replace def assert_equal(ref, tri): @@ -194,3 +199,131 @@ def compute_actual_scale(x, dtype, per_batch_scale=False): }[dtype] maxvals = x.abs().amax(dim=tuple(range(1, x.ndim))) if per_batch_scale else x.abs().max() return maxvals / max_finite + + +# --- create tensor --- + + +def normalize_blocks(x, BLOCK_SIZE=None): + if BLOCK_SIZE is None: + BLOCK_SIZE = int(MXFP_BLOCK_SIZE) + x_ndim = x.ndim + if x_ndim == 2: + x = x.unsqueeze(0) + for e, i, j in itertools.product(range(x.shape[0]), range(0, x.shape[1], BLOCK_SIZE), + range(0, x.shape[2], BLOCK_SIZE)): + i_end = min(i + BLOCK_SIZE, x.shape[1]) + j_end = min(j + BLOCK_SIZE, x.shape[2]) + block = x[e, i:i_end, j:j_end] + m_abs = block.abs().max() + i_len = i_end - i + j_len = j_end - j + min_len = min(i_len, j_len) + signs = torch.randint(0, 2, (max(i_len, j_len), ), device=x.device) * 2 - 1 + block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs + if j_len > i_len: + block[i_len - 1, i_len:] = signs[min_len:] * m_abs + elif i_len > j_len: + block[j_len:, j_len - 1] = signs[min_len:] * m_abs + if x_ndim == 2: + x = x.squeeze(0) + return x + + +def alloc_rand(shape, device, dtype, requires_grad=False): + if dtype.itemsize == 1: + tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) + return tmp.to(dtype).requires_grad_(requires_grad) + ret = torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) + ret = normalize_blocks(ret) + return ret + + +def make_slice_sizes(n_slices, total_size, device="cuda"): + torch.manual_seed(0) + dtype = torch.int32 + if total_size < 0: + raise ValueError("total_size must be non-negative") + if n_slices <= 0: + return torch.zeros((0, ), dtype=dtype, device=device) + if total_size == 0: + return torch.zeros((n_slices, ), dtype=dtype, device=device) + # always set one slice size to zero + probs = torch.ones(n_slices, device=device) / n_slices + if n_slices > 1: + probs[2] += probs[1] + probs[1] = 0. + assignments = torch.multinomial(probs, total_size, replacement=True) + counts = torch.bincount(assignments, minlength=n_slices).to(dtype) + assert counts.sum().item() == total_size + assert len(counts) == n_slices + return counts + + +def pad_rows_to_multiples(A, indices, multiple=128, pad_value=float('nan')): + """ + Insert padding so that each row A[i] (for i in indices) + appears at an output row index that is a multiple of `multiple`. + """ + D = A.size(1) + out = [] + for i_cur, i_next in zip(indices[:-1], indices[1:]): + size = (i_next - i_cur) + size_padded = ((size + multiple - 1) // multiple) * multiple + cur = torch.full((size_padded, D), pad_value, dtype=A.dtype, device=A.device) + cur[:size, :] = A[i_cur:i_next, :] + out.append(cur) + return torch.vstack(out) + + +def pad_ragged_tensor(x, x_ragged_metadata, hbm_swizzling, transpose): + multiple = 128 if hbm_swizzling else 64 + if transpose: + y = pad_rows_to_multiples(x.T, x_ragged_metadata.slice_offs, multiple=multiple, pad_value=0).T.contiguous() + else: + y = pad_rows_to_multiples(x, x_ragged_metadata.slice_offs, multiple=multiple, pad_value=0).contiguous() + + y_ragged_metadata = replace(x_ragged_metadata, slice_offs=x_ragged_metadata.block_offs(multiple) * multiple, + slice_sizes_divisibility=multiple) + return y, y_ragged_metadata + + +def make_random_tensor(shape, n_slices, ragged_dim, ragged_padding, device, dtype, mxfp_dim, transpose, + squeeze_batch_dim, hbm_swizzling=False, is_mx_rowmajor=False): + # allocate buffer + buffer_shape = ((n_slices, ) if ragged_dim is None else tuple()) + shape + buffer_dtype = torch.bfloat16 if dtype.has_mx_scale else dtype.torch_dtype + buffer = alloc_rand(buffer_shape, device=device, dtype=buffer_dtype) + if squeeze_batch_dim: + buffer = buffer.squeeze(0) + # handle raggedness + ragged_metadata = None + if ragged_dim is not None: + slice_sizes = make_slice_sizes(n_slices, shape[ragged_dim], device=device) + ragged_metadata = make_ragged_tensor_metadata(slice_sizes, shape[ragged_dim]) + if ragged_padding: + buffer, ragged_metadata = pad_ragged_tensor(buffer, ragged_metadata, hbm_swizzling, ragged_dim == 1) + # handle transpose + if transpose: + buffer = buffer.mT.contiguous().mT + # handle mxfp + scales = None + if mxfp_dim is not None: + assert dtype.has_mx_scale + buffer_dtype = dtype.torch_dtype + if is_mx_rowmajor: + scales = downcast_to_mxfp(buffer, buffer_dtype, axis=mxfp_dim)[1] + buffer = downcast_to_mxfp(buffer.mT.contiguous(), buffer_dtype, axis=mxfp_dim)[0].mT + else: + buffer, scales = downcast_to_mxfp(buffer, buffer_dtype, axis=mxfp_dim) + buffer = wrap_torch_tensor(buffer, FP4 if dtype.is_mxfloat4 else buffer_dtype) + scales = wrap_torch_tensor(scales) + if dtype.is_mxfloat4 and hbm_swizzling and not is_mx_rowmajor: + # convert buffer to swizzled hbm layout + buffer_layout, buffer_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mxfp_dim) + buffer = convert_layout(buffer, buffer_layout, **buffer_layout_opts) + # convert scales to swizzled hbm layout + scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( + mx_axis=mxfp_dim, num_warps=8) + scales = convert_layout(scales, scale_layout, **scale_layout_opts) + return buffer, scales, ragged_metadata From 597b17ecc27887c3c60390439cd6824b2d13ac45 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sun, 23 Nov 2025 13:57:20 -0800 Subject: [PATCH 2/4] [CMake] Fix build from empty CMAKE_LIBRARY_OUTPUT_DIRECTORY (#8810) If building from root `CMakeLists.txt`, we will have an empty `CMAKE_LIBRARY_OUTPUT_DIRECTORY`, then this will try to create `/plugins` and see permission failures. --- lib/Plugins/CMakeLists.txt | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/Plugins/CMakeLists.txt b/lib/Plugins/CMakeLists.txt index 593335d222..1987910922 100644 --- a/lib/Plugins/CMakeLists.txt +++ b/lib/Plugins/CMakeLists.txt @@ -35,9 +35,16 @@ foreach( plugin ${TRITON_PLUGIN_PASSES} ) "$<$:-undefined dynamic_lookup>" ) - set_target_properties(${plugin} PROPERTIES + # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python + # build. It is empty if building directly from the root + # CMakeLists.txt file. Therefore if not building from Python just + # use the default CMake shared lib path otherwise this causes a hard + # build error + if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + set_target_properties(${plugin} PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../plugins") + endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) target_compile_options(${plugin} PRIVATE -fvisibility=hidden) endforeach() From dcad27098ebea3e1ef30aca0a97ec774a60fbfdf Mon Sep 17 00:00:00 2001 From: Saeid Rostami <123997133+saeid-rostami@users.noreply.github.com> Date: Sun, 23 Nov 2025 17:52:07 -0500 Subject: [PATCH 3/4] [AMD] Fix test_dot_multidim floating-point comparision (#8780) This PR fixes test failures in `test_dot_multidim` that occur on AMD RDNA3 GPUs due to overly strict floating-point comparisons. The test currently uses `torch.equal()` which requires exact bit-for-bit equality, but different GPU architectures and drivers can produce slightly different results at the machine epsilon level while still being mathematically correct. This change replaces the exact equality check with `torch.allclose()` using appropriate tolerances. Triton --> tensor([[[[ -3.0000, -44.0000, 8.0000, ..., 1.0000, -28.0000, -12.0000], [ -6.0000, 46.0000, 10.0000, ..., -13.0000, 35.0000, 50.0000], [ 83.0000, 5.0000, 59.0000, ..., -32.0000, -19.0000, 96.0000], ..., Torch ---> tensor([[[[ -3., -44., 8., ..., 1., -28., -12.], [ -6., 46., 10., ..., -13., 35., 50.], [ 83., 5., 59., ..., -32., -19., 96.], ..., The printed values appear identical, but `torch.equal()` detects differences at the least significant bits level. --- python/test/unit/language/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d9446fc5ca..b40f98712b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6693,4 +6693,4 @@ def kernel(X, Y, Z, RANK: tl.constexpr, TRANS_A: tl.constexpr, TRANS_B: tl.const d = a.to(torch.float32) @ b.to(torch.float32) - assert torch.equal(c, d) + assert torch.allclose(c, d, rtol=1e-3, atol=1e-2) From 83eb05c24d757d6134ea37d3886c6093b1d1cd91 Mon Sep 17 00:00:00 2001 From: Witold Dziurdz Date: Thu, 27 Nov 2025 11:28:22 +0000 Subject: [PATCH 4/4] Align behavior with CUDA/HIP: skip test_matmul when swiglu_opts is not None and do_gamma is set Signed-off-by: Witold Dziurdz (cherry picked from commit 1479afdd64a69345c171ef4f5c504d68771b562b) Signed-off-by: Anatoly Myachev --- python/triton_kernels/tests/test_matmul.py | 10 ++-- python/triton_kernels/tests/test_reduce.py | 2 +- .../triton_kernels/triton_kernels/matmul.py | 2 +- .../matmul_details/opt_flags.py | 16 +++--- .../opt_flags_details/opt_flags_intel.py | 4 +- scripts/skiplist/default/triton_kernels.txt | 17 ------- scripts/skiplist/xe2/triton_kernels.txt | 49 ------------------- 7 files changed, 19 insertions(+), 81 deletions(-) diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index 31093f567b..0a2b70141c 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -16,7 +16,7 @@ # testing utilities from triton_kernels.testing import assert_close, make_random_tensor # target-specific utilities -from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4 +from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4, is_xpu from triton_kernels.swiglu import swiglu, swiglu_fn from triton_kernels.swiglu import PrecisionConfig as SwiGLUPrecisionConfig @@ -243,6 +243,10 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, if split_k is not None and split_k > 1: pytest.skip("splitK hasn't been fully tested on AMD GPU.") + elif is_xpu(): + if swiglu_opts is not None and do_gamma: + pytest.xfail("NYI: swiglu and gamma not supported together") + if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3(): pytest.xfail("float8_e4m3fnuz only tested on AMD CDNA3 Platform") @@ -276,12 +280,12 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, if hbm_swizzling: pytest.skip("NYI: nner_expt_opt and HBM swizzling") if not colmajor_mxfp_weight: - if torch.cuda.get_device_capability()[0] < 10: + if is_cuda() and torch.cuda.get_device_capability()[0] < 10: pytest.skip("transposed mxfp weight not supported with cuda capability < 10") if block_m == 16: pytest.skip("PassManager::run failed from Triton compiler") # TODO: should construct the test case differently rather than overriding here - if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10: + if "float8" in weight_dtype_str and is_cuda() and torch.cuda.get_device_capability()[0] < 10: b_transpose = True torch.manual_seed(0) diff --git a/python/triton_kernels/tests/test_reduce.py b/python/triton_kernels/tests/test_reduce.py index 77b87c808d..6ca8a09e22 100644 --- a/python/triton_kernels/tests/test_reduce.py +++ b/python/triton_kernels/tests/test_reduce.py @@ -60,7 +60,7 @@ def plus_a_reduce(x, a): @pytest.mark.parametrize("dim", [0, 1, 2]) def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn, device): is_hip = triton.runtime.driver.active.get_current_target().backend == "hip" - is_pre_h100 = torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0) + is_pre_h100 = device == "cuda" and torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0) if (is_hip or is_pre_h100) and "float8" in dtype_str: pytest.skip("float8 not supported on CUDA < 9.0") torch.manual_seed(0) diff --git a/python/triton_kernels/triton_kernels/matmul.py b/python/triton_kernels/triton_kernels/matmul.py index ee1804fac9..e6d6bb283c 100644 --- a/python/triton_kernels/triton_kernels/matmul.py +++ b/python/triton_kernels/triton_kernels/matmul.py @@ -318,7 +318,7 @@ def matmul(a, b, bias, ) # there seems to be a bug on A100 # pytest -vs test_matmul.py::test_op[False-False-False-False-pad_b-16-768-512-1024-ragged-float16-float16-10-1-False-None-False-False-False-True-None] - if ragged_dimension == "K" and torch.cuda.get_device_capability()[0] < 9: + if ragged_dimension == "K" and is_cuda() and torch.cuda.get_device_capability()[0] < 9: opt_flags.num_stages = 1 if ragged_dimension == "K": a_has_tma = opt_flags.is_persistent and (a.stride(-1) != 1 or (a_ragged_metadata.slice_sizes_divisibility is not None)) diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py index 16fc572ca2..5c08d238b5 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py @@ -54,7 +54,7 @@ def make_default_opt_flags_intel( m, n, k, - routing_data, + ragged_metadata, can_use_persistent_tma, can_use_split_k, enforce_bitwise_invariance, @@ -65,13 +65,13 @@ def make_default_opt_flags_intel( ): constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "max_allowable_mn"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() - # tokens per expert - if routing_data is None: - tokens_per_expt = m - elif routing_data.expected_tokens_per_expt is None: - tokens_per_expt = max(1, m // routing_data.n_expts_tot) + # tokens per slice + if ragged_metadata is None: + slice_size = m + elif ragged_metadata.expected_slice_size is None: + slice_size = max(1, m // ragged_metadata.n_slices) else: - tokens_per_expt = routing_data.expected_tokens_per_expt + slice_size = ragged_metadata.expected_slice_size # pid swizzling group_m = 8 xcd_swizzle = 1 @@ -81,7 +81,7 @@ def make_default_opt_flags_intel( elif enforce_bitwise_invariance: block_m = 128 else: - block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) + block_m = max(16, min(triton.next_power_of_2(slice_size), 128)) # block n block_n = opt_flags_intel.compute_block_n(n) # is_persistent diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_intel.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_intel.py index 6c3a6ebc3f..586cc01095 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_intel.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_intel.py @@ -4,7 +4,7 @@ def compute_grid_size(routing_data, m, n, block_m, block_n): if routing_data is not None: - grid_m = routing_data.n_blocks(m, block_m) + grid_m = routing_data.n_blocks(routing_data.n_slices, m, block_m) else: grid_m = triton.cdiv(m, block_m) grid_n = (n + block_n - 1) // block_n @@ -19,7 +19,7 @@ def compute_block_n(n: int): def compute_block_k(k: int | None, is_persistent: bool, precision_config): if k is not None: block_k = max(32, min(128, triton.next_power_of_2(k))) - has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None + has_mx_weight_scale = precision_config is not None and precision_config.b_mx_scale is not None if is_persistent and has_mx_weight_scale: block_k = min(block_k, 128) return block_k diff --git a/scripts/skiplist/default/triton_kernels.txt b/scripts/skiplist/default/triton_kernels.txt index 7870a886fd..e69de29bb2 100644 --- a/scripts/skiplist/default/triton_kernels.txt +++ b/scripts/skiplist/default/triton_kernels.txt @@ -1,17 +0,0 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/5074 -tests/test_matmul.py::test_op[False-False-False-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] diff --git a/scripts/skiplist/xe2/triton_kernels.txt b/scripts/skiplist/xe2/triton_kernels.txt index 108cf912ba..e69de29bb2 100644 --- a/scripts/skiplist/xe2/triton_kernels.txt +++ b/scripts/skiplist/xe2/triton_kernels.txt @@ -1,49 +0,0 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/5074 -tests/test_matmul.py::test_op[False-False-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-False-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-False-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-True-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-True-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] -tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] -tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]