From 2b29c3dbf82c663ccdbb0a72bd97b40b1473bfa7 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 29 Oct 2025 00:18:17 -0700 Subject: [PATCH 1/9] [triton_kernels] decouple split-k reduction from inter-expert reductions in matmul (#8483) --- python/triton_kernels/bench/distributed.py | 3 +- python/triton_kernels/tests/test_matmul.py | 8 +- python/triton_kernels/tests/test_reduce.py | 15 +- .../triton_kernels/matmul_ogs.py | 282 ++++++------------ .../matmul_ogs_details/_common.py | 8 +- .../matmul_ogs_details/_matmul_ogs.py | 9 +- .../matmul_ogs_details/_p_matmul_ogs.py | 10 +- .../matmul_ogs_details/_reduce_grouped.py | 102 ------- .../matmul_ogs_details/opt_flags.py | 4 - .../triton_kernels/triton_kernels/reduce.py | 168 ++++++----- .../triton_kernels/specialize.py | 51 ++++ 11 files changed, 262 insertions(+), 398 deletions(-) delete mode 100644 python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index 09f89b8723..e675051846 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -277,7 +277,8 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac # precision configs pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale) - act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.0, 1.0), 2) + act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2), + (1.0, 1.0)) pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale) pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale) if rank == 0: diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index a260724061..3db95e91d2 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -130,8 +130,8 @@ def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, mode ) if weight_use_flexpoint else InFlexData(), out_data=OutFlexData( dtype=out_dtype, - expected_scale=make(4.00, 5.00, mode == "batched" or expt_is_inner), - actual_scale=make(0, 0, mode == "batched" or expt_is_inner), + expected_scale=make_scalar(4.00), + actual_scale=make_scalar(0), checksum_scale=None, ) if act_use_flexpoint else OutFlexData(), ) @@ -776,8 +776,8 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, precision_config=SwiGLUPrecisionConfig(swiglu_limit)) b = matmul_ogs( x, w, bias, rdata, gindx, sindx, precision_opt, - fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), - (swiglu_alpha, swiglu_limit), 2)) + fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2), + (swiglu_alpha, swiglu_limit))) except opt_flags.InapplicableConstraint: pytest.skip("inapplicable constraint") diff --git a/python/triton_kernels/tests/test_reduce.py b/python/triton_kernels/tests/test_reduce.py index f257f00f4d..9f64ffb20f 100644 --- a/python/triton_kernels/tests/test_reduce.py +++ b/python/triton_kernels/tests/test_reduce.py @@ -5,6 +5,7 @@ from triton_kernels.numerics_details.mxfp import upcast_from_mxfp_torch, downcast_to_mxfp_torch from triton_kernels.numerics import InFlexData, OutFlexData import triton +import triton.language as tl def init_mask(mask_mode, B, M, N, device): @@ -30,8 +31,9 @@ def 
dtype_str_to_torch(dtype_str: str) -> torch.dtype: @triton.jit -def plus_a(x, a): - return x + a +def plus_a_reduce(x, a): + y = x + a + return tl.sum(y.reshape([x.shape[0], x.shape[1] // 2, 2]), axis=2) @pytest.mark.parametrize("B, M, N, postprocess_fn", [ @@ -84,14 +86,15 @@ def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn): reduce(x, dim=dim, mask=mask, x_mxscale=x_mscale) return if postprocess_fn == "plus_ten": - postprocess_fn_tri = PostprocessFn(specs=FnSpecs("plus_a", plus_a, ("a", )), fn_args=(10, )) - postprocess_fn_ref = lambda x: x + 10 + postprocess_fn_tri = PostprocessFn(specs=FnSpecs("plus_a", plus_a_reduce, ("a", ), reduction_n=2), + fn_args=(10, )) + postprocess_fn_ref = lambda x: (x + 10).reshape([x.shape[0], x.shape[1] // 2, 2]).sum(dim=2) else: postprocess_fn_tri = postprocess_fn_ref = None y_tri, y_tri_mxscale = reduce(x, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_tri, - postprocess_fn=postprocess_fn_tri) + postprocess_fn1=postprocess_fn_tri) y_ref, y_ref_mxscale = reduce_torch(x, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_ref, - postprocess_fn=postprocess_fn_ref) + postprocess_fn1=postprocess_fn_ref) if is_mx: y_ref = upcast_from_mxfp_torch(y_ref, y_ref_mxscale, torch.float16, axis=-1) y_tri = upcast_from_mxfp_torch(y_tri, y_tri_mxscale, torch.float16, axis=-1) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py index 5f5a3bf433..85291c667a 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs.py @@ -2,7 +2,6 @@ # fmt: off from dataclasses import dataclass, field import itertools -import sys import torch import triton from enum import Enum, auto @@ -14,12 +13,13 @@ # details from .matmul_ogs_details._matmul_ogs import _matmul_ogs from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn -from .matmul_ogs_details._reduce_grouped import _reduce_grouped from .numerics_details.mxfp import MXFP_BLOCK_SIZE from .tensor_details.layout_details.strided import StridedLayout -from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint -from .specialize import specialize +from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints +from .specialize import FnSpecs, SpecializationModule, ClosureArg from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor, RaggedTensorMetadata +from .reduce import reduce +from .reduce import PostprocessFn as ReducePostprocessFn @dataclass @@ -62,23 +62,10 @@ def n_blocks(self, n_rows, block_m): else: return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1 -@dataclass(frozen=True) -class FnSpecs: - name: str - fn: "triton.runtime.jit.JITFunction" - fn_arg_names: tuple[str] - fn_arg_do_not_specialize: tuple[str] = tuple() - - @staticmethod - def default(): - return FnSpecs("dflt", None, tuple()) - - @dataclass(frozen=True) class FusedActivation: specs: FnSpecs = FnSpecs.default() fn_args: tuple[object] = tuple() - reduction_n: int = 1 @dataclass(frozen=True) @@ -99,39 +86,13 @@ class FusedComm: reduce_rank: int = 0 n_reduce_shards: int = 1 -EpilogueSpecs = FnSpecs # TODO: remove this alias when callers are updated - -_kernels = dict() - - -def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()): - global _kernels - key = (fused_activation.name, 
epilogue.name) - if key in _kernels: - return _kernels[key] - spec_constants = { - "ACTIVATION_FN": fused_activation.fn, - "EPILOGUE_FN": epilogue.fn, - } - spec_tuples = { - "activation_fn_args": fused_activation.fn_arg_names, - "epilogue_fn_args": epilogue.fn_arg_names, - } - do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize - import types - - module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}") - sys.modules[module.__name__] = module - module._matmul_ogs = specialize(_matmul_ogs, module, spec_constants, spec_tuples, - do_not_specialize=do_not_specialize) - module._p_matmul_ogs = specialize(_p_matmul_ogs, module, spec_constants, spec_tuples, - do_not_specialize=do_not_specialize) - module._reduce_grouped = specialize(_reduce_grouped, module, spec_constants, spec_tuples, - do_not_specialize=do_not_specialize) - _kernels[key] = module - return module - - +specializations = SpecializationModule("matmul_ogs", + kernels=[("_matmul_ogs", _matmul_ogs), ("_p_matmul_ogs", _p_matmul_ogs)], + closure_args={ + "epilogue": ClosureArg("EPILOGUE_FN", "epilogue_fn_args"), # + "activation": ClosureArg("ACTIVATION_FN", "activation_fn_args"), # + }, +) # ----------------------------------------------------------------------------- # Matrix Multiplication + Outer Gather/Scatter # ----------------------------------------------------------------------------- @@ -275,28 +236,29 @@ def init_allocation(x, w, precision_config, fused_activation, # if the activations are gathered, then M is number of gather indices if gather_indx is not None: M = gather_indx.src_indx.shape[0] - # final output - if routing_data.n_expts_act == 1 or scatter_indx is None: + if scatter_indx is not None: + M = scatter_indx.src_indx.shape[0] + if scatter_indx is None: y_rows = M else: - Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows - y_rows = Mc + y_rows = M // routing_data.n_expts_act y_rows *= n_reduce_shards if inner_routing_data is not None: batch_dim = inner_routing_data.base.n_expts_tot else: batch_dim = x.shape[0] if x.ndim == 3 else 1 - out_shape = (batch_dim, y_rows, N // fused_activation.reduction_n) + out_shape = (batch_dim, y_rows, N // fused_activation.specs.reduction_n) out_dtype = precision_config.out_dtype or x.dtype output = (out_shape, out_dtype) # ---- scratchpad -----# scratchpad = dict() - if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter): + N_scratch = N // fused_activation.specs.reduction_n if opt_flags.split_k == 1 else N + if opt_flags.split_k > 1 or scatter_indx is not None: scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype - scratchpad["matmul"] = ((opt_flags.split_k, batch_dim, M, N), scratch_out_dtype) + scratchpad["matmul"] = ((opt_flags.split_k, batch_dim, M, N_scratch), scratch_out_dtype) if "matmul" in scratchpad and precision_config.out_scale is not None: assert batch_dim == 1, "batch_dim > 1 not supported yet" - scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE)), torch.uint8) + scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N_scratch, MXFP_BLOCK_SIZE)), torch.uint8) return MatmulAllocation(x.device, output, scratchpad) def apply_allocation(allocation: MatmulAllocation, output): @@ -337,87 +299,6 @@ def _canonicalize_storage(storage, out_ndim, flex_data): new_storage_data = flex_data.reinterpret(new_storage_data) return Storage(new_storage_data, storage.layout) -# - -def reduce_grouped(x: 
torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_mx_scale: torch.Tensor, - fused_activation, epilogue, - x_flex: InFlexData | None = None, - out_flex: OutFlexData | None = None, x_mx_scale: torch.Tensor | None = None, - out_dtype: bool = None, flexpoint_saturate_inf: bool = False): - """ - In-place grouped row reduction. - - Arguments - - x: Tensor[AnyFloat] of shape [(num_groups * K), N] - - indx: Tensor[Int] of shape [num_groups, K] - - Description - For each group g in [0, num_groups), this routine sums the K rows of `x` - specified by `indx[g, :]` and overwrites the row corresponding to the first - valid (non-negative) index with the per-group sum. Accumulation is performed - in float32 for numerical stability, and the result is written back in the - dtype of `x`. - - Behavior and edge cases - - Invalid (-1) entries are skipped during accumulation and do not generate - memory traffic. If a group has no valid entries, nothing is written for - that group. - - Reduction is performed tile-by-tile along the N dimension within a single - kernel launch (persistent along N) to minimize launch overhead. - - Performance notes - - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x), - plus index reads. With no invalid entries, this becomes (K + 1) reads/writes - of length N per group. - - Returns - - The input tensor `x` (modified in place). - """ - M = x.shape[2] # Only used for per-batch flex scale. - if indx is None and x.shape[0] == 1: - return x.squeeze(0), None - if indx is not None: - num_groups = indx.shape[0] - else: - # Handle batched matmul (K, B, M, N) by pretending it to be (K, 1, B*M, N). - x = x.view(x.shape[0], 1, x.shape[1] * x.shape[2], x.shape[3]) - num_groups = x.shape[-2] - if x_flex is None: - x_flex = InFlexData() - if out_flex is None: - out_flex = OutFlexData() - K = 1 if indx is None else indx.shape[1] - out_dtype = x.dtype if out_dtype is None else out_dtype - assert x.shape[-1] % fused_activation.reduction_n == 0 - BLOCK_N = 512 - # Resolve scalar flex scales (may be None) - x_expected_scale = None if x_flex is None else x_flex.scale - out_expected_scale = None if out_flex is None else out_flex.expected_scale - out_actual_scale = None if out_flex is None else out_flex.actual_scale - out_checksum_scale = None if out_flex is None else out_flex.checksum_scale - # Resolve MXFP output scale row stride - stride_mxb = 0 if x_mx_scale is None else x_mx_scale.stride(0) - stride_mxs = 0 if x_mx_scale is None else x_mx_scale.stride(1) - stride_omxs = 0 if out_mx_scale is None else out_mx_scale.stride(0) - kernels = get_kernels(epilogue.specs, fused_activation.specs) - kernels._reduce_grouped[(num_groups, )]( - x_flex.reinterpret(x), x.stride(0), x.stride(2), x.stride(3), # - x_expected_scale, # scalar input scale - out_flex.reinterpret(out), out.stride(1), out.stride(2), # - out_expected_scale, out_actual_scale, out_checksum_scale, - out_flex is not None and out_flex.is_per_batch, - indx, - x.shape[0], M, x.shape[-1], # - x_mx_scale, stride_mxb, stride_mxs, # - out_mx_scale, stride_omxs, # - *fused_activation.fn_args, fused_activation.reduction_n, - *epilogue.fn_arg_values_finalize, - HAS_IN_MX_SCALE=x_mx_scale is not None, HAS_OUT_MX_SCALE=out_mx_scale is not None, - FLEXPOINT_SATURATE_INF=flexpoint_saturate_inf, # - BLOCK_N=BLOCK_N, K=K, # - num_warps=1, # - ) - return out, out_mx_scale # ----------------------------------------------------------------------------- # Triton Implementation @@ -430,20 +311,20 @@ def 
matmul_ogs_set_idle_sms(num_idle_sms): update_opt_flags_constraints({"idle_sms": num_idle_sms}) def matmul_ogs(x, w, bias, - routing_data: RoutingData | None = None, - gather_indx: GatherIndx | None = None, - scatter_indx: ScatterIndx | None = None, - precision_config: PrecisionConfig | None = None, - betas: torch.Tensor | None = None, - gammas: torch.Tensor | None = None, - out_alpha: float | None = None, - y: torch.Tensor | None = None, - fused_comm: FusedComm | None = None, - fused_activation: FusedActivation | None = None, - epilogue: Epilogue | None = None, - y_acc_in: torch.Tensor | None = None, - inner_routing_data: InnerRoutingData | None = None, - ): + routing_data: RoutingData | None = None, + gather_indx: GatherIndx | None = None, + scatter_indx: ScatterIndx | None = None, + precision_config: PrecisionConfig | None = None, + betas: torch.Tensor | None = None, + gammas: torch.Tensor | None = None, + out_alpha: float | None = None, + y: torch.Tensor | None = None, + fused_comm: FusedComm | None = None, + fused_activation: FusedActivation | None = None, + epilogue: Epilogue | None = None, + y_acc_in: torch.Tensor | None = None, + inner_routing_data: InnerRoutingData | None = None, +): """ Y[:, :] = 0. for e in num_experts: @@ -475,7 +356,7 @@ def matmul_ogs(x, w, bias, if precision_config is None: precision_config = PrecisionConfig() if fused_activation is None: - fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1) + fused_activation = FusedActivation(FnSpecs.default(), tuple()) if epilogue is None: epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False) if routing_data is None: @@ -538,15 +419,12 @@ def matmul_ogs(x, w, bias, has_gather_tma = has_gather and target_info.has_tma_gather() # hopper w/ mxfp4 doesn't support TMA can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4) - can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1) opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config, batch_size, M, N, w.shape[-2], routing_data, - can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize, + can_use_tma, scatter_indx is not None, epilogue.effective_itemsize, x_transpose, y_acc_in is not None, inner_routing_data.block_k if inner_routing_data is not None else None, ) - if not can_use_fused_scatter and opt_flags.fused_scatter: - raise InapplicableConstraint("Fused scatter is not supported") if inner_routing_data is not None: assert opt_flags.block_k == inner_routing_data.block_k assert opt_flags.split_k == 1 @@ -572,7 +450,7 @@ def matmul_ogs(x, w, bias, # fused activation matmul_fused_activation = fused_activation reduce_fused_activation = FusedActivation() - if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter): + if opt_flags.split_k > 1: matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation # allocate output/scratchpad memory allocation = init_allocation(x, w, precision_config, fused_activation, @@ -614,8 +492,8 @@ def matmul_ogs(x, w, bias, max_grid = batch_size * grid_m * grid_n * opt_flags.split_k grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid # canonicalize storage - has_scatter_tma = opt_flags.fused_scatter and target_info.has_tma_gather() - y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if opt_flags.fused_scatter 
else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:])) + has_scatter_tma = scatter_indx is not None and target_info.has_tma_gather() + y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if has_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:])) x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data) w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data) y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data) @@ -635,10 +513,10 @@ def matmul_ogs(x, w, bias, x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data # create tma descriptor for y y_has_tma = ( - opt_flags.is_persistent and (has_scatter_tma or not opt_flags.fused_scatter) - and (y_acc_in is None or y_acc_is_y) and fused_comm is None + opt_flags.is_persistent and (scatter_indx is None or has_scatter_tma) + and (y_acc_in is None or y_acc_is_y) ) - block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.reduction_n + block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.specs.reduction_n y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n] y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense" y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data @@ -666,17 +544,21 @@ def matmul_ogs(x, w, bias, out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None) out_matmul_scale_strides = (0, ) * (4 - len(out_matmul_scale_strides)) + out_matmul_scale_strides # launch kernel - kernels = get_kernels(epilogue.specs, matmul_fused_activation.specs) + kernels = specializations.get(epilogue=epilogue.specs, activation=matmul_fused_activation.specs) # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs. 
# w_transpose = w_storage.data.stride()[-1] != 1 + if gather_indx is not None: + gather_src_indx = torch.div(gather_indx.src_indx, routing_data.n_expts_act, rounding_mode='trunc') fused_comm_kwargs = { "pYPtrs": fused_comm.out_handles, "ScatterShardIndx": fused_comm.scatter_shard_indx, "reduce_rank": fused_comm.reduce_rank, "n_reduce_shards": fused_comm.n_reduce_shards, } if fused_comm is not None else {} + # if routing_data.n_expts_act > 1: + # y_storage.data.view(torch.uint8).zero_() (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)]( y_tensor_or_tma, y_storage.data, *out_matmul.stride(), *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex), @@ -693,18 +575,18 @@ def matmul_ogs(x, w, bias, x.shape[-2] if routing_data.expt_hist is None else None, N, K, K_W, betas, gammas, - None if gather_indx is None else gather_indx.src_indx, + None if gather_indx is None else gather_src_indx, None if gather_indx is None else gather_indx.dst_indx, # Only for launch_metadata None if scatter_indx is None else scatter_indx.src_indx, num_indx, - None if not opt_flags.fused_scatter else scatter_indx.dst_indx, - None if not opt_flags.fused_scatter else scatter_indx.dst_indx.shape[0], + None if scatter_indx is None else scatter_indx.dst_indx, + None if scatter_indx is None else scatter_indx.dst_indx.shape[0], *expt_data_args, batch_size, grid_m, grid_n, out_alpha, - *matmul_fused_activation.fn_args, matmul_fused_activation.reduction_n, + *matmul_fused_activation.fn_args, matmul_fused_activation.specs.reduction_n, *epilogue.fn_arg_values_matmul, - routing_data.n_expts_tot, routing_data.n_expts_act, + routing_data.n_expts_tot, precision_config.max_num_imprecise_acc, precision_config.allow_tf32, precision_config.flexpoint_saturate_inf, @@ -715,6 +597,7 @@ def matmul_ogs(x, w, bias, opt_flags.block_n, opt_flags.block_k, opt_flags.group_m, + INIT_OUTPUT_TO_ZERO=routing_data.n_expts_act == 1, XCD_SWIZZLE=opt_flags.xcd_swizzle, SWIZZLE_MX_VALUE=w.storage.layout.name, SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name, @@ -734,21 +617,50 @@ def matmul_ogs(x, w, bias, NUM_SMS = grid if opt_flags.is_persistent else 0, **fused_comm_kwargs, **opt_flags.target_kernel_kwargs) - # Build grouped reduction inputs in a uniform way - group_indx = None if scatter_indx is None or opt_flags.fused_scatter else scatter_indx.src_indx.view(-1, routing_data.n_expts_act) - out_final, out_final_mx_scale = reduce_grouped( - out_matmul, - group_indx, - memory["output"].squeeze(0), - precision_config.out_scale, - reduce_fused_activation, - epilogue, - x_flex=InFlexData(dtype=out_matmul_flex.dtype, scale=out_matmul_flex.expected_scale), - out_flex=precision_config.flex_ctx.out_data, - x_mx_scale=out_matmul_scale.squeeze(1) if out_matmul_has_mx else None, - out_dtype=memory["output"].dtype, - flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf, - ) + + out_final_mx_scale = None + if opt_flags.split_k > 1: + assert not out_matmul_has_mx + has_scatter = scatter_indx is not None + postprocess_fn1 = ReducePostprocessFn(specs=reduce_fused_activation.specs, fn_args=reduce_fused_activation.fn_args) + postprocess_fn2 = None if has_scatter else ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize) + y, y_mx_scale = reduce( + x = out_matmul.view(out_matmul.shape[0], -1, out_matmul.shape[-1]), + dim = 0, + # output data/metadata + y = None if has_scatter else memory["output"].view(-1, memory["output"].shape[-1]), + y_dtype = 
out_matmul.dtype if has_scatter else memory["output"].dtype, + y_flex = OutFlexData() if has_scatter else precision_config.flex_ctx.out_data, + y_flex_saturate_inf = None if has_scatter else precision_config.flexpoint_saturate_inf, + y_has_mx = scatter_indx is None and precision_config.out_scale is not None, + # fused functions + postprocess_fn1 = postprocess_fn1, + postprocess_fn2 = postprocess_fn2, + ) + y_shape = out_matmul.shape[1:-1] + (out_matmul.shape[-1] // reduce_fused_activation.specs.reduction_n,) + out_matmul = y.view(*y_shape).unsqueeze(0) + if y_mx_scale is not None: + out_final_mx_scale = y_mx_scale.view(out_matmul.shape[-2], triton.cdiv(out_matmul.shape[-1], 32)) + # TODO: change `matmul_ogs` semantics and move this to another op! + if scatter_indx is not None: + mask = (scatter_indx.src_indx != -1).view(out_matmul.shape[-2]//routing_data.n_expts_act, routing_data.n_expts_act, 1) + out_matmul = out_matmul.view(out_matmul.shape[-2]//routing_data.n_expts_act, routing_data.n_expts_act, -1) + mask = mask.expand_as(out_matmul) + out_matmul_scale_shape = out_matmul.shape[:-1] + (triton.cdiv(out_matmul.shape[-1], 32),) + postprocess_fn = ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize) + x_flex = InFlexData(dtype=out_matmul_flex.dtype, scale=out_matmul_flex.expected_scale) + out_final, out_final_mx_scale = reduce(out_matmul, dim=1, postprocess_fn2=postprocess_fn, x_flex=x_flex, # + mask=mask, + y=memory["output"].squeeze(0).squeeze(0), + x_mxscale=out_matmul_scale.view(*out_matmul_scale_shape) if out_matmul_has_mx else None, + y_has_mx=precision_config.out_scale is not None, + y_flex=precision_config.flex_ctx.out_data, + y_flex_saturate_inf=precision_config.flexpoint_saturate_inf, + ) + out_final = out_final.unsqueeze(0) + else: + out_final = out_matmul.squeeze(0) + if not (is_input_batched or inner_routing_data is not None): out_final = out_final.squeeze(0) if out_final_mx_scale is not None: diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py index 1d71b9bc27..2e30113bf7 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py @@ -238,18 +238,12 @@ def matmul_launch_metadata(grid, kernel, args): fM = M if M is not None else n_tokens ret[f"flops{nbits}"] = 2.0 * fM * N * K * (1 if expt_is_inner else batch_size) - dst = args.get("GatherDstIndx", None) # sindx = args.get("WriteBackIndx", None) n_x_bytes = X.numel() * X.element_size() n_y_bytes = Y.numel() * Y.element_size() if hist is not None: assert n_tokens is not None - n_expts_act = args["N_EXPTS_ACT"] - - if (dst is not None) and launch_metadata_allow_sync(): - n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum() - else: - n_read_rows = n_tokens + n_read_rows = n_tokens if expt_is_inner: n_x_bytes = n_read_rows * X.shape[-2] * X.element_size() diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py index 5b671cb11b..8c6f3ca1f5 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py @@ -71,7 +71,7 @@ def _matmul_ogs( # epilogue transform EPILOGUE_FN: tl.constexpr, epilogue_fn_args, # MoE config - N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr, + N_EXPTS_TOT: tl.constexpr, # precision 
config MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr, FLEXPOINT_SATURATE_INF: tl.constexpr, @@ -81,6 +81,7 @@ def _matmul_ogs( # optimization config BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr, + INIT_OUTPUT_TO_ZERO: tl.constexpr, # One of ["HOPPER", "BLACKWELL", None] SWIZZLE_MX_VALUE: tl.constexpr, # One of ["HOPPER", "BLACKWELL", None] @@ -198,7 +199,7 @@ def _matmul_ogs( # We are tiling Y here, so the tiling is independent of matmul (where we # tile X & W and scatter to different rows of Y). # TODO: refactor (same code in _p_matmul_ogs) - if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1: + if HAS_FUSED_SCATTER and INIT_OUTPUT_TO_ZERO: tl.device_assert(batch_size == 1) pid_mnk = pid if XCD_SWIZZLE != 1: @@ -241,7 +242,7 @@ def _matmul_ogs( else: GatherIndx += start_m # no needs to bounds-check here because `offs_x_m` wraps around M dim - offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT + offs_x_m = tl.load(GatherIndx + offs_x_m) offs_k = off_k_x + tl.arange(0, BLOCK_K) XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k @@ -455,7 +456,7 @@ def _matmul_ogs( YActualScale += start_m * stride_y_mx_m YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n else: - YActualScalePtrs = YActualScale + (offs_y_m - num_idxs // N_EXPTS_ACT).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n + YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :]) else: if PER_BATCH_OUT_SCALE: diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py index 3c9f05f499..6d24f73329 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py @@ -80,7 +80,7 @@ def _p_matmul_ogs( # epilogue transform EPILOGUE_FN: tl.constexpr, epilogue_fn_args, # MoE config - N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr, + N_EXPTS_TOT: tl.constexpr, # precision config MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr, FLEXPOINT_SATURATE_INF: tl.constexpr, @@ -90,6 +90,7 @@ def _p_matmul_ogs( # optimization config BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr, + INIT_OUTPUT_TO_ZERO: tl.constexpr, # NYI: Must be None SWIZZLE_MX_VALUE: tl.constexpr, # One of ["BLACKWELL", None] @@ -172,7 +173,7 @@ def _p_matmul_ogs( yN = N // ACTIVATION_REDUCTION_N # set masked out rows to 0 - if HAS_SCATTER and N_EXPTS_ACT == 1: + if HAS_SCATTER and INIT_OUTPUT_TO_ZERO: # Iterate with reversed pids so that later pids will get more tiles if the number of # tiles isn't evenly divisible by the number of SMs. 
# The main loop after this iterates in the forward direction such that earlier @@ -233,15 +234,14 @@ def _p_matmul_ogs( offs_x_m += start_z * (stride_x_z // stride_x_m) offs_x_m = tl.where(mask_m, offs_x_m, -1) else: - offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, - mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT + offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m, other=-1) elif X_TMA_MODE is None or is_x_microscaled: offs_m = off_m + tl.arange(0, BLOCK_M) offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M) # no needs to bounds-check here because `offs_m` wraps around M dim if GatherIndx is not None: tl.static_assert(HAS_GATHER) - offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT + offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m if is_x_microscaled: diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py deleted file mode 100644 index b6eac14b2f..0000000000 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py +++ /dev/null @@ -1,102 +0,0 @@ -from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale -from triton_kernels.numerics_details.mxfp import quantize_mxfp8_fn -import triton -import triton.language as tl - - -@triton.jit -def _reduce_grouped(X, stride_xb: tl.uint64, stride_xm: tl.uint64, stride_xn, # - XScale, # input scalar flex scale - Out, stride_om: tl.uint64, stride_on, # output tensor - OutExpectedScale, OutActualScale, OutChecksumScale, # output scalar flex scales - PER_BATCH_OUT_SCALE: tl.constexpr, InIndx, B, M, N, # - XMxScale, stride_mxb: tl.uint64, - stride_mxs: tl.uint64, # optional per-32-col output MXFP scales (uint8) - OutMxScale, stride_omxs: tl.uint64, # optional per-32-col output MXFP scales (uint8) - # fused activation function - ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr, - # epilogue transform - EPILOGUE_FN: tl.constexpr, epilogue_fn_args, - # - HAS_IN_MX_SCALE: tl.constexpr, HAS_OUT_MX_SCALE: tl.constexpr, FLEXPOINT_SATURATE_INF: tl.constexpr, - K: tl.constexpr, BLOCK_N: tl.constexpr): - pid_t = tl.program_id(0) - BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N - # persistent along N: single program on N, iterate tiles of size BLOCK_N - start = pid_t * K - # load indices into a tuple - if InIndx is None: - indxs = (pid_t, ) - else: - indxs = () - for i in tl.static_range(0, K): - indxs = indxs + (tl.load(InIndx + start + i), ) - # determine first valid topk row - fi = indxs[(K - 1)] - for i in tl.static_range(K - 2, -1, -1): - fi = tl.where(indxs[i] != -1, indxs[i], fi) - # record overwritten row index (may be -1 if none) - XPtrs = X + tl.arange(0, BLOCK_N) * stride_xn - OutPtrs = Out + tl.arange(0, BLOCK_N_OUT) * stride_on - if HAS_IN_MX_SCALE: - XScalePtrs = XMxScale + tl.arange(0, BLOCK_N // 32) * stride_xn - if HAS_OUT_MX_SCALE: - OutScalePtrs = OutMxScale + tl.arange(0, BLOCK_N_OUT // 32) * stride_on - if PER_BATCH_OUT_SCALE: - out_batch_idx = pid_t // M - OutExpectedScale += out_batch_idx - OutActualScale += out_batch_idx - if OutChecksumScale is not None: - OutChecksumScale += out_batch_idx - x_scale = load_scale(XScale) - for n_curr in tl.range(0, N, BLOCK_N, num_stages=4): - acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32) - x_n_mask = tl.arange(0, BLOCK_N) < N - n_curr - 
x_n_mask_scale = tl.arange(0, BLOCK_N // 32) < tl.cdiv(N - n_curr, 32) - # accumulate contributions for this tile - for i in tl.static_range(0, K): - curr = tl.zeros([BLOCK_N], dtype=tl.float32) - # iterate over split_k partial values - for b in tl.range(0, B): - is_valid = indxs[i] != -1 - x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb - vals = tl.load(x_row_ptr, mask=x_n_mask & is_valid, other=0.0) - vals = vals.to(tl.float32) - if HAS_IN_MX_SCALE: - scale_row_ptr = XScalePtrs + indxs[i] * stride_mxs + b * stride_mxb - scale = tl.load(scale_row_ptr, mask=x_n_mask_scale & is_valid, other=0.) - scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) - vals = vals.reshape([BLOCK_N // 32, 32]) - vals = (scale[:, None] * vals).reshape([BLOCK_N]) - curr += vals - # apply nonlinearity to split-k output - if ACTIVATION_FN is not None: - curr = ACTIVATION_FN(curr[None, :], *activation_fn_args) - curr = tl.reshape(curr, [curr.shape[-1]]) - # update final accumulator - acc += curr - acc *= x_scale - # Compute per-32-col MXFP scales for this tile if requested - Nrem = (N - n_curr) // ACTIVATION_REDUCTION_N - out_n_mask = tl.arange(0, BLOCK_N_OUT) < Nrem - out_n_mask_scale = tl.arange(0, BLOCK_N_OUT // 32) < tl.cdiv(Nrem, 32) - if HAS_OUT_MX_SCALE: - acc, acc_scale = quantize_mxfp8_fn(acc[None, :], out_n_mask[None, :]) - acc = tl.reshape(acc, [acc.shape[-1]]) - acc_scale = tl.reshape(acc_scale, [acc_scale.shape[-1]]) - # Convert to flexpoint output if configured (scalar scales) - acc = float_to_flex(acc, OutExpectedScale, OutActualScale, OutChecksumScale, None, Out, FLEXPOINT_SATURATE_INF) - if not HAS_OUT_MX_SCALE and EPILOGUE_FN is not None: - acc = EPILOGUE_FN(acc, *epilogue_fn_args, target_dtype=Out.dtype.element_ty) - # write-back for this tile - out_ptr = OutPtrs + pid_t * stride_om - tl.store(out_ptr, acc, mask=out_n_mask) - if HAS_OUT_MX_SCALE: - out_scale_ptr = OutScalePtrs + pid_t * stride_omxs - tl.store(out_scale_ptr, acc_scale, mask=out_n_mask_scale) - XPtrs += BLOCK_N * stride_xn - OutPtrs += BLOCK_N_OUT * stride_on - if HAS_IN_MX_SCALE: - XScalePtrs += BLOCK_N // 32 * stride_xn - if HAS_OUT_MX_SCALE: - OutScalePtrs += BLOCK_N_OUT // 32 * stride_xn diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index e6c7b9857c..31f37ba444 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -28,10 +28,6 @@ class OptFlags: arch: str target_kernel_kwargs: dict - def __post_init__(self): - if self.fused_scatter and self.split_k != 1: - raise ValueError("Not supported") - def max_allowable_mn( max_mn: int, diff --git a/python/triton_kernels/triton_kernels/reduce.py b/python/triton_kernels/triton_kernels/reduce.py index c712a13536..4d64173a9d 100644 --- a/python/triton_kernels/triton_kernels/reduce.py +++ b/python/triton_kernels/triton_kernels/reduce.py @@ -6,23 +6,7 @@ from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale from triton_kernels.numerics import InFlexData, OutFlexData, MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 from typing import Optional -import types -import sys -from .specialize import specialize - -_kernels = dict() - - -@dataclass(frozen=True) -class FnSpecs: - name: str - fn: "triton.runtime.jit.JITFunction" - fn_arg_names: tuple[str] - fn_arg_do_not_specialize: tuple[str] = tuple() - - @staticmethod - def 
default(): - return FnSpecs("dflt", None, tuple()) +from .specialize import SpecializationModule, ClosureArg, FnSpecs @dataclass(frozen=True) @@ -31,30 +15,17 @@ class PostprocessFn: fn_args: tuple[object] = tuple() -def get_kernels(fn_specs: FnSpecs = FnSpecs.default()): - global _kernels - key = (fn_specs.name, ) - if key in _kernels: - return _kernels[key] - spec_constants = {"POSTPROCESS_FN": fn_specs.fn} - spec_tuples = {"postprocess_fn_args": fn_specs.fn_arg_names} - do_not_specialize = fn_specs.fn_arg_do_not_specialize - module = types.ModuleType(f"reduce{'_'.join(key)}") - sys.modules[module.__name__] = module - module._reduce = specialize(_reduce, module, spec_constants, spec_tuples, do_not_specialize=do_not_specialize) - _kernels[key] = module - return module - - @triton.jit -def _reduce(X, stride_xr, stride_x0, stride_x1, # x tensor (input) +def _reduce(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x tensor (input) XMx, stride_xmxr, stride_xmx0, stride_xmx1, # x mx scale - Y, stride_y0, stride_y1, # y tensor (output) + Y, stride_y0: tl.int64, stride_y1, # y tensor (output) YMx, stride_ymx0, stride_ymx1, # y mx scale Mask, stride_mr, stride_m0, stride_m1, # mask tensor Scale, stride_sr, stride_s0, stride_s1, # scale tensor - K, S0, S1, # shape (K = reduction dim; S0, S1 = output dims) - POSTPROCESS_FN: tl.constexpr, postprocess_fn_args, XFlex, # x flex (global) scale + K, S0, X_S1, Y_S1, # shape (K = reduction dim; S0, IN_S1 = input dims, OUT_S1 = output dims) + POSTPROCESS_FN1: tl.constexpr, postprocess_fn1_args, # + POSTPROCESS_FN2: tl.constexpr, postprocess_fn2_args, # + XFlex, # x flex (global) scale YFlexExpected, YFlexActual, YFlexChecksum, Y_FLEX_SATURATE_INF: tl.constexpr, # y flex (global) scale IS_MASK_NONE: tl.constexpr, # BROADCAST_R: tl.constexpr, # @@ -65,54 +36,73 @@ def _reduce(X, stride_xr, stride_x0, stride_x1, # x tensor (input) SCALE_BROADCAST_S0: tl.constexpr, # SCALE_BROADCAST_S1: tl.constexpr, # BLOCK_S0: tl.constexpr, # - BLOCK_S1: tl.constexpr, # + BLOCK_X_S1: tl.constexpr, # + BLOCK_Y_S1: tl.constexpr, # ): pid_s0 = tl.program_id(0) pid_s1 = tl.program_id(1) - tl.static_assert(BLOCK_S1 % 32 == 0) - BLOCK_SMX1: tl.constexpr = BLOCK_S1 // 32 + tl.static_assert(BLOCK_X_S1 % 32 == 0) + BLOCK_X_SMX1: tl.constexpr = BLOCK_X_S1 // 32 + BLOCK_Y_SMX1: tl.constexpr = BLOCK_Y_S1 // 32 offs_s0 = pid_s0 * BLOCK_S0 + tl.arange(0, BLOCK_S0) - offs_s1 = pid_s1 * BLOCK_S1 + tl.arange(0, BLOCK_S1) - offs_smx1 = pid_s1 * BLOCK_SMX1 + tl.arange(0, BLOCK_SMX1) + offs_x_s1 = pid_s1 * BLOCK_X_S1 + tl.arange(0, BLOCK_X_S1) + offs_x_smx1 = pid_s1 * BLOCK_X_SMX1 + tl.arange(0, BLOCK_X_SMX1) valid_s0 = offs_s0 < S0 - valid_s1 = offs_s1 < S1 - valid_smx1 = offs_smx1 < tl.cdiv(S1, 32) - y = tl.zeros((BLOCK_S0, BLOCK_S1), dtype=tl.float32) + valid_x_s1 = offs_x_s1 < X_S1 + valid_in_smx1 = offs_x_smx1 < tl.cdiv(X_S1, 32) + y = tl.zeros((BLOCK_S0, BLOCK_X_S1), dtype=tl.float32) x_flex_scale = load_scale(XFlex) for k in tl.range(0, K, num_stages=2): - x_ptrs = X + k * stride_xr + offs_s0[:, None] * stride_x0 + offs_s1[None, :] * stride_x1 - x = tl.load(x_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=0.0) + x_ptrs = X + k * stride_xr + offs_s0[:, None] * stride_x0 + offs_x_s1[None, :] * stride_x1 + x = tl.load(x_ptrs, mask=valid_s0[:, None] & valid_x_s1[None, :], other=0.0) x = x.to(tl.float32) if XMx is not None: - xmx_ptrs = XMx + k * stride_xmxr + offs_s0[:, None] * stride_xmx0 + offs_smx1[None, :] * stride_xmx1 - xmx = tl.load(xmx_ptrs, mask=valid_s0[:, None] & 
valid_smx1[None, :], other=0.0) + xmx_ptrs = XMx + k * stride_xmxr + offs_s0[:, None] * stride_xmx0 + offs_x_smx1[None, :] * stride_xmx1 + xmx = tl.load(xmx_ptrs, mask=valid_s0[:, None] & valid_in_smx1[None, :], other=0.0) xmx = (xmx.to(tl.uint32) << 23).to(tl.float32, bitcast=True) - x = (xmx[:, :, None] * x.reshape([BLOCK_S0, BLOCK_S1 // 32, 32])).reshape([BLOCK_S0, BLOCK_S1]) + x = (xmx[:, :, None] * x.reshape([BLOCK_S0, BLOCK_X_S1 // 32, 32])).reshape([BLOCK_S0, BLOCK_X_S1]) x = x * x_flex_scale if not IS_SCALE_NONE: k_term_s = 0 if SCALE_BROADCAST_R else (k * stride_sr) s0_term_s = 0 if SCALE_BROADCAST_S0 else (offs_s0[:, None] * stride_s0) - s1_term_s = 0 if SCALE_BROADCAST_S1 else (offs_s1[None, :] * stride_s1) + s1_term_s = 0 if SCALE_BROADCAST_S1 else (offs_x_s1[None, :] * stride_s1) s_ptrs = Scale + k_term_s + s0_term_s + s1_term_s - s = tl.load(s_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=1) + s = tl.load(s_ptrs, mask=valid_s0[:, None] & valid_x_s1[None, :], other=1) x = x * s if not IS_MASK_NONE: k_term = 0 if BROADCAST_R else (k * stride_mr) s0_term = 0 if BROADCAST_S0 else (offs_s0[:, None] * stride_m0) - s1_term = 0 if BROADCAST_S1 else (offs_s1[None, :] * stride_m1) + s1_term = 0 if BROADCAST_S1 else (offs_x_s1[None, :] * stride_m1) m_ptrs = Mask + k_term + s0_term + s1_term - m = tl.load(m_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=1) + m = tl.load(m_ptrs, mask=valid_s0[:, None] & valid_x_s1[None, :], other=1) x = tl.where(m != 0, x, 0.0) y += x - if POSTPROCESS_FN is not None: - y = POSTPROCESS_FN(y, *postprocess_fn_args) + if POSTPROCESS_FN1 is not None: + y = POSTPROCESS_FN1(y, *postprocess_fn1_args) + offs_y_s1 = pid_s1 * BLOCK_Y_S1 + tl.arange(0, BLOCK_Y_S1) + offs_y_smx1 = pid_s1 * BLOCK_Y_SMX1 + tl.arange(0, BLOCK_Y_SMX1) + valid_y_s1 = offs_y_s1 < Y_S1 + valid_y_smx1 = offs_y_smx1 < tl.cdiv(Y_S1, 32) y = float_to_flex(y, YFlexExpected, YFlexActual, YFlexChecksum, None, Y, Y_FLEX_SATURATE_INF) - y_ptrs = Y + offs_s0[:, None] * stride_y0 + offs_s1[None, :] * stride_y1 + # TODO (phil): keeping for backward compatibility, but will remove ! 
+ if YMx is None and POSTPROCESS_FN2 is not None: + y = POSTPROCESS_FN2(y, *postprocess_fn2_args, target_dtype=Y.dtype.element_ty) + y_ptrs = Y + offs_s0[:, None] * stride_y0 + offs_y_s1[None, :] * stride_y1 if YMx is not None: - y, y_scale = quantize_mxfp8_fn(y, valid_s1[None, :]) - y_mx_ptrs = YMx + offs_s0[:, None] * stride_ymx0 + offs_smx1[None, :] * stride_ymx1 - tl.store(y_mx_ptrs, y_scale, mask=valid_s0[:, None] & valid_smx1[None, :]) - tl.store(y_ptrs, y, mask=valid_s0[:, None] & valid_s1[None, :]) + y, y_scale = quantize_mxfp8_fn(y, valid_y_s1[None, :]) + y_mx_ptrs = YMx + offs_s0[:, None] * stride_ymx0 + offs_y_smx1[None, :] * stride_ymx1 + tl.store(y_mx_ptrs, y_scale, mask=valid_s0[:, None] & valid_y_smx1[None, :]) + tl.store(y_ptrs, y, mask=valid_s0[:, None] & valid_y_s1[None, :]) + + +specializations = SpecializationModule( + "reduce", + kernels=[("_reduce", _reduce)], + closure_args={ + "postprocess_fn1": ClosureArg("POSTPROCESS_FN1", "postprocess_fn1_args"), + "postprocess_fn2": ClosureArg("POSTPROCESS_FN2", "postprocess_fn2_args"), + }, +) def reduce( @@ -122,9 +112,14 @@ def reduce( scale: Optional[torch.Tensor] = None, x_mxscale: Optional[torch.Tensor] = None, x_flex: Optional[InFlexData] = InFlexData(), + y_dtype: Optional[torch.dtype] = None, y_flex: Optional[OutFlexData] = OutFlexData(), y_flex_saturate_inf: bool = False, - postprocess_fn: Optional[PostprocessFn] = None, + y_has_mx: Optional[bool] = None, + y: Optional[torch.Tensor] = None, + postprocess_fn1: Optional[PostprocessFn] = None, + # TODO: keeping for backward compatibility, but will remove ! + postprocess_fn2: Optional[PostprocessFn] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Performs a reduction over the specified dimension of the input tensor, @@ -163,20 +158,29 @@ def reduce( assert triton.cdiv(x.shape[-1], 32) * 32 == x_mxscale.shape[-1] * 32 assert dim != -1 # assert not y_flex.is_per_batch - if postprocess_fn is None: - postprocess_fn = PostprocessFn() + if postprocess_fn1 is None: + postprocess_fn1 = PostprocessFn() + if postprocess_fn2 is None: + postprocess_fn2 = PostprocessFn() + if y_dtype is None: + y_dtype = x.dtype if y_flex is None: y_flex = OutFlexData() if x_flex is None: x_flex = InFlexData() + if y_has_mx is None: + y_has_mx = x_mxscale is not None # input shapes dims = (0, 1, 2) nonred = tuple(d for d in dims if d != dim) - S0, S1 = x.shape[nonred[0]], x.shape[nonred[1]] - y = torch.empty((S0, S1), device=x.device, dtype=x.dtype) + S0, X_S1 = x.shape[nonred[0]], x.shape[nonred[1]] + Y_S1 = X_S1 // postprocess_fn1.specs.reduction_n + if y is None: + y = torch.empty((S0, Y_S1), device=x.device, dtype=y_dtype) + assert y.shape == (S0, Y_S1), f"y.shape: {y.shape} != ({S0}, {Y_S1})" y_mxscale = None - if x_mxscale is not None: - y_mxscale = torch.empty((S0, triton.cdiv(S1, 32)), device=x.device, dtype=x_mxscale.dtype) + if y_has_mx: + y_mxscale = torch.empty((S0, triton.cdiv(Y_S1, 32)), device=x.device, dtype=torch.uint8) # Strides for X along reduced and non-reduced dims stride_xr = x.stride(dim) stride_x0 = x.stride(nonred[0]) @@ -207,20 +211,23 @@ def reduce( K = x.shape[dim] # Always use the 2D tiled kernel with constexpr metaprogramming for mask broadcasting BLOCK_S0 = 64 - BLOCK_S1 = 128 - grid = (triton.cdiv(S0, BLOCK_S0), triton.cdiv(S1, BLOCK_S1)) - mask_arg = mask if mask is not None else x - scale_arg = scale if scale is not None else x - reduce_kernel = get_kernels(postprocess_fn.specs)._reduce + BLOCK_X_S1 = 128 + BLOCK_Y_S1 = 128 // 
postprocess_fn1.specs.reduction_n + grid = (triton.cdiv(S0, BLOCK_S0), triton.cdiv(Y_S1, BLOCK_Y_S1)) + mask_arg = mask if mask is not None else None + scale_arg = scale if scale is not None else None + reduce_kernel = specializations.get(postprocess_fn1=postprocess_fn1.specs, + postprocess_fn2=postprocess_fn2.specs)._reduce reduce_kernel[grid]( - x, stride_xr, stride_x0, stride_x1, # + x_flex.reinterpret(x), stride_xr, stride_x0, stride_x1, # x_mxscale, stride_xmxr, stride_xmx0, stride_xmx1, # - y, y.stride(0), y.stride(1), # + y_flex.reinterpret(y), y.stride(0), y.stride(1), # y_mxscale, stride_ymx0, stride_ymx1, # mask_arg, stride_mr, stride_m0, stride_m1, # scale_arg, stride_sr, stride_s0, stride_s1, # - K, S0, S1, # - *postprocess_fn.fn_args, x_flex.scale, y_flex.expected_scale, y_flex.actual_scale, y_flex.checksum_scale, + K, S0, X_S1, Y_S1, # + *postprocess_fn1.fn_args, *postprocess_fn2.fn_args, # + x_flex.scale, y_flex.expected_scale, y_flex.actual_scale, y_flex.checksum_scale, # y_flex_saturate_inf, # IS_MASK_NONE=(mask is None), # BROADCAST_R=(stride_mr == 0), # @@ -231,7 +238,8 @@ def reduce( SCALE_BROADCAST_S0=(stride_s0 == 0), # SCALE_BROADCAST_S1=(stride_s1 == 0), # BLOCK_S0=BLOCK_S0, # - BLOCK_S1=BLOCK_S1, # + BLOCK_X_S1=BLOCK_X_S1, # + BLOCK_Y_S1=BLOCK_Y_S1, # num_warps=4 # ) return y, y_mxscale @@ -251,7 +259,7 @@ def reduce_torch(x: torch.Tensor, dim: int, mask: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, # x_mxscale: Optional[torch.Tensor] = None, # x_flex: Optional[InFlexData] = InFlexData(), y_flex: Optional[OutFlexData] = OutFlexData(), - y_flex_saturate_inf: bool = False, postprocess_fn: Optional[callable] = None): + y_flex_saturate_inf: bool = False, postprocess_fn1: Optional[callable] = None): from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch, upcast_from_mxfp_torch x_dtype = x.dtype # upcast input @@ -269,8 +277,8 @@ def reduce_torch(x: torch.Tensor, dim: int, mask: Optional[torch.Tensor] = None, mask = torch.ones(1, dtype=torch.bool, device=x.device) mask = mask.to(torch.bool) ret = torch.where(mask, x * scale, 0).sum(dim=dim) - if postprocess_fn is not None: - ret = postprocess_fn(ret) + if postprocess_fn1 is not None: + ret = postprocess_fn1(ret) if y_flex is not None: y_flex.actual_scale.copy_(compute_actual_scale(ret, x_dtype, y_flex.is_per_batch)) ret = (ret / y_flex.expected_scale).to(x_dtype) diff --git a/python/triton_kernels/triton_kernels/specialize.py b/python/triton_kernels/triton_kernels/specialize.py index 4d116123a0..bf0cf00a33 100644 --- a/python/triton_kernels/triton_kernels/specialize.py +++ b/python/triton_kernels/triton_kernels/specialize.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import inspect import re import textwrap @@ -62,6 +63,19 @@ def _empty_fn(): return f +@dataclass(frozen=True) +class FnSpecs: + name: str + fn: "triton.runtime.jit.JITFunction" + fn_arg_names: tuple[str] + fn_arg_do_not_specialize: tuple[str] = tuple() + reduction_n: int = 1 + + @staticmethod + def default(): + return FnSpecs("dflt", None, tuple()) + + def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()): assert isinstance(fn, triton.runtime.jit.JITFunction) if name is None: @@ -149,3 +163,40 @@ def new_repr(specialization): co_firstlineno=max(1, orig_code.co_firstlineno - line_delta), ) return ret + + +@dataclass(frozen=True) +class ClosureArg: + fn_name: str + fn_params_name: str + + +class SpecializationModule: + + def __init__(self, module_name: str, kernels: list[tuple[str, 
object]], closure_args: dict[str, ClosureArg]): + self.module_name = module_name + self.kernels = kernels + self.closure_args = closure_args + self._modules = dict() + + def get(self, **kwargs): + import types + import sys + specs = [FnSpecs.default()] * len(self.closure_args) + for key, value in kwargs.items(): + specs[list(self.closure_args.keys()).index(key)] = value + key = tuple(spec.name for spec in specs) + if key in self._modules: + return self._modules[key] + spec_constants = {arg.fn_name: spec.fn for arg, spec in zip(self.closure_args.values(), specs)} + spec_tuples = {arg.fn_params_name: spec.fn_arg_names for arg, spec in zip(self.closure_args.values(), specs)} + do_not_specialize = [] + for spec in specs: + do_not_specialize.extend(spec.fn_arg_do_not_specialize) + module = types.ModuleType(self.module_name + '_'.join(key)) + sys.modules[module.__name__] = module + for kernel_name, kernel_fn in self.kernels: + setattr(module, kernel_name, + specialize(kernel_fn, module, spec_constants, spec_tuples, do_not_specialize=do_not_specialize)) + self._modules[key] = module + return module From b6201360e9acb5ae2d4e423c2607cb21bf104069 Mon Sep 17 00:00:00 2001 From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> Date: Wed, 29 Oct 2025 14:40:39 +0000 Subject: [PATCH 2/9] =?UTF-8?q?[RELAND][LAYOUTS]=C2=A0Generate=20distribut?= =?UTF-8?q?ed=20layouts=20for=20`tcgen05.ld/st`=20generically=20(#8421)=20?= =?UTF-8?q?(#8495)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR relands https://github.com/triton-lang/triton/pull/8386. It depends on https://github.com/triton-lang/triton/pull/8492 to avoid regressing in some workloads. --- .../TritonGPU/IR/LinearLayoutConversions.h | 23 +- .../Dialect/TritonNvidiaGPU/IR/Dialect.h | 59 ++- .../TritonNvidiaGPU/IR/TensorMemoryUtils.h | 37 ++ include/triton/Tools/LinearLayout.h | 19 + .../TritonToTritonGPU/RelayoutTritonGPU.cpp | 12 +- .../TritonGPU/IR/LinearLayoutConversions.cpp | 227 +-------- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 41 +- lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt | 1 + lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp | 407 +++++++++++----- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 42 +- .../TritonNvidiaGPU/IR/TensorMemoryUtils.cpp | 317 +++++++++++++ .../Transforms/OptimizeTMemLayouts.cpp | 90 ++-- .../Transforms/PromoteLHSToTMem.cpp | 18 +- lib/Tools/LinearLayout.cpp | 34 ++ python/examples/gluon/01-attention-forward.py | 66 ++- python/src/gluon_ir.cc | 74 +++ python/src/ir.cc | 2 + python/test/gluon/test_core.py | 52 ++- .../experimental/gluon/language/_semantic.py | 80 +++- .../language/nvidia/blackwell/__init__.py | 137 ++---- python/triton/language/core.py | 2 +- .../translator_helpers.py | 17 +- python/tutorials/gluon/06-tcgen05.py | 38 +- python/tutorials/gluon/07-persistence.py | 4 +- .../tutorials/gluon/08-warp-specialization.py | 10 +- test/Conversion/relayout_tritongpu.mlir | 12 +- .../tritongpu_to_llvm_blackwell.mlir | 27 +- test/TritonGPU/accelerate-matmul.mlir | 20 +- test/TritonGPU/canonicalize.mlir | 2 +- test/TritonNvidiaGPU/interleave_tmem.mlir | 47 +- .../test_tensor_memory_allocation.mlir | 4 +- test/TritonNvidiaGPU/tmem_layouts.mlir | 90 ++-- .../WarpSpecialization/WSDataPartition.cpp | 8 +- .../TensorMemoryToLLVM.cpp | 435 +++--------------- 34 files changed, 1365 insertions(+), 1089 deletions(-) create mode 100644 include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h create mode 100644 lib/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.cpp 
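Before the file-by-file changes of this second patch, a brief recap of the first patch's main API change may help: split-k accumulation no longer goes through the deleted `_reduce_grouped` kernel but through the standalone `reduce` op, with any fused activation passed as `postprocess_fn1`. The sketch below is illustrative only and is not part of either patch; it assumes a CUDA device, and the `plus_a_pairwise_sum` helper plus all shapes are invented for the example, mirroring the updated `tests/test_reduce.py` above.

```python
import torch
import triton
import triton.language as tl

from triton_kernels.reduce import reduce, PostprocessFn
from triton_kernels.specialize import FnSpecs


@triton.jit
def plus_a_pairwise_sum(x, a):
    # Illustrative post-processing fn: add `a`, then sum adjacent column pairs
    # (reduction_n=2), in the style of plus_a_reduce in tests/test_reduce.py.
    y = x + a
    return tl.sum(y.reshape([x.shape[0], x.shape[1] // 2, 2]), axis=2)


split_k, M, N = 4, 128, 256
# Stand-in for the (split_k, M, N) matmul scratchpad that matmul_ogs reduces.
x = torch.randn(split_k, M, N, device="cuda", dtype=torch.float32)

# Plain split-k accumulation: sum the partial results along dim 0.
y, _ = reduce(x, dim=0)
assert y.shape == (M, N)

# Same reduction with a fused post-processing function; the output has
# N // reduction_n columns, mirroring what matmul_ogs now does when the
# activation is deferred from the matmul kernel to the reduction.
fn = PostprocessFn(specs=FnSpecs("plus_a", plus_a_pairwise_sum, ("a",), reduction_n=2),
                   fn_args=(10,))
y2, _ = reduce(x, dim=0, postprocess_fn1=fn)
assert y2.shape == (M, N // 2)

# Reference, following reduce_torch: sum over dim, then post-process.
ref = (x.sum(dim=0) + 10).reshape(M, N // 2, 2).sum(dim=-1)
torch.testing.assert_close(y2, ref)
```

Decoupling the reduction this way lets the same specialization machinery (`SpecializationModule` with `postprocess_fn1`/`postprocess_fn2` closure args) serve both the split-k reduction and the inter-expert (scatter) reduction, which matmul_ogs now issues as two separate `reduce` calls.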
diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index fed4ded91f..5745a88243 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -117,19 +117,6 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef shape, int32_t elemBitWidth, unsigned instBitWidth, unsigned numLanesInShuffleGroup); -LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType, - int numWarps); - -std::optional -getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType, - int numWarps); - -// Return a layout valid for TMemLoad op for a tmem layout of block MxN that -// distribute the data long M for the warp groups. This doesn't affect the TMem -// layout it just returns a distributed layout compatible for tmem_load. -LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType, - int numWarps); - // Create LinearLayout for scale in scaled mfma. LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, ArrayRef dotOperandShape, @@ -161,5 +148,15 @@ std::optional chooseMfmaLikeStoreLayout(RankedTensorType valType); LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared, bool disableSwizzle); +// Make a LinearLayout that maps a block-id to an N-dimensional index. +// +// The tensor is split up into CTAsPerCGA pieces, which are distributed among +// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups). +// +// See the nomenclature note at the top of the LinearLayoutConversions.cpp file +// for an explanation of why this is called makeCgaLayout when it accepts a +// CTALayoutAttr. +LinearLayout makeCgaLayout(CTALayoutAttr layout); + } // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h index 6b934ea8a6..8da92235d0 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -29,6 +29,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" // TritonNvidiaGPU depends on Triton #include "triton/Dialect/Triton/IR/Dialect.h" @@ -61,14 +62,52 @@ struct TMemAllocation { int numCols; }; +// Used to describe the layout of the TMEM load/store instructions +enum class TMemAccessAtom { I32x32b, I16x64b, I16x128b, I16x256b, I16x32bx2 }; + +inline int getElementsPerThread(TMemAccessAtom atom) { + switch (atom) { + case TMemAccessAtom::I32x32b: + case TMemAccessAtom::I16x64b: + case TMemAccessAtom::I16x32bx2: + return 1; + case TMemAccessAtom::I16x128b: + return 2; + case TMemAccessAtom::I16x256b: + return 4; + } + llvm_unreachable("Unknown TMemAccessAtom"); +} + +inline const char *getOpShape(TMemAccessAtom atom) { + switch (atom) { + case TMemAccessAtom::I32x32b: + return "32x32b"; + case TMemAccessAtom::I16x64b: + return "16x64b"; + case TMemAccessAtom::I16x128b: + return "16x128b"; + case TMemAccessAtom::I16x256b: + return "16x256b"; + case TMemAccessAtom::I16x32bx2: + return "16x32bx2"; + } + llvm_unreachable("Unknown TMemAccessAtom"); +} + +LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom, + bool unpacked); + TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType); -gpu::DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N, - 
RankedTensorType oltType, - unsigned numWarps); -gpu::DistributedEncodingTrait +SmallVector +getTmemCompatibleLayouts(gpu::MemDescType memType, unsigned numWarps, + ArrayRef ctaSplit = {1, 1}); + +std::optional getTmemLoadLayoutSplitLongM(RankedTensorType tensorType, gpu::MemDescType memType, int numWarps); + SmallVector getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType, gpu::MemDescType memType); @@ -76,9 +115,15 @@ getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType, bool isDistributedLayoutTMemCompatible(Operation *op, RankedTensorType tensorType, gpu::MemDescType memType); -bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType, - gpu::MemDescType memType, - int numWarps); + +gpu::DistributedEncodingTrait +getDefaultLayoutForTmemLdSt(gpu::MemDescType memType, unsigned numWarps, + gpu::CTALayoutAttr ctaLayout); + +std::optional +getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom, + unsigned numWarps, + gpu::CTALayoutAttr ctaLayout); } // namespace mlir::triton::nvidia_gpu diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h b/include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h new file mode 100644 index 0000000000..3ae002a597 --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h @@ -0,0 +1,37 @@ +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" + +#include +#include +#include + +namespace mlir::triton::nvidia_gpu { + +// Get the maximum number of registers per thread based on the context. This is +// by default 256, but it can be overridden by `ttg.maxnreg` set on the module +// or a contextual register limit set by the compiler on partitions. +int getContextualMaxNReg(Operation *op); +struct TMemLdStEncodingInfo { + TMemAccessAtom atom; + LinearLayout reps; + ColumnAction perm; + int numRegsPerMessage; + std::optional secondHalfOffset; + std::optional broadcast = std::nullopt; + bool unpacked = false; + unsigned vec = 1; + bool padding = false; +}; + +FailureOr +computeTMemLdStEncodingInfo(RankedTensorType regTy, gpu::MemDescType memTy, + int maxnreg, + std::function emitError = {}); + +} // namespace mlir::triton::nvidia_gpu + +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_ diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 36001e212f..175aefcaea 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -558,6 +558,25 @@ class LinearLayout { return reshapeOuts({{*getOutDimNames().begin(), getTotalOutDimSize()}}); } + // Resizes the dimension to one that is smallre or equal to the given size. + // These operations are similar to `sublayout` but at a dimension level. + [[nodiscard]] LinearLayout resizeInDim(StringAttr inDim, + int32_t newSize) const; + [[nodiscard]] LinearLayout resizeOutDim(StringAttr outDim, + int32_t newSize) const; + + [[nodiscard]] LinearLayout renameInDim(StringAttr oldDim, + StringAttr newDim) const { + auto bases = getBases(); + auto it = bases.find(oldDim); + assert(it != bases.end()); + auto value = std::move(it->second); + bases.erase(it); + bases.insert({newDim, std::move(value)}); + return LinearLayout(bases, getOutDims(), + /*requireSurjective=*/isSurjective()); + } + // Concatenates two layouts by their in (resp. 
out) dimensions. The layouts // must have the same output (resp. input) dimensions and sizes and different // input (resp. output) dimensions. The input dimensions of this layout are diff --git a/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp index 193f6f5052..33da83e4d6 100644 --- a/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp @@ -21,16 +21,10 @@ namespace ttng = triton::nvidia_gpu; RankedTensorType getTMEMTensorLayout(const TypeConverter *tc, RankedTensorType type, MemDescType memdesc, unsigned numWarps) { - Attribute encoding; type = cast(tc->convertType(type)); - if (isa(memdesc.getEncoding())) { - encoding = LinearEncodingAttr::get( - type.getContext(), getScaleTMEMStoreLinearLayout(type, numWarps)); - } else { - auto tmemEnc = cast(memdesc.getEncoding()); - encoding = ttng::getTmemCompatibleLayout( - tmemEnc.getBlockM(), tmemEnc.getBlockN(), type, numWarps); - } + auto ctaLayout = getCTALayout(type.getEncoding()); + auto encoding = + ttng::getDefaultLayoutForTmemLdSt(memdesc, numWarps, ctaLayout); return type.cloneWithEncoding(encoding); } diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 4b7ee487c6..8908a0d98a 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -53,35 +53,6 @@ SmallVector permuteDimNames(const SmallVector &names, return ret; } -// Make a LinearLayout that maps a block-id to an N-dimensional index. -// -// The tensor is split up into CTAsPerCGA pieces, which are distributed among -// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups). -// -// See the nomenclature note at the top of the file for an explanation of why -// this is called makeCgaLayout when it accepts a CTALayoutAttr. -LinearLayout makeCgaLayout(CTALayoutAttr layout) { - MLIRContext *ctx = layout.getContext(); - StringAttr kBlock = S("block"); - - int rank = layout.getCTAOrder().size(); - SmallVector outDimNames = standardOutDimNames(ctx, rank); - - LinearLayout ret = LinearLayout::empty(); - for (int i = 0; i < rank; i++) { - // Start with the most minor dimension, which is order[0]. - int dim = layout.getCTAOrder()[i]; - int split = layout.getCTASplitNum()[dim]; - int ctas = layout.getCTAsPerCGA()[dim]; - assert(ctas % split == 0); - ret *= LinearLayout::identity1D(split, kBlock, outDimNames[dim]) * - LinearLayout::zeros1D(ctas / split, kBlock, outDimNames[dim]); - } - - // Transpose to standard order (dim0, dim1, ...). - return ret.transposeOuts(outDimNames); -} - LinearLayout swizzledSharedToLinearLayout(ArrayRef shape, SwizzledSharedEncodingAttr shared) { MLIRContext *ctx = shared.getContext(); @@ -185,6 +156,28 @@ sharedToLinearLayoutAMDRotating(ArrayRef shape, } // namespace +LinearLayout makeCgaLayout(CTALayoutAttr layout) { + MLIRContext *ctx = layout.getContext(); + StringAttr kBlock = S("block"); + + int rank = layout.getCTAOrder().size(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < rank; i++) { + // Start with the most minor dimension, which is order[0]. 
+ int dim = layout.getCTAOrder()[i]; + int split = layout.getCTASplitNum()[dim]; + int ctas = layout.getCTAsPerCGA()[dim]; + assert(ctas % split == 0); + ret *= LinearLayout::identity1D(split, kBlock, outDimNames[dim]) * + LinearLayout::zeros1D(ctas / split, kBlock, outDimNames[dim]); + } + + // Transpose to standard order (dim0, dim1, ...). + return ret.transposeOuts(outDimNames); +} + // Returns the layout of a single core matrix which tiles the nvmma layout LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared, bool disableSwizzle) { @@ -1757,180 +1750,4 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) { return mfmaLL.compose(swapLL); } -LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType, - int numWarps) { - assert(numWarps == 4 || numWarps == 8); - MLIRContext *ctx = scaleType.getContext(); - - using basisT = std::vector>; - StringAttr kRegister = StringAttr::get(ctx, "register"); - StringAttr kLane = StringAttr::get(ctx, "lane"); - StringAttr kWarp = StringAttr::get(ctx, "warp"); - - int64_t M = scaleType.getDimSize(0); - int64_t N = scaleType.getDimSize(1); - auto CTALayout = getCTALayout(scaleType.getEncoding()); - basisT regBase; - - // Pick a layout that will be trivial to store into the following TMEM layout: - // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x - // Pack 4 scales together, if there are less than 4 we replicate the data. - for (int i = 1; i < 4; i = i << 1) { - if (i >= N) - regBase.push_back({0, 0}); - else - regBase.push_back({0, i}); - } - // Distribute 32 elements of M along a warp. - basisT laneBase = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}; - // The data are replicated across all the warps of each warpgroups. - basisT warpBase = {{0, 0}, {0, 0}}; - for (int i = 32; i < M; i = i << 1) { - regBase.push_back({i, 0}); - } - for (int i = 4; i < N; i = i << 1) { - regBase.push_back({0, i}); - } - // If we have 8 warps distribute the last dimension on the second warp group. - if (numWarps == 8) { - warpBase.push_back(regBase.back()); - regBase.pop_back(); - } - - SmallVector outDimNames = standardOutDimNames(ctx, 2); - auto regLanes = - LinearLayout({{kRegister, regBase}, {kLane, laneBase}, {kWarp, warpBase}}, - {outDimNames[0], outDimNames[1]}); - - return combineCtaCgaWithShape(regLanes, CTALayout, scaleType.getShape()); -} - -std::optional -getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType, - int numWarps) { - // Too small to distribute on two warp groups while using 16x256 message. 
- if (numWarps == 8 && M == 64 && N <= 16 && - oldType.getElementTypeBitWidth() < 32) { - return {}; - } - assert(numWarps == 4 || numWarps == 8); - auto ctaLayout = getCTALayout(oldType.getEncoding()); - SmallVector shape = getShapePerCTA(oldType); - MLIRContext *ctx = ctaLayout.getContext(); - - using basisT = std::vector>; - StringAttr kRegister = StringAttr::get(ctx, "register"); - StringAttr kLane = StringAttr::get(ctx, "lane"); - StringAttr kWarp = StringAttr::get(ctx, "warp"); - SmallVector outDimNames = standardOutDimNames(ctx, 2); - - unsigned numElementsPerThread = 256 / oldType.getElementTypeBitWidth(); - int kWidth = 64 / oldType.getElementTypeBitWidth(); - // Follow the layout given by a tmem load using this layout for the inner - // shape: - // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b - LinearLayout innerTile = - nvidiaMmaTile(ctx, {8, numElementsPerThread}, kWidth, {1, 0}, {0, 1}); - innerTile = - innerTile * LinearLayout::identity1D(2, kRegister, outDimNames[0]); - // Then distribute the rest along warpgroups and registers. - // Then the last warp distribute along M or N following the same order as - // in getTmemLoadStoreLayout32x32b. This allows us to use the same lowering to - // tmem for load and store. This part could be generalized by making the - // lowering of tmem load and store rely more on linear layout. - bool distributeMAlongWarps = false; - bool distributeNAlongWarps = false; - // Figure out how to distribute acorss warpgroups. - if (numWarps == 8) { - if (shape[0] > 128) { - distributeMAlongWarps = true; - } else { - distributeNAlongWarps = true; - } - } - int nBase = numElementsPerThread; - int maxRegN = - std::min(N, distributeNAlongWarps ? (int)shape[1] / 2 : (int)shape[1]); - if (maxRegN / nBase > 1) { - innerTile = innerTile * LinearLayout::identity1D(maxRegN / nBase, kRegister, - outDimNames[1]); - } - if (M != 64) { - innerTile = - innerTile * LinearLayout::identity1D(2, kRegister, outDimNames[0]); - } - // Distribute M along 4 warps to satisfy TMEM requirements. - innerTile = innerTile * LinearLayout::identity1D(4, kWarp, outDimNames[0]); - - // Fill out the rest of the shape with M first then N. - int numMRegDim = std::min(128, (int)shape[0]) / M; - if (numMRegDim > 1) { - innerTile = innerTile * - LinearLayout::identity1D(numMRegDim, kRegister, outDimNames[0]); - } - // Dim M=128 should be distributed on the second warp group. - int nextDim = 128; - if (distributeMAlongWarps) { - innerTile = innerTile * LinearLayout::identity1D(2, kWarp, outDimNames[0]); - nextDim <<= 1; - } - numMRegDim = shape[0] / nextDim; - if (numMRegDim > 1) { - innerTile = innerTile * - LinearLayout::identity1D(numMRegDim, kRegister, outDimNames[0]); - } - int maxN = distributeNAlongWarps ? 
shape[1] / 2 : shape[1]; - int numNRegDim = maxN / maxRegN; - if (numNRegDim > 1) { - innerTile = innerTile * - LinearLayout::identity1D(numNRegDim, kRegister, outDimNames[1]); - } - if (distributeNAlongWarps) { - innerTile = innerTile * LinearLayout::identity1D(2, kWarp, outDimNames[1]); - } - return combineCtaCgaWithShape(innerTile, ctaLayout, oldType.getShape()); -} - -LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType, - int numWarps) { - assert(numWarps == 8); - auto ctaLayout = getCTALayout(oldType.getEncoding()); - SmallVector shape = getShapePerCTA(oldType); - MLIRContext *ctx = ctaLayout.getContext(); - - using basisT = std::vector>; - StringAttr kRegister = StringAttr::get(ctx, "register"); - StringAttr kLane = StringAttr::get(ctx, "lane"); - StringAttr kWarp = StringAttr::get(ctx, "warp"); - - // Follow the layout given by a tmem load using this layout: - // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-1632b2 - basisT laneBase; - assert(M == 128); - for (int i = 1; i < 16; i = i << 1) { - laneBase.push_back({i, 0}); - } - basisT regBase; - for (int i = 1; i < N / 2; i = i << 1) { - regBase.push_back({0, i}); - } - laneBase.push_back({0, N / 2}); - // then replicate the pattern. - for (int i = N; i < shape[1]; i = i << 1) { - regBase.push_back({0, i}); - } - for (int i = M; i < shape[0]; i = i << 1) { - regBase.push_back({i, 0}); - } - // warp 0 and 4 can only access M[0:32], therefore we need to interleave the - // data. - basisT warpBase = {{32, 0}, {64, 0}, {16, 0}}; - SmallVector outDimNames = standardOutDimNames(ctx, 2); - auto regLanes = - LinearLayout({{kRegister, regBase}, {kLane, laneBase}, {kWarp, warpBase}}, - {outDimNames[0], outDimNames[1]}); - - return combineCtaCgaWithShape(regLanes, ctaLayout, oldType.getShape()); -} - } // namespace mlir::triton::gpu diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 6bec8e1e32..25dee6c054 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -444,13 +444,6 @@ class BlockedToMMA : public mlir::OpRewritePattern { } }; -// Pick the layout to match MXFP scales layout in register so that it can be -// copied directly using tmem st. 
-static Attribute getTmemScales(RankedTensorType type, unsigned numWarps) { - return triton::gpu::LinearEncodingAttr::get( - type.getContext(), getScaleTMEMStoreLinearLayout(type, numWarps)); -} - static bool canUseTwoCTAs(triton::DotOp dotOp) { RankedTensorType retType = dotOp.getType(); auto retShapePerCTA = getShapePerCTA(retType); @@ -575,12 +568,12 @@ class BlockedToMMAv5 : public mlir::OpRewritePattern { CTASplitNum[1]); Attribute tensorMemorySpace = triton::nvidia_gpu::TensorMemorySpaceAttr::get(context); - Type accMemDescType = triton::gpu::MemDescType::get( - oldRetType.getShape(), oldRetType.getElementType(), accEncoding, - tensorMemorySpace, - /*mutableMemory=*/true); - Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout( - instrShape[0], instrShape[1], oldRetType, numWarps); + MemDescType accMemDescType = + MemDescType::get(oldRetType.getShape(), oldRetType.getElementType(), + accEncoding, tensorMemorySpace, + /*mutableMemory=*/true); + auto newDistributedEncoding = nvidia_gpu::getDefaultLayoutForTmemLdSt( + accMemDescType, numWarps, CTALayout); auto newAccType = oldRetType.cloneWithEncoding(newDistributedEncoding); Value cvtAcc = ConvertLayoutOp::create(rewriter, loc, newAccType, dotOp.getOperand(2)); @@ -856,12 +849,12 @@ class ScaledBlockedToMMAv5 context, m, n, colStride, CTASplitNum[0], CTASplitNum[1]); Attribute tensorMemorySpace = triton::nvidia_gpu::TensorMemorySpaceAttr::get(context); - Type accMemDescType = triton::gpu::MemDescType::get( - oldRetType.getShape(), oldRetType.getElementType(), accEncoding, - tensorMemorySpace, - /*mutableMemory=*/true); - Attribute newDistributedEncoding = - nvidia_gpu::getTmemCompatibleLayout(m, n, oldRetType, numWarps); + MemDescType accMemDescType = + MemDescType::get(oldRetType.getShape(), oldRetType.getElementType(), + accEncoding, tensorMemorySpace, + /*mutableMemory=*/true); + auto newDistributedEncoding = nvidia_gpu::getDefaultLayoutForTmemLdSt( + accMemDescType, numWarps, CTALayout); auto newAccType = oldRetType.cloneWithEncoding(newDistributedEncoding); Value cvtAcc = ConvertLayoutOp::create(rewriter, loc, newAccType, dotOp.getOperand(2)); @@ -875,16 +868,18 @@ class ScaledBlockedToMMAv5 Attribute scaleEncoding = triton::nvidia_gpu::TensorMemoryScalesEncodingAttr::get( context, CTASplitNum[0], CTASplitNum[1]); - Type scaleAType = triton::gpu::MemDescType::get( + MemDescType scaleAType = triton::gpu::MemDescType::get( oldScaleAType.getShape(), oldScaleAType.getElementType(), scaleEncoding, tensorMemorySpace, /*mutableMemory=*/false); - Type scaleBType = triton::gpu::MemDescType::get( + MemDescType scaleBType = triton::gpu::MemDescType::get( oldScaleBType.getShape(), oldScaleBType.getElementType(), scaleEncoding, tensorMemorySpace, /*mutableMemory=*/false); - Attribute scaleALayout = getTmemScales(oldScaleAType, numWarps); - Attribute scaleBLayout = getTmemScales(oldScaleBType, numWarps); + Attribute scaleALayout = nvidia_gpu::getDefaultLayoutForTmemLdSt( + scaleAType, numWarps, getCTALayout(oldScaleAType.getEncoding())); + Attribute scaleBLayout = nvidia_gpu::getDefaultLayoutForTmemLdSt( + scaleBType, numWarps, getCTALayout(oldScaleBType.getEncoding())); RankedTensorType newScaleAType = oldScaleAType.cloneWithEncoding(scaleALayout); RankedTensorType newScaleBType = diff --git a/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt index 47414c697e..51b023370c 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ 
-1,5 +1,6 @@ add_triton_library(TritonNvidiaGPUIR Dialect.cpp + TensorMemoryUtils.cpp Ops.cpp DEPENDS diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp index bd1070a059..138736751b 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -35,6 +35,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -85,101 +86,304 @@ TMemAllocation getTmemAllocSizes(MemDescType memDescType) { return {nRow, nCol}; } -DistributedEncodingTrait getTmemLoadStoreLayout32x32b(unsigned M, unsigned N, - RankedTensorType oldType, - unsigned numWarps) { - assert(numWarps == 4 || numWarps == 8); - auto shape = getShapePerCTA(oldType); - assert(shape.size() == 2); - SmallVector sizePerThread; - SmallVector threadsPerWarp; - SmallVector warpsPerCTA; - SmallVector order; - SmallVector blocksPerTile = {(unsigned)shape[0] / M, - (unsigned)shape[1] / N}; - int numBlocks = blocksPerTile[0] * blocksPerTile[1]; - if (M == 64) { - unsigned numWarpGroups = numWarps / 4; - if (numBlocks == 1) { - // Split along the N dimension - sizePerThread = {1, ceil(N, numWarpGroups * 2)}; - threadsPerWarp = {16, 2}; - warpsPerCTA = {4, numWarpGroups}; +LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom, + bool unpacked) { + auto str_attr = [&](StringRef str) { return StringAttr::get(ctx, str); }; + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kRow = str_attr("row"); + auto kCol = str_attr("col"); + // Set the output order to be kRow, kCol and the input order to be kReg first + LinearLayout tile = LinearLayout::identity1D(1, kReg, kRow) * + LinearLayout::identity1D(1, kReg, kCol) * + LinearLayout::identity1D(1, kLane, kRow) * + LinearLayout::identity1D(1, kLane, kCol); + // Each register moves 32/bitwidth (= 2) columns when unpacked + if (unpacked) { + tile *= LinearLayout::zeros1D(1, kReg, kCol, 2); + } + if (atom == TMemAccessAtom::I32x32b) { + tile *= LinearLayout::identity1D(32, kLane, kRow); + } else if (atom == TMemAccessAtom::I16x32bx2) { + tile *= LinearLayout::identity1D(16, kLane, kRow); + } else if (atom == TMemAccessAtom::I16x64b) { + LinearLayout::BasesT bases; + bases[kLane] = std::vector>{ + {8, 0}, {0, 1}, {1, 0}, {2, 0}, {4, 0}}; + tile *= LinearLayout(bases, {kRow, kCol}); + } else if (atom == TMemAccessAtom::I16x128b) { + tile *= LinearLayout::identity1D(4, kLane, kCol) * + LinearLayout::identity1D(8, kLane, kRow) * + LinearLayout::identity1D(2, kReg, kRow); + } else if (atom == TMemAccessAtom::I16x256b) { + tile *= LinearLayout::identity1D(2, kReg, kCol) * + LinearLayout::identity1D(4, kLane, kCol) * + LinearLayout::identity1D(8, kLane, kRow) * + LinearLayout::identity1D(2, kReg, kRow); + } else { + llvm_unreachable("Unsupported TMEM access atom"); + } + auto nCol = tile.getOutDimSize(kCol); + auto bases = tile.getBases(); + bases[kWarp].push_back({32, 0}); + bases[kWarp].push_back({64, 0}); + auto ret = LinearLayout(bases, {{kRow, 128}, {kCol, nCol}}, false); + return ret; +} + +static std::optional getDistributedLayoutForTmemLdSt( + const LinearLayout &ll, TMemAccessAtom atom, unsigned numWarps, + int bitwidth, std::optional ctaLayout = std::nullopt) { + auto dims = 
to_vector(ll.getOutDimNames()); + assert(dims.size() == 2); + auto rowColDims = to_vector(ll.getInDimNames()); + auto *ctx = dims[0].getContext(); + // Add block dimension + if (ctaLayout) { + // Get CTALayout without broadcasting to divide the ll + // as the TMEM layout does not reflect CTA broadcasting + auto splitNum = ctaLayout->getCTASplitNum(); + auto ctaBlockSplit = + CTALayoutAttr::get(ctx, splitNum, splitNum, ctaLayout->getCTAOrder()); + auto ctaBlockSplitLL = gpu::makeCgaLayout(ctaBlockSplit); + assert(ctaBlockSplitLL.getNumOutDims() == ll.getNumOutDims()); + // rename block into col + auto kBlock = StringAttr::get(ctx, "block"); + auto ctaCol = ctaBlockSplitLL.renameInDim(kBlock, rowColDims[1]); + auto quot = divideRight(ll, ctaCol); + assert(quot.has_value()); + auto maybeRet = + getDistributedLayoutForTmemLdSt(*quot, atom, numWarps, bitwidth); + if (!maybeRet) + return maybeRet; + // Add the full ctaBlock layout (with broadcasting) + auto ctaBlock = gpu::makeCgaLayout(*ctaLayout); + return *maybeRet * ctaBlock; + } + // This code is dual to the one in lowerTMemLdSt + if (bitwidth != 32) { + // TODO move this to a helper function + auto kReg = StringAttr::get(ctx, "register"); + LinearLayout quot; + int bestContig = 1; + for (int contig = 1; bitwidth * contig <= 32; contig *= 2) { + auto maybeQuot = divideLeft( + ll, LinearLayout::identity1D(contig, rowColDims[1], dims[1])); + if (!maybeQuot) + break; + quot = *maybeQuot; + bestContig = contig; + } + + // Pack contiguous elements + // This works to pack b8 or b16 into b32 but also b8 into b16 and recurse + if (bestContig > 1) { + auto ret = getDistributedLayoutForTmemLdSt(quot, atom, numWarps, + bitwidth * bestContig); + if (!ret) + return ret; + auto castbbitwidth = LinearLayout::identity1D(bestContig, kReg, dims[1]); + return castbbitwidth * ret.value(); + } + if (auto maybeQuot = divideLeft( + ll, LinearLayout::zeros1D(32 / bitwidth, rowColDims[1], dims[1]) * + LinearLayout::identity1D(2, rowColDims[1], dims[1])); + bitwidth == 16 && maybeQuot) { + // Unpacked case + auto ret = + getDistributedLayoutForTmemLdSt(*maybeQuot, atom, numWarps, 32); + if (!ret) + return ret; + auto castbbitwidth = LinearLayout::identity1D(2, kReg, dims[1]); + return castbbitwidth * ret.value(); + } else if (auto maybeQuot = + divideLeft(ll, LinearLayout::zeros1D( + 32 / bitwidth, rowColDims[1], dims[1]))) { + // Software padding + assert(maybeQuot); + return getDistributedLayoutForTmemLdSt(*maybeQuot, atom, numWarps, 32); + } else if (ll.getInDimSize(rowColDims[1]) == 1) { + // Software padding with just one column + return getDistributedLayoutForTmemLdSt(ll, atom, numWarps, 32); } else { - sizePerThread = {1, ceil(N, 2)}; - threadsPerWarp = {16, 2}; - warpsPerCTA = {0, 0}; - // Distribute at most as many warp groups as there is blocks - // along M dimension. - warpsPerCTA[0] = 4 * std::min(blocksPerTile[0], numWarpGroups); - // Distribute rest of the warp groups along N dimension. 
- warpsPerCTA[1] = ceil(numWarpGroups, warpsPerCTA[0] / 4); + assert(false && "Should not happen"); } - } else { - unsigned numWarpGroups = numWarps / 4; - if (shape[0] > 128) { - // Split along M dimension - sizePerThread = {1, N}; - threadsPerWarp = {32, 1}; - warpsPerCTA = {4 * numWarpGroups, 1}; + } + // getTileLayout returns the layout for a bitwidth of 32 + assert(bitwidth == 32); + auto tile = getTileLayout(ctx, atom, false); + // Plan: + // tile: register, lane, warp -> row, cols + // ll: row, cols -> dim0, dim1 + // We extend the tile to have the right vectorisation and the result is given + // by ll o tile : register, lane warp -> dim0, dim1 If the tile is too + // large, we cannot use the tile + + auto nColsTile = tile.getOutDimSize(rowColDims[1]); + auto nColsLL = ll.getInDimSize(rowColDims[1]); + auto nColsMissing = nColsLL / nColsTile; + if (nColsMissing == 0) { + return std::nullopt; + } + auto kReg = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + auto kWarp = StringAttr::get(ctx, "warp"); + bool instr32Rows = atom == TMemAccessAtom::I32x32b; + bool layout16Rows = + ll.getBasis(rowColDims[0], llvm::Log2_32(16)) == ArrayRef{0, 0}; + + // We are choosing the distributed layout (ll o tile). In the lowering + // we will do ll^{-1} o (ll o tile) and we expect to get tile back. + // For this to be possible, ll should accept a left-inverse, that is, it + // should be injective + // In less fancy words, we look for the `comp` layout not to have any zero + // basis as that would disallow the resulting layout to be left-divisible by + // the tile + auto comp = + tile.compose(ll).sublayout({kReg, kLane}, to_vector(ll.getOutDimNames())); + if (instr32Rows) { + // We will use 16x32bx2 instruction for lane=16 so we remove the last lane + // basis + comp = comp.resizeInDim(kLane, comp.getInDimSize(kLane) / 2); + } + if (!comp.isInjective()) + return std::nullopt; + + // Fit the warp bases either tiling on the RHS or in row=16 + StringAttr row16; + // If we need to fit something (the instruction does not cover it + // and the layout has 32 rows) we first try to fit a warp, and if we + // can't we fit a register + if (!instr32Rows && !layout16Rows) { + if (numWarps > 4) { + row16 = kWarp; + } else { + row16 = kReg; + } + } + + // We reserve enough columns to fit in the warps + int warpsToTile = numWarps / ((row16 == kWarp) ? 8 : 4); + // Cap warps to tile above by nColsMissing. 
The rest go to broadcasting + int warpBroadcast = warpsToTile / std::min(nColsMissing, warpsToTile); + warpsToTile /= warpBroadcast; + auto nColsOrig = nColsLL; + nColsMissing /= warpsToTile; + nColsLL /= warpsToTile; + + if (nColsMissing > 1) { + if (instr32Rows && layout16Rows) { + // If the lane 16 would load repeated data, instead we make it load half + // of the data via the 16x32bx2 instruction + tile *= LinearLayout::identity1D(nColsMissing / 2, kReg, rowColDims[1]); + auto bases = tile.getBases(); + bases[kLane].back() = {0, nColsLL / 2}; + tile = LinearLayout( + bases, {{rowColDims[0], 128}, {rowColDims[1], nColsLL}}, false); + } else { - // Split along N dimension - sizePerThread = {1, ceil(N, numWarpGroups)}; - threadsPerWarp = {32, 1}; - warpsPerCTA = {4, numWarpGroups}; + tile *= LinearLayout::identity1D(nColsMissing, kReg, rowColDims[1]); } } - order = {0, 1}; - auto ctaLayout = getCTALayout(oldType.getEncoding()); - return triton::gpu::BlockedEncodingAttr::get(ctaLayout.getContext(), - sizePerThread, threadsPerWarp, - warpsPerCTA, order, ctaLayout); + auto bases = tile.getBases(); + auto &warpBases = bases[kWarp]; + if (row16) { + bases[row16].push_back({16, 0}); + } + + // Add the bases we had reserved for the warps to tile + assert(nColsOrig / nColsLL == warpsToTile); + for (int i = nColsLL; i < nColsOrig; i *= 2) { + bases[kWarp].push_back({0, i}); + } + // Broadcast in the rest of the warps if we need more bases + for (int i = 1; i < warpBroadcast; i *= 2) { + bases[kWarp].push_back({0, 0}); + } + tile = LinearLayout(bases, {{rowColDims[0], 128}, {rowColDims[1], nColsOrig}}, + false); + auto ret = tile.compose(ll); + return ret; +} + +std::optional +getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom, + unsigned numWarps, + gpu::CTALayoutAttr ctaLayout) { + assert(memType.getMemorySpace() == + TensorMemorySpaceAttr::get(memType.getContext())); + assert(numWarps >= 4 && llvm::isPowerOf2_32(numWarps) && + "numWarps must be a power of 2 and >= 4"); + assert(atom != TMemAccessAtom::I16x32bx2 && + "This layout is inferred sometimes for the 32x32b atom"); + auto ll = toLinearLayout(memType.getShape(), memType.getEncoding()); + auto bitwidth = memType.getElementTypeBitWidth(); + return getDistributedLayoutForTmemLdSt(ll, atom, numWarps, bitwidth, + ctaLayout); } -DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N, - RankedTensorType oldType, - unsigned numWarps) { +DistributedEncodingTrait +getDefaultLayoutForTmemLdSt(gpu::MemDescType memType, unsigned numWarps, + gpu::CTALayoutAttr ctaLayout) { + auto *ctx = memType.getContext(); bool prefer16x256 = triton::tools::getBoolEnv("TRITON_PREFER_TMEM_16x256_LAYOUT"); if (prefer16x256) { - std::optional ll = - getTmemLoadStoreLayout16x256(M, N, oldType, numWarps); - if (ll) { - return LinearEncodingAttr::get(oldType.getContext(), *ll); + auto layout = getDistributedLayoutForTmemLdSt( + memType, TMemAccessAtom::I16x256b, numWarps, ctaLayout); + if (layout) { + return LinearEncodingAttr::get(ctx, *layout); } } - return getTmemLoadStoreLayout32x32b(M, N, oldType, numWarps); + auto layout = getDistributedLayoutForTmemLdSt( + memType, TMemAccessAtom::I32x32b, numWarps, ctaLayout); + assert(layout); + return LinearEncodingAttr::get(ctx, *layout); } -DistributedEncodingTrait +std::optional getTmemLoadLayoutSplitLongM(RankedTensorType tensorType, MemDescType memType, int numWarps) { - auto tmemEnc = dyn_cast( - memType.getEncoding()); - if (!tmemEnc || tmemEnc.getBlockM() != 128) - return {}; - 
int M = tmemEnc.getBlockM(); - int N = tmemEnc.getBlockN(); - auto llEncoding = dyn_cast(tensorType.getEncoding()); - if (!llEncoding) - return {}; - auto CTALayout = getCTALayout(tensorType.getEncoding()); - auto shapePerCTA = mlir::triton::gpu::getShapePerCTA(tensorType); if (numWarps != 8) - return {}; - LinearLayout llLayout = - gpu::getTmemLoadLayoutSplitLongM(M, N, tensorType, numWarps); - return LinearEncodingAttr::get(tensorType.getContext(), llLayout); -} + return std::nullopt; -bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType, - MemDescType memType, int numWarps) { - auto layout = getTmemLoadLayoutSplitLongM(tensorType, memType, numWarps); + auto ctaLayout = getCTALayout(tensorType.getEncoding()); + std::optional layout = getDistributedLayoutForTmemLdSt( + memType, TMemAccessAtom::I32x32b, numWarps, ctaLayout); if (!layout) - return false; - return areLayoutsEquivalent( - tensorType.getShape(), cast(layout), - cast(tensorType.getEncoding())); + return std::nullopt; + auto ret = *layout; + + // Optimisation for reductions: + // We can map lane=16 to any dimension, and it will be lowered to 32x16bx2. + // As such, if we have 8 warps and the basis warp=4 is mapped to a different + // dimension than warp=1, warp=2, and lane=16 is mapped to the same dimension + // as the first two warp bases, we can swap warp=4 and lane=16. + // Generally, we don't want warp=4 to have data on a different dimension to + // dim=1 and dim=2 + auto *ctx = tensorType.getContext(); + auto kLane = StringAttr::get(ctx, "lane"); + auto kWarp = StringAttr::get(ctx, "warp"); + auto dims = to_vector(ret.getOutDimNames()); + + // In most cases this is going to be dim=0, but the optimization + // also applies for scales where we may be able to have the layout + // replicated across warps + for (int dim : {0, 1}) { + auto w1dim = ret.getBasis(kWarp, 0, dims[dim]) == 0; + auto w2dim = ret.getBasis(kWarp, 1, dims[dim]) == 0; + auto w4dim = ret.getBasis(kWarp, 2, dims[dim]) == 0; + auto l16dim = ret.getBasis(kLane, 4, dims[dim]) == 0; + if (l16dim != w4dim && w1dim == w2dim && w1dim == l16dim) { + auto bases = ret.getBases(); + std::swap(bases[kWarp][2], bases[kLane][4]); + return LinearEncodingAttr::get( + tensorType.getContext(), + LinearLayout(bases, ret.getOutDims(), ret.isSurjective())); + } + } + return std::nullopt; } SmallVector @@ -187,36 +391,22 @@ getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType, MemDescType memType) { int numWarps = lookupNumWarps(op); assert(numWarps % 4 == 0); - - if (isa( - memType.getEncoding())) { - return {triton::gpu::LinearEncodingAttr::get( - tensorType.getContext(), - getScaleTMEMStoreLinearLayout(tensorType, numWarps))}; - } - + auto ctaLayout = getCTALayout(tensorType.getEncoding()); SmallVector layouts; - auto attr = - cast(memType.getEncoding()); - int blockM = attr.getBlockM(); - int blockN = attr.getBlockN(); - - if (DistributedEncodingTrait splitMLayout = - getTmemLoadLayoutSplitLongM(tensorType, memType, numWarps)) - layouts.push_back(splitMLayout); - - if (auto ll16x256 = - getTmemLoadStoreLayout16x256(blockM, blockN, tensorType, numWarps)) { - layouts.push_back( - LinearEncodingAttr::get(tensorType.getContext(), ll16x256.value())); + for (auto atom : {TMemAccessAtom::I32x32b, TMemAccessAtom::I16x256b, + TMemAccessAtom::I16x128b, TMemAccessAtom::I16x64b}) { + auto ll = + getDistributedLayoutForTmemLdSt(memType, atom, numWarps, ctaLayout); + if (ll) { + layouts.push_back( + LinearEncodingAttr::get(tensorType.getContext(), 
ll.value())); + } + } + // Small hack until we generalise isDistributedLayoutTMemCompatible + auto ll = getTmemLoadLayoutSplitLongM(tensorType, memType, numWarps); + if (ll) { + layouts.push_back(ll.value()); } - - layouts.push_back(nvidia_gpu::getTmemLoadStoreLayout32x32b( - blockM, blockN, tensorType, numWarps)); - - // TODO: Add support for more layout compatible with tmem load/store. There - // will only be a discret set of layout possible due to the limiations of - // tmem_load/store. return layouts; } @@ -224,13 +414,8 @@ getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType, bool isDistributedLayoutTMemCompatible(Operation *op, RankedTensorType tensorType, gpu::MemDescType memType) { - SmallVector layouts = - getTmemCompatibleLayouts(op, tensorType, memType); - auto enc = cast(tensorType.getEncoding()); - return llvm::any_of(layouts, [&](DistributedEncodingTrait layout) { - return areLayoutsEquivalent(tensorType.getShape(), - cast(layout), enc); - }); + auto maxnreg = getContextualMaxNReg(op); + return succeeded(computeTMemLdStEncodingInfo(tensorType, memType, maxnreg)); } LogicalResult diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 1576445459..923ab0abc1 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -29,6 +29,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h" #include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.cpp.inc" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" #include "llvm/Support/ErrorHandling.h" @@ -593,31 +594,22 @@ static LogicalResult verifyTMEMOperand(Operation *op, RankedTensorType type, MemDescType memdesc, StringRef regName) { if (type.getRank() != 2) return op->emitOpError(regName) << " must be a 2D tensor"; - if (type.getEncoding()) { - auto enc = dyn_cast(type.getEncoding()); - if (!enc) { - return op->emitOpError(regName) - << " does not have an distributed encoding"; - } - SmallVector layouts = - getTmemCompatibleLayouts(op, type, memdesc); - if (layouts.empty()) { - return op->emitOpError(regName) - << " does not have any TMEM compatible layouts"; - } - if (llvm::none_of(layouts, [&](DistributedEncodingTrait layout) { - return areLayoutsEquivalent(type.getShape(), - cast(layout), - cast(enc)); - })) { - InFlightDiagnostic diag = op->emitOpError(regName) - << " layout is not TMEM compatible"; - for (Attribute layout : layouts) - diag.attachNote() << "potential TMEM layout: " << layout; - return diag; - } - } - return success(); + if (!type.getEncoding()) + return success(); + + auto maxnreg = getContextualMaxNReg(op); + if (isDistributedLayoutTMemCompatible(op, type, memdesc)) + return success(); + + // If it failed, give the user a hint + SmallVector layouts = + getTmemCompatibleLayouts(op, type, memdesc); + + InFlightDiagnostic diag = op->emitOpError(regName); + diag.attachNote() << "Got: " << type.getEncoding(); + for (Attribute layout : layouts) + diag.attachNote() << "potential TMEM layout: " << layout; + return diag; } LogicalResult TMEMStoreOp::verify() { diff --git a/lib/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.cpp b/lib/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.cpp new file mode 100644 index 0000000000..1a25e99599 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.cpp @@ -0,0 +1,317 @@ 
+#include "triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h" + +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" + +#include +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace mlir::triton::nvidia_gpu { + +namespace { + +constexpr int maxRegisters = 256; +constexpr int largestTmemLoadStore = 128; + +// Similar to largestVectorisation in TritonGPUToLLVM/Utility.cpp +std::optional> +getVec(const LinearLayout &cvt, const LinearLayout &tile, int maxnreg) { + auto *ctx = cvt.getInDimNames().begin()->getContext(); + auto kReg = StringAttr::get(ctx, "register"); + auto kCol = StringAttr::get(ctx, "col"); + LinearLayout reps, vec; + ColumnAction perm; + // Heuristic: + // Do not use more than half the registers as otherwise it's prone to spilling + assert(maxnreg / 2 <= largestTmemLoadStore); + auto maxReg = maxnreg / 2; + // Heuristic: + // If maxnreg is 256 and we need more than one message, we don't use max + // vectorisation as ptxas' scheduler breaks... + if (maxnreg == 256 && cvt.getInDimSize(kReg) > maxReg) { + maxReg /= 2; + } + auto maxVec = maxReg / tile.getInDimSize(kReg); + int i = 1; + for (; i <= maxVec; i *= 2) { + vec = LinearLayout::identity1D(i, kReg, kCol); + auto vecTile = tile * vec; + auto maybePerm = regPermForDivide(cvt, vecTile, /*left=*/true); + if (!maybePerm) { + break; + } + // nb. We could remove this part once we are confident the algo works + perm = *maybePerm; + auto newCvt = maybePerm->apply(cvt); + auto maybeReps = getReps(newCvt, vecTile); + if (!maybeReps.has_value()) { + break; + } + reps = *maybeReps; + } + if (i == 1) { + // Couldn't lower the tile + return std::nullopt; + } + // i is the smallest power of 2 that *cannot* be used to lower the tile + // so we return i / 2. + assert(i > 1); + return std::make_tuple(std::move(reps), std::move(perm), + (i / 2) * tile.getInDimSize(kReg)); +} +} // namespace + +// Get the maximum number of registers per thread based on the context. This is +// by default 256, but it can be overridden by `ttg.maxnreg` set on the module +// or a contextual register limit set by the compiler on partitions. +int getContextualMaxNReg(Operation *op) { + // Check the immediate parent op to see if it places a register constraint. + auto getFromParent = [](Operation *op) -> std::optional { + Operation *parent = op->getParentOp(); + if (auto mod = dyn_cast(parent)) { + if (auto attr = mod->getAttrOfType(AttrMaxRegistersName)) + return attr.getInt(); + return {}; + } + + if (auto partitions = dyn_cast(parent)) { + // Check if the partition has reduced registers. + unsigned idx = op->getParentRegion()->getRegionNumber(); + if (auto actRegisters = partitions.getParentOp().getActualRegisters()) + return (*actRegisters)[1 + idx]; + return {}; + } + + if (auto wsOp = dyn_cast(op->getParentOp())) { + // Check the register usage of the default warpgroup. + if (auto actRegisters = wsOp.getActualRegisters()) + return actRegisters->front(); + return {}; + } + + return {}; + }; + + // PTXAS validates the register usage of `tcgen05.ld` and `tcgen05.st` + // instructions based on the static number of registers set on the module, not + // the dynamic allocation. This just means the register limit used for the + // purpose of subtiling TMEM messages cannot be higher than the module's. 
+ auto mod = op->getParentOfType(); + int maxnreg = maxRegisters; + + for (; op != mod; op = op->getParentOp()) { + if (std::optional limit = getFromParent(op)) { + maxnreg = std::min(maxnreg, *limit); + break; + } + } + + if (auto maxnregAttr = mod->getAttrOfType(AttrMaxRegistersName)) + maxnreg = std::min(maxnreg, maxnregAttr.getInt()); + + return maxnreg; +} + +FailureOr +lowerTMemLdSt(const LinearLayout &cvt, int maxnreg, int bitwidth, bool isScales, + std::function emitError, + bool unpacked = false) { + // We will fill in the returned value recursively (if it exists) + + // Remove broadcasting in the registers + auto removeBroadcastSrc = actionRemoveBroadcastedRegs(cvt); + if (!removeBroadcastSrc.isIdentity()) { + auto prmtCvt = removeBroadcastSrc.apply(cvt); + auto info = lowerTMemLdSt(prmtCvt, maxnreg, bitwidth, isScales, emitError, + unpacked); + if (failed(info)) + return failure(); + info->broadcast = std::move(removeBroadcastSrc); + return info; + } + auto *ctx = cvt.getInDimNames().begin()->getContext(); + auto S = [ctx](StringRef str) { return StringAttr::get(ctx, str); }; + auto kReg = S("register"); + auto kLane = S("lane"); + auto kRow = S("row"); + auto kCol = S("col"); + if (bitwidth < 32) { + LinearLayout quot; + int bestContig = 1; + for (int contig = 1; bitwidth * contig <= 32; contig *= 2) { + auto maybeQuot = + divideLeft(cvt, LinearLayout::identity1D(contig, kReg, kCol)); + if (!maybeQuot) + break; + quot = *maybeQuot; + bestContig = contig; + } + bool padding = false; + int newBitwidth = bitwidth; + if (bestContig > 1) { + // There are contiguous elements along kCol, so we can pack them into a + // larger dtype + unpacked = false; + newBitwidth = bitwidth * bestContig; + } else if (auto maybeQuot = divideLeft( + cvt, LinearLayout::zeros1D(1, kReg, kCol, 32 / bitwidth) * + LinearLayout::identity1D(2, kReg, kCol)); + bitwidth == 16 && maybeQuot) { + // Unpacked just supported for bitwidth 16 + unpacked = true; + quot = *maybeQuot; + newBitwidth = 32; + } else if (auto maybeQuot = divideLeft( + cvt, LinearLayout::zeros1D(1, kReg, kCol, 32 / bitwidth))) { + // We software-pad the elements when we either do not have enough elements + // to fill a full 32b register, e.g., colN = 1 and colStride != 1 or when + // bitwidth == 8 (this happens with scales with K=1). + // These two cases are mostly supported for testing purposes. 
+ unpacked = bitwidth == 16; + quot = *maybeQuot; + padding = true; + newBitwidth = 32; + } else { + if (emitError) { + emitError() << "Failed to lower TMEM load/store: TMEM layout is not " + "packed or unpacked"; + } + return failure(); + } + // When unpacked each register moves 32/bitwidth (= 2) columns + if (unpacked) { + quot = LinearLayout::zeros1D(1, kReg, kCol, 32 / bitwidth) * quot; + } + auto info = lowerTMemLdSt(quot, maxnreg, newBitwidth, isScales, emitError, + unpacked); + if (failed(info)) + return failure(); + if (bestContig > 1) { + info->vec = bestContig; + } + if (unpacked) { + info->unpacked = true; + } + if (padding) { + info->padding = true; + } + return info; + } + + assert(bitwidth == 32); + + // The algorithm goes as: + // - Try to match the tile with one of the standard messages + // - If it doesn't match, we use the 16x32bx2 message + // Note that it can match one and only one of the layouts, even after register + // reordering, as the layouts yield predetermined positions for the lanes + // We store the instruction, the resulting reps layout, the permutation and + // the number of registers per message + std::optional msgInfo; + for (auto atom : {TMemAccessAtom::I32x32b, TMemAccessAtom::I16x256b, + TMemAccessAtom::I16x64b, TMemAccessAtom::I16x128b}) { + auto tile = getTileLayout(ctx, atom, unpacked); + auto maybeReps = getVec(cvt, tile, maxnreg); + if (maybeReps) { + // Cannot match more than one + msgInfo = {atom, std::get<0>(*maybeReps), std::get<1>(*maybeReps), + std::get<2>(*maybeReps)}; + break; + } + } + std::optional secondHalfOffset = std::nullopt; + if (!msgInfo) { + // Quotient by the smaller tile and then, if possible, we set the + // secondHalfOffset to the last kLane basis + auto tile = getTileLayout(ctx, TMemAccessAtom::I16x32bx2, unpacked); + auto maybeReps = getVec(cvt, tile, maxnreg); + if (maybeReps) { + auto [reps, perm, numRegsPerMessage] = std::move(*maybeReps); + // Find the last kLane basis and use it as secondHalfOffset + auto row = reps.getBasis(kLane, 4, kRow); + auto col = reps.getBasis(kLane, 4, kCol); + secondHalfOffset = (row << 16) | col; + if (*secondHalfOffset == 0) { + // Workaround for ptxas bug, we cannot use secondHalfOffset = 0 to write + // only 16 elements. We use secondHalfOffset = 1 instead and we pad the + // allocation. 
+ if (!isScales) { + if (emitError) { + emitError() + << "Only supported for scales as we pad the allocation."; + } + return failure(); + } + secondHalfOffset = 1; + } + // We "quotient it out", meaning we remove the last basis from reps + auto basis = reps.getBases(); + basis[kLane][4] = {0, 0}; + reps = LinearLayout(basis, reps.getOutDims(), /*isSurjective=*/false); + msgInfo = {TMemAccessAtom::I16x32bx2, reps, perm, numRegsPerMessage}; + } + } + + if (!msgInfo) { + if (emitError) { + emitError() + << "Failed to lower TMEM load/store: unsupported dst layout\n" + + cvt.toString(); + } + return failure(); + } + auto info = std::move(*msgInfo); + info.secondHalfOffset = secondHalfOffset; + return info; +} + +FailureOr +computeTMemLdStEncodingInfo(RankedTensorType regTy, MemDescType memTy, + int maxnreg, + std::function emitError) { + auto memLayout = toLinearLayout(memTy); + auto regLayout = toLinearLayout(regTy); + auto cvt = regLayout.invertAndCompose(memLayout); + auto *ctx = regTy.getContext(); + auto S = [ctx](StringRef str) { return StringAttr::get(ctx, str); }; + auto kWarp = S("warp"); + auto kRow = S("row"); + // Warps 0-3 must map to row=32 and row=64 whether with broadcasting or not + if (!(regLayout.getBasis(kWarp, 0) == memLayout.getBasis(kRow, 5) && + regLayout.getBasis(kWarp, 1) == memLayout.getBasis(kRow, 6))) { + if (emitError) { + emitError() << "warps=1,2 must map to rows=32,64. Got:\n" + << regLayout.toString() << "\n" + << memLayout.toString(); + } + return failure(); + } + // Map warp bases to row=32 and row=64 in the cvt. This would be done + // automatically in `invertAndCompose` if we had a different dimension name + // for these rows. We can do this in the future if needed. + auto bases = cvt.getBases(); + bases[kWarp][0] = {32, 0}; + bases[kWarp][1] = {64, 0}; + cvt = LinearLayout(bases, cvt.getOutDims(), + /*isSurjective=*/cvt.isSurjective()); + + // tmemBase already encodes CTA/block offsets so we just remove them from the + // cvt + auto kBlock = StringAttr::get(ctx, "block"); + auto kCol = StringAttr::get(ctx, "col"); + auto nCTAs = cvt.getInDimSize(kBlock); + auto maybeQuot = + divideRight(cvt, LinearLayout::identity1D(nCTAs, kBlock, kCol)); + assert(maybeQuot.has_value()); + auto quot = maybeQuot->unsqueezeIn(kBlock); + + bool isScales = isa(memTy.getEncoding()); + int bitwidth = memTy.getElementTypeBitWidth(); + return lowerTMemLdSt(quot, maxnreg, bitwidth, isScales, emitError); +} + +} // namespace mlir::triton::nvidia_gpu diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp index b821affba7..c9472bd128 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp @@ -117,8 +117,12 @@ class TMemSplitLoadPattern : public OpRewritePattern { nOffset, splitNSize); // Choose a layout compatible with the slice size. 
- Attribute distLayout = getTmemCompatibleLayout( - mDim, splitNSize, splitOp.getOutLHS().getType(), numWarps); + gpu::MemDescType subSliceType = + cast(subSlice.getType()); + auto ctaLayout = + ttg::getCTALayout(splitOp.getOutLHS().getType().getEncoding()); + auto distLayout = nvidia_gpu::getDefaultLayoutForTmemLdSt( + subSliceType, numWarps, ctaLayout); RankedTensorType newLoadType = splitOp.getOutLHS().getType().cloneWithEncoding(distLayout); @@ -180,23 +184,22 @@ class TMemStoreJoinPattern : public OpRewritePattern { int numWarps = ttg::lookupNumWarps(storeOp); Value truePred = arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); - Attribute distLayout = getTmemCompatibleLayout( - mDim, splitNSize, joinOp.getLhs().getType(), numWarps); - auto newStoreType = joinOp.getLhs().getType().cloneWithEncoding(distLayout); - - // First slice. - auto subSlice0 = TMEMSubSliceOp::create(b, loc, tmem, 0, splitNSize); - auto cvt0 = - ttg::ConvertLayoutOp::create(b, loc, newStoreType, joinOp.getLhs()); - auto store0 = - TMEMStoreOp::create(b, loc, subSlice0, cvt0.getResult(), truePred); - // Second slice. - auto subSlice1 = - TMEMSubSliceOp::create(b, loc, tmem, splitNSize, splitNSize); - auto cvt1 = - ttg::ConvertLayoutOp::create(b, loc, newStoreType, joinOp.getRhs()); - auto store1 = - TMEMStoreOp::create(b, loc, subSlice1, cvt1.getResult(), truePred); + auto ctaLayout = ttg::getCTALayout(joinOp.getLhs().getType().getEncoding()); + auto *ctx = joinOp.getContext(); + + auto createSlice = [&](TypedValue input, int offset) { + auto subSlice = TMEMSubSliceOp::create(b, loc, tmem, offset, splitNSize); + auto distLayout = nvidia_gpu::getDefaultLayoutForTmemLdSt( + subSlice.getType(), numWarps, ctaLayout); + auto newType = input.getType().cloneWithEncoding(distLayout); + auto cvt = ttg::ConvertLayoutOp::create(b, loc, newType, input); + auto store = + TMEMStoreOp::create(b, loc, subSlice, cvt.getResult(), truePred); + return store; + }; + + auto store0 = createSlice(joinOp.getLhs(), 0); + auto store1 = createSlice(joinOp.getRhs(), splitNSize); b.eraseOp(storeOp); return success(); } @@ -218,14 +221,6 @@ class TMemLoadReducePattern : public OpRewritePattern { // is already reduction friendly. 
if (numWarps != 8) return failure(); - auto tmemEnc = dyn_cast( - tmemLoadOp.getSrc().getType().getEncoding()); - if (!tmemEnc) - return failure(); - int M = tmemEnc.getBlockM(); - int N = tmemEnc.getBlockN(); - if (M != 128) - return failure(); bool foundReductionAlongN = false; auto filter = [&](Operation *op) { if (isa(op) || op->hasTrait()) @@ -246,13 +241,15 @@ class TMemLoadReducePattern : public OpRewritePattern { // M = 96 warp 4 gets M = 16, warp 5 gets M = 48, warp 6 gets M = 80, // warp 7 gets M = 112 RankedTensorType oldType = tmemLoadOp.getType(); - Attribute newLayout = ttg::LinearEncodingAttr::get( - tmemLoadOp.getContext(), - ttg::getTmemLoadLayoutSplitLongM(M, N, oldType, numWarps)); - if (newLayout == oldType.getEncoding()) + std::optional newLayout = + getTmemLoadLayoutSplitLongM(oldType, tmemLoadOp.getSrc().getType(), + numWarps); + if (!newLayout) + return failure(); + if (newLayout.value() == oldType.getEncoding()) return failure(); - auto newType = oldType.cloneWithEncoding(newLayout); + auto newType = oldType.cloneWithEncoding(newLayout.value()); tmemLoadOp.getResult().setType(newType); OpBuilder builder(tmemLoadOp); builder.setInsertionPointAfter(tmemLoadOp); @@ -279,16 +276,19 @@ class TMemFromSharedMemPattern : public OpRewritePattern { int N = tmemEnc.getBlockN(); int numWarps = ttg::lookupNumWarps(tmemStoreOp); // Compute the alternative layout. - std::optional ll = gpu::getTmemLoadStoreLayout16x256( - M, N, tmemStoreOp.getSrc().getType(), numWarps); + auto ctaLayout = + ttg::getCTALayout(tmemStoreOp.getSrc().getType().getEncoding()); + std::optional ll = + nvidia_gpu::getDistributedLayoutForTmemLdSt( + tmemStoreOp.getDst().getType(), TMemAccessAtom::I16x256b, numWarps, + ctaLayout); if (!ll) return failure(); Attribute newEncoding = gpu::LinearEncodingAttr::get(tmemStoreOp.getContext(), *ll); - auto newType = RankedTensorType::get( - tmemStoreOp.getSrc().getType().getShape(), - tmemStoreOp.getSrc().getType().getElementType(), newEncoding); - if (newType == tmemStoreOp.getSrc().getType()) + auto oldType = tmemStoreOp.getSrc().getType(); + auto newType = oldType.cloneWithEncoding(newEncoding); + if (newType == oldType) return failure(); SetVector slice; @@ -345,17 +345,18 @@ class TMemToSharedMemPattern : public OpRewritePattern { int M = tmemEnc.getBlockM(); int N = tmemEnc.getBlockN(); int numWarps = ttg::lookupNumWarps(tmemLoadOp); + auto oldType = tmemLoadOp.getType(); + auto ctaLayout = ttg::getCTALayout(oldType.getEncoding()); + auto memType = cast(tmemLoadOp.getSrc().getType()); // Compute the alternative layout. - std::optional ll = - gpu::getTmemLoadStoreLayout16x256(M, N, tmemLoadOp.getType(), numWarps); + auto ll = nvidia_gpu::getDistributedLayoutForTmemLdSt( + memType, TMemAccessAtom::I16x256b, numWarps, ctaLayout); if (!ll) return failure(); Attribute newEncoding = gpu::LinearEncodingAttr::get(tmemLoadOp.getContext(), *ll); - auto newType = RankedTensorType::get(tmemLoadOp.getType().getShape(), - tmemLoadOp.getType().getElementType(), - newEncoding); - if (newType == tmemLoadOp.getType()) + auto newType = oldType.cloneWithEncoding(newEncoding); + if (newType == oldType) return failure(); SetVector slice; @@ -409,7 +410,6 @@ class TMemToSharedMemPattern : public OpRewritePattern { return failure(); // Use the new layout and rely on RemoveLayoutConversions pass to propagate // the convert_layout. 
- Type oldType = tmemLoadOp.getType(); rewriter.modifyOpInPlace( tmemLoadOp, [&]() { tmemLoadOp.getResult().setType(newType); }); rewriter.setInsertionPointAfter(tmemLoadOp); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp index 34b76c67a9..35a79a7145 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp @@ -19,18 +19,11 @@ namespace nvidia_gpu { namespace { template -Attribute getLHSTMemLayout(MMAOpTy tcGen5MMAOp, RankedTensorType srcType) { +Attribute getLHSTMemLayout(MMAOpTy tcGen5MMAOp, gpu::MemDescType lhsTMEMType, + ttg::CTALayoutAttr ctaLayout) { int numWarps = ttg::lookupNumWarps(tcGen5MMAOp); - auto accTmemEncoding = dyn_cast( - tcGen5MMAOp.getD().getType().getEncoding()); - auto lhs = tcGen5MMAOp.getA(); - auto lhsShape = lhs.getType().getShape(); - // M has to follow the MMA size, as it is related to the message we are using. - // N has to follow the number of columns in the LHS. - int M = accTmemEncoding.getBlockM(); - int N = lhsShape[1]; - Attribute resLayout = getTmemCompatibleLayout(M, N, srcType, numWarps); - return resLayout; + return nvidia_gpu::getDefaultLayoutForTmemLdSt(lhsTMEMType, numWarps, + ctaLayout); } template class LHSToTMem : public OpRewritePattern { @@ -79,7 +72,8 @@ template class LHSToTMem : public OpRewritePattern { if (!layoutTmemCompatible) { if (!comesFromLoadOrBlockArg(src) || triton::tools::getBoolEnv("ALLOW_LHS_TMEM_LAYOUT_CONVERSION")) { - newLayout = getLHSTMemLayout(tcGen5MMAOp, srcType); + newLayout = getLHSTMemLayout(tcGen5MMAOp, lhsMemDescType, + ttg::getCTALayout(srcType.getEncoding())); } else { return failure(); } diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 8c723ca3dc..11b4367072 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -462,6 +462,40 @@ LinearLayout LinearLayout::reshapeOuts( return LinearLayout(std::move(newBases), newOutDims, isSurjective()); } +LinearLayout LinearLayout::resizeInDim(StringAttr inDim, + int32_t newSize) const { + assert(llvm::isPowerOf2_32(newSize)); + assert(newSize <= getInDimSize(inDim)); + auto newBases = bases; + newBases[inDim].resize(llvm::Log2_32(newSize)); + return LinearLayout(std::move(newBases), getOutDims(), + /*requiresSurjective=*/false); +} + +LinearLayout LinearLayout::resizeOutDim(StringAttr outDim, + int32_t newSize) const { + assert(llvm::isPowerOf2_32(newSize)); + assert(newSize <= getOutDimSize(outDim)); + auto newBases = bases; + // Zero-out the basis vectors that are greater than or equal to the new size + for (auto &[inDim, inDimBases] : newBases) { + for (auto &basis : inDimBases) { + auto &b = basis[getOutDimIndex(outDim)]; + if (b >= newSize) { + b = 0; + } + } + } + auto outDims = getOutDims(); + for (auto &[outDim, outDimSize] : outDims) { + if (outDim == outDim) { + outDimSize = newSize; + } + } + return LinearLayout(std::move(newBases), outDims, + /*requiresSurjective=*/false); +} + LinearLayout LinearLayout::concatIns(const LinearLayout &other) const { assert(llvm::to_vector(getOutDimNames()) == llvm::to_vector(other.getOutDimNames()) && diff --git a/python/examples/gluon/01-attention-forward.py b/python/examples/gluon/01-attention-forward.py index 16d549b272..86b2bf4ce2 100644 --- a/python/examples/gluon/01-attention-forward.py +++ b/python/examples/gluon/01-attention-forward.py @@ -1,4 +1,5 @@ import copy +import math import torch import triton import 
pytest @@ -13,7 +14,7 @@ from triton.experimental.gluon.language.nvidia.blackwell import ( TensorMemoryLayout, allocate_tensor_memory, - get_tmem_32x32b_reg_layout, + get_tmem_reg_layout, tensor_memory_descriptor, tma, mbarrier, @@ -21,6 +22,7 @@ tcgen05_commit, float2, ) +from triton.experimental.gluon.language.nvidia.blackwell.float2 import Float2Tensor # ===-----------------------------------------------------------------------===# # Layout Utilities @@ -35,12 +37,6 @@ def get_mma_instr_shape(shape, element_ty): return (m, n, k) -@gluon.constexpr_function -def get_mma_reg_layout(shape, num_warps, dtype=gl.float32): - instr_shape = get_mma_instr_shape(shape, dtype) - return get_tmem_32x32b_reg_layout(*instr_shape[:2], shape, num_warps) - - # ===-----------------------------------------------------------------------===# # Data Abstractions # ===-----------------------------------------------------------------------===# @@ -227,7 +223,6 @@ class AttentionConfig: p_tmem_layout: gl.constexpr qk_layout: gl.constexpr - o_layout: gl.constexpr o_splitn_layout: gl.constexpr alpha_2d_layout: gl.constexpr @@ -266,14 +261,15 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE self.qk_tmem_layout = gl.constexpr(TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1)) self.o_tmem_layout = gl.constexpr(TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), col_stride=1)) self.p_tmem_layout = gl.constexpr(TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1)) + o_splitn_tmem_layout: gl.constexpr = TensorMemoryLayout( + (o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR), col_stride=1) self.qk_layout = gl.constexpr( - get_tmem_32x32b_reg_layout(qk_instr_shape[0], qk_instr_shape[0], self.qk_shape, self.num_warps)) - self.o_layout = gl.constexpr( - get_tmem_32x32b_reg_layout(o_instr_shape[0], o_instr_shape[1], self.o_shape, self.num_warps)) + get_tmem_reg_layout(gl.float32, self.qk_shape, self.qk_tmem_layout, self.num_warps, + instr_variant="32x32b_splitn")) self.o_splitn_layout = gl.constexpr( - get_tmem_32x32b_reg_layout(o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR, - (self.o_shape[0], self.o_shape[1] // self.SPLIT_D_FACTOR), self.num_warps)) + get_tmem_reg_layout(gl.float32, (self.o_shape[0], self.o_shape[1] // self.SPLIT_D_FACTOR), + o_splitn_tmem_layout, self.num_warps)) self.alpha_2d_layout = gl.constexpr(gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1])) is_fp16 = self.dtype.value in [gl.float16, gl.bfloat16] @@ -399,10 +395,28 @@ def _borrow_s_for_epilogue(config, s_tmem): @gluon.constexpr_function -def _get_split_n_layout(layout, SPLIT_FACTOR: gl.constexpr = 2): - layout = copy.deepcopy(layout) - layout.size_per_thread[1] //= SPLIT_FACTOR - return layout +def _get_split_n_layout(layout: gl.constexpr, SPLIT_FACTOR: gl.constexpr = 2): + assert isinstance(layout, gl.DistributedLinearLayout), "split_n requires a distributed layout" + assert SPLIT_FACTOR == 1 or SPLIT_FACTOR == 2, "split_n requires a split factor of 1 or 2" + if SPLIT_FACTOR == 1: + return layout + else: + target = [0, layout.shape[1] // 2] # [0, 2^{m-1}] + last_reg_idx = len(layout.reg_bases) - 1 + reg_last = layout.reg_bases[last_reg_idx] + + if reg_last == target: + return layout + + ret = copy.deepcopy(layout) + + # Find [0, 2^{m-1}] across lists and swap it with last reg + for L in (ret.reg_bases, ret.lane_bases, ret.warp_bases, ret.block_bases): + for i, b in enumerate(L): + if b == target: + L[i], ret.reg_bases[last_reg_idx] = 
reg_last, target + return ret + assert False, f"split_n requires having a basis {target}. Got\n{layout}" @gluon.jit @@ -419,9 +433,17 @@ def _split_n(x, SPLIT_FACTOR: gl.constexpr = 2): @gluon.constexpr_function def _get_join_n_layout(layout, SPLIT_FACTOR: gl.constexpr = 2): - layout = copy.deepcopy(layout) - layout.size_per_thread[1] *= SPLIT_FACTOR - return layout + assert isinstance(layout, gl.DistributedLinearLayout), "join_n requires a Linear layout" + shape = list(layout.shape) + regs = [[0, shape[1] * (1 << i)] for i in range(int(math.log2(SPLIT_FACTOR)))] + shape[1] *= SPLIT_FACTOR + return gl.DistributedLinearLayout( + layout.reg_bases + regs, + layout.lane_bases, + layout.warp_bases, + layout.block_bases, + shape, + ) @gluon.jit @@ -572,7 +594,8 @@ def _compute_and_store_exp2(config, qk, p_tmem): @gluon.jit def _subtiled_qk_load(config, s_tmem): SIZE: gl.constexpr = s_tmem.shape[1] // config.SPLIT_QK_LOAD_FACTOR - layout: gl.constexpr = _get_split_n_layout(config.qk_layout, config.SPLIT_QK_LOAD_FACTOR) + s = s_tmem.slice(0, SIZE) + layout: gl.constexpr = get_tmem_reg_layout(gl.float32, s.shape, s.layout, config.num_warps) qks = () for i in gl.static_range(config.SPLIT_QK_LOAD_FACTOR): qks = qks + (s_tmem.slice(i * SIZE, SIZE).load(layout), ) @@ -623,6 +646,7 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, # mbarrier.arrive(exp_bar, count=1) l_ij = float2.pack2(*_split_n(p)).sum(axis=1) + l_ij = Float2Tensor(gl.convert_layout(l_ij.value, l_i.value.type.layout, assert_trivial=True)) alpha = gl.convert_layout(alpha, l_i.value.type.layout, assert_trivial=True) l_i = float2.fma(l_i, float2.pack2(alpha, alpha), l_ij) m_i = m_ij diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index e9e6effb4b..ce6dddbd6f 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -2,11 +2,16 @@ #include "pybind11/pybind11.h" #include +#include +#include + #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/Types.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Gluon/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" @@ -15,6 +20,8 @@ #include "triton/Tools/GenericSwizzling.h" #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/MathExtras.h" using namespace mlir; namespace py = pybind11; @@ -43,6 +50,7 @@ static void printDiagStr(llvm::raw_ostream &os, const Diagnostic &diag) { } struct GluonOpBuilder : public TritonOpBuilder { + using TritonOpBuilder::TritonOpBuilder; // Construct an attribute or type while calling its verifier. Error messages // are intercepted and sent back to Python via a C++ exception. 
template @@ -287,6 +295,14 @@ void init_gluon_ir(py::module &&m) { /*mutableMemory=*/true, /*allocShape=*/allocShape); }) + .def("get_cta_layout", + [](GluonOpBuilder &self, std::vector &ctasPerCga, + std::vector &ctaSplitNum, + std::vector &ctaOrder) -> Attribute { + auto ctx = self.getContext(); + return self.getChecked(ctx, ctasPerCga, + ctaSplitNum, ctaOrder); + }) .def("get_blocked_layout", [](GluonOpBuilder &self, std::vector &sizePerThread, std::vector &threadsPerWarp, @@ -819,6 +835,64 @@ void init_gluon_ir(py::module &&m) { self.create(tokens, num); }); + m.def( + "compute_tmem_reg_layout", + [](py::object elementTyObj, std::vector shape, + py::object layoutObj, unsigned numWarps, const std::string &atomName, + std::vector ctasPerCga, std::vector ctaSplitNum, + std::vector ctaOrder) -> py::object { + DialectRegistry registry; + registry.insert(); + MLIRContext context(MLIRContext::Threading::DISABLED); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + + GluonOpBuilder builder(&context); + auto builderObj = + py::cast(&builder, py::return_value_policy::reference); + + auto elementType = elementTyObj.attr("to_ir")(builderObj).cast(); + auto layoutAttr = + layoutObj.attr("_to_ir")(builderObj).cast(); + auto allocShape = shape; + + auto ctx = builder.getContext(); + auto memDescTy = builder.getChecked( + shape, elementType, layoutAttr, + ttng::TensorMemorySpaceAttr::get(ctx), + /*mutableMemory=*/true, allocShape); + auto ctaLayoutAttr = builder.getChecked( + ctx, ctasPerCga, ctaSplitNum, ctaOrder); + + auto maybeAtom = + llvm::StringSwitch>(atomName) + .Case("32x32b", ttng::TMemAccessAtom::I32x32b) + .Case("16x64b", ttng::TMemAccessAtom::I16x64b) + .Case("16x128b", ttng::TMemAccessAtom::I16x128b) + .Case("16x256b", ttng::TMemAccessAtom::I16x256b) + .Case("16x32bx2", ttng::TMemAccessAtom::I16x32bx2) + .Default(std::nullopt); + if (!maybeAtom) + throw std::invalid_argument("unknown TMEM access atom: " + atomName); + auto atom = *maybeAtom; + if (atom == ttng::TMemAccessAtom::I16x32bx2) + throw std::invalid_argument( + "Atom 16x32bx2 is inferred implicitly and cannot be requested " + "explicitly"); + if (numWarps < 4 || !llvm::isPowerOf2_32(numWarps)) + throw std::invalid_argument( + "numWarps must be a power of two and >= 4"); + + auto layout = ttng::getDistributedLayoutForTmemLdSt( + memDescTy, atom, numWarps, ctaLayoutAttr); + if (!layout) + return py::none(); + + auto attr = ttg::LinearEncodingAttr::get(ctx, *layout); + return layoutToGluon(attr); + }); + py::class_(m, "WarpSpecializeOp", py::module_local()) .def("get_default_region", &ttg::WarpSpecializeOp::getDefaultRegion, diff --git a/python/src/ir.cc b/python/src/ir.cc index 93be162289..20bff8a104 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -34,6 +34,7 @@ #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/Support/FileSystem.h" @@ -340,6 +341,7 @@ void init_triton_ir(py::module &&m) { DialectRegistry registry; registry.insert 1) or (not transpose_b and swizzling_b == 0 and shape_n > 1)): fresh_knobs.nvidia.disable_ptxas_opt = True @@ -846,7 +846,7 @@ def kernel(s_ptr, out_ptr): s_tmem = allocate_tensor_memory(ttgl.float32, (BLOCK_M, N), layout=tmem_layout) o_tmem = allocate_tensor_memory(ttgl.float32, (BLOCK_M, 
N), layout=tmem_layout) - layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, (BLOCK_M, N), num_warps=4) + layout: ttgl.constexpr = get_tmem_reg_layout(ttgl.float32, (BLOCK_M, N), tmem_layout, num_warps=4) offsets = ttgl.arange(0, BLOCK_M)[:, None] * N + ttgl.arange(0, N)[None, :] offsets = ttgl.set_auto_layout(offsets, layout) @@ -860,7 +860,7 @@ def kernel(s_ptr, out_ptr): p_tmem.store(ttgl.full((BLOCK_M, N), 0.0, dtype=ttgl.float16, layout=layout)) d1_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, 2), col_stride=1) - d1_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, 2, (BLOCK_M, 2), num_warps=4) + d1_layout: ttgl.constexpr = get_tmem_reg_layout(ttgl.float32, (BLOCK_M, 2), d1_tmem_layout, num_warps=4) m_tmem = s_tmem.slice(N // 4, 2)._reinterpret(ttgl.float32, [BLOCK_M, 2], d1_tmem_layout) m_tmem.store(ttgl.full((BLOCK_M, 2), 2.0, dtype=ttgl.float32, layout=d1_layout)) @@ -925,7 +925,10 @@ def kernel(a_ptr, b_ptr, c_ptr, d_ptr): a_offsets = ttgl.arange(0, BLOCK_M)[:, None] * N + ttgl.arange(0, N)[None, :] b_offsets = ttgl.arange(0, N)[:, None] * N + ttgl.arange(0, N)[None, :] - a_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, (BLOCK_M, N), num_warps=4) + a_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), col_stride=1) + acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), col_stride=1) + a_layout: ttgl.constexpr = get_tmem_reg_layout(ttgl.float16, (BLOCK_M, N), a_tmem_layout, num_warps=4, + instr_variant="32x32b_splitn") b_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0]) a_offsets = ttgl.set_auto_layout(a_offsets, a_layout) b_offsets = ttgl.set_auto_layout(b_offsets, b_layout) @@ -934,8 +937,6 @@ def kernel(a_ptr, b_ptr, c_ptr, d_ptr): b = ttgl.load(b_ptr + b_offsets) c = ttgl.load(c_ptr + a_offsets) - a_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), col_stride=1) - acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), col_stride=1) al_tmem = allocate_tensor_memory(ttgl.float16, (BLOCK_M, N), layout=a_tmem_layout) ar_tmem = allocate_tensor_memory(ttgl.float16, (BLOCK_M, N), layout=a_tmem_layout) acc_tmem = allocate_tensor_memory(ttgl.float32, (BLOCK_M, N), layout=acc_tmem_layout) @@ -1078,10 +1079,14 @@ def test_tmem_copy_no_scales(M, N, BLOCK_N, num_warps, swizzle): @gluon.jit def tmem_copy_no_scales(in_ptr, out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, BLOCK_N: ttgl.constexpr, swizzle: ttgl.constexpr, num_warps: ttgl.constexpr): - tmem_reg_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout( - M=128, - N=BLOCK_N, - shape=[M, N], + tmem_layout: ttgl.constexpr = TensorMemoryLayout( + block=(128, BLOCK_N), + col_stride=32 // in_ptr.dtype.element_ty.primitive_bitwidth, + ) + tmem_reg_layout: ttgl.constexpr = get_tmem_reg_layout( + in_ptr.dtype.element_ty, + (M, N), + tmem_layout, num_warps=num_warps, ) offs_m = ttgl.arange(0, M, ttgl.SliceLayout(1, tmem_reg_layout)) @@ -1089,10 +1094,6 @@ def tmem_copy_no_scales(in_ptr, out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, B offs = offs_m[:, None] * N + offs_n[None, :] input = ttgl.load(in_ptr + offs) - tmem_layout: ttgl.constexpr = TensorMemoryLayout( - block=(128, BLOCK_N), - col_stride=32 // in_ptr.dtype.element_ty.primitive_bitwidth, - ) smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzle, element_bitwidth=32, rank=2) smem = ttgl.allocate_shared_memory(in_ptr.dtype.element_ty, [M, N], layout=smem_layout) @@ -1433,16 +1434,19 @@ def 
kernel(out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, a, # Accumulator in TMEM initialized to ones acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([M, N], col_stride=1) - tmem_reg_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(M, N, [M, N], ttgl.num_warps()) + tmem_reg_layout: ttgl.constexpr = get_tmem_reg_layout(ttgl.float32, (M, N), acc_tmem_layout, ttgl.num_warps()) acc_init = ttgl.zeros([M, N], ttgl.float32, layout=tmem_reg_layout) acc_tmem = allocate_tensor_memory(ttgl.float32, [M, N], acc_tmem_layout, acc_init) # Zero scales in TMEM scale_layout: ttgl.constexpr = TensorMemoryScalesLayout() - scale_reg_layout: ttgl.constexpr = get_tmem_scales_reg_layout(M, N, [M, N], ttgl.num_warps()) - scale_offs_k = ttgl.arange(0, (K // 32), layout=ttgl.SliceLayout(0, scale_reg_layout))[None, :] - scale_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, scale_reg_layout))[:, None] - scale_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(1, scale_reg_layout))[:, None] + scale_reg_layout_m: ttgl.constexpr = get_tmem_reg_layout(ttgl.int8, (M, K // 32), scale_layout, + ttgl.num_warps()) + scale_reg_layout_n: ttgl.constexpr = get_tmem_reg_layout(ttgl.int8, (N, K // 32), scale_layout, + ttgl.num_warps()) + scale_offs_k = ttgl.arange(0, (K // 32), layout=ttgl.SliceLayout(0, scale_reg_layout_m))[None, :] + scale_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, scale_reg_layout_m))[:, None] + scale_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(1, scale_reg_layout_n))[:, None] a_scale_init = ttgl.load(a_scale + scale_offs_m * (K // 32) + scale_offs_k) b_scale_init = ttgl.load(b_scale + scale_offs_n * (K // 32) + scale_offs_k) a_scale_tmem = allocate_tensor_memory(ttgl.int8, [M, K // 32], scale_layout, a_scale_init) diff --git a/python/triton/experimental/gluon/language/_semantic.py b/python/triton/experimental/gluon/language/_semantic.py index 7e3579dc28..00d5701ad7 100644 --- a/python/triton/experimental/gluon/language/_semantic.py +++ b/python/triton/experimental/gluon/language/_semantic.py @@ -2,8 +2,8 @@ import math from triton.language.semantic import TritonSemantic from . 
import _core as ttgl -from ._layouts import AutoLayout, DistributedLayout, SliceLayout, SharedLayout -from triton._C.libtriton.gluon_ir import GluonOpBuilder +from ._layouts import AutoLayout, DistributedLayout, DistributedLinearLayout, SliceLayout, SharedLayout +from triton._C.libtriton.gluon_ir import GluonOpBuilder, compute_tmem_reg_layout from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values TensorTy = TypeVar("TensorTy") @@ -18,6 +18,71 @@ def _is_int_list(value): return isinstance(value, Sequence) and all(isinstance(i, int) for i in value) +def _compute_tmem_reg_layout(element_ty, shape, layout, num_warps, instr_variant, ctas_per_cga, cta_split_num, + cta_order): + _check(isinstance(instr_variant, str), lambda: "instr_variant must be a string") + _check(instr_variant in ("32x32b", "16x64b", "16x128b", "16x256b", "16x32bx2", "32x32b_splitn"), + lambda: f"unknown instr_variant: {instr_variant}") + _check(isinstance(num_warps, int), lambda: f"num_warps must be an int but got {type(num_warps)!r}") + _check(num_warps >= 4 and (num_warps & (num_warps - 1)) == 0, lambda: "num_warps must be a power of two and >= 4") + + shape = list(shape) + _check(all(isinstance(dim, int) for dim in shape), lambda: f"shape entries must be ints but got {shape}") + rank = len(shape) + _check(rank == 2, lambda: "expected a 2D tensor") + + ctas_per_cga = list(ctas_per_cga) + cta_split_num = list(cta_split_num) + cta_order = list(cta_order) + splitn = instr_variant == "32x32b_splitn" + atom_variant = "32x32b" if splitn else instr_variant + + _check(len(ctas_per_cga) == rank, lambda: "ctas_per_cga rank mismatch") + _check(len(cta_split_num) == rank, lambda: "cta_split_num rank mismatch") + _check(len(cta_order) == rank, lambda: "cta_order rank mismatch") + + layout_obj = compute_tmem_reg_layout( + element_ty, + shape, + layout, + num_warps, + atom_variant, + ctas_per_cga, + cta_split_num, + cta_order, + ) + _check(layout_obj is not None, + lambda: f"TMEM layout '{atom_variant}' unsupported for shape {shape} and num_warps {num_warps}") + + if splitn: + N = shape[1] + if not layout_obj.reg_bases: + # We cannot use this layout in a load or a store ATM due to a PTX bug! + # You can work around this by loading to 32x32b and follow by a convert_layout to this layout. + _check(layout_obj.lane_bases[-1] == [0, N // 2], + lambda: f"splitn with 1 register requires the last lane basis to be [0, N / 2]. Got {layout_obj}") + layout_obj.reg_bases.append([0, N // 2]) + layout_obj.lane_bases[-1] = [0, 0] + elif layout_obj.reg_bases[-1] != [0, N // 2]: + bitwidth = element_ty.primitive_bitwidth + _check( + len(layout_obj.reg_bases) * bitwidth > 32, + lambda: "splitn requires register bases of more than 2 32 bit registers") + + reg_bases = layout_obj.reg_bases + for bases_str in ("lane_bases", "warp_bases"): + bases = getattr(layout_obj, bases_str) + for i, basis in enumerate(bases): + if basis == [0, N // 2]: + reg_bases[-1], bases[i] = bases[i], reg_bases[-1] + return layout_obj + assert False, f"splitn requires at least one basis of [0, N / 2]. 
Got {layout}" + return layout_obj + + +_compute_tmem_reg_layout.__triton_builtin__ = True + + class GluonCallerContext: def __init__(self, num_warps: int): @@ -177,7 +242,9 @@ def convert_layout(self, value, layout, assert_trivial=False): ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout) ret_ty_ir = ret_ty.to_ir(self.builder) if assert_trivial and not self.builder.is_convert_layout_trivial(ret_ty_ir, value.handle): - raise TypeError(f"layout conversion from {ty.layout} to {layout} is not trivial") + raise TypeError(f"layout conversion from {ty.layout} to {layout} is not trivial.\n" + f"The linear layouts are:\n{self.to_linear_layout(ty.layout, ty.shape)}\n" + f"{self.to_linear_layout(layout, ty.shape)}") handle = self.builder.create_convert_layout(ret_ty_ir, value.handle) return ttgl.tensor(handle, ret_ty) @@ -241,7 +308,12 @@ def to_linear_layout(self, layout, shape): if not isinstance(shape, list): shape = list(shape) - return self.builder.to_linear_layout(layout._to_ir(self.builder), shape) + layout = ttgl._unwrap_if_constexpr(layout) + + if isinstance(layout, (AutoLayout, DistributedLinearLayout)): + return ttgl.constexpr(layout) + + return ttgl.constexpr(self.builder.to_linear_layout(layout._to_ir(self.builder), shape)) def shared_dealloc(self, mem_desc): self.builder.create_local_dealloc(mem_desc.handle) diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py index 2636d1d72e..3523042295 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -5,8 +5,7 @@ from triton.runtime.jit import constexpr_function from triton.experimental.gluon.language import _core as ttgl from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr -from triton.experimental.gluon.language._layouts import BlockedLayout, _get_shape_per_cta, DistributedLinearLayout -from triton.experimental.gluon.language._semantic import _check +from triton.experimental.gluon.language._semantic import _check, _compute_tmem_reg_layout from . import tma from ..hopper import fence_async_shared, mbarrier @@ -21,8 +20,7 @@ "allocate_tensor_memory", "async_copy", "fence_async_shared", - "get_tmem_32x32b_reg_layout", - "get_tmem_scales_reg_layout", + "get_tmem_reg_layout", "mbarrier", "mma_v2", "tensor_memory_descriptor", @@ -91,101 +89,48 @@ def mangle(self) -> str: @constexpr_function -def _cdiv(x, div): - return (x + div - 1) // div - - -@constexpr_function -def get_tmem_32x32b_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_split_num=None, cta_order=None): - """Returns a BlockedLayout compatible with load/store on tensor memory with the 32x32b instruction variant. 
+def get_tmem_reg_layout( + element_ty, + shape, + layout, + num_warps, + instr_variant="32x32b", + ctas_per_cga=(1, 1), + cta_split_num=(1, 1), + cta_order=(1, 0), +): """ - assert len(shape) == 2, "expected a 2D tensor" - assert num_warps in [4, 8], "expected 4 or 8 warps" - - shape_per_cta = _get_shape_per_cta(shape, cta_split_num) - blocks_per_tile = [shape_per_cta[0] // M, shape_per_cta[1] // N] - num_blocks = blocks_per_tile[0] * blocks_per_tile[1] - - num_warp_groups = num_warps // 4 - if M == 64: - threads_per_warp = [16, 2] - if num_blocks == 1: - size_per_thread = [1, _cdiv(N, num_warp_groups * 2)] - warps_per_cta = [4, num_warp_groups] - else: - size_per_thread = [1, _cdiv(N, 2)] - warps_per_cta = [4 * min(blocks_per_tile[0], num_warp_groups)] - warps_per_cta.append(_cdiv(num_warp_groups, warps_per_cta[0] // 4)) - else: - if shape[0] > 128: - size_per_thread = [1, N] - threads_per_warp = [32, 1] - warps_per_cta = [4 * num_warp_groups, 1] - else: - size_per_thread = [1, _cdiv(N, num_warp_groups)] - threads_per_warp = [32, 1] - warps_per_cta = [4, num_warp_groups] - return BlockedLayout( - size_per_thread=size_per_thread, - threads_per_warp=threads_per_warp, - warps_per_cta=warps_per_cta, - order=[0, 1], - ctas_per_cga=ctas_per_cga, - cta_split_num=cta_split_num, - cta_order=cta_order, - ) - + Returns a DistributedLinearLayout compatible with TMEM load/store instructions. -@constexpr_function -def get_tmem_scales_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_split_num=None, cta_order=None): - """Return a linear layout that is compatible with tmem scaled layout. + Args: + element_ty (dtype): Element type stored in tensor memory. + shape (Sequence[int]): Global tensor shape addressed by the TMEM descriptor. + layout (TensorMemoryLayout): Tensor memory layout descriptor. + num_warps (int): Number of warps participating in the operation. + instr_variant (str): TMEM instruction variant (e.g. ``\"32x32b\"``). + ctas_per_cga (tuple[int, int]): CTA grouping along each dimension. + cta_split_num (tuple[int, int]): CTA split factors along each dimension. + cta_order (tuple[int, int]): CTA order. """ - assert len(shape) == 2, "expected a 2D tensor" - assert num_warps in [4, 8], "expected 4 or 8 warps" - - # Use per-CTA shape to build the linear layout bases - shape_per_cta = _get_shape_per_cta(shape, cta_split_num) - M_cta, N_cta = shape_per_cta[0], shape_per_cta[1] - - # Register bases: pack 4 scales together along N; if fewer than 4, replicate. - reg_bases = [] - i = 1 - while i < 4: - if i >= N_cta: - reg_bases.append([0, 0]) - else: - reg_bases.append([0, i]) - i <<= 1 - - # Lane bases: distribute 32 rows of M along a warp. - lane_bases = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]] - - # Warp bases: replicate across warps within a warpgroup by default. - warp_bases = [[0, 0], [0, 0]] - - # Extend register bases for larger M and N beyond the initial pack. - i = 32 - while i < M_cta: - reg_bases.append([i, 0]) - i <<= 1 - - i = 4 - while i < N_cta: - reg_bases.append([0, i]) - i <<= 1 - - # For 8 warps, distribute the last dimension on the second warpgoup. - if num_warps == 8: - warp_bases.append(reg_bases[-1]) - reg_bases.pop() - - # No explicit CTA mapping here; the register layout is per-CTA. 
- return DistributedLinearLayout( - reg_bases=reg_bases, - lane_bases=lane_bases, - warp_bases=warp_bases, - block_bases=[], - shape=shape_per_cta, + + def _unwrap(x): + if isinstance(x, ttgl.constexpr): + return _unwrap(x.value) + if isinstance(x, list): + return [_unwrap(i) for i in x] + if isinstance(x, tuple): + return tuple(_unwrap(i) for i in x) + return x + + return _compute_tmem_reg_layout( + _unwrap(element_ty), + _unwrap(shape), + _unwrap(layout), + _unwrap(num_warps), + _unwrap(instr_variant), + _unwrap(ctas_per_cga), + _unwrap(cta_split_num), + _unwrap(cta_order), ) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 3faefc88b6..d239bb6a7a 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -570,7 +570,7 @@ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: def to_ir(self, builder: ir.builder) -> ir.type: if self.name.startswith("fp8"): - if self.name not in builder.options.supported_fp8_dtypes: + if hasattr(builder, "options") and self.name not in builder.options.supported_fp8_dtypes: raise ValueError(f'type {self} not supported in this architecture. ' f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') diff --git a/python/triton/tools/triton_to_gluon_translater/translator_helpers.py b/python/triton/tools/triton_to_gluon_translater/translator_helpers.py index e4c07fb1c1..15905094d1 100644 --- a/python/triton/tools/triton_to_gluon_translater/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translater/translator_helpers.py @@ -5,8 +5,7 @@ TensorMemoryLayout, TensorMemoryScalesLayout, allocate_tensor_memory, - get_tmem_32x32b_reg_layout, - get_tmem_scales_reg_layout, + get_tmem_reg_layout, tcgen05_mma, tcgen05_mma_scaled, tcgen05_commit, @@ -39,7 +38,8 @@ def tl_dot(a, b, acc=None, input_precision=None, allow_tf32=None, max_num_imprec a_smem = ttgl.allocate_shared_memory(a.dtype, [M, K], nvmma_layout_a, a) b_smem = ttgl.allocate_shared_memory(b.dtype, [K, N], nvmma_layout_b, b) acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([M, N], col_stride=1) - tmem_reg_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(M, N, [M, N], ttgl.num_warps()) + acc_dtype: ttgl.constexpr = acc.dtype if acc is not None else out_dtype + tmem_reg_layout: ttgl.constexpr = get_tmem_reg_layout(acc_dtype, (M, N), acc_tmem_layout, ttgl.num_warps()) if acc is not None: acc_temp = ttgl.convert_layout(acc, tmem_reg_layout) else: @@ -79,7 +79,8 @@ def tl_dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=No a_smem = ttgl.allocate_shared_memory(lhs.dtype, lhs.shape, nvmma_layout_a, lhs) b_smem = ttgl.allocate_shared_memory(rhs.dtype, rhs.shape, nvmma_layout_b, rhs) acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([M, N], col_stride=1) - tmem_reg_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(M, N, [M, N], ttgl.num_warps()) + acc_dtype: ttgl.constexpr = acc.dtype if acc is not None else out_dtype + tmem_reg_layout: ttgl.constexpr = get_tmem_reg_layout(acc_dtype, (M, N), acc_tmem_layout, ttgl.num_warps()) if acc is not None: acc_temp = ttgl.convert_layout(acc, tmem_reg_layout) else: @@ -88,10 +89,10 @@ def tl_dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=No bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) mbarrier.init(bar, count=1) scale_layout: ttgl.constexpr = TensorMemoryScalesLayout() - scale_layout_reg_lhs: ttgl.constexpr = get_tmem_scales_reg_layout(lhs_scale.type.shape[0], lhs_scale.type.shape[1], 
- lhs_scale.type.shape, ttgl.num_warps()) - scale_layout_reg_rhs: ttgl.constexpr = get_tmem_scales_reg_layout(rhs_scale.type.shape[1], rhs_scale.type.shape[0], - rhs_scale.type.shape, ttgl.num_warps()) + scale_layout_reg_lhs: ttgl.constexpr = get_tmem_reg_layout(lhs_scale.type.element_ty, lhs_scale.type.shape, + scale_layout, ttgl.num_warps()) + scale_layout_reg_rhs: ttgl.constexpr = get_tmem_reg_layout(rhs_scale.type.element_ty, rhs_scale.type.shape, + scale_layout, ttgl.num_warps()) lhs_scale = ttgl.convert_layout(lhs_scale, scale_layout_reg_lhs) rhs_scale = ttgl.convert_layout(rhs_scale, scale_layout_reg_rhs) a_scale_tmem = allocate_tensor_memory(lhs_scale.dtype, lhs_scale.shape, scale_layout, lhs_scale) diff --git a/python/tutorials/gluon/06-tcgen05.py b/python/tutorials/gluon/06-tcgen05.py index 614a5926be..ac11f3c622 100644 --- a/python/tutorials/gluon/06-tcgen05.py +++ b/python/tutorials/gluon/06-tcgen05.py @@ -23,7 +23,7 @@ from triton.experimental.gluon.language.nvidia.blackwell import ( TensorMemoryLayout, allocate_tensor_memory, - get_tmem_32x32b_reg_layout, + get_tmem_reg_layout, tma, mbarrier, tcgen05_mma, @@ -117,10 +117,10 @@ def tmem_example_kernel(in_ptr, out_ptr, M: gl.constexpr, N: gl.constexpr, num_w ) # Get the register layout needed to access the tensor memory using a helper. - tmem_reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout( - M=64, - N=64, - shape=[M, N], + tmem_reg_layout: gl.constexpr = get_tmem_reg_layout( + in_ptr.dtype.element_ty, + (M, N), + tmem_layout, num_warps=num_warps, ) @@ -188,7 +188,12 @@ def small_mma_kernel(a_desc, b_desc, c_desc, d_desc, tmem_block: gl.constexpr, col_stride=32 // d_desc.dtype.primitive_bitwidth, ) acc_tmem = allocate_tensor_memory(d_desc.dtype, [M, N], acc_tmem_layout) - acc_reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout(tmem_block[0], tmem_block[1], [M, N], num_warps) + acc_reg_layout: gl.constexpr = get_tmem_reg_layout( + d_desc.dtype, + (M, N), + acc_tmem_layout, + num_warps, + ) acc = c_smem.load(acc_reg_layout) acc_tmem.store(acc) @@ -200,7 +205,12 @@ def small_mma_kernel(a_desc, b_desc, c_desc, d_desc, tmem_block: gl.constexpr, ) lhs_tmem = allocate_tensor_memory(a_desc.dtype, [M, K], lhs_tmem_layout) - lhs_reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout(M, K, [M, K], num_warps) + lhs_reg_layout: gl.constexpr = get_tmem_reg_layout( + a_desc.dtype, + (M, K), + lhs_tmem_layout, + num_warps, + ) lhs = a_smem.load(lhs_reg_layout) lhs_tmem.store(lhs) a = lhs_tmem @@ -352,7 +362,12 @@ def blocked_matmul_kernel(a_desc, b_desc, c_desc, TRANSPOSE_B: gl.constexpr, num mbarrier.invalidate(tma_bar) mbarrier.invalidate(mma_bar) - acc_reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, [BLOCK_M, BLOCK_N], num_warps) + acc_reg_layout: gl.constexpr = get_tmem_reg_layout( + gl.float32, + (BLOCK_M, BLOCK_N), + tmem_layout, + num_warps, + ) acc = acc_tmem.load(acc_reg_layout) # Downcast accumulator and store tile of C. 
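[Editor's note, not part of the patch] The tutorial, test, and translator hunks above all apply the same migration: the removed get_tmem_32x32b_reg_layout / get_tmem_scales_reg_layout helpers took explicit M/N tile sizes, while the new get_tmem_reg_layout derives the register layout from the element type, the tensor shape, and the TensorMemoryLayout (or TensorMemoryScalesLayout) itself. Below is a minimal sketch of the new call shape; it assumes the import paths used by the Gluon tutorials and a Blackwell (tcgen05) target, and the kernel name and body are illustrative, not taken from the patch.

# Hypothetical round trip through tensor memory using the new helper (sketch only).
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    allocate_tensor_memory,
    get_tmem_reg_layout,
)


@gluon.jit
def tmem_roundtrip_kernel(in_ptr, out_ptr, M: gl.constexpr, N: gl.constexpr):
    # Declare the TMEM layout once (M is expected to be 64 or 128) ...
    tmem_layout: gl.constexpr = TensorMemoryLayout((M, N), col_stride=1)
    # ... and derive the matching register layout from dtype + shape + TMEM layout,
    # instead of rebuilding it from (M, N, shape) as the old helper required.
    reg_layout: gl.constexpr = get_tmem_reg_layout(gl.float32, (M, N), tmem_layout, gl.num_warps())

    offs = gl.arange(0, M)[:, None] * N + gl.arange(0, N)[None, :]
    offs = gl.set_auto_layout(offs, reg_layout)

    acc_tmem = allocate_tensor_memory(gl.float32, [M, N], tmem_layout)
    acc_tmem.store(gl.load(in_ptr + offs))               # registers -> tensor memory
    gl.store(out_ptr + offs, acc_tmem.load(reg_layout))  # tensor memory -> registers -> global

The same derivation covers the scale tensors: passing TensorMemoryScalesLayout() as the layout argument replaces the old get_tmem_scales_reg_layout call, as the translator_helpers.py hunk above shows.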
@@ -570,7 +585,12 @@ def blocked_matmul_pipelined_kernel(a_desc, b_desc, c_desc, num_warps: gl.conste tma.async_copy_global_to_shared(a_desc, [off_m + BLOCK_M, k], load_v_bar, v_bufs.index(load_index)) k += BLOCK_K - acc_reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, [BLOCK_M, BLOCK_N], num_warps) + acc_reg_layout: gl.constexpr = get_tmem_reg_layout( + gl.float32, + (BLOCK_M, BLOCK_N), + tmem_layout, + num_warps, + ) mma_index, mma_phase, mma_counter = get_and_increment(mma_counter) ub_bar = mma_ub_bars.index(mma_index) diff --git a/python/tutorials/gluon/07-persistence.py b/python/tutorials/gluon/07-persistence.py index bf86cfc65c..5e578b278e 100644 --- a/python/tutorials/gluon/07-persistence.py +++ b/python/tutorials/gluon/07-persistence.py @@ -48,7 +48,7 @@ TensorMemoryLayout, tensor_memory_descriptor, allocate_tensor_memory, - get_tmem_32x32b_reg_layout, + get_tmem_reg_layout, tcgen05_mma, tcgen05_commit, ) @@ -149,7 +149,7 @@ def initialize(dtype: gl.constexpr, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr acc_tmem = allocate_tensor_memory(gl.float32, [BLOCK_M, BLOCK_N], layout) bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) mbarrier.init(bar, count=1) - reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, [BLOCK_M, BLOCK_N], num_warps) + reg_layout: gl.constexpr = get_tmem_reg_layout(gl.float32, (BLOCK_M, BLOCK_N), layout, num_warps) return MMAv5(gl.to_tensor(False), acc_tmem, bar, gl.to_tensor(0), reg_layout) @gluon.jit diff --git a/python/tutorials/gluon/08-warp-specialization.py b/python/tutorials/gluon/08-warp-specialization.py index ed0a9940cc..521460b30e 100644 --- a/python/tutorials/gluon/08-warp-specialization.py +++ b/python/tutorials/gluon/08-warp-specialization.py @@ -35,7 +35,7 @@ TensorMemoryLayout, tensor_memory_descriptor, allocate_tensor_memory, - get_tmem_32x32b_reg_layout, + get_tmem_reg_layout, tcgen05_mma, tcgen05_commit, ) @@ -529,7 +529,13 @@ def matmul_epilogue_partition(p, SchedulerImpl: gl.constexpr): acc_empty_bars = p.acc_empty_bars acc_ready_bars = p.acc_ready_bars acc_state = Counter.create(0, p.acc_empty_bars.shape[0]) - acc_layout: gl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, [BLOCK_M, BLOCK_N], p.num_warps) + acc_tmem_layout: gl.constexpr = TensorMemoryLayout([BLOCK_M, BLOCK_N], col_stride=1) + acc_layout: gl.constexpr = get_tmem_reg_layout( + dtype, + (BLOCK_M, BLOCK_N), + acc_tmem_layout, + p.num_warps, + ) SPLIT_N: gl.constexpr = BLOCK_N // p.SUBTILE_FACTOR acc_smem = gl.allocate_shared_memory(dtype, [BLOCK_M, SPLIT_N], p.c_desc.layout) diff --git a/test/Conversion/relayout_tritongpu.mlir b/test/Conversion/relayout_tritongpu.mlir index 341a19271a..34bc3188b8 100644 --- a/test/Conversion/relayout_tritongpu.mlir +++ b/test/Conversion/relayout_tritongpu.mlir @@ -5,22 +5,22 @@ #tmem2 = #ttng.tensor_memory_encoding #tmem_scales = #ttng.tensor_memory_scales_encoding<> -// CHECK-DAG: [[BLOCKN64:#.*]] = #ttg.blocked<{sizePerThread = [1, 64] -// CHECK-DAG: [[BLOCKN128:#.*]] = #ttg.blocked<{sizePerThread = [1, 128] +// CHECK-DAG: [[LINEAR64:#.*]] = #ttg.linear<{register = {{\[\[}}0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32{{]]}}, lane = {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [16, 0{{]]}}, warp = {{\[\[}}32, 0], [64, 0{{]]}}, block = []}> +// CHECK-DAG: [[LINEAR128:#.*]] = #ttg.linear<{register = {{\[\[}}0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64{{]]}}, lane = {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [16, 0{{]]}}, warp = {{\[\[}}32, 0], [64, 0{{]]}}, block = []}> 
// CHECK-DAG: [[SCALES:#.*]] = #ttg.linear<{register = {{\[\[}}0, 1], [0, 2], [32, 0], [64, 0], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64{{]]}}, lane = {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [16, 0{{]]}}, warp = {{\[\[}}0, 0], [0, 0{{]]}}, block = []}> -// CHECK-DAG: [[BLOCK64_SPLIT:#.*]] = #ttg.blocked<{sizePerThread = [1, 32] +// CHECK-DAG: [[LINEAR_STORE:#.*]] = #ttg.linear<{register = {{\[\[}}0, 1], [0, 2], [0, 4], [0, 8], [0, 16{{]]}}, lane = {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [0, 32{{]]}}, warp = {{\[\[}}16, 0], [32, 0{{]]}}, block = []}> // CHECK: @tmem_alloc tt.func @tmem_alloc() { %cst = arith.constant dense<1.0> : tensor<128x128xf32> - // CHECK: ttng.tmem_alloc {{.*}} (tensor<128x128xf32, [[BLOCKN128]]>) -> + // CHECK: ttng.tmem_alloc {{.*}} (tensor<128x128xf32, [[LINEAR128]]>) -> %result = ttng.tmem_alloc %cst : (tensor<128x128xf32>) -> !ttg.memdesc<128x128xf32, #tmem0, #ttng.tensor_memory> tt.return } // CHECK: @tmem_load tt.func @tmem_load(%desc: !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory>) { - // CHECK: ttng.tmem_load {{.*}} -> tensor<128x64xf32, [[BLOCKN64]]> + // CHECK: ttng.tmem_load {{.*}} -> tensor<128x64xf32, [[LINEAR64]]> %result = ttng.tmem_load %desc : !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory> -> tensor<128x64xf32> tt.return } @@ -29,7 +29,7 @@ tt.func @tmem_load(%desc: !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory>) tt.func @tmem_store(%desc: !ttg.memdesc<64x64xf32, #tmem2, #ttng.tensor_memory, mutable>) { %cst = arith.constant dense<1.0> : tensor<64x64xf32> %true = arith.constant true - // CHECK: ttng.tmem_store {{.*}} tensor<64x64xf32, [[BLOCK64_SPLIT]]> -> + // CHECK: ttng.tmem_store {{.*}} tensor<64x64xf32, [[LINEAR_STORE]]> -> ttng.tmem_store %cst, %desc, %true : tensor<64x64xf32> -> !ttg.memdesc<64x64xf32, #tmem2, #ttng.tensor_memory, mutable> tt.return } diff --git a/test/Conversion/tritongpu_to_llvm_blackwell.mlir b/test/Conversion/tritongpu_to_llvm_blackwell.mlir index febda9c140..3a1579b292 100644 --- a/test/Conversion/tritongpu_to_llvm_blackwell.mlir +++ b/test/Conversion/tritongpu_to_llvm_blackwell.mlir @@ -150,22 +150,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [64, 0]], warp = [[16, 0], [32, 0]], block = []}> #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @tensor_memory_ld_m64 // CHECK: nvgpu.tensor_memory_base - // CHECK: tcgen05.st.sync.aligned.16x32bx2.x64.b32 - // CHECK: tcgen05.st.sync.aligned.16x32bx2.x64.b32 + // CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32 // CHECK: nvvm.tcgen05.wait - // CHECK: tcgen05.ld.sync.aligned.16x32bx2.x64.b32 - // CHECK: tcgen05.ld.sync.aligned.16x32bx2.x64.b32 + // CHECK: tcgen05.ld.sync.aligned.32x32b.x128.b32 // CHECK: nvvm.tcgen05.wait tt.func public @tensor_memory_ld_m64(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> - %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, 
tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> - %20 = ttng.tmem_load %0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #linear> + %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #linear>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %20 = ttng.tmem_load %0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #linear> tt.return } } @@ -449,13 +446,13 @@ module attributes {"ttg.num-warps" = 8 : i32} { // ----- -#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0], [128, 0]], block = []}> #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @tensor_memory_ld_256x64_8_warps_splitM tt.func public @tensor_memory_ld_256x64_8_warps_splitM(%tmem: !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable>) { - // CHECK-COUNT-2: tcgen05.ld.sync.aligned.16x32bx2.x32.b32 + // CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32 // CHECK-NOT: tcgen05.ld %result = ttng.tmem_load %tmem : !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<256x64xf32, #linear> tt.return @@ -464,13 +461,13 @@ module attributes {"ttg.num-warps" = 8 : i32} { // ----- -#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 64]], warp = [[32, 0], [64, 0], [16, 0]], block = []}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0], [0, 64]], block = []}> #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @tensor_memory_ld_128x128_8_warps_splitM tt.func public @tensor_memory_ld_128x128_8_warps_splitM(%tmem: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>) { - // CHECK-COUNT-1: tcgen05.ld.sync.aligned.16x32bx2.x64.b32 + // CHECK-COUNT-1: tcgen05.ld.sync.aligned.32x32b.x64.b32 // CHECK-NOT: tcgen05.ld %result = ttng.tmem_load %tmem : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #linear> tt.return @@ -479,13 +476,13 @@ module attributes {"ttg.num-warps" = 8 : i32} { // ----- -#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0], [0, 32]], block = []}> #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @tensor_memory_ld_128x64_8_warps_splitM tt.func public @tensor_memory_ld_128x64_8_warps_splitM(%tmem: !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>) { - // CHECK-COUNT-1: tcgen05.ld.sync.aligned.16x32bx2.x32.b32 + // CHECK-COUNT-1: tcgen05.ld.sync.aligned.32x32b.x32.b32 // CHECK-NOT: tcgen05.ld %result = ttng.tmem_load %tmem 
: !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #linear> tt.return diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 7b42b91e2e..e5b4e9f9f6 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -229,7 +229,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // LAYOUT_16x256{LITERAL}: #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[32, 0], [64, 0]], block = []}> // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> - // CHECK-DAG: #[[$T:.+]] = #ttg.blocked<{sizePerThread = [1, 256], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> + // CHECK-DAG: #[[$L:.+]] = #ttg.linear<{register = {{\[\[0, 1\], \[0, 2\], \[0, 4\], \[0, 8\], \[0, 16\], \[0, 32\], \[0, 64\], \[0, 128\]\]}}, lane = {{\[\[1, 0\], \[2, 0\], \[4, 0\], \[8, 0\], \[16, 0\]\]}}, warp = {{\[\[32, 0\], \[64, 0\]\]}}, block = []}> // CHECK-LABEL: mmav5 // CHECK-DAG: %[[TRUE:.+]] = arith.constant true // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !ttg.memdesc<128x64xf16, #{{.*}}, #smem @@ -237,7 +237,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-DAG: %[[ACC:.+]], %[[ACC_TOK:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x256xf32, #{{.*}}>) -> (!ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable>, !ttg.async.token) // CHECK: %[[MMA_TOK:.+]] = ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]][%[[ACC_TOK]]], %[[TRUE]], %[[TRUE]] : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x256xf16, #shared, #smem>, !ttg.memdesc<128x256xf32, #[[$TMEM]], #ttng.tensor_memory, mutable> // CHECK: %[[R:.+]], %{{.*}} = ttng.tmem_load %[[ACC]][%[[MMA_TOK]]] : !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> -> tensor<128x256xf32 - // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$T]]> -> tensor<128x256xf32, #[[$B]]> + // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$L]]> -> tensor<128x256xf32, #[[$B]]> // CHECK: tt.return %[[CVT]] : tensor<128x256xf32 tt.func public @mmav5(%a: tensor<128x64xf16, #blocked2>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> { %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> @@ -292,7 +292,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // LAYOUT_16x256{LITERAL}: #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = [[64, 0]]}> // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> - // CHECK-DAG: #[[$T:.+]] = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> + // CHECK-DAG: #[[$L:.+]] = #ttg.linear<{register = {{\[\[0, 1\], \[0, 2\], \[0, 4\], \[0, 8\], \[0, 
16\], \[0, 32\], \[0, 64\]\]}}, lane = {{\[\[1, 0\], \[2, 0\], \[4, 0\], \[8, 0\], \[0, 128\]\]}}, warp = {{\[\[16, 0\], \[32, 0\]\]}}, block = {{\[\[64, 0\]\]}}}> // CHECK-LABEL: mmav5 // CHECK-DAG: %[[TRUE:.+]] = arith.constant true // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !ttg.memdesc<128x64xf16, #{{.*}}, #smem @@ -300,7 +300,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-DAG: %[[ACC:.+]], %[[ACC_TOK:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x256xf32, #{{.*}}>) -> (!ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable>, !ttg.async.token) // CHECK: %[[MMA_TOK:.+]] = ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]][%[[ACC_TOK]]], %[[TRUE]], %[[TRUE]] : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x256xf16, #shared, #smem>, !ttg.memdesc<128x256xf32, #[[$TMEM]], #ttng.tensor_memory, mutable> // CHECK: %[[R:.+]], %{{.*}} = ttng.tmem_load %[[ACC]][%[[MMA_TOK]]] : !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> -> tensor<128x256xf32 - // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$T]]> -> tensor<128x256xf32, #[[$B]]> + // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$L]]> -> tensor<128x256xf32, #[[$B]]> // CHECK: tt.return %[[CVT]] : tensor<128x256xf32 tt.func public @mmav5_multi_ctas(%a: tensor<128x64xf16, #blocked2>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> { %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> @@ -319,7 +319,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> - // CHECK-DAG: #[[$T:.+]] = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> + // CHECK-DAG: #[[$L:.+]] = #ttg.linear<{register = {{\[\[0, 1\], \[0, 2\], \[0, 4\], \[0, 8\], \[0, 16\], \[0, 32\], \[0, 64\]\]}}, lane = {{\[\[1, 0\], \[2, 0\], \[4, 0\], \[8, 0\], \[0, 128\]\]}}, warp = {{\[\[16, 0\], \[32, 0\]\]}}, block = {{\[\[64, 0\]\]}}}> // CHECK-DAG: #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> // CHECK-DAG: #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> // CHECK-LABEL: mmav5 @@ -329,7 +329,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-DAG: %[[ACC:.+]], %[[ACC_TOK:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x256xf32, #{{.*}}>) -> (!ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable>, !ttg.async.token) // CHECK: %[[MMA_TOK:.+]] = ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]][%[[ACC_TOK]]], %[[TRUE]], %[[TRUE]] {two_ctas} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x256xf16, #shared1, #smem>, !ttg.memdesc<128x256xf32, #[[$TMEM]], #ttng.tensor_memory, mutable> // CHECK: %[[R:.+]], %{{.*}} = 
ttng.tmem_load %[[ACC]][%[[MMA_TOK]]] : !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> -> tensor<128x256xf32 - // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$T]]> -> tensor<128x256xf32, #[[$B]]> + // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$L]]> -> tensor<128x256xf32, #[[$B]]> // CHECK: tt.return %[[CVT]] : tensor<128x256xf32 tt.func public @mmav5_2ctas(%a: tensor<128x64xf16, #blocked2>, %b_ptr: tensor<64x256x!tt.ptr, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> { %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> @@ -546,12 +546,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { - // LAYOUT_16x256{LITERAL}: #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [16, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[32, 0], [64, 0], [0, 64]], block = []}> + // LAYOUT_16x256{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [8, 0]], lane = [[64, 0], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[0, 0], [0, 0], [16, 0]], block = []}> // CHECK-DAG: #[[$TMEM1:.+]] = #ttng.tensor_memory_scales_encoding - // CHECK{LITERAL}-DAG: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> + // CHECK{LITERAL}-DAG: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0], [0, 4]], block = []}> // CHECK-LABEL: mmav5_block_scaled_8_warps - // CHECK: ttng.tmem_alloc %{{.*}} : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #[[$TMEM1]], #ttng.tensor_memory> - // CHECK: ttng.tmem_alloc %{{.*}} : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #[[$TMEM1]], #ttng.tensor_memory> + // CHECK: ttng.tmem_alloc %{{.*}} : (tensor<128x8xi8, #linear1>) -> !ttg.memdesc<128x8xi8, #[[$TMEM1]], #ttng.tensor_memory> + // CHECK: ttng.tmem_alloc %{{.*}} : (tensor<128x8xi8, #linear1>) -> !ttg.memdesc<128x8xi8, #[[$TMEM1]], #ttng.tensor_memory> // CHECK: ttng.tc_gen5_mma_scaled tt.func public @mmav5_block_scaled_8_warps(%a: tensor<128x256xi8, #blocked2>, %scale_a: tensor<128x8xi8, #blocked1>, %b: tensor<256x128xi8, #blocked>, %scale_b: tensor<128x8xi8, #blocked1>) -> tensor<128x128xf32, #blocked> { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index 61aa2d326a..3f8ba99f68 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -125,7 +125,7 @@ tt.func @test_canonicalize_convert_local_load(%arg0: !ttg.async.token) -> tensor // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> -#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}> +#linear = #ttg.linear<{register = [[0, 1], [0, 
2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0], [0, 32]], block = []}> #tmem = #ttng.tensor_memory_encoding // CHECK-LABEL: test_canonicalize_convert_tmem_store module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { diff --git a/test/TritonNvidiaGPU/interleave_tmem.mlir b/test/TritonNvidiaGPU/interleave_tmem.mlir index 1935c626b9..32526789ad 100644 --- a/test/TritonNvidiaGPU/interleave_tmem.mlir +++ b/test/TritonNvidiaGPU/interleave_tmem.mlir @@ -1,7 +1,8 @@ // RUN: triton-opt %s --triton-nvidia-interleave-tmem --allow-unregistered-dialect | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +#linear64 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0], [0, 32]], block = []}> +#linear128 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0], [0, 64]], block = []}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> #smem = #ttg.shared_memory @@ -19,11 +20,11 @@ tt.func public @sink_load(%arg0: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_m // CHECK: ttg.convert_layout // CHECK: arith.truncf %subslice0 = ttng.tmem_subslice %arg0 {N = 0 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> - %subtile0 = ttng.tmem_load %subslice0 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1> - %outLHS = ttg.convert_layout %subtile0 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #blocked> + %subtile0 = ttng.tmem_load %subslice0 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #linear64> + %outLHS = ttg.convert_layout %subtile0 : tensor<128x64xf32, #linear64> -> tensor<128x64xf32, #blocked> %subslice1 = ttng.tmem_subslice %arg0 {N = 64 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> - %subtile1 = ttng.tmem_load %subslice1 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1> - %outRHS = ttg.convert_layout %subtile1 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #blocked> + %subtile1 = ttng.tmem_load %subslice1 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #linear64> + %outRHS = ttg.convert_layout %subtile1 : tensor<128x64xf32, #linear64> -> tensor<128x64xf32, #blocked> // CHECK: ttng.tmem_load // CHECK: ttg.convert_layout @@ -33,16 +34,16 @@ tt.func public @sink_load(%arg0: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_m %5 = arith.truncf %outLHS : tensor<128x64xf32, #blocked> to tensor<128x64xf16, #blocked> %true = arith.constant true - %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> - ttng.tmem_store %cst, %arg2, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #linear128> + ttng.tmem_store %cst, %arg2, %true : tensor<128x128xf32, #linear128> -> 
!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %6 = arith.truncf %outRHS : tensor<128x64xf32, #blocked> to tensor<128x64xf16, #blocked> // CHECK: ttng.tmem_load // CHECK: ttg.convert_layout // CHECK: "unknow_may_side_effect"() : () -> () // CHECK: arith.truncf - %7 = ttng.tmem_load %arg2 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> - %8 = ttg.convert_layout %7 : tensor<128x128xf32, #blocked1> -> tensor<128x128xf32, #blocked> + %7 = ttng.tmem_load %arg2 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #linear128> + %8 = ttg.convert_layout %7 : tensor<128x128xf32, #linear128> -> tensor<128x128xf32, #blocked> "unknow_may_side_effect"() : () -> () %9 = arith.truncf %8 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> @@ -62,7 +63,7 @@ tt.func @interleave_load_store_ws() { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %c32 = arith.constant 32 : i32 - %alpha = arith.constant dense<0.5> : tensor<128x64xf32, #blocked1> + %alpha = arith.constant dense<0.5> : tensor<128x64xf32, #linear64> %true = arith.constant true // CHECK: scf.for @@ -77,18 +78,18 @@ tt.func @interleave_load_store_ws() { // CHECK-NEXT: [[M0:%.+]] = arith.mulf [[L0]] // CHECK-NEXT: ttng.tmem_store [[M0]], [[S0]] %slice0 = ttng.tmem_subslice %cur_acc {N = 0 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> - %val0 = ttng.tmem_load %slice0 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1> - %mul0 = arith.mulf %val0, %alpha : tensor<128x64xf32, #blocked1> + %val0 = ttng.tmem_load %slice0 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #linear64> + %mul0 = arith.mulf %val0, %alpha : tensor<128x64xf32, #linear64> // CHECK-NEXT: [[L1:%.+]] = ttng.tmem_load [[S1]] // CHECK-NEXT: [[M1:%.+]] = arith.mulf [[L1]] // CHECK-NEXT: ttng.tmem_store [[M1]], [[S1]] %slice1 = ttng.tmem_subslice %cur_acc {N = 64 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> - %val1 = ttng.tmem_load %slice1 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1> - %mul1 = arith.mulf %val1, %alpha : tensor<128x64xf32, #blocked1> + %val1 = ttng.tmem_load %slice1 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #linear64> + %mul1 = arith.mulf %val1, %alpha : tensor<128x64xf32, #linear64> - ttng.tmem_store %mul0, %slice0, %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> - ttng.tmem_store %mul1, %slice1, %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %mul0, %slice0, %true : tensor<128x64xf32, #linear64> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %mul1, %slice1, %true : tensor<128x64xf32, #linear64> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> } ttg.warp_return @@ -99,23 +100,23 @@ tt.func @interleave_load_store_ws() { // CHECK-LABEL: @arrive_barrier tt.func @arrive_barrier(%arg0: !ttg.memdesc<1xi64, #shared, #smem, mutable>) { %true = arith.constant true - %cst = arith.constant dense<0.0> : tensor<128x128xf32, #blocked1> + %cst = arith.constant dense<0.0> : tensor<128x128xf32, #linear128> // CHECK-COUNT-2: 
ttng.tmem_alloc %alloc = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %noalias_alloc = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK-NEXT: tmem_store // CHECK-NEXT: tmem_load - %0 = ttng.tmem_load %alloc : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> - ttng.tmem_store %cst, %noalias_alloc, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %0 = ttng.tmem_load %alloc : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #linear128> + ttng.tmem_store %cst, %noalias_alloc, %true : tensor<128x128xf32, #linear128> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK-NEXT: arrive_barrier ttng.arrive_barrier %arg0, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable> - "user"(%0) : (tensor<128x128xf32, #blocked1>) -> () + "user"(%0) : (tensor<128x128xf32, #linear128>) -> () tt.return } // CHECK-LABEL: @sink_alloc_op -tt.func @sink_alloc_op(%arg0: tensor<128x128xf32, #blocked1>) { +tt.func @sink_alloc_op(%arg0: tensor<128x128xf32, #linear128>) { %c0 = arith.constant 0 : i32 %true = arith.constant true @@ -126,11 +127,11 @@ tt.func @sink_alloc_op(%arg0: tensor<128x128xf32, #blocked1>) { // CHECK: [[SUBVIEW1:%.+]] = ttg.memdesc_index [[ALLOC1]] %subview1 = ttg.memdesc_index %alloc1[%c0] : !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK-NEXT: tmem_store %arg0, [[SUBVIEW1]] - ttng.tmem_store %arg0, %subview1, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %arg0, %subview1, %true : tensor<128x128xf32, #linear128> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK-NEXT: [[ALLOC0:%.+]] = ttng.tmem_alloc // CHECK: [[SUBVIEW0:%.+]] = ttg.memdesc_index [[ALLOC0]] // CHECK-NEXT: tmem_store %arg0, [[SUBVIEW0]] - ttng.tmem_store %arg0, %subview0, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %arg0, %subview0, %true : tensor<128x128xf32, #linear128> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> tt.return } diff --git a/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir b/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir index 52132242b8..e92523080d 100644 --- a/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir +++ b/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir @@ -162,8 +162,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> #tmem = #ttng.tensor_memory_encoding #tmem1 = #ttng.tensor_memory_encoding module 
attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32, ttg.shared = 65536 : i32} { diff --git a/test/TritonNvidiaGPU/tmem_layouts.mlir b/test/TritonNvidiaGPU/tmem_layouts.mlir index 8c2c436041..f3c506e8f5 100644 --- a/test/TritonNvidiaGPU/tmem_layouts.mlir +++ b/test/TritonNvidiaGPU/tmem_layouts.mlir @@ -1,10 +1,10 @@ // RUN: triton-opt %s -split-input-file --triton-nvidia-optimize-tmem-layouts --allow-unregistered-dialect | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 64], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 2, 1], order = [0, 2, 1]}> -#blocked3 = #ttg.blocked<{sizePerThread = [1, 64, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 2], order = [0, 1, 2]}> -#blocked4 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 2, 1], order = [2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 2, 1], order = [2, 1, 0]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 64]], warp = [[32, 0], [64, 0], [16, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [0, 0, 8], [0, 0, 16], [0, 0, 32]], lane = [[1, 0, 0], [2, 0, 0], [4, 0, 0], [8, 0, 0], [0, 1, 0]], warp = [[32, 0, 0], [64, 0, 0], [16, 0, 0]], block = []}> +#linear2 = #ttg.linear<{register = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 8, 0], [0, 16, 0], [0, 32, 0]], lane = [[1, 0, 0], [2, 0, 0], [4, 0, 0], [8, 0, 0], [0, 0, 1]], warp = [[32, 0, 0], [64, 0, 0], [16, 0, 0]], block = []}> #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100"} { @@ -17,31 +17,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK: %[[L1:.+]] = ttng.tmem_load %[[S1]] : !ttg.memdesc<128x64xf32 // CHECK: %[[C1:.+]] = ttg.convert_layout %[[L1]] // CHECK: tt.return %[[C0]], %[[C1]] - %0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> - %1 = tt.reshape %0 : tensor<128x128xf32, #blocked1> -> tensor<128x2x64xf32, #blocked2> - %2 = tt.trans %1 {order = array} : tensor<128x2x64xf32, #blocked2> -> tensor<128x64x2xf32, #blocked3> - %3 = ttg.convert_layout %2 : tensor<128x64x2xf32, #blocked3> -> tensor<128x64x2xf32, #blocked4> - %outLHS, %outRHS = tt.split %3 : tensor<128x64x2xf32, #blocked4> -> tensor<128x64xf32, #blocked> + %0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #linear> + %1 = tt.reshape %0 : tensor<128x128xf32, #linear> -> tensor<128x2x64xf32, #linear1> + %2 = tt.trans %1 {order = array} : tensor<128x2x64xf32, #linear1> -> tensor<128x64x2xf32, #linear2> + %3 = ttg.convert_layout %2 : tensor<128x64x2xf32, #linear2> -> tensor<128x64x2xf32, #blocked1> + %outLHS, %outRHS = tt.split %3 : tensor<128x64x2xf32, #blocked1> -> tensor<128x64xf32, #blocked> tt.return %outLHS, %outRHS : tensor<128x64xf32, #blocked>, tensor<128x64xf32, #blocked> } } // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> -#blocked2 = 
#ttg.blocked<{sizePerThread = [1, 2, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [4, 1, 2], order = [1, 2, 0]}> -#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 2, 1], order = [2, 1, 0]}> -#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> -#blocked7 = #ttg.blocked<{sizePerThread = [1, 1, 128], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 2, 1], order = [0, 2, 1]}> -#blocked8 = #ttg.blocked<{sizePerThread = [1, 128, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 2], order = [0, 1, 2]}> -#linear = #ttg.linear<{register = [[0, 0, 1], [0, 64, 0], [4, 0, 0], [8, 0, 0], [16, 0, 0], [32, 0, 0], [64, 0, 0]], lane = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 8, 0], [0, 16, 0]], warp = [ -[0, 32, 0], [1, 0, 0], [2, 0, 0]], block = []}> -#linear1 = #ttg.linear<{register = [[0, 64], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [1, 0], [2, 0]], block = []}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 2, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [4, 1, 2], order = [1, 2, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 2, 1], order = [2, 1, 0]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128]], warp = [[32, 0], [64, 0], [16, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [0, 0, 8], [0, 0, 16], [0, 0, 32], [0, 0, 64]], lane = [[1, 0, 0], [2, 0, 0], [4, 0, 0], [8, 0, 0], [0, 1, 0]], warp = [[32, 0, 0], [64, 0, 0], [16, 0, 0]], block = []}> +#linear2 = #ttg.linear<{register = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 8, 0], [0, 16, 0], [0, 32, 0], [0, 64, 0]], lane = [[1, 0, 0], [2, 0, 0], [4, 0, 0], [8, 0, 0], [0, 0, 1]], warp = [[32, 0, 0], [64, 0, 0], [16, 0, 0]], block = []}> +#linear3 = #ttg.linear<{register = [[0, 0, 1], [0, 64, 0], [4, 0, 0], [8, 0, 0], [16, 0, 0], [32, 0, 0], [64, 0, 0]], lane = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 8, 0], [0, 16, 0]], warp = [[0, 32, 0], [1, 0, 0], [2, 0, 0]], block = []}> +#linear4 = #ttg.linear<{register = [[0, 64], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [1, 0], [2, 0]], block = []}> #tmem = #ttng.tensor_memory_encoding - module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @subtile4_tmem_load - tt.func public @subtile4_tmem_load(%arg0: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>) -> (tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>) { + tt.func public @subtile4_tmem_load(%arg0: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>) -> (tensor<128x64xf32, #blocked>, tensor<128x64xf32, #blocked>, tensor<128x64xf32, #blocked>, tensor<128x64xf32, #blocked>) { // CHECK: %[[S0:.+]] = ttng.tmem_subslice %{{.+}} {N = 0 : i32} // CHECK: %[[S1:.+]] = ttng.tmem_subslice %[[S0]] {N = 0 : i32} // CHECK: %[[L1:.+]] = ttng.tmem_load %[[S1]] : !ttg.memdesc<128x64xf32 @@ -57,18 +55,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK: %[[L5:.+]] = ttng.tmem_load %[[S5]] : !ttg.memdesc<128x64xf32 // CHECK: 
%[[C5:.+]] = ttg.convert_layout %[[L5]] // CHECK: tt.return %[[C1]], %[[C2]], %[[C4]], %[[C5]] - %0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked> - %1 = tt.reshape %0 : tensor<128x256xf32, #blocked> -> tensor<128x2x128xf32, #blocked7> - %2 = tt.trans %1 {order = array} : tensor<128x2x128xf32, #blocked7> -> tensor<128x128x2xf32, #blocked8> - %3 = ttg.convert_layout %2 : tensor<128x128x2xf32, #blocked8> -> tensor<128x128x2xf32, #linear> - %outLHS, %outRHS = tt.split %3 : tensor<128x128x2xf32, #linear> -> tensor<128x128xf32, #linear1> - %4 = tt.reshape %outLHS : tensor<128x128xf32, #linear1> -> tensor<128x2x64xf32, #blocked2> - %5 = tt.trans %4 {order = array} : tensor<128x2x64xf32, #blocked2> -> tensor<128x64x2xf32, #blocked3> - %outLHS_1, %outRHS_1 = tt.split %5 : tensor<128x64x2xf32, #blocked3> -> tensor<128x64xf32, #blocked4> - %6 = tt.reshape %outRHS : tensor<128x128xf32, #linear1> -> tensor<128x2x64xf32, #blocked2> - %7 = tt.trans %6 {order = array} : tensor<128x2x64xf32, #blocked2> -> tensor<128x64x2xf32, #blocked3> - %outLHS_2, %outRHS_2 = tt.split %7 : tensor<128x64x2xf32, #blocked3> -> tensor<128x64xf32, #blocked4> - tt.return %outLHS_1, %outRHS_1, %outLHS_2, %outRHS_2 : tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4> + %result = ttng.tmem_load %arg0 : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #linear> + %0 = tt.reshape %result : tensor<128x256xf32, #linear> -> tensor<128x2x128xf32, #linear1> + %1 = tt.trans %0 {order = array} : tensor<128x2x128xf32, #linear1> -> tensor<128x128x2xf32, #linear2> + %2 = ttg.convert_layout %1 : tensor<128x128x2xf32, #linear2> -> tensor<128x128x2xf32, #linear3> + %outLHS, %outRHS = tt.split %2 : tensor<128x128x2xf32, #linear3> -> tensor<128x128xf32, #linear4> + %3 = tt.reshape %outLHS : tensor<128x128xf32, #linear4> -> tensor<128x2x64xf32, #blocked1> + %4 = tt.trans %3 {order = array} : tensor<128x2x64xf32, #blocked1> -> tensor<128x64x2xf32, #blocked2> + %outLHS_0, %outRHS_1 = tt.split %4 : tensor<128x64x2xf32, #blocked2> -> tensor<128x64xf32, #blocked> + %5 = tt.reshape %outRHS : tensor<128x128xf32, #linear4> -> tensor<128x2x64xf32, #blocked1> + %6 = tt.trans %5 {order = array} : tensor<128x2x64xf32, #blocked1> -> tensor<128x64x2xf32, #blocked2> + %outLHS_2, %outRHS_3 = tt.split %6 : tensor<128x64x2xf32, #blocked2> -> tensor<128x64xf32, #blocked> + tt.return %outLHS_0, %outRHS_1, %outLHS_2, %outRHS_3 : tensor<128x64xf32, #blocked>, tensor<128x64xf32, #blocked>, tensor<128x64xf32, #blocked>, tensor<128x64xf32, #blocked> } } @@ -109,7 +107,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0], [128, 0]], block = []}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 2, 64], threadsPerWarp = [32, 1, 1], warpsPerCTA = [8, 1, 1], order = [0, 2, 1]}> #blocked3 = #ttg.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [8, 1, 1], order = [0, 1, 2]}> #blocked4 = #ttg.blocked<{sizePerThread = [1, 1, 2], 
threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 2, 1], order = [2, 1, 0]}> @@ -120,8 +118,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK-NOT: ttng.tmem_subslice // CHECK: tt.return tt.func public @subtile_tmem_load_256(%arg0: !ttg.memdesc<256x128xf32, #tmem, #ttng.tensor_memory, mutable>) -> (tensor<256x64xf32, #blocked>, tensor<256x64xf32, #blocked>) { - %0 = ttng.tmem_load %arg0 : !ttg.memdesc<256x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<256x128xf32, #blocked1> - %1 = tt.reshape %0 : tensor<256x128xf32, #blocked1> -> tensor<256x2x64xf32, #blocked2> + %0 = ttng.tmem_load %arg0 : !ttg.memdesc<256x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<256x128xf32, #linear> + %1 = tt.reshape %0 : tensor<256x128xf32, #linear> -> tensor<256x2x64xf32, #blocked2> %2 = tt.trans %1 {order = array} : tensor<256x2x64xf32, #blocked2> -> tensor<256x64x2xf32, #blocked3> %3 = ttg.convert_layout %2 : tensor<256x64x2xf32, #blocked3> -> tensor<256x64x2xf32, #blocked4> %outLHS, %outRHS = tt.split %3 : tensor<256x64x2xf32, #blocked4> -> tensor<256x64xf32, #blocked> @@ -131,27 +129,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0], [0, 32]], block = []}> #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { -// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}> // CHECK-LABEL: tmem_load_reduce -tt.func public @tmem_load_reduce(%arg0: !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> { - %0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory> -> tensor<128x64xf32, #blocked> - // CHECK: ttng.tmem_load %{{.*}} : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory> -> tensor<128x64xf32, #linear> +tt.func public @tmem_load_reduce(%arg0: !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #linear}>> { + %0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory> -> tensor<128x64xf32, #linear> + // CHECK: ttng.tmem_load %{{.*}} : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory> -> tensor<128x64xf32, #linear1> %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ ^bb0(%arg2: f32, %arg3: f32): %2 = arith.addf %arg2, %arg3 : f32 tt.reduce.return %2 : f32 - }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> - tt.return %1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<128x64xf32, #linear>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #linear}>> + tt.return %1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #linear}>> } } - // ----- #blocked = #ttg.blocked<{sizePerThread = [64, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> @@ -203,14 +200,15 @@ module 
attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> #smem = #ttg.shared_memory +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> #tmem = #ttng.tensor_memory_encoding -// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [16, 0], [128, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[32, 0], [64, 0]], block = []}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [128, 0], [16, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[32, 0], [64, 0]], block = []}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { tt.func @reshape_memedesc_negative(%arg0: !ttg.memdesc<256x16xf32, #tmem, #ttng.tensor_memory>, %arg1: !ttg.memdesc<16x256xf8E4M3FN, #shared, #smem, mutable>) { // CHECK: %[[L:.+]] = ttng.tmem_load %{{.+}} : !ttg.memdesc<256x16xf32, #tmem, #ttng.tensor_memory> -> tensor<256x16xf32, #linear> // CHECK: ttg.convert_layout %[[L:.+]] - %result = ttng.tmem_load %arg0 : !ttg.memdesc<256x16xf32, #tmem, #ttng.tensor_memory> -> tensor<256x16xf32, #blocked> - %0 = tt.trans %result {order = array} : tensor<256x16xf32, #blocked> -> tensor<16x256xf32, #blocked1> + %result = ttng.tmem_load %arg0 : !ttg.memdesc<256x16xf32, #tmem, #ttng.tensor_memory> -> tensor<256x16xf32, #linear> + %0 = tt.trans %result {order = array} : tensor<256x16xf32, #linear> -> tensor<16x256xf32, #blocked1> %1 = tt.fp_to_fp %0, rounding = rtne : tensor<16x256xf32, #blocked1> -> tensor<16x256xf8E4M3FN, #blocked1> ttg.local_store %1, %arg1 : tensor<16x256xf8E4M3FN, #blocked1> -> !ttg.memdesc<16x256xf8E4M3FN, #shared, #smem, mutable> tt.return diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp index 5ada24afd4..a38822ea63 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp @@ -864,8 +864,8 @@ static Operation *sliceOp(Operation *op, int offset, IRMapping &mappings, // The source op is already sliced at this point, so srcTy, type, tmem is // sliced. We use getTmemCompatibleLayout to get a block layout that is for // the sliced tmem here. - Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout( - tmem.getBlockM(), tmem.getBlockN(), oldRetType, numWarps); + auto newDistributedEncoding = + nvidia_gpu::getDefaultLayoutForTmemLdSt(type, numWarps, CTALayout); // oldRetType is the desired output, we slice it and convert from the // compatible layout to the sliced desired output. 
@@ -919,8 +919,8 @@ static Operation *sliceOp(Operation *op, int offset, IRMapping &mappings, accEncoding, retType.getMemorySpace(), retType.getMutableMemory()); - Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout( - accEncoding.getBlockM(), accEncoding.getBlockN(), srcTy, numWarps); + auto newDistributedEncoding = + nvidia_gpu::getDefaultLayoutForTmemLdSt(retType, numWarps, CTALayout); auto newAccType = RankedTensorType::get( srcTy.getShape(), srcTy.getElementType(), newDistributedEncoding); auto cvtOp = builder.createWithAsyncTaskIds( diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp index 86014a6bbf..3c030f9575 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp @@ -10,6 +10,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Types.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h" #include "triton/Tools/LayoutUtils.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" @@ -29,21 +30,6 @@ static constexpr int maxRegisters = 256; namespace { -struct TMemAccessAtom { - int colsPerThread; - int rowsPerThread; - const char *opShape; -}; - -constexpr TMemAccessAtom TMemAccess32x32b{ - 1 /*colsPerThread*/, 1 /*rowsPerThread*/, "32x32b" /*opShape*/}; - -constexpr TMemAccessAtom TMemAccess16x256b{ - 2 /*colsPerThread*/, 2 /*rowsPerThread*/, "16x256b" /*opShape*/}; - -constexpr TMemAccessAtom TMemAccess16x32bx2{ - 1 /*colsPerThread*/, 1 /*rowsPerThread*/, "16x32bx2" /*opShape*/}; - struct TMemCopyAtom { int nRow; int bCol; @@ -101,91 +87,6 @@ TMemCopyAtom getTMemCopyAtom(const LinearLayout &cvt, int bitwidth) { } } -// Similar to largestVectorisation in TritonGPUToLLVM/Utility.cpp -std::optional> -getVec(const LinearLayout &cvt, const LinearLayout &tile, int maxnreg) { - auto *ctx = cvt.getInDimNames().begin()->getContext(); - auto kReg = StringAttr::get(ctx, "register"); - auto kCol = StringAttr::get(ctx, "col"); - LinearLayout reps, vec; - ColumnAction perm; - // Heuristic: - // Do not use more than half the registers as otherwise it's prone to spilling - assert(maxnreg / 2 <= largestTmemLoadStore); - auto maxReg = maxnreg / 2; - // Heuristic: - // If maxnreg is 256 and we need more than one message, we don't use max - // vectorisation as ptxas' scheduler breaks... - if (maxnreg == 256 && cvt.getInDimSize(kReg) > maxReg) { - maxReg /= 2; - } - auto maxVec = maxReg / tile.getInDimSize(kReg); - int i = 1; - for (; i <= maxVec; i *= 2) { - vec = LinearLayout::identity1D(i, kReg, kCol); - auto vecTile = tile * vec; - auto maybePerm = regPermForDivide(cvt, vecTile, /*left=*/true); - if (!maybePerm) { - if (i == 1) { - // Couldn't lower the tile - return std::nullopt; - } - break; - } - // nb. We could remove this part once we are confident the algo works - perm = *maybePerm; - auto newCvt = maybePerm->apply(cvt); - auto maybeReps = getReps(newCvt, vecTile); - if (!maybeReps.has_value()) { - if (i == 1) { - // Couldn't lower the tile - return std::nullopt; - } - break; - } - reps = *maybeReps; - } - // i is the smallest power of 2 that *cannot* be used to lower the tile - // so we return i / 2. 
- assert(i > 1); - return std::make_tuple(std::move(reps), std::move(perm), - (i / 2) * tile.getInDimSize(kReg)); -} - -LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom, - bool unpacked) { - auto kReg = str_attr("register"); - auto kLane = str_attr("lane"); - auto kWarp = str_attr("warp"); - auto kRow = str_attr("row"); - auto kCol = str_attr("col"); - // Set the output order to be kRow, kCol and the input order to be kReg first - LinearLayout tile = LinearLayout::identity1D(1, kReg, kRow) * - LinearLayout::identity1D(1, kReg, kCol); - if (atom.opShape == std::string("32x32b")) { - tile *= LinearLayout::identity1D(32, kLane, kRow); - } else if (atom.opShape == std::string("16x32bx2")) { - tile *= LinearLayout::identity1D(16, kLane, kRow); - } else if (atom.opShape == std::string("16x256b")) { - tile *= LinearLayout::identity1D(2, kReg, kCol) * - LinearLayout::identity1D(4, kLane, kCol) * - LinearLayout::identity1D(8, kLane, kRow) * - LinearLayout::identity1D(2, kReg, kRow); - } else { - llvm_unreachable("Unsupported TMEM access atom"); - } - // Each register moves 32/bitwidth (= 2) columns when unpacked - if (unpacked) { - tile = LinearLayout::zeros1D(1, kReg, kCol, 2) * tile; - } - auto nCol = tile.getOutDimSize(kCol); - auto bases = tile.getBases(); - bases[kWarp].push_back({32, 0}); - bases[kWarp].push_back({64, 0}); - auto ret = LinearLayout(bases, {{kRow, 128}, {kCol, nCol}}, false); - return ret; -} - SmallVector pack(ArrayRef values, Type outType, Location loc, ConversionPatternRewriter &rewriter, bool pad = false) { auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -258,14 +159,14 @@ SmallVector unpack(ArrayRef packedValues, Type outType, void createTensorMemoryStore(Location loc, Value address, int colOffset, SmallVector &srcs, std::optional secondHalfOffset, Value pred, - bool unpacked, const TMemAccessAtom &atom, + bool unpacked, TMemAccessAtom atom, ConversionPatternRewriter &rewriter) { PTXBuilder ptxBuilder; std::string packedStr = unpacked ? ".unpack::16b" : ""; - unsigned numRepeats = srcs.size() / (atom.rowsPerThread * atom.colsPerThread); - std::string opcode = "@$0 tcgen05.st.sync.aligned." + - std::string(atom.opShape) + ".x" + - std::to_string(numRepeats) + packedStr; + unsigned numRepeats = srcs.size() / getElementsPerThread(atom); + std::string opcode = "@$0 tcgen05.st.sync.aligned."; + opcode += getOpShape(atom); + opcode += ".x" + std::to_string(numRepeats) + packedStr; opcode += ".b32 [$1 + " + std::to_string(colOffset) + "], "; if (secondHalfOffset) opcode += std::to_string(*secondHalfOffset) + ", {"; @@ -290,69 +191,18 @@ void createTensorMemoryStore(Location loc, Value address, int colOffset, ptxBuilder.launch(rewriter, loc, voidTy); } -// Get the maximum number of registers per thread based on the context. This is -// by default 256, but it can be overridden by `ttg.maxnreg` set on the module -// or a contextual register limit set by the compiler on partitions. -int getContextualMaxNReg(Operation *op) { - // Check the immediate parent op to see if it places a register constraint. - auto getFromParent = [](Operation *op) -> std::optional { - Operation *parent = op->getParentOp(); - if (auto mod = dyn_cast(parent)) { - if (auto attr = mod->getAttrOfType(AttrMaxRegistersName)) - return attr.getInt(); - return {}; - } - - if (auto partitions = dyn_cast(parent)) { - // Check if the partition has reduced registers. 
- unsigned idx = op->getParentRegion()->getRegionNumber(); - if (auto actRegisters = partitions.getParentOp().getActualRegisters()) - return (*actRegisters)[1 + idx]; - return {}; - } - - if (auto wsOp = dyn_cast(op->getParentOp())) { - // Check the register usage of the default warpgroup. - if (auto actRegisters = wsOp.getActualRegisters()) - return actRegisters->front(); - return {}; - } - - return {}; - }; - - // PTXAS validates the register usage of `tcgen05.ld` and `tcgen05.st` - // instructions based on the static number of registers set on the module, not - // the dynamic allocation. This just means the register limit used for the - // purpose of subtiling TMEM messages cannot be higher than the module's. - auto mod = op->getParentOfType(); - int maxnreg = maxRegisters; - - for (; op != mod; op = op->getParentOp()) { - if (std::optional limit = getFromParent(op)) { - maxnreg = std::min(maxnreg, *limit); - break; - } - } - - if (auto maxnregAttr = mod->getAttrOfType(AttrMaxRegistersName)) - maxnreg = std::min(maxnreg, maxnregAttr.getInt()); - - return maxnreg; -} - Value createTensorMemoryLoad(Location loc, MLIRContext *ctx, Value address, int colOffset, std::optional secondHalfOffset, bool unpacked, int numRegPerMessage, - const TMemAccessAtom &atom, + TMemAccessAtom atom, ConversionPatternRewriter &rewriter) { PTXBuilder ptxBuilder; // If the memory is unpacked we need to pack on the fly when loading. std::string packedStr = unpacked ? ".pack::16b" : ""; - unsigned numRepeats = - numRegPerMessage / (atom.rowsPerThread * atom.colsPerThread); - std::string opcode = "tcgen05.ld.sync.aligned." + std::string(atom.opShape) + - ".x" + std::to_string(numRepeats) + packedStr + ".b32 {"; + unsigned numRepeats = numRegPerMessage / getElementsPerThread(atom); + std::string opcode = "tcgen05.ld.sync.aligned."; + opcode += getOpShape(atom); + opcode += ".x" + std::to_string(numRepeats) + packedStr + ".b32 {"; SmallVector operands; for (int i = 0; i < numRegPerMessage; i++) { @@ -416,12 +266,14 @@ static SmallVector unpackResults(Value packedValues, Type elemTy, return resultVals; } -FailureOr> -lowerTMemLdSt(Location loc, MLIRContext *ctx, - ConversionPatternRewriter &rewriter, const LinearLayout &reps, - ArrayRef vals, TMemAccessAtom atom, Type llvmElemTy, - Value tmemBase, Value pred, int valsPerMessage, bool unpacked, - std::optional secondHalfOffset) { +SmallVector lowerTMemLdSt(Location loc, + ConversionPatternRewriter &rewriter, + const LinearLayout &reps, ArrayRef vals, + TMemAccessAtom atom, Type llvmElemTy, + Value tmemBase, Value pred, int valsPerMessage, + bool unpacked, + std::optional secondHalfOffset) { + auto *ctx = rewriter.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); auto kReg = str_attr("register"); auto kLane = str_attr("lane"); @@ -481,216 +333,77 @@ lowerTMemLdSt(Location loc, MLIRContext *ctx, return resultVals; } -FailureOr> -lowerTMemLdSt(Location loc, MLIRContext *ctx, - ConversionPatternRewriter &rewriter, const LinearLayout &cvt, - ArrayRef vals, Type llvmElemTy, Value tmemBase, - int maxnreg, Value pred, bool isScales = false, - bool unpacked = false) { - assert(cvt.getNumOutDims() == 2); +static SmallVector +lowerTMemLdStFromInfo(Location loc, ConversionPatternRewriter &rewriter, + TMemLdStEncodingInfo &info, Value pred, Type llvmElemTy, + ArrayRef vals, Value tmemBase) { bool isStore = !vals.empty(); - // Remove broadcasting in the registers - auto removeBroadcastSrc = actionRemoveBroadcastedRegs(cvt); - if (!removeBroadcastSrc.isIdentity()) { - auto 
prmtCvt = removeBroadcastSrc.apply(cvt); + if (info.broadcast) { + auto removeBroadcast = std::move(info.broadcast.value()); + info.broadcast = std::nullopt; + auto inVals = to_vector(vals); if (isStore) { - inVals = removeBroadcastSrc.apply(inVals); + inVals = removeBroadcast.apply(inVals); } - auto outValsOr = - lowerTMemLdSt(loc, ctx, rewriter, prmtCvt, inVals, llvmElemTy, tmemBase, - maxnreg, pred, isScales, unpacked); - if (failed(outValsOr)) - return failure(); - auto outVals = std::move(*outValsOr); + auto outVals = lowerTMemLdStFromInfo(loc, rewriter, info, pred, llvmElemTy, + inVals, tmemBase); if (!isStore) { - outVals = broadcastAs(outVals, cvt); + outVals = broadcastAs(outVals, info.reps); } return outVals; } - auto kReg = str_attr("register"); - auto kLane = str_attr("lane"); - auto kRow = str_attr("row"); - auto kCol = str_attr("col"); - - // Default to unpacked=false for bitwidth == 32 if (llvmElemTy.getIntOrFloatBitWidth() < 32) { - auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); - LinearLayout quot; - Type packedElemTy; - int bestContig = 1; - for (int contig = 1; bitwidth * contig <= 32; contig *= 2) { - auto maybeQuot = - divideLeft(cvt, LinearLayout::identity1D(contig, kReg, kCol)); - if (!maybeQuot) - break; - quot = *maybeQuot; - bestContig = contig; - } + unsigned bitwidth = llvmElemTy.getIntOrFloatBitWidth(); bool padding = false; - if (bestContig > 1) { + Type packedElemTy; + if (info.vec > 1) { // There are contiguous elements along kCol, so we can pack them into a // larger dtype - unpacked = false; - packedElemTy = int_ty(bitwidth * bestContig); - } else if (auto maybeQuot = divideLeft( - cvt, LinearLayout::zeros1D(1, kReg, kCol, 32 / bitwidth) * - LinearLayout::identity1D(2, kReg, kCol)); - bitwidth == 16 && maybeQuot) { - // Unpacked just supported for bitwidth 16 - unpacked = true; - quot = *maybeQuot; - packedElemTy = i32_ty; - } else if (auto maybeQuot = divideLeft( - cvt, LinearLayout::zeros1D(1, kReg, kCol, 32 / bitwidth))) { - // We software-pad the elements when we either do not have enough elements - // to fill a full 32b register, e.g., colN = 1 and colStride != 1 or when - // bitwidth == 8 (this happens with scales with K=1). - // These two cases are mostly supported for testing purposes. 
- unpacked = bitwidth == 16; - quot = *maybeQuot; - packedElemTy = i32_ty; - padding = true; + packedElemTy = int_ty(bitwidth * info.vec); + info.vec = 1; } else { - emitError(loc, "Failed to lower TMEM load/store: TMEM layout is not " - "packed or unpacked"); - return failure(); - } - // When unpacked each register moves 32/bitwidth (= 2) columns - if (unpacked) { - quot = LinearLayout::zeros1D(1, kReg, kCol, 32 / bitwidth) * quot; + padding = info.padding; + assert(info.unpacked || info.padding); + packedElemTy = i32_ty; } - SmallVector inVals; + SmallVector inVals = to_vector(vals); if (isStore) { - inVals = pack(vals, packedElemTy, loc, rewriter, padding); + inVals = pack(inVals, packedElemTy, loc, rewriter, padding); } - auto outValsOr = - lowerTMemLdSt(loc, ctx, rewriter, quot, inVals, packedElemTy, tmemBase, - maxnreg, pred, isScales, unpacked); - if (failed(outValsOr)) - return failure(); - auto outVals = std::move(*outValsOr); + auto outVals = lowerTMemLdStFromInfo(loc, rewriter, info, pred, + packedElemTy, inVals, tmemBase); if (!isStore) { outVals = unpack(outVals, llvmElemTy, loc, rewriter, padding); } return outVals; } - assert(!isStore || cvt.getInDimSize(kReg) == vals.size()); - assert(llvmElemTy.getIntOrFloatBitWidth() == 32); - - // The algorithm goes as: - // - Try to match the tile with one of the standard messages - // - If it doesn't match, we use the 16x32bx2 message - // Note that it can match one and only one of the layouts, even after register - // reordering, as the layouts yield predetermined positions for the lanes - // We store the instruction, the resulting reps layout, the permutation and - // the number of registers per message - std::optional> - msgInfo; - for (auto atom : {TMemAccess32x32b, TMemAccess16x256b}) { - auto tile = getTileLayout(ctx, atom, unpacked); - auto maybeReps = getVec(cvt, tile, maxnreg); - if (maybeReps) { - // Cannot match more than one - msgInfo = {atom, std::get<0>(*maybeReps), std::get<1>(*maybeReps), - std::get<2>(*maybeReps)}; - break; - } - } - std::optional secondHalfOffset = std::nullopt; - if (!msgInfo) { - // Quotient by the smaller tile and then, if possible, we set the - // secondHalfOffset to the last kLane basis - auto tile = getTileLayout(ctx, TMemAccess16x32bx2, unpacked); - auto maybeReps = getVec(cvt, tile, maxnreg); - if (maybeReps) { - auto [reps, perm, numRegsPerMessage] = std::move(*maybeReps); - // Find the last kLane basis and use it as secondHalfOffset - auto row = reps.getBasis(kLane, 4, kRow); - auto col = reps.getBasis(kLane, 4, kCol); - secondHalfOffset = (row << 16) | col; - if (*secondHalfOffset == 0) { - // Workaround for ptxas bug, we cannot use secondHalfOffset = 0 to write - // only 16 elements. We use secondHalfOffset = 1 instead and we pad the - // allocation. 
- assert(isScales && - "Only supported for scales as we pad the allocation."); - secondHalfOffset = 1; - } - // We "quotient it out", meaning we remove the last basis from reps - auto basis = reps.getBases(); - basis[kLane][4] = {0, 0}; - reps = LinearLayout(basis, reps.getOutDims(), /*isSurjective=*/false); - msgInfo = {TMemAccess16x32bx2, reps, perm, numRegsPerMessage}; - } - } - - if (!msgInfo) { - emitError(loc, "Failed to lower TMEM load/store: unsupported dst layout\n" + - cvt.toString()); - return failure(); - } - auto [atom, reps, perm, numRegsPerMessage] = std::move(msgInfo.value()); - - SmallVector inVals; + SmallVector inVals = to_vector(vals); if (isStore) { - inVals = to_vector(vals); - inVals = perm.apply(inVals); + inVals = info.perm.apply(inVals); } - auto outValsOr = lowerTMemLdSt(loc, ctx, rewriter, reps, inVals, atom, - llvmElemTy, tmemBase, pred, numRegsPerMessage, - unpacked, secondHalfOffset); - if (failed(outValsOr)) - return failure(); - auto outVals = std::move(*outValsOr); - assert(isStore || outVals.size() == cvt.getInDimSize(kReg)); + auto outVals = lowerTMemLdSt( + loc, rewriter, info.reps, inVals, info.atom, llvmElemTy, tmemBase, pred, + info.numRegsPerMessage, info.unpacked, info.secondHalfOffset); if (!isStore) { - outVals = perm.inverse().apply(outVals); + outVals = info.perm.inverse().apply(outVals); } return outVals; } -static FailureOr> lowerTMemLdStFromTypes( - Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, - RankedTensorType regTy, MemDescType memTy, Value tmemBase, int maxnreg, - Value pred, Type llvmElemTy, ArrayRef vals) { - auto memLayout = toLinearLayout(memTy); - auto regLayout = toLinearLayout(regTy); - auto cvt = regLayout.invertAndCompose(memLayout); - auto kWarp = str_attr("warp"); - auto kRow = str_attr("row"); - // Warps 0-3 must map to row=32 and row=64 whether with broadcasting or not - if (!(regLayout.getBasis(kWarp, 0) == memLayout.getBasis(kRow, 5) && - regLayout.getBasis(kWarp, 1) == memLayout.getBasis(kRow, 6))) { - emitError( - loc, - "Failed to lower TMEM load/store: unsupported src/dst combination\n" + - regLayout.toString() + "\n" + memLayout.toString()); - return failure(); - } - // Map warp bases to row=32 and row=64 in the cvt. This would be done - // automatically in `invertAndCompose` if we had a different dimension name - // for these rows. We can do this in the future if needed. 
- auto bases = cvt.getBases(); - bases[kWarp][0] = {32, 0}; - bases[kWarp][1] = {64, 0}; - cvt = LinearLayout(bases, cvt.getOutDims(), - /*isSurjective=*/cvt.isSurjective()); - - // tmemBase already encodes CTA/block offsets so we just remove them from the - // cvt - auto kBlock = str_attr("block"); - auto kCol = str_attr("col"); - auto nCTAs = cvt.getInDimSize(kBlock); - auto maybeQuot = - divideRight(cvt, LinearLayout::identity1D(nCTAs, kBlock, kCol)); - assert(maybeQuot.has_value()); - auto quot = maybeQuot->unsqueezeIn(kBlock); - - bool isScales = isa(memTy.getEncoding()); - return lowerTMemLdSt(loc, ctx, rewriter, quot, vals, llvmElemTy, tmemBase, - maxnreg, pred, isScales); +static SmallVector +lowerTMemLdStFromTypes(Location loc, ConversionPatternRewriter &rewriter, + RankedTensorType regTy, MemDescType memTy, + Value tmemBase, int maxnreg, Value pred, Type llvmElemTy, + ArrayRef vals) { + auto diag = [loc]() { return emitError(loc); }; + auto encodingInfoOr = + computeTMemLdStEncodingInfo(regTy, memTy, maxnreg, diag); + assert(succeeded(encodingInfoOr) && + "TMEM layout verification should catch invalid layouts"); + return lowerTMemLdStFromInfo(loc, rewriter, *encodingInfoOr, pred, llvmElemTy, + vals, tmemBase); } struct TensorMemoryLoadOpConversion @@ -710,15 +423,13 @@ struct TensorMemoryLoadOpConversion auto b = TritonLLVMOpBuilder(loc, rewriter); auto maxnreg = getContextualMaxNReg(op); - auto resultValsOr = - lowerTMemLdStFromTypes(loc, ctx, rewriter, regTy, memTy, tmemBase, - maxnreg, b.i1_val(true), llvmElemTy, {}); - if (failed(resultValsOr)) - return failure(); + auto resultVals = + lowerTMemLdStFromTypes(loc, rewriter, regTy, memTy, tmemBase, maxnreg, + b.i1_val(true), llvmElemTy, {}); Type structTy = getTypeConverter()->convertType(op.getType()); - Value resultStruct = packLLElements(loc, getTypeConverter(), *resultValsOr, - rewriter, structTy); + Value resultStruct = + packLLElements(loc, getTypeConverter(), resultVals, rewriter, structTy); // Wait insertion could be moved to the TTGIR level if needed. NVVM::Tcgen05WaitOp::create(rewriter, loc, NVVM::Tcgen05WaitKind::LOAD); rewriter.replaceOp(op, {resultStruct}); @@ -747,11 +458,8 @@ struct TensorMemoryStoreOpConversion SmallVector srcValues = unpackLLElements(loc, adaptor.getSrc(), rewriter); auto maxnreg = getContextualMaxNReg(op); - auto lowered = - lowerTMemLdStFromTypes(loc, ctx, rewriter, regTy, memTy, tmemBase, - maxnreg, pred, llvmElemTy, srcValues); - if (failed(lowered)) - return failure(); + lowerTMemLdStFromTypes(loc, rewriter, regTy, memTy, tmemBase, maxnreg, pred, + llvmElemTy, srcValues); NVVM::Tcgen05WaitOp::create(rewriter, loc, NVVM::Tcgen05WaitKind::STORE); // Emit a barrier to ensure all threads have finished writing to tensor @@ -795,11 +503,8 @@ struct TensorMemoryAllocOpConversion SmallVector srcValues = unpackLLElements(loc, adaptor.getSrc(), rewriter); Value ptr = b.inttoptr(base.getType(), allocAddress); - auto lowered = - lowerTMemLdStFromTypes(loc, ctx, rewriter, regTy, memTy, ptr, maxnreg, - b.i1_val(true), llvmElemTy, srcValues); - if (failed(lowered)) - return failure(); + lowerTMemLdStFromTypes(loc, rewriter, regTy, memTy, ptr, maxnreg, + b.i1_val(true), llvmElemTy, srcValues); NVVM::Tcgen05WaitOp::create(rewriter, loc, NVVM::Tcgen05WaitKind::STORE); // Emit a barrier to ensure all threads have finished writing to tensor // memory before any use of the tensor memory. 
From bd4df82f737205fa7fecbd9cea495b8e005ccd95 Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Wed, 29 Oct 2025 15:33:04 +0000 Subject: [PATCH 3/9] [Gluon] Unwrap constexpr on TensorMemoryLayout attributes (#8585) --- python/test/gluon/test_frontend.py | 10 ++++++++++ .../gluon/language/nvidia/blackwell/__init__.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 904d418048..3751841dc9 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -3157,3 +3157,13 @@ def nv_tma_descriptor_store_kernel(input_ptr): } } """) + + +@filecheck_test +def tmem_constexpr(): + tmem_shape: ttgl.constexpr = (64, 64) + bitwidth: ttgl.constexpr = 32 + tmem_layout: ttgl.constexpr = TensorMemoryLayout(tmem_shape, col_stride=32 // bitwidth) + + # CHECK-NOT: constexpr + anchor_noinline(tmem_layout) diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py index 3523042295..b29394f465 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -46,6 +46,9 @@ class TensorMemoryLayout: cta_split_num: Optional[Tuple[int, int]] = None def __post_init__(self): + super().__setattr__("block", _unwrap_if_constexpr(self.block)) + super().__setattr__("col_stride", _unwrap_if_constexpr(self.col_stride)) + super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) assert len(self.block) == 2 assert self.cta_split_num is None or len(self.cta_split_num) == 2 assert self.col_stride >= 1 and (self.col_stride & @@ -77,6 +80,7 @@ class TensorMemoryScalesLayout: cta_split_num: Optional[Tuple[int, int]] = None def __post_init__(self): + super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) assert self.cta_split_num is None or len(self.cta_split_num) == 2 def _to_ir(self, builder): From 33f077b379b46f88ddd7fe91a7b22e2dc4bc742b Mon Sep 17 00:00:00 2001 From: Mieszko Dziadowiec Date: Wed, 29 Oct 2025 16:37:33 +0100 Subject: [PATCH 4/9] [Interpreter][histogram] Fix silent data corruption (#8550) There's silent data corruption when calling `tl.histogram` with the interpreter. ```python # test.py import torch import ctypes import triton import triton.language as tl @triton.jit def histogram_kernel(x_ptr, z_ptr): offset = tl.arange(0, 1) x = tl.load(x_ptr + offset) z = tl.histogram(x, 1) buf = (ctypes.c_int32 * 2).from_address(int(z_ptr)) print(f'before store: {list(buf)}') tl.store(z_ptr + offset, z) # tl.store treats z values as int64 while they're int32 print(f'after store: {list(buf)}') device = 'cpu' torch.manual_seed(17) x = torch.ones(1, device=device, dtype=torch.int32) z = torch.ones(2, dtype=torch.int32, device=device) histogram_kernel[(1, )](x, z) # Output: # TRITON_INTERPRET=1 TRITON_TEST_SUITE=interpreter python test.py # before store: [1, 1] # after store: [1, 0] <- second element shouldn't be cleared ``` Based on the `np.histogram` docs (https://numpy.org/doc/2.3/reference/generated/numpy.histogram.html), the returned dtype is taken from the optional weights param when it is passed, and is int64 otherwise. That leads to `tl.store` thinking it's saving int64 values while the tensor passed in my example holds int32, so it writes 8 bytes at once instead of 4; the extra 4 bytes land outside the element's range, causing silent data corruption.
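As a sanity check, here is a standalone NumPy-only sketch (not part of this patch) of the `weights` workaround the fix relies on: passing a `weights` array with the input's dtype makes the returned counts keep that dtype.

```python
import numpy as np

data = np.array([1], dtype=np.int32)
bins = 1
# All-ones weights leave the counts unchanged, but np.histogram now
# allocates the result with weights.dtype (int32) instead of the default int64.
dummy_weights = np.ones_like(data, dtype=data.dtype)
counts = np.histogram(data, bins=bins, range=(0, bins), weights=dummy_weights)[0]
print(f'Counts dtype with weights: {counts.dtype}')  # int32
```

The default behavior without `weights` is shown next: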
```python import numpy as np data = np.array([1], dtype=np.int32) bins = 1 print(f'Data dtype before: {data.dtype}') histogram = np.histogram(data, bins=bins, range=(0, bins))[0] print(f'Data dtype after: {histogram.dtype}') # Data dtype before: int32 # Data dtype after: int64 ``` Applying "dummy_weights" fixes the returned data type as expected, which fixes the data corruption. ------------------------------ # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because it concerns np.histogram-specific behavior in interpreter mode. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --- python/test/unit/language/test_core.py | 17 +++++++++++++++++ python/triton/runtime/interpreter.py | 10 +++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6cd41c47ce..4ef02a883d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2748,6 +2748,23 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): assert (z_torch == z).all() +@pytest.mark.interpreter +def test_histogram_silent_data_corruption(device): + + @triton.jit + def histogram_kernel(x_ptr, z_ptr): + offset = tl.arange(0, 1) + x = tl.load(x_ptr + offset) + z = tl.histogram(x, 1) + tl.store(z_ptr + offset, z) + + x = torch.ones(1, device=device, dtype=torch.int32) + z = torch.ones(2, device=device, dtype=torch.int32) + + histogram_kernel[(1, )](x, z) + assert z[1] == 1, f"Second element shouldn't be affected, expected_buffer=[1, 1], actual_buffer={z}" + + # ------------------------ # test histogram with mask # ------------------------ diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 9f77b6232d..2af2e3a718 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -603,9 +603,17 @@ def create_make_range(self, ret_ty, start, stop): def create_histogram(self, data, bins, mask): if mask is None: mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1) + + # By default np.histogram returns int64 dtype values + # Docs specify that returned dtype is taken based on optional weights.dtype + # This is fix for interpreter cases where for example int32 tensor is being passed + # But unexpectedly int64 values are being returned causing + # tl.store to write 8 bytes instead of 4 bytes which lead to silent data corruption + dummy_weights = np.ones_like(data.data, dtype=data.data.dtype) + # force all masked elements to zero data = np.where(mask.data, data.data, np.zeros_like(data.data)) - histogram = np.histogram(data, bins=bins, range=(0, bins))[0] + histogram = np.histogram(data, bins=bins, range=(0, bins), weights=dummy_weights)[0] # remove overcounted elements histogram[0] -= np.logical_not(mask.data).sum() return 
TensorHandle(histogram, tl.int32) From a295e601341a451042dc09cf7b0139aedb1942b1 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 29 Oct 2025 15:40:12 +0000 Subject: [PATCH 5/9] [AMD] Add `amdgpu.async_wait` to explicitly represent number of async transactions (#8575) `ttg.async_wait` counts the number of outstanding `ttg.async_commit_group` ops. However, when lowering to LLVM on AMD we require the number of outstanding async intrinsics/final assembly instructions. The conversion is already done by `UpdateAsyncWaitCount`, which modifies the `num` of `ttg.async_wait` in place. This PR introduces a new op `amdgpu.async_wait` to make the change in semantics explicit in the IR. `UpdateAsyncWaitCount` is moved to `TTGIR->LLVM` primarily so that it also covers `Gluon` kernels, and we should always call it since it will only have an effect if there are `ttg.async_wait` ops present in the kernel. To avoid membar changes, this also adds a `ttgpu.LocalBarrier` after each `amdgpu.async_wait`. Membar will respect the newly added barrier and behave the same as for `ttg.async_wait`. --- .../amd/async-ops-alias-scopes.mlir | 8 ++-- test/Conversion/amd/async_ops_to_llvm.mlir | 10 ++-- .../amd/amd-update-async-wait-count.mlir | 46 +++++++++---------- third_party/amd/backend/compiler.py | 3 +- .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 17 +++++++ .../include/TritonAMDGPUTransforms/Passes.td | 2 +- .../lib/TritonAMDGPUToLLVM/AsyncUtility.cpp | 19 +++++++- .../amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h | 6 +++ .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 9 ++-- .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 2 + .../UpdateAsyncWaitCount.cpp | 19 ++++++-- 11 files changed, 98 insertions(+), 43 deletions(-) diff --git a/test/Conversion/amd/async-ops-alias-scopes.mlir b/test/Conversion/amd/async-ops-alias-scopes.mlir index 209e976ece..cc1f056344 100644 --- a/test/Conversion/amd/async-ops-alias-scopes.mlir +++ b/test/Conversion/amd/async-ops-alias-scopes.mlir @@ -65,7 +65,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ tt.func public @local_loads_with_token_from_async_wait(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !ttg.memdesc<64x1xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x16xf16, #shared, #smem, mutable>) { - %3 = ttg.async_wait {num = 1 : i32} + %3 = amdgpu.async_wait {num_inst = 1 : i32} // Check alias information is added for different lowering paths @@ -111,7 +111,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ %0 = ttg.async_copy_global_to_local %ptr, %arg1 : tensor<64x1x!tt.ptr, #blocked> -> <64x1xf32, #shared, #smem, mutable> %1 = ttg.async_commit_group tokens %0 - %3 = ttg.async_wait %1 {num = 1 : i32} + %3 = amdgpu.async_wait %1 {num_inst = 1 : i32} // Check alias information is not used at all for different lowering paths // COMMON-NOT: [[$ASYNC_COPY_SCOPE]] @@ -146,14 +146,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 - %1 = ttg.async_wait {num = 1 : i32} + %1 = amdgpu.async_wait {num_inst = 1 : i32} // COMMON: llvm.load %2 = ttg.local_load %arg1 token %1 : !ttg.memdesc<64x1xf16, #shared, #smem, mutable> -> tensor<64x1xf16, #blocked> %loop_result:2 = scf.for %arg14 = %c0_i32 to %loopIterCount step %c1_i32 iter_args(%arg10 = %1, %arg11 = %2) -> (!ttg.async.token, tensor<64x1xf16, #blocked>) : i32 { // COMMON: llvm.load {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], 
noalias_scopes = [[[$ASYNC_COPY_SCOPE]]] %3 = ttg.local_load %arg1 token %arg10 : !ttg.memdesc<64x1xf16, #shared, #smem, mutable> -> tensor<64x1xf16, #blocked> - %4 = ttg.async_wait {num = 1 : i32} + %4 = amdgpu.async_wait {num_inst = 1 : i32} scf.yield %4, %3: !ttg.async.token, tensor<64x1xf16, #blocked> } diff --git a/test/Conversion/amd/async_ops_to_llvm.mlir b/test/Conversion/amd/async_ops_to_llvm.mlir index a4a806a9b3..ef5830194e 100644 --- a/test/Conversion/amd/async_ops_to_llvm.mlir +++ b/test/Conversion/amd/async_ops_to_llvm.mlir @@ -106,24 +106,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: rocdl.s.waitcnt -49168 // CHECK: rocdl.s.waitcnt -7937 // CHECK: rocdl.s.barrier - ttg.async_wait {num = 0 : i32} + amdgpu.async_wait {num_inst = 0 : i32} // CHECK: rocdl.s.waitcnt -49167 // CHECK: rocdl.s.waitcnt -7937 // CHECK: rocdl.s.barrier - ttg.async_wait {num = 1 : i32} + amdgpu.async_wait {num_inst = 1 : i32} // CHECK: rocdl.s.waitcnt -2 // CHECK: rocdl.s.waitcnt -7937 // CHECK: rocdl.s.barrier - ttg.async_wait {num = 62 : i32} + amdgpu.async_wait {num_inst = 62 : i32} // CHECK: rocdl.s.waitcnt -1 // CHECK: rocdl.s.waitcnt -7937 // CHECK: rocdl.s.barrier - ttg.async_wait {num = 63 : i32} + amdgpu.async_wait {num_inst = 63 : i32} // Check that we clamp values > 63 // CHECK: rocdl.s.waitcnt -1 // CHECK: rocdl.s.waitcnt -7937 // CHECK: rocdl.s.barrier - ttg.async_wait {num = 64 : i32} + amdgpu.async_wait {num_inst = 64 : i32} tt.return } } diff --git a/test/TritonGPU/amd/amd-update-async-wait-count.mlir b/test/TritonGPU/amd/amd-update-async-wait-count.mlir index afaf10d99f..c032ca03eb 100644 --- a/test/TritonGPU/amd/amd-update-async-wait-count.mlir +++ b/test/TritonGPU/amd/amd-update-async-wait-count.mlir @@ -18,10 +18,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %3 = ttg.async_commit_group tokens %2 // Do not wait on the second async_copy => waitcnt 2 - // CHECK: ttg.async_wait {{.*}} {num = 2 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 2 %9 = ttg.async_wait %1 {num = 0 : i32} // No async_copies in between => waitcnt 0 - // CHECK: ttg.async_wait {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 0 %10 = ttg.async_wait %3 {num = 0 : i32} tt.return } @@ -47,10 +47,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %3 = ttg.async_commit_group tokens %2 // Do not wait on the second async_copy => waitcnt 2 - // CHECK: ttg.async_wait {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 0 %9 = ttg.async_wait %3 {num = 0 : i32} // No async_copies in between => waitcnt 0 - // CHECK: ttg.async_wait {{.*}} {num = 2 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 2 %10 = ttg.async_wait %1 {num = 0 : i32} tt.return } @@ -77,9 +77,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %4 = tt.load %arg3 : tensor<128x16x!tt.ptr, #blocked> - // CHECK: ttg.async_wait {{.*}} {num = 2 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 2 %9 = ttg.async_wait %1 {num = 0 : i32} - // CHECK: ttg.async_wait {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 0 %10 = ttg.async_wait %3 {num = 0 : i32} tt.return } @@ -106,7 +106,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %2 = ttg.async_copy_global_to_local %arg4, %arg2 : tensor<16x256x!tt.ptr, #blocked1> -> <16x256xf16, #shared1, #smem, mutable> %3 = ttg.async_commit_group tokens %2 %8:2 = scf.for %arg14 = %c0_i32 to %arg0 
step %c1_i32 iter_args(%arg15 = %1, %arg16 = %3) -> (!ttg.async.token, !ttg.async.token) : i32 { - // CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0 %10 = ttg.async_wait %arg15, %arg16 {num = 2 : i32} %11 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> %12 = ttg.async_commit_group tokens %11 @@ -114,7 +114,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %14 = ttg.async_commit_group tokens %13 scf.yield %12, %14: !ttg.async.token, !ttg.async.token } - // CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0 %9 = ttg.async_wait %8#0, %8#1 {num = 0 : i32} tt.return } @@ -145,7 +145,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %6 = ttg.async_copy_global_to_local %arg4, %arg2 : tensor<16x256x!tt.ptr, #blocked1> -> <16x256xf16, #shared1, #smem, mutable> %7 = ttg.async_commit_group tokens %6 %8:4 = scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg15 = %1, %arg16 = %5, %arg17 = %3, %arg18 = %7) -> (!ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 { - // CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 3 + // CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 3 %10 = ttg.async_wait %arg15, %arg17 {num = 2 : i32} %11 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> %12 = ttg.async_commit_group tokens %11 @@ -153,7 +153,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %14 = ttg.async_commit_group tokens %13 scf.yield %arg16, %12, %arg18, %14 : !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token } - // CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0 %9 = ttg.async_wait %8#0, %8#1, %8#2, %8#3 {num = 0 : i32} tt.return } @@ -185,12 +185,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %8:4 = scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg15 = %1, %arg16 = %5, %arg17 = %3, %arg18 = %7) -> (!ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 { %103 = scf.if %cond -> (!ttg.async.token) { // We wait on both tokens so we interleave with one iteration => 3 - // CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 3 + // CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 3 %token1 = ttg.async_wait %arg15, %arg17 {num = 2 : i32} scf.yield %token1 : !ttg.async.token } else { // We only wait on the token of the first load so we can interleave one more load => 3 + 2 - // CHECK: ttg.async_wait {{.*}} {num = 5 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 5 %token2 = ttg.async_wait %arg15 {num = 1 : i32} scf.yield %token2 : !ttg.async.token } @@ -200,7 +200,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %14 = ttg.async_commit_group tokens %13 scf.yield %arg16, %12, %arg18, %14 : !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token } - // CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0 %9 = ttg.async_wait %8#0, %8#1, %8#2, %8#3 {num = 0 : i32} tt.return } @@ -235,7 +235,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %cond_load = ttg.async_copy_global_to_local %arg4, %arg2 : tensor<16x256x!tt.ptr, #blocked1> -> 
<16x256xf16, #shared1, #smem, mutable> %cond_load_commit = ttg.async_commit_group tokens %cond_load // We wait on both tokens (3) and additionally we should count the load inside our block (+2) => 5 - // CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 5 + // CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 5 %token1 = ttg.async_wait %arg15, %arg17 {num = 2 : i32} scf.yield %token1 : !ttg.async.token } else { @@ -247,7 +247,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %14 = ttg.async_commit_group tokens %13 scf.yield %arg16, %12, %arg18, %14 : !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token } - // CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0 %9 = ttg.async_wait %8#0, %8#1, %8#2, %8#3 {num = 0 : i32} tt.return } @@ -279,7 +279,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %7 = ttg.async_commit_group tokens %6 %8:4 = scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg15 = %1, %arg16 = %5, %arg17 = %3, %arg18 = %7) -> (!ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 { // The then block contains 3 instructions and the else 1 so we expect the count to be 3 (1 + 2) because there are also 2 instructions outside the scf.if in the loop body - // CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 3 + // CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 3 %token1 = ttg.async_wait %arg15, %arg17 {num = 2 : i32} %103 = scf.if %cond -> (!ttg.async.token) { @@ -296,7 +296,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %14 = ttg.async_commit_group tokens %13 scf.yield %arg16, %103, %arg18, %14 : !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token } - // CHECK: ttg.async_wait {{.*}}, {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}}, {{.*}} {num_inst = 0 %9 = ttg.async_wait %8#0, %8#1, %8#2, %8#3 {num = 0 : i32} tt.return } @@ -323,14 +323,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %7 = ttg.async_commit_group tokens %6 // Dynamic iteration count so we should not count its body %30 = scf.for %arg21 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg30 = %6) -> (!ttg.async.token) : i32 { - // CHECK: ttg.async_wait {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 0 %31 = ttg.async_wait %arg30 {num = 1 : i32} // Emits 1 direct to lds instruction %32 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> %33 = ttg.async_commit_group tokens %32 scf.yield %33 : !ttg.async.token } - // CHECK: ttg.async_wait {{.*}} {num = 1 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 1 %10 = ttg.async_wait %1 {num = 1 : i32} tt.return } @@ -357,14 +357,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %7 = ttg.async_commit_group tokens %6 // Loop with 4 iterations => 4 instructions %30 = scf.for %arg21 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg30 = %6) -> (!ttg.async.token) : i32 { - // CHECK: ttg.async_wait {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 0 %31 = ttg.async_wait %arg30 {num = 1 : i32} // Emits 1 direct to lds instruction %32 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> %33 = ttg.async_commit_group tokens %32 scf.yield %33 : !ttg.async.token } - // CHECK: ttg.async_wait {{.*}} {num = 5 + // CHECK: 
amdgpu.async_wait {{.*}} {num_inst = 5 %10 = ttg.async_wait %1 {num = 1 : i32} tt.return } @@ -397,10 +397,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // Check that we do not take other TDM loads into account (they use a different HW counter) - // CHECK: ttg.async_wait {{.*}} {num = 2 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 2 %cw1 = ttg.async_wait %21 {num = 0 : i32} - // CHECK: ttg.async_wait {{.*}} {num = 0 + // CHECK: amdgpu.async_wait {{.*}} {num_inst = 0 %cw2 = ttg.async_wait %51 {num = 0 : i32} tt.return } diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index cdfc13aa95..2cf1105ee4 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -256,8 +256,6 @@ def make_ttgir(mod, metadata, options): passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) - if use_async_copy: - amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch) pm.run(mod, 'make_ttgir') return mod @@ -283,6 +281,7 @@ def make_llir(src, metadata, options): # TritonGPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch) # custom_lds_size is an experimental parameter that defines amount of LDS available # for one thread block. Measured in bytes. # diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 7f5d080335..9c904fb803 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -775,4 +775,21 @@ def AsyncTDMWait : TT_AMDGPU_Op<"async_tdm_wait"> { let assemblyFormat = "$asyncToken attr-dict"; } +//===----------------------------------------------------------------------===// +// AsyncWait +//===----------------------------------------------------------------------===// + +def AsyncWaitOp : TT_AMDGPU_Op<"async_wait"> { + let summary = "Wait until there are less than or equal to the given number of outstanding async intrinsics"; + let description = [{ + Similar to ttg.async_wait but instead of waiting on oustanding ttg.async_commit_groups + this op waits on the number of outstanding async instructions/intrinsics as required for the + lowering to LLVM on the AMD backend. + }]; + + let arguments = (ins Variadic:$asyncToken, I32Attr:$num_inst); + let results = (outs TTG_AsyncToken:$retToken); + let assemblyFormat = "($asyncToken^)? attr-dict"; +} + #endif diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index bd9fc77e31..42723b8c4b 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -257,7 +257,7 @@ def TritonAMDGPUUpdateAsyncWaitCount: Pass<"tritonamdgpu-update-async-wait-count compute the number of interleaving global memory instructions to emit the correct waitcnt during lowering. 
}]; - let dependentDialects = []; + let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ Option<"archGenerationName", "arch-generation-name", diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp index fe2142fb88..55222173be 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp @@ -2,6 +2,7 @@ #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "TargetInfo.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" @@ -13,7 +14,7 @@ constexpr const char *syncedViaAsyncWaitAttrName = // if all defining operations are an AsyncWait bool comesFromAsyncWait(Value token) { if (auto defOp = token.getDefiningOp()) { - return isa(defOp); + return isa(defOp); } auto blockArg = dyn_cast(token); @@ -50,6 +51,22 @@ bool comesFromAsyncWait(Value token) { } } // namespace +void addLocalBarrierAfterAmdGpuAsyncWait(ModuleOp mod) { + auto *ctx = mod->getContext(); + + SmallVector waits; + mod->walk([&waits](amdgpu::AsyncWaitOp waitOp) { waits.push_back(waitOp); }); + + IRRewriter builder(mod.getContext()); + for (auto waitOp : waits) { + if (isa(waitOp->getNextNode())) + continue; + + builder.setInsertionPointAfter(waitOp); + builder.create(waitOp->getLoc()); + } +} + void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod) { auto *ctx = mod->getContext(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h index e03d35bd91..c826d017de 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h @@ -9,6 +9,12 @@ namespace mlir::triton::AMD { class TargetInfo; +// Walks the module and adds a LocalBarrier after any amdgpu.async_wait if there +// is not already a barrier following it. This mimicks what Member does for +// common async wait operations and avoids AMD specific modifications to Membar. +// This yields to the same behaviour compared to when membar adds the barrier. +void addLocalBarrierAfterAmdGpuAsyncWait(ModuleOp mod); + // Annotates LocalLoadOps with ttg.amdgpu.syncedByAsyncWait=true if they are // synced by an AsyncWait. void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index c3047355bd..b953188fe9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1887,14 +1887,15 @@ struct AtomicRMWOpConversion } }; -struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern { +struct AsyncWaitOpConversion + : public ConvertOpToLLVMPattern { AsyncWaitOpConversion(LLVMTypeConverter &converter, const AMD::TargetInfo &targetInfo, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {} LogicalResult - matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor, + matchAndRewrite(amdgpu::AsyncWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -1912,7 +1913,7 @@ struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern { // interested in those. 
// Clamp vmcnt to 6bits; a lower vmcnt will produce a conservative wait - unsigned vmCnt = std::min(63u, op.getNum()); + unsigned vmCnt = std::min(63u, op.getNumInst()); // Extract low and high bits and combine while setting all other bits to 1 unsigned lowBits = vmCnt & 0xF; @@ -1925,7 +1926,7 @@ struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern { } case ISAFamily::GFX1250: { // Clamp asyncCnt to 6bits(hw imit); lower means conservative - unsigned asyncCnt = std::min(63u, op.getNum()); + unsigned asyncCnt = std::min(63u, op.getNumInst()); LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.s.wait.asynccnt", {}, {b.i16_val(asyncCnt)}); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index a514f51749..b033177e6a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -121,6 +121,8 @@ struct ConvertTritonAMDGPUToLLVM if (targetInfo.requiresAliasInfoForAsyncOps()) AMD::annotateLocalLoadsSyncedViaAsyncWait(mod); + + AMD::addLocalBarrierAfterAmdGpuAsyncWait(mod); ModuleMembarAnalysis membarPass(&allocation, mlir::triton::AMD::membarFilter); membarPass.run(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp index ff8459a9f4..fa39cb4dbd 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp @@ -109,10 +109,23 @@ void updateWaitCount(WaitType waitOp, waitCnt = std::min(waitCnt, tokenWaitCnt); } - if (waitCnt == std::numeric_limits::max() || waitOp.getNum() == waitCnt) - return; + if (waitCnt == std::numeric_limits::max()) { + // TODO(alex): set to conservative waitcnt=0 after gluon refactoring + waitCnt = waitOp.getNum(); + } - rewriter.modifyOpInPlace(waitOp, [&]() { waitOp.setNum(waitCnt); }); + if (std::is_same_v) { + // Replace ttg.async_wait which counts outstanding commits groups with + // amdgpu.async_wait which counts the number of oustanding + // intrinsics + auto tokens = waitOp.getAsyncToken(); + rewriter.setInsertionPointAfter(waitOp); + rewriter.replaceOpWithNewOp(waitOp, tokens, waitCnt); + } else { + // For TDM each TTGIR op will create exactly one intrinsics so we do not use + // a separate op + rewriter.modifyOpInPlace(waitOp, [&]() { waitOp.setNum(waitCnt); }); + } } } // anonymous namespace From 430f8b290d9084c17d10d0609b7c8f1996f5254d Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Wed, 29 Oct 2025 19:05:41 +0000 Subject: [PATCH 6/9] [BACKEND] Fix unnecessary cvt caused by wgmma wait op (#8579) Fixes #8578 We're using the wrong output constraint which leads llvm to extend the fp16 value to 32-bits. Fixing the constraint removes the conversion. Note that we still end up with a no-op sequence like: ```ptx mov.b32 {%rs1, %rs2}, %r1 mov.b32 %r2, {%rs1, %rs2} ``` However, `ptxas` is able to optimize these out. 
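For reference, here is a minimal standalone sketch of the per-element constraint selection this patch introduces; it is not the pass code itself (`ElemTy` and `constraintFor` are illustrative names), but it reproduces the constraint string checked by the new `@wgmma_wait` test for a `(f32, f32, i32, i32, f16, f16)` result struct:

```cpp
// Illustrative sketch only: map each struct member's bitwidth to a PTX
// inline-asm output constraint, mirroring the logic added to
// WGMMAWaitGroupOpPattern::getOutputConstraints in this patch.
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR element type: just a bitwidth and an
// "is f32" flag, which is all the constraint choice depends on here.
struct ElemTy {
  unsigned bitwidth;
  bool isF32;
};

std::string constraintFor(const ElemTy &ty) {
  switch (ty.bitwidth) {
  case 64:
    return "=l";                   // 64-bit register
  case 32:
    return ty.isF32 ? "=f" : "=r"; // f32 register vs. generic 32-bit register
  case 16:
    return "=h";                   // 16-bit register, so f16 values stay in fp16 regs
  default:
    return "<unsupported>";
  }
}

int main() {
  // The struct from the new @wgmma_wait test: (f32, f32, i32, i32, f16, f16).
  std::vector<ElemTy> members = {{32, true},  {32, true},  {32, false},
                                 {32, false}, {16, false}, {16, false}};
  std::string constraints;
  for (const auto &m : members)
    constraints += constraintFor(m) + ",";
  // Prints "=f,=f,=r,=r,=h,=h,0,1,2,3,4,5", matching the CHECK line in
  // test/Conversion/nvgpu_to_llvm.mlir added below.
  std::cout << constraints << "0,1,2,3,4,5\n";
}
```

The key point is the `case 16` branch: with the old single-constraint logic every non-f32 member got `=r`, which forced fp16 results into 32-bit registers and introduced the extra conversion.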
--- test/Conversion/nvgpu_to_llvm.mlir | 13 ++++++++++ .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 26 ++++++++++++++++--- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/test/Conversion/nvgpu_to_llvm.mlir b/test/Conversion/nvgpu_to_llvm.mlir index c9b4804c04..e0b76704a6 100644 --- a/test/Conversion/nvgpu_to_llvm.mlir +++ b/test/Conversion/nvgpu_to_llvm.mlir @@ -58,6 +58,19 @@ llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) { // ----- +!struct = !llvm.struct<(f32, f32, i32, i32, f16, f16)> + +// CHECK-LABEL: @wgmma_wait +llvm.func @wgmma_wait(%in: !struct) { + // CHECK: // wait for regs: $0,$1,$2,$3,$4,$5 + // CHECK: wgmma.wait_group.sync.aligned 0; + // CHECK: "=f,=f,=r,=r,=h,=h,0,1,2,3,4,5" + %out = nvgpu.wgmma_wait_group %in {pendings = 0 : i32} : !struct + llvm.return +} + +// ----- + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @tensor_memory_base_lowering // CHECK: %[[TID:.+]] = nvvm.read.ptx.sreg.tid.x : i32 diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index d86881e240..879589e6dc 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -303,9 +303,29 @@ class WGMMAWaitGroupOpPattern : public OpRewritePattern { Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { auto outputStructType = cast(op.getType()); uint32_t numOutputRegs = outputStructType.getBody().size(); - std::string output = - outputStructType.getBody().front().isF32() ? "=f" : "=r"; - return Constraints(numOutputRegs, output); + Constraints constraints; + constraints.reserve(numOutputRegs); + mlir::DataLayout dl(op->getParentOfType()); + for (auto ty : outputStructType.getBody()) { + auto bitwidth = dl.getTypeSizeInBits(ty); + std::string c; + switch (bitwidth) { + case 64: + c = "=l"; + break; + case 32: + c = ty.isF32() ? 
"=f" : "=r"; + break; + case 16: + c = "=h"; + break; + default: + llvm::report_fatal_error("Unexpected bitwidth in WGMMAWaitGroupOp: " + + Twine(bitwidth)); + } + constraints.push_back(c); + } + return constraints; } OperandsAndConstraints From 3f5eb5075e4f4875a0f12d686c958a0f15e901cc Mon Sep 17 00:00:00 2001 From: xiaohuguo2023 <149615094+xiaohuguo2023@users.noreply.github.com> Date: Wed, 29 Oct 2025 22:26:43 +0000 Subject: [PATCH 7/9] [AMD] reimplement fast_tanhf() to avoid overflow (#8551) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### The Problem with the Original Formula The original formula is: ``` tanh(x) = (e^(2x) - 1) / (e^(2x) + 1) ``` - Issue with large positive x: - When x = 20: e^(40) ≈ 2.4 × 10^17 → manageable - When x = 50: e^(100) ≈ 2.7 × 10^43 → overflow to infinity - Result: (∞ - 1)/(∞ + 1) = NaN x - For negative x: The formula actually works fine because e^(2x) → 0, giving (-1)/(1) = -1 ### The Numerically Stable Solution - For Positive x: Reformulation ``` tanh(x) = (e^(2x) - 1) / (e^(2x) + 1) = (e^(2x) + 1 - 2) / (e^(2x) + 1) = 1 - 2/(e^(2x) + 1) ``` - For Negative x: Using Symmetry ``` tanh(-x) = (e^(-2x) - 1) / (e^(-2x) + 1) = (2/(e^(-2x) + 1) - 1) = -1 × (1 - 2/(e^(2|x|) + 1)) ``` ### Unified formulation: ``` tanh(x) = sign(x) × (1 - 2/(e^(2|x|) + 1)) ``` --- .../TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp | 56 +++++++++++++------ 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp index 21cbbe64e6..75f04163d6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -106,29 +106,49 @@ class CallOpConversion : public OpRewritePattern { assert(operands[0].getType().getIntOrFloatBitWidth() == 32); LLVM::FastmathFlagsAttr defaultFlags{}; - // Calculate 2*x - auto twoX = LLVM::FMulOp::create( - rewriter, loc, rewriter.getF32Type(), operands[0], + // Numerically stable tanh implementation: + // For positive x: tanh(x) = 1 - 2/(e^(2x) + 1) + // For negative x: tanh(x) = -tanh(-x) = -(1 - 2/(e^(-2x) + 1)) + // = 2/(e^(-2x) + 1) - 1 + // This avoids overflow when e^(2x) becomes infinity for large x + + // Get absolute value of x + auto absX = LLVM::FAbsOp::create(rewriter, loc, rewriter.getF32Type(), + operands[0]); + + // Calculate 2*|x| + auto twoAbsX = LLVM::FMulOp::create( + rewriter, loc, rewriter.getF32Type(), absX, LLVM::createConstantF32(loc, rewriter, 2.0), defaultFlags); - // Calculate fast_expf(2*x) using the utility function - auto exp2X = createFastExpf(rewriter, loc, twoX->getResult(0), - rewriter.getF32Type(), ftz); + // Calculate e^(2*|x|) + auto exp2AbsX = createFastExpf(rewriter, loc, twoAbsX->getResult(0), + rewriter.getF32Type(), ftz); - // Calculate exp2X - 1 - auto exp2XMinus1 = LLVM::FSubOp::create( - rewriter, loc, rewriter.getF32Type(), exp2X->getResult(0), + // Calculate e^(2*|x|) + 1 + auto exp2AbsXPlus1 = LLVM::FAddOp::create( + rewriter, loc, rewriter.getF32Type(), exp2AbsX->getResult(0), LLVM::createConstantF32(loc, rewriter, 1.0), defaultFlags); - // Calculate exp2X + 1 - auto exp2XPlus1 = LLVM::FAddOp::create( - rewriter, loc, rewriter.getF32Type(), exp2X->getResult(0), - LLVM::createConstantF32(loc, rewriter, 1.0), defaultFlags); - - // Calculate tanh(X) = (exp2X - 1) / (exp2X + 1) - replacementOp = LLVM::FDivOp::create( - rewriter, loc, returnType, exp2XMinus1->getResult(0), - 
exp2XPlus1->getResult(0), defaultFlags); + // Calculate 2 / (e^(2*|x|) + 1) + auto two = LLVM::createConstantF32(loc, rewriter, 2.0); + auto ratio = + LLVM::FDivOp::create(rewriter, loc, rewriter.getF32Type(), two, + exp2AbsXPlus1->getResult(0), defaultFlags); + + // Calculate 1 - 2/(e^(2*|x|) + 1) + auto one = LLVM::createConstantF32(loc, rewriter, 1.0); + auto posResult = + LLVM::FSubOp::create(rewriter, loc, rewriter.getF32Type(), one, + ratio->getResult(0), defaultFlags); + + // Apply the sign of the original input using copysign + // tanh(x) = sign(x) * (1 - 2/(e^(2*|x|) + 1)) + const char *intrinsic = "llvm.copysign.f32"; + auto args = + llvm::SmallVector{posResult->getResult(0), operands[0]}; + replacementOp = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, + returnType, args); } if (replacementOp) { From c186592a17299439900d712e85556e8578345821 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 29 Oct 2025 16:18:17 -0700 Subject: [PATCH 8/9] [FRONTEND] Add scales dimension checks for dot_scaled (#8564) --- include/triton/Dialect/Triton/IR/TritonOps.td | 1 + lib/Dialect/Triton/IR/Ops.cpp | 38 +++++++++++++++++++ .../test/unit/language/test_compile_errors.py | 20 ++++++++++ python/triton/language/semantic.py | 19 ++++++++++ test/Conversion/tritongpu_to_llvm_sm120.mlir | 8 ++-- test/Triton/invalid.mlir | 15 ++++++++ test/TritonGPU/accelerate-matmul.mlir | 8 ++-- 7 files changed, 101 insertions(+), 8 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index a745dd12e8..dbb4c24a22 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -732,6 +732,7 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, `lhs` `=` $a_elem_type `rhs` `=` $b_elem_type attr-dict `:` type($a) (`,` type($a_scale)^)? `*` type($b) (`,` type($b_scale)^)? `->` type($d) }]; + let hasVerifier = 1; } // diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 67b1fc683e..7be62c7340 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -356,6 +356,44 @@ bool DotScaledOp::verifyOutputDims() { return true; } +LogicalResult DotScaledOp::verify() { + auto aShape = this->getA().getType().getShape(); + int64_t rank = aShape.size(); + + auto k = aShape[rank - 1]; + if (this->getAElemType() == ScaleDotElemType::E2M1) { + if (this->getLhsKPack()) + k *= 2; + } + auto cShape = this->getC().getType().getShape(); + int64_t mDim = cShape[cShape.size() - 2]; + int64_t nDim = cShape[cShape.size() - 1]; + + if (getAScale()) { + auto aScaleShape = getAScale().getType().getShape(); + if (aScaleShape[rank - 2] != mDim) + return this->emitError( + "scales M dimension must match the operand M dimension"); + int scale_factor = + isa(getAScale().getType().getElementType()) ? 16 : 32; + if (aScaleShape[rank - 1] != k / scale_factor) + return this->emitError("scales K dimension must match the operand K " + "divided by the scale factor"); + } + if (getBScale()) { + auto bScaleShape = getBScale().getType().getShape(); + if (bScaleShape[rank - 2] != nDim) + return this->emitError( + "scales N dimension must match the operand N dimension"); + int scale_factor = + isa(getBScale().getType().getElementType()) ? 
16 : 32; + if (bScaleShape[rank - 1] != k / scale_factor) + return this->emitError("scales K dimension must match the operand K " + "divided by the scale factor"); + } + return success(); +} + //-- MakeRangeOp -- OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { // make_range(start, start + 1) -> constant(start) diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index bf2caee374..2b26cd675a 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -489,3 +489,23 @@ def kernel(N: tl.constexpr): with pytest.raises(CompilationError, match="N marked as constexpr and listed in do_not_specialize"): kernel[(1, )](5) + + +def test_dot_scaled_shape_verification(fresh_triton_cache): + + @triton.jit + def kernel(): + M: tl.constexpr = 32 + K: tl.constexpr = 64 + N: tl.constexpr = 32 + a = tl.full((M, K), 0, tl.uint8) + b = tl.full((K, N), 0, tl.uint8) + lhs_scale_wrong = tl.full((M, 4), 0, tl.uint8) + rhs_scale = tl.full((N, 2), 0, tl.uint8) + acc = tl.full((M, N), 0.0, tl.float32) + tl.dot_scaled(a, lhs_scale_wrong, "e5m2", b, rhs_scale, "e5m2", acc, False, True, True, tl.float32) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + assert str(e.value.__cause__) == "lhs_scale must be a tensor of shape [32, 2]. Got ['32', '4']" diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 4b75d8e7cd..5f58f7bafb 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1590,6 +1590,20 @@ def _bitcast_to_fp_type(self, val: TensorTy, float_format: str): assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" return self.bitcast(val, triton_ty) + def verify_scaled_shape(self, M, N, K, lhs_scale, rhs_scale): + if lhs_scale is not None: + scale_factor = 16 if lhs_scale.dtype.is_fp8e4nv() else 32 + lhs_scale_shape = lhs_scale.type.shape + assert lhs_scale_shape == [ + M, K // scale_factor + ], f"lhs_scale must be a tensor of shape [{M}, {K // scale_factor}]. Got {lhs_scale_shape}" + if rhs_scale is not None: + scale_factor = 16 if rhs_scale.dtype.is_fp8e4nv() else 32 + rhs_scale_shape = rhs_scale.type.shape + assert rhs_scale_shape == [ + N, K // scale_factor + ], f"rhs_scale must be a tensor of shape [{N}, {K // scale_factor}]. Got {rhs_scale_shape}" + def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy, rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool, lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy: @@ -1621,8 +1635,11 @@ def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: T assert PACKED_B_DIM == PACKED_A_DIM, f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. 
Got {K=}" B = lhs.type.shape[0] if lhs_rank == 3 else None + K = K_LHS if not lhs_k_pack: M = M * PACKED_A + else: + K = K * PACKED_A if not rhs_k_pack: N = N * PACKED_B ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) @@ -1634,6 +1651,8 @@ def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: T assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle + self.verify_scaled_shape(M, N, K, None if lhs_scale_is_none else lhs_scale, + None if rhs_scale_is_none else rhs_scale) return self.tensor( self.builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle, rhs_format_enum, fast_math, lhs_k_pack, rhs_k_pack, acc_handle), ret_ty) diff --git a/test/Conversion/tritongpu_to_llvm_sm120.mlir b/test/Conversion/tritongpu_to_llvm_sm120.mlir index e07c9af9d0..19ec3ccf40 100644 --- a/test/Conversion/tritongpu_to_llvm_sm120.mlir +++ b/test/Conversion/tritongpu_to_llvm_sm120.mlir @@ -9,17 +9,17 @@ module attributes {"ttg.target" = "cuda:120", "ttg.num-ctas" = 1 : i32, "ttg.num // CHECK: mma.sync.aligned.m16n8k32.row.col.kind::mxf8f6f4.block_scale.scale_vec::1X tt.func public @sm120_mmav2_dot_scaled( %a: tensor<128x32xf8E5M2, #blocked_k>, - %sa: tensor<128x2xi8, #blocked>, + %sa: tensor<128x1xi8, #blocked>, %b: tensor<32x128xf8E5M2, #blocked>, - %sb: tensor<128x2xi8, #blocked>, + %sb: tensor<128x1xi8, #blocked>, %out: !tt.ptr ){ %c = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %a_d = ttg.convert_layout %a : tensor<128x32xf8E5M2, #blocked_k> -> tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> %b_d = ttg.convert_layout %b : tensor<32x128xf8E5M2, #blocked> -> tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> %d = tt.dot_scaled %a_d scale %sa, %b_d scale %sb, %c lhs = e5m2 rhs = e5m2 {fastMath = false} - : tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<128x2xi8, #blocked> - * tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, tensor<128x2xi8, #blocked> + : tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<128x1xi8, #blocked> + * tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, tensor<128x1xi8, #blocked> -> tensor<128x128xf32, #blocked> %out_splat = tt.splat %out : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> %out_ptrs = tt.broadcast %out_splat : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index f303a386d9..767dd8ecc2 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -541,6 +541,21 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} // ----- +module { + tt.func @dot_scaled_invalid_dims( + %a: tensor<128x128xf8E4M3FN>, + %b: tensor<128x128xf8E4M3FN>, + %a_scale: tensor<128x128xi8>, + %b_scale: tensor<128x4xi8>) -> tensor<128x128xf32> { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32> + // expected-error @below {{scales K dimension must match the operand K divided by the scale factor}} + %result = tt.dot_scaled %a scale %a_scale, %b scale %b_scale, %cst lhs = e4m3 rhs = e4m3 {fastMath = true} : tensor<128x128xf8E4M3FN>, tensor<128x128xi8> * tensor<128x128xf8E4M3FN>, tensor<128x4xi8>-> tensor<128x128xf32> + tt.return %result : 
tensor<128x128xf32> + } +} + +// ----- + tt.func @unsplat_invalid(%arg0: tensor<128xf32>) { // expected-error @below {{source tensor must have exactly one element}} %0 = tt.unsplat %arg0 : tensor<128xf32> diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index e5b4e9f9f6..77b89cc0fd 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -680,9 +680,9 @@ module attributes {"ttg.target" = "cuda:120", "ttg.num-ctas" = 1 : i32, "ttg.num // CHECK-LABEL: @sm120_dot_scaled_basic tt.func public @sm120_dot_scaled_basic( %a: tensor<128x32xi8, #blocked_k>, - %scale_a: tensor<128x2xi8, #blocked>, + %scale_a: tensor<128x1xi8, #blocked>, %b: tensor<32x128xi8, #blocked>, - %scale_b: tensor<128x2xi8, #blocked> + %scale_b: tensor<128x1xi8, #blocked> ) -> tensor<128x128xf32, #blocked> { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> // CHECK-DAG: tt.dot_scaled @@ -690,8 +690,8 @@ module attributes {"ttg.target" = "cuda:120", "ttg.num-ctas" = 1 : i32, "ttg.num // CHECK-DAG: #linear1 // CHECK-NOT: ttng.tc_gen5_mma_scaled %d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} - : tensor<128x32xi8, #blocked_k>, tensor<128x2xi8, #blocked> - * tensor<32x128xi8, #blocked>, tensor<128x2xi8, #blocked> + : tensor<128x32xi8, #blocked_k>, tensor<128x1xi8, #blocked> + * tensor<32x128xi8, #blocked>, tensor<128x1xi8, #blocked> -> tensor<128x128xf32, #blocked> tt.return %d : tensor<128x128xf32, #blocked> } From 478816633db30ada6c676e038aac4440f68e70c9 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Sat, 15 Nov 2025 02:21:29 +0000 Subject: [PATCH 9/9] [TEST] Update triton_kernels skiplist after `2b29c3d` Signed-off-by: Whitney Tsang --- scripts/skiplist/default/triton_kernels.txt | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/scripts/skiplist/default/triton_kernels.txt b/scripts/skiplist/default/triton_kernels.txt index c39b51f509..69153c1a90 100644 --- a/scripts/skiplist/default/triton_kernels.txt +++ b/scripts/skiplist/default/triton_kernels.txt @@ -12,18 +12,23 @@ tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-704-800 tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] +tests/test_matmul.py::test_op[False-False-False-True-True-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] +tests/test_matmul.py::test_op[False-False-False-True-True-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] 
tests/test_matmul.py::test_op[False-False-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] +tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-700-700-ragged-float16-float16-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] +tests/test_matmul.py::test_op[False-False-True-True-True-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] +tests/test_matmul.py::test_op[False-False-True-True-True-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-700-700-ragged-float16-float16-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] @@ -37,6 +42,8 @@ tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-704-800- tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] +tests/test_matmul.py::test_op[False-True-False-True-True-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] +tests/test_matmul.py::test_op[False-True-False-True-True-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-700-700-ragged-float16-float16-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] 
tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] @@ -44,9 +51,12 @@ tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800- tests/test_matmul.py::test_op[False-True-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] +tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-700-700-ragged-float16-float16-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False] tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True] tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True] +tests/test_matmul.py::test_op[False-True-True-True-True-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True] +tests/test_matmul.py::test_op[False-True-True-True-True-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]