From 8de5bb5456d1f375ee80454b25ea5e08b666df2a Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 3 Dec 2025 12:29:30 -0800 Subject: [PATCH 01/40] init einsum Signed-off-by: Phuong Nguyen --- tests/jax/test_custom_call_compute.py | 53 +++ tests/jax/test_einsum.py | 219 +++++++++ transformer_engine/jax/cpp_extensions/amax.py | 36 ++ transformer_engine/jax/cpp_extensions/base.py | 89 +++- transformer_engine/jax/cpp_extensions/gemm.py | 53 +-- .../jax/cpp_extensions/quantization.py | 40 +- transformer_engine/jax/einsum.py | 424 ++++++++++++++++++ transformer_engine/jax/quantize/tensor.py | 119 +++-- 8 files changed, 930 insertions(+), 103 deletions(-) create mode 100644 tests/jax/test_einsum.py create mode 100644 transformer_engine/jax/einsum.py diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index c8bd9d47c3..897d9f683e 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1290,6 +1290,59 @@ def test_quantize_dact_dbias_mxfp8_scaling( ) +class TestQuantizeWithVmap: + """Test vmap support for quantization primitives.""" + + @pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16]) + @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("q_layout", [QuantizeLayout.ROWWISE]) + def test_vmap_quantize(self, in_dtype, scaling_mode, q_layout): + """Test that vmap works with tex.quantize using the general batcher.""" + # Determine q_dtype based on scaling mode + if scaling_mode.is_nvfp4_scaling: + q_dtype = jnp.float4_e2m1fn + else: + q_dtype = jnp.float8_e4m3fn + + # Create batched input (E, M, K) - E experts + E, M, K = 4, 64, 128 + key = jax.random.PRNGKey(0) + batched_input = jax.random.uniform(key, (E, M, K), in_dtype) + + # Create per-expert quantizers + quantizers = [ + QuantizerFactory.create( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + ) + for _ in range(E) + ] + + # Stack quantizers for vmap + stacked_quantizers = jax.tree_util.tree_map(lambda *args: jnp.stack(args), *quantizers) + + # Vmap over expert dimension + def quantize_single(x, quantizer): + return tex.quantize(x, quantizer=quantizer, flatten_axis=-1) + + vmapped_quantize = jax.vmap(quantize_single, in_axes=(0, 0)) + result = vmapped_quantize(batched_input, stacked_quantizers) + + # Verify shapes + assert result.data.shape == (E, M, K) + assert result.scale_inv.shape[0] == E # Per-expert scales + + # Compare with calling quantize for each expert individually + individual_results = [] + for i in range(E): + res_i = tex.quantize(batched_input[i], quantizer=quantizers[i], flatten_axis=-1) + individual_results.append(res_i.data) + + expected = jnp.stack(individual_results, axis=0) + assert_allclose(result.data, expected, dtype=quantizers[0].q_dtype) + + valid_fp8_gemm_operand_types = [ (jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e5m2, jnp.float8_e4m3fn), diff --git a/tests/jax/test_einsum.py b/tests/jax/test_einsum.py new file mode 100644 index 0000000000..39dffa6787 --- /dev/null +++ b/tests/jax/test_einsum.py @@ -0,0 +1,219 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Tests for TE einsum operation with FP8 quantization.""" + +import jax +import jax.numpy as jnp +import pytest +from jax import value_and_grad + +from utils import assert_allclose, pytest_parametrize_wrapper +from transformer_engine.jax.einsum import einsum +from transformer_engine.jax.quantize import ( + QuantizerFactory, + QuantizeMeta, + QuantizeMetaSet, +) +from transformer_engine.jax.quantize import helper + + +# Test parameters +DTYPES = [jnp.bfloat16] +# (B, S, M, E, C, H) +# B: Batch size +# S: Sequence length (number of tokens) +# M: Model dimension (hidden size) +# E: Number of experts +# C: Capacity (max tokens per expert) +# H: Hidden dimension (MLP intermediate size) +MOE_CASES = [ + (2, 32, 128, 4, 32, 64), +] + +# Get supported recipes +supported_recipes = helper.get_supported_quantization_recipes() +supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes] + + +@pytest.fixture(autouse=True, scope="module") +def init(): + """WAR for CUDA uninitialize error""" + # Calling customcalls before jax may cause CUDA uninitialize error + _ = jnp.zeros(0) + yield + + +class TestMoEMLPWithRecipes: + """Test MoE MLP operations with different FP8 recipes and gradients.""" + + def _get_quantizer_sets(self, recipe, num_experts): + return QuantizerFactory.create_set( + n_quantizer_sets=num_experts, + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ), + ) + + def _einsum(self, equation, *operands, quantizer_sets=None, quantizer_dim=None, fallback=False): + out = einsum( + equation, + *operands, + quantizer_sets=quantizer_sets, + quantizer_dim=quantizer_dim, + fallback=fallback, + ) + return jnp.mean(out) + + def _ref_einsum(self, equation, *operands): + out = jnp.einsum(equation, *operands) + return jnp.mean(out) + + @pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES) + @pytest_parametrize_wrapper("recipe", supported_recipes) + def test_mlp_up_grad(self, B, S, M, E, C, H, recipe): + """Test MLP up: EBCM,EMH->EBCH with gradients and different recipes.""" + # Create per-expert quantizers + quantizer_sets = self._get_quantizer_sets(recipe, E) + dispatched = jax.random.normal( + jax.random.PRNGKey(0), (E, B, C, M), dtype=jnp.bfloat16 + ) / jnp.sqrt(M) + weights = jax.random.normal(jax.random.PRNGKey(1), (E, M, H), dtype=jnp.bfloat16) + + # Compute with TE einsum with quantization + loss_te, grads_te = value_and_grad(self._einsum, argnums=(1, 2))( + "EBCM,EMH->EBCH", dispatched, weights, quantizer_sets=quantizer_sets, quantizer_dim="E" + ) + + # Compute reference (BF16) + loss_ref, grads_ref = value_and_grad(self._ref_einsum, argnums=(1, 2))( + "EBCM,EMH->EBCH", dispatched, weights + ) + + # Verify shapes and no NaNs + assert grads_te[0].shape == dispatched.shape + assert grads_te[1].shape == weights.shape + assert not jnp.isnan(loss_te) + assert jnp.all(jnp.isfinite(grads_te[0])) + assert jnp.all(jnp.isfinite(grads_te[1])) + + # Compare with reference (with FP8 tolerance) + assert_allclose(loss_te, loss_ref, dtype=quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[0], grads_ref[0], dtype=quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[1], grads_ref[1], dtype=quantizer_sets[0].dgrad.q_dtype) + + @pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES) + @pytest_parametrize_wrapper("recipe", supported_recipes) + def test_mlp_down_grad(self, B, S, M, E, C, H, recipe): + """Test MLP down: EBCH,EHM->EBCM with gradients and different recipes.""" + # Create per-expert quantizers + quantizer_sets = self._get_quantizer_sets(recipe, E) + + hidden = jax.random.normal( + jax.random.PRNGKey(0), (E, B, C, H), dtype=jnp.bfloat16 + ) / jnp.sqrt(H) + weights = jax.random.normal(jax.random.PRNGKey(1), (E, H, M), dtype=jnp.bfloat16) + + # Compute with TE einsum with quantization + loss_te, grads_te = value_and_grad(self._einsum, argnums=(1, 2))( + "EBCH,EHM->EBCM", hidden, weights, quantizer_sets=quantizer_sets, quantizer_dim="E" + ) + + # Compute reference (BF16) + loss_ref, grads_ref = value_and_grad(self._ref_einsum, argnums=(1, 2))( + "EBCH,EHM->EBCM", hidden, weights + ) + + # Verify shapes and no NaNs + assert grads_te[0].shape == hidden.shape + assert grads_te[1].shape == weights.shape + assert not jnp.isnan(loss_te) + assert jnp.all(jnp.isfinite(grads_te[0])) + assert jnp.all(jnp.isfinite(grads_te[1])) + + # Compare with reference (with FP8 tolerance) + assert_allclose(loss_te, loss_ref, dtype=quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[0], grads_ref[0], dtype=quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[1], grads_ref[1], dtype=quantizer_sets[0].dgrad.q_dtype) + + @pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES) + @pytest_parametrize_wrapper("recipe", supported_recipes) + def test_full_moe_grad(self, B, S, M, E, C, H, recipe): + """Test full MoE pipeline (all 4 einsums) with gradients and different recipes.""" + # Create per-expert quantizers for each einsum + mlp_up_quantizer_sets = self._get_quantizer_sets(recipe, E) + mlp_down_quantizer_sets = self._get_quantizer_sets(recipe, E) + + tokens = jax.random.normal(jax.random.PRNGKey(0), (B, S, M), dtype=jnp.bfloat16) / jnp.sqrt(M) + routing = jax.random.normal(jax.random.PRNGKey(1), (B, S, E, C), dtype=jnp.bfloat16) + routing = jax.nn.softmax(routing, axis=-1) # Normalize routing weights + up_weights = jax.random.normal( + jax.random.PRNGKey(2), (E, M, H), dtype=jnp.bfloat16 + ) / jnp.sqrt(H) + down_weights = jax.random.normal( + jax.random.PRNGKey(3), (E, H, M), dtype=jnp.bfloat16 + ) / jnp.sqrt(M) + + # TE implementation with quantization + def full_moe_te(tokens, routing, up_w, down_w): + """Complete MoE pipeline with TE einsum.""" + dispatched = einsum("BSM,BSEC->EBCM", tokens, routing, fallback=True) + hidden = einsum( + "EBCM,EMH->EBCH", + dispatched, + up_w, + quantizer_sets=mlp_up_quantizer_sets, + quantizer_dim="E", + ) + expert_out = einsum( + "EBCH,EHM->EBCM", + hidden, + down_w, + quantizer_sets=mlp_down_quantizer_sets, + quantizer_dim="E", + ) + output = einsum("EBCM,BSEC->BSM", expert_out, routing, fallback=True) + return jnp.sum(output) + + # Reference implementation with jnp.einsum + def full_moe_ref(tokens, routing, up_w, down_w): + """Complete MoE pipeline with jnp.einsum.""" + dispatched = jnp.einsum("BSM,BSEC->EBCM", tokens, routing) + hidden = jnp.einsum("EBCM,EMH->EBCH", dispatched, up_w) + expert_out = jnp.einsum("EBCH,EHM->EBCM", hidden, down_w) + output = jnp.einsum("EBCM,BSEC->BSM", expert_out, routing) + return jnp.sum(output) + + loss_te, grads_te = value_and_grad(full_moe_te, argnums=(0, 1, 2, 3))( + tokens, routing, up_weights, down_weights + ) + + loss_ref, grads_ref = value_and_grad(full_moe_ref, argnums=(0, 1, 2, 3))( + tokens, routing, up_weights, down_weights + ) + + # Verify all gradient shapes + assert grads_te[0].shape == tokens.shape, f"tokens grad shape mismatch" + assert grads_te[1].shape == routing.shape, f"routing grad shape mismatch" + assert grads_te[2].shape == up_weights.shape, f"up_weights grad shape mismatch" + assert grads_te[3].shape == down_weights.shape, f"down_weights grad shape mismatch" + + # Verify no NaNs or Infs + assert not jnp.isnan(loss_te), "Loss is NaN" + assert jnp.isfinite(loss_te), "Loss is Inf" + assert jnp.all(jnp.isfinite(grads_te[0])), "tokens grad has NaN/Inf" + assert jnp.all(jnp.isfinite(grads_te[1])), "routing grad has NaN/Inf" + assert jnp.all(jnp.isfinite(grads_te[2])), "up_weights grad has NaN/Inf" + assert jnp.all(jnp.isfinite(grads_te[3])), "down_weights grad has NaN/Inf" + + # Compare with reference (with FP8 tolerance) + assert_allclose(loss_te, loss_ref, dtype=mlp_up_quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[0], grads_ref[0], dtype=mlp_up_quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[1], grads_ref[1], dtype=mlp_up_quantizer_sets[0].dgrad.q_dtype) + assert_allclose(grads_te[2], grads_ref[2], dtype=mlp_down_quantizer_sets[0].x.q_dtype) + assert_allclose(grads_te[3], grads_ref[3], dtype=mlp_down_quantizer_sets[0].dgrad.q_dtype) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/transformer_engine/jax/cpp_extensions/amax.py b/transformer_engine/jax/cpp_extensions/amax.py index 2f3bc402ec..19e229c1ee 100644 --- a/transformer_engine/jax/cpp_extensions/amax.py +++ b/transformer_engine/jax/cpp_extensions/amax.py @@ -160,6 +160,18 @@ def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types output_spec = (f"{prefix}_amax",) return SdyShardingRule((input_spec,), (output_spec,)) + @staticmethod + def batcher(batched_args, batch_dims, *, amax_scope, transpose_batch_sequence): + """Batcher for amax calculation - returns single amax value.""" + return AmaxCalculationPrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "amax_scope": amax_scope, + "transpose_batch_sequence": transpose_batch_sequence, + }, + ) + register_primitive(AmaxCalculationPrimitive, outer_only=True) @@ -370,6 +382,30 @@ def shardy_sharding_rule( output_post_rht_amax_spec = (f"{prefix}_post_rht_amax",) return SdyShardingRule((input_spec,), (output_amax_spec, output_post_rht_amax_spec)) + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ): + """Batcher for RHT amax calculation - returns 2 amax values.""" + return RHTAmaxCalculationPrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "amax_scope": amax_scope, + "transpose_batch_sequence": transpose_batch_sequence, + "rht_matrix_random_sign_mask_t": rht_matrix_random_sign_mask_t, + "produce_regular_amax": produce_regular_amax, + "flatten_axis": flatten_axis, + }, + ) + register_primitive(RHTAmaxCalculationPrimitive) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 556b587191..9f88265e93 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -7,13 +7,14 @@ import warnings from abc import ABCMeta, abstractmethod from functools import partial +from typing import Any, Sequence, Union, Tuple from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching from jax._src import dispatch -from jax import ffi +from jax import ffi, numpy as jnp import transformer_engine_jax @@ -168,6 +169,92 @@ def shardy_sharding_rule(*args): del args return "... -> ..." + @classmethod + def batcher_impl( + cls, + batched_args: Sequence[Any], + batch_dims: Sequence[Union[int, None]], + static_kwargs: dict, + ) -> Tuple[Tuple[Any, ...], Tuple[Union[int, None], ...]]: + """Batcher implementation for JAX primitives. + + Implements the standard batching pattern: loop over batch dimension, + call primitive for each slice, and stack results. + + Args: + batched_args: Tuple of input tensors (some may be batched) + batch_dims: Tuple indicating batch dimension for each arg (None if not batched) + static_kwargs: Dictionary of static arguments to pass to primitive.bind() + + Returns: + Tuple of (output_tensors, output_batch_dims) + + Example: + @staticmethod + def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): + return MyPrimitive.batcher_impl( + batched_args, batch_dims, + static_kwargs={'arg1': arg1, 'arg2': arg2, 'arg3': arg3}, + ) + """ + from jax import lax + + # Find batch dimension and validate all batched args have the same batch_dim + batch_dim = None + batch_size = None + for arg, bdim in zip(batched_args, batch_dims): + if bdim is not None: + if batch_dim is None: + batch_dim = bdim + batch_size = arg.shape[bdim] + elif bdim != batch_dim: + raise ValueError( + "All batched arguments must have the same batch dimension. " + f"Got batch_dims={batch_dims}" + ) + assert batch_dim is not None and batch_size is not None, "Invalid batching config!" + + # Loop over batch dimension and collect results + all_results = [] + + for i in range(batch_size): + # Extract slice for each argument + sliced_args = [] + for arg, bdim in zip(batched_args, batch_dims): + if bdim is not None: + slice_i = lax.index_in_dim(arg, i, bdim, keepdims=False) + sliced_args.append(slice_i) + else: # For empty args + sliced_args.append(arg) + + # Call primitive with unbatched slices + result_i = cls.outer_primitive.bind(*sliced_args, **static_kwargs) + + # Normalize to tuple + if not isinstance(result_i, (tuple, list)): + result_i = (result_i,) + elif isinstance(result_i, list): + result_i = tuple(result_i) + + all_results.append(result_i) + + # Transpose: from list of tuples to tuple of lists + # all_results = [(out0_0, out1_0, ...), (out0_1, out1_1, ...), ...] + # transposed = ([out0_0, out0_1, ...], [out1_0, out1_1, ...], ...) + transposed = tuple(zip(*all_results)) + + # Stack each output along the batch dimension + stacked_results = tuple( + jnp.stack(list(out_list), axis=batch_dim) for out_list in transposed + ) + + # Single output: return unwrapped result + if len(stacked_results) == 1: + return stacked_results[0], batch_dim + + # Multiple outputs: return tuple of results + return stacked_results, [batch_dim for _ in stacked_results] + # Registry to store all registered primitive classes _primitive_registry = {} diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 76a8b225ba..55a1700838 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -808,40 +808,33 @@ def batcher( sequence_dim, is_outer, ): - del transpose_batch_sequence, sequence_dim, is_outer assert GemmPrimitive.outer_primitive is not None lhs_bdims, _, rhs_bdims, *_ = batch_dims - # Batched GEMM is not supported - assert ( - lhs_bdims is None and rhs_bdims is None - ), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})" - out_bdims = (None,) - - # Bias gradient is never batched - bias_bdims = (None,) - - # Pre-GeLU output, if exists, is batched like GEMM output - pre_gelu_bdims = (None,) - if fuse_gelu and not grad: - pre_gelu_bdims = out_bdims + # Validate batch dimensions + if lhs_bdims is not None or rhs_bdims is not None: + assert lhs_bdims == rhs_bdims, ( + "Batched GEMM requires matching batch dimensions, " + f"got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" + ) - return ( - GemmPrimitive.outer_primitive.bind( - *batched_args, - out_dtype=out_dtype, - contracting_dims=contracting_dims, - scaling_mode=scaling_mode, - fuse_bias=fuse_bias, - fuse_gelu=fuse_gelu, - grad=grad, - use_split_accumulator=use_split_accumulator, - collective_op=collective_op, - transpose_batch_sequence=transpose_batch_sequence, - sequence_dim=sequence_dim, - is_outer=is_outer, - ), - (out_bdims, bias_bdims, pre_gelu_bdims), + # Use general batcher from BasePrimitive + return GemmPrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "out_dtype": out_dtype, + "contracting_dims": contracting_dims, + "scaling_mode": scaling_mode, + "fuse_bias": fuse_bias, + "fuse_gelu": fuse_gelu, + "grad": grad, + "use_split_accumulator": use_split_accumulator, + "collective_op": collective_op, + "transpose_batch_sequence": transpose_batch_sequence, + "sequence_dim": sequence_dim, + "is_outer": is_outer, + }, ) @staticmethod diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index b3f24e9337..53c6937fb4 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -361,34 +361,24 @@ def batcher( stochastic_rounding, use_rht, ): - """ - to describe batch rules for vmap - """ - del is_outer + """Batch rule for quantization primitive using general batcher.""" check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None - x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args - x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims - out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim - return ( - BaseDBiasQuantizePrimitive.outer_primitive.bind( - x, - scale, - amax, - sr_rng_state, - post_rht_amax, - rht_matrix, - out_dtype=out_dtype, - scaling_mode=scaling_mode, - q_layout=q_layout, - flatten_axis=flatten_axis, - scale_dtype=scale_dtype, - is_dbias=is_dbias, - stochastic_rounding=stochastic_rounding, - use_rht=use_rht, - ), - out_bdims, + return BaseDBiasQuantizePrimitive.batcher_impl( + batched_args, + batch_dims, + static_kwargs={ + "out_dtype": out_dtype, + "scaling_mode": scaling_mode, + "q_layout": q_layout, + "flatten_axis": flatten_axis, + "scale_dtype": scale_dtype, + "is_dbias": is_dbias, + "is_outer": is_outer, + "stochastic_rounding": stochastic_rounding, + "use_rht": use_rht, + }, ) @staticmethod diff --git a/transformer_engine/jax/einsum.py b/transformer_engine/jax/einsum.py new file mode 100644 index 0000000000..20084c77ea --- /dev/null +++ b/transformer_engine/jax/einsum.py @@ -0,0 +1,424 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Einsum operation with FP8 quantization support for Transformer Engine in JAX. + +This module provides an einsum implementation that decomposes einsum operations into +a sequence of GEMMs, each with its own quantizer for FP8 support. It follows the +pattern of jax.numpy.einsum but uses TE's optimized GEMM operations. + +This module provides an einsum implementation optimized for Mixture-of-Experts (MoE) +models with per-expert quantization support. It leverages JAX's vmap and TE's dense +layer to efficiently handle tensor contractions with a single batch dimension. + +Key Features: + - **Per-expert quantization**: Each expert can have independent scaling and quantization parameters + - **Automatic differentiation**: Full gradient support via dense layer's VJP + - **Single batch dimension**: Optimized for MoE patterns (expert dimension) + - **Explicit API**: Requires quantizer_dim when using quantization + +Limitations: + - **NN layout only**: LHS last dim must contract, RHS last dim must not contract + - **Single batch dimension**: Only one batch dimension supported + - **2-operand only**: Only supports binary operations + - **Explicit quantizer_dim**: Required when quantizer_sets is provided + + For operations that don't meet these requirements (e.g., routing operations + like "BSM,BSEC->EBCM"), use jnp.einsum instead, or set fallback=True to + automatically fall back to jnp.einsum when the operation is not supported. + +Example - MoE Forward Pass with Per-Expert FP8: + ```python + from transformer_engine.jax.einsum import einsum + from transformer_engine.jax.quantize import QuantizerFactory, QuantizeMeta, QuantizeMetaSet + + # Create per-expert quantizers (E experts) + quantizer_sets = [ + QuantizerFactory.create_set( + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ) + ) for _ in range(num_experts) + ] + + # MoE pipeline with per-expert quantization, + # 1. Dispatch: BSM,BSEC -> EBCM (no quantization - routing operation) + dispatched = jnp.einsum("BSM,BSEC->EBCM", tokens, routing) + # Or with fallback: + # dispatched = einsum("BSM,BSEC->EBCM", tokens, routing, fallback=True) + + # 2. MLP Up: EBCM,EMH -> EBCH (per-expert quantization) + hidden = einsum("EBCM,EMH->EBCH", dispatched, expert_up_weights, + quantizer_sets=expert_quantizers, quantizer_dim='E') + + # 3. MLP Down: EBCH,EHM -> EBCM (per-expert quantization) + expert_out = einsum("EBCH,EHM->EBCM", hidden, expert_down_weights, + quantizer_sets=expert_quantizers, quantizer_dim='E') + + # 4. Combine: EBCM,BSEC -> BSM (no quantization - routing operation) + output = jnp.einsum("EBCM,BSEC->BSM", expert_out, routing) + # Or with fallback: + # output = einsum("EBCM,BSEC->BSM", expert_out, routing, fallback=True) + ``` + +Implementation Details: + The einsum function works by: + 1. Parsing the einsum equation to identify the single batch dimension and contracting dimensions + 2. Validating that quantizer_sets length matches the quantizer dimension size + 3. Creating a vmapped version of TE's dense layer over the batch dimension + 4. Vmapping over quantizer_sets to provide per-batch (e.g., per-expert) quantization + 5. Leveraging dense's existing VJP for automatic differentiation + + This design reuses TE's well-tested dense layer infrastructure while enabling + per-expert quantization for MoE models with minimal code complexity. +""" + +from typing import Tuple, Optional, List +import jax +import jax.numpy as jnp + +from .dense import dense +from .quantize import ( + QuantizerSet, + noop_quantizer_set, +) + + +def _parse_einsum_input(equation: str, *operands) -> Tuple[str, List[str], str]: + """Parse einsum equation into input specs and output spec. + + Args: + equation: Einsum equation string (e.g., "ij,jk->ik" or "BNSM,BNSEC->EBNCM") + operands: Input tensors + + Returns: + Tuple of (equation, input_specs, output_spec) + + Raises: + ValueError: If number of operands doesn't match equation + """ + # Remove spaces + equation = equation.replace(" ", "") + + if "->" in equation: + inputs_str, output_str = equation.split("->") + input_specs = inputs_str.split(",") + else: + # Implicit output mode + inputs_str = equation + input_specs = inputs_str.split(",") + # Compute implicit output + all_indices = set() + for spec in input_specs: + all_indices.update(spec) + output_str = "".join(sorted(all_indices)) + + # Validate each operand's ndim matches its spec + for i, (operand, spec) in enumerate(zip(operands, input_specs)): + expected_ndim = len(spec) + actual_ndim = operand.ndim + if actual_ndim != expected_ndim: + raise ValueError( + f"Operand {i} has {actual_ndim} dimensions but equation '{equation}' " + f"expects {expected_ndim} dimensions (spec: '{spec}'). " + f"Operand shape: {operand.shape}" + ) + + return equation, input_specs, output_str + + +def _find_contracting_and_batch_dims(lhs_spec: str, rhs_spec: str, output_spec: str): + """Find contracting and batch dimensions for a GEMM operation. + + Args: + lhs_spec: Index specification for LHS (e.g., "BNSM") + rhs_spec: Index specification for RHS (e.g., "BNSEC") + output_spec: Index specification for output (e.g., "EBNCM") + + Returns: + Tuple of (lhs_contracting, rhs_contracting, lhs_batch, rhs_batch) + """ + # Contracting dimensions: indices in both lhs and rhs but not in output + lhs_set = set(lhs_spec) + rhs_set = set(rhs_spec) + output_set = set(output_spec) + + contracting_indices = (lhs_set & rhs_set) - output_set + + # Batch dimensions: indices in lhs, rhs, and output + batch_indices = lhs_set & rhs_set & output_set + + # Find positions + lhs_contracting = tuple(i for i, c in enumerate(lhs_spec) if c in contracting_indices) + rhs_contracting = tuple(i for i, c in enumerate(rhs_spec) if c in contracting_indices) + lhs_batch = tuple(i for i, c in enumerate(lhs_spec) if c in batch_indices) + rhs_batch = tuple(i for i, c in enumerate(rhs_spec) if c in batch_indices) + + return lhs_contracting, rhs_contracting, lhs_batch, rhs_batch + + +def _einsum_to_gemm_info(equation: str, *operands): + """Extract GEMM information from einsum equation. + + Args: + equation: Einsum equation + operands: Input tensors + + Returns: + Dict with keys: lhs_idx, rhs_idx, contracting_dims, batch_dims, output_spec + """ + equation, input_specs, output_spec = _parse_einsum_input(equation, *operands) + + if len(input_specs) != 2: + raise NotImplementedError(f"Einsum with {len(input_specs)} operands not yet supported") + + lhs_spec, rhs_spec = input_specs + + lhs_contracting, rhs_contracting, lhs_batch, rhs_batch = _find_contracting_and_batch_dims( + lhs_spec, rhs_spec, output_spec + ) + + return { + "lhs_idx": 0, + "rhs_idx": 1, + "lhs_spec": lhs_spec, + "rhs_spec": rhs_spec, + "output_spec": output_spec, + "contracting_dims": (lhs_contracting, rhs_contracting), + "batch_dims": (lhs_batch, rhs_batch), + } + + +def einsum( + equation: str, + *operands: jnp.ndarray, + quantizer_sets: Optional[List[QuantizerSet]] = None, + quantizer_dim: Optional[str] = None, + operand_axes: Optional[List[Tuple[str, ...]]] = None, + output_axes: Optional[Tuple[str, ...]] = None, + fallback: bool = False, +) -> jnp.ndarray: + """Perform einsum operation with optional FP8 quantization using vmap + dense. + + This function implements einsum by: + 1. Identifying batch dimensions + 2. Using vmap to vectorize over batch dimensions + 3. Calling the existing dense() function which has VJP already implemented + + Each batched GEMM can have its own quantizer_set, enabling per-expert + quantization in MoE models. + + Args: + equation: Einsum equation string (e.g., "ij,jk->ik", "BSM,BSEC->EBCM") + *operands: Input tensors + quantizer_sets: List or tuple of QuantizerSets. Length must match the size of + the dimension specified by quantizer_dim. If None, creates noop quantizers. + quantizer_dim: Index label indicating which dimension the quantizers correspond to. + For MoE, this is typically 'E' (expert dimension). If None and + quantizer_sets is provided, assumes first batch dimension at position 0. + operand_axes: List of logical axes tuples for sharding each operand + output_axes: Logical axes for sharding the output + fallback: Whether to fallback to jnp.einsum if the einsum operation is not supported. + When fallback=True, unsupported operations (e.g., non-NN layouts, routing + operations) will use jnp.einsum. Note: quantization will NOT be applied + when falling back. + + Returns: + Result of the einsum operation + + Examples: + # Simple matrix multiplication with FP8 + result = einsum("ij,jk->ik", A, B, quantizer_sets=my_quantizer_set) + + # MoE with per-expert quantizers (E experts) + expert_quantizers = [quantizer_e0, quantizer_e1, ..., quantizer_eN] + result = einsum("EBNCM,EMH->EBNCH", tokens, weights, + quantizer_sets=expert_quantizers) + + # With fallback for routing operations + result = einsum("BSM,BSEC->EBCM", tokens, routing, fallback=True) + # Falls back to jnp.einsum (no quantization) + """ + if operand_axes is None: + operand_axes = [None] * len(operands) + + if len(operands) != 2: + if fallback: + import warnings + + warnings.warn( + f"TE einsum only supports 2-operand einsum, got {len(operands)} operands. " + "Falling back to jnp.einsum (no quantization will be applied).", + stacklevel=2, + ) + return jnp.einsum(equation, *operands) + raise NotImplementedError("Only 2-operand einsum currently supported") + + # Parse einsum to get GEMM info + gemm_info = _einsum_to_gemm_info(equation, *operands) + contracting_dims = gemm_info["contracting_dims"] + batch_dims = gemm_info["batch_dims"] + lhs_spec = gemm_info["lhs_spec"] + rhs_spec = gemm_info["rhs_spec"] + + lhs, rhs = operands + + # Validate quantizer_dim is provided when quantizer_sets is given + if quantizer_sets is not None and quantizer_dim is None: + raise ValueError( + "quantizer_dim must be specified when quantizer_sets is provided. " + "This explicitly indicates which dimension the quantizers correspond to." + ) + + # Find quantizer dimension + quantizer_dim_lhs = None + quantizer_dim_rhs = None + + if quantizer_dim is not None: + # Find position of quantizer_dim in lhs and rhs specs + if quantizer_dim in lhs_spec: + quantizer_dim_lhs = lhs_spec.index(quantizer_dim) + if quantizer_dim in rhs_spec: + quantizer_dim_rhs = rhs_spec.index(quantizer_dim) + + if quantizer_dim_lhs is None and quantizer_dim_rhs is None: + raise ValueError(f"quantizer_dim '{quantizer_dim}' not found in equation '{equation}'") + + # Check if we have batch dimensions + has_batch_dims = bool(batch_dims[0] or batch_dims[1]) + + # Determine expected quantizer_sets length based on quantizer_dim + if quantizer_dim is not None: + if quantizer_dim_lhs is not None: + expected_length = lhs.shape[quantizer_dim_lhs] + else: + expected_length = rhs.shape[quantizer_dim_rhs] + else: + # No quantizer_dim: determine from batch dimension + if has_batch_dims: + expected_length = lhs.shape[batch_dims[0][0]] + else: + expected_length = 1 + + # Validate and initialize quantizer_sets + if quantizer_sets is None: + quantizer_sets = [noop_quantizer_set] * expected_length + elif not isinstance(quantizer_sets, (list, tuple)): + raise TypeError(f"quantizer_sets must be a list or tuple, got {type(quantizer_sets)}") + elif len(quantizer_sets) != expected_length: + raise ValueError( + f"quantizer_sets length ({len(quantizer_sets)}) must match " + f"{'dimension ' + repr(quantizer_dim) if quantizer_dim else 'batch dimension'} " + f"size ({expected_length})" + ) + + # Validate that this is NN layout (required by dense) + # For NN: lhs last dim must contract, rhs last dim must NOT contract + lhs_ndim = len(gemm_info["lhs_spec"]) + rhs_ndim = len(gemm_info["rhs_spec"]) + lhs_last_contracts = lhs_ndim - 1 in contracting_dims[0] + rhs_last_contracts = rhs_ndim - 1 in contracting_dims[1] + + if not lhs_last_contracts or rhs_last_contracts: + if fallback: + import warnings + + if quantizer_sets is not None and quantizer_sets != [noop_quantizer_set] * len( + quantizer_sets + ): + warnings.warn( + f"TE einsum only supports NN layout. Equation '{equation}' is not NN layout. " + "Falling back to jnp.einsum. WARNING: Quantization will NOT be applied!", + stacklevel=2, + ) + return jnp.einsum(equation, *operands) + raise ValueError( + "TE einsum only supports NN layout (non-transposed matrix multiplication). Equation" + f" '{equation}' is not NN layout:\n - LHS '{gemm_info['lhs_spec']}': last dimension" + f" must contract (got contracting_dims={contracting_dims[0]})\n - RHS" + f" '{gemm_info['rhs_spec']}': last dimension must NOT contract (got" + f" contracting_dims={contracting_dims[1]})\nFor non-NN layouts (e.g., routing" + " operations), use jnp.einsum instead." + ) + + # Create vmapped dense function for batch dimensions + has_batch_dims = bool(batch_dims[0] or batch_dims[1]) + + if has_batch_dims: + # Validate single batch dimension (MoE use case) + if len(batch_dims[0]) != 1 or len(batch_dims[1]) != 1: + if fallback: + import warnings + + if quantizer_sets is not None and quantizer_sets != [noop_quantizer_set] * len( + quantizer_sets + ): + warnings.warn( + "TE einsum only supports single batch dimension. Got" + f" {len(batch_dims[0])} batch dims in lhs and {len(batch_dims[1])} in rhs." + " Falling back to jnp.einsum. WARNING: Quantization will NOT be applied!", + stacklevel=2, + ) + return jnp.einsum(equation, *operands) + raise NotImplementedError( + "Only single batch dimension is currently supported. " + f"Got {len(batch_dims[0])} batch dims in lhs and {len(batch_dims[1])} in rhs. " + f"Equation: '{equation}'" + ) + + lhs_batch_dim = batch_dims[0][0] + rhs_batch_dim = batch_dims[1][0] + + # Adjust contracting dims for the unbatched shapes seen by Python code + # (primitives will see batched shapes, but Python validation sees unbatched) + adj_lhs_contracting = tuple( + dim - (1 if dim > lhs_batch_dim else 0) for dim in contracting_dims[0] + ) + adj_rhs_contracting = tuple( + dim - (1 if dim > rhs_batch_dim else 0) for dim in contracting_dims[1] + ) + adj_contracting_dims = (adj_lhs_contracting, adj_rhs_contracting) + + # Stack quantizers into a pytree structure that vmap can handle + # QuantizerSet is already a pytree, so we can stack them + # For BF16 without quantizer_dim, this will be a stack of noop_quantizer_sets + stacked_quantizers = jax.tree_util.tree_map(lambda *args: jnp.stack(args), *quantizer_sets) + + # Vmap over quantizers (or repeated noop quantizers for BF16) + def dense_with_quantizer(lhs_single, rhs_single, quantizer_set): + """Dense with explicit quantizer argument for vmapping.""" + return dense( + lhs_single, + rhs_single, + None, + contracting_dims=adj_contracting_dims, # Adjusted for unbatched shapes + transpose_batch_sequence=False, + input_axes=operand_axes[0], + kernel_axes=operand_axes[1], + output_axes=output_axes, + quantizer_set=quantizer_set, + ) + + vmapped_func = jax.vmap( + dense_with_quantizer, + in_axes=(lhs_batch_dim, rhs_batch_dim, 0), # vmap over stacked quantizers + out_axes=0, + ) + output = vmapped_func(lhs, rhs, stacked_quantizers) + else: + # No batch dimensions - direct dense call + # quantizer_set length already validated to be 1 + output = dense( + lhs, + rhs, + None, + contracting_dims=contracting_dims, + transpose_batch_sequence=False, + input_axes=operand_axes[0], + kernel_axes=operand_axes[1], + output_axes=output_axes, + quantizer_set=quantizer_sets[0], + ) + + return output diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 90f139c3da..120bd05c13 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -209,49 +209,63 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): flatten_axis: int has_rht_applied: bool - def __post_init__(self): - """Validates and adjusts the scale_inv shape after initialization. - - Ensures the scale_inv shape matches the expected shape based on the scaling mode - and quantization direction. Pads the scale_inv if necessary. - """ - assert self.flatten_axis > 0 - assert ( - 0 < self.flatten_axis < len(self.data.shape) - ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" - - if self.scaling_mode == ScalingMode.NO_SCALING: - self.scale_inv = jnp.empty((0,), dtype=jnp.float32) - else: - unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, - data_layout=self.data_layout, - is_colwise=self.is_colwise, - is_padded=False, - # expect the flatten_axis wrt the N layout - flatten_axis=( - self.flatten_axis - if self.data_layout == "N" - else self.data.ndim - self.flatten_axis - ), - ) - unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape( - self.data.shape, - data_layout=self.data_layout, - is_colwise=self.is_colwise, - is_padded=False, - # expect the flatten_axis wrt the N layout - flatten_axis=( - self.flatten_axis - if self.data_layout == "N" - else self.data.ndim - self.flatten_axis - ), - broadcast_2d_scale_shape_to_1d=True, - ) - assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), ( - f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or" - f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}." - ) + # def __post_init__(self): + # """Validates and adjusts the scale_inv shape after initialization. + # + # Ensures the scale_inv shape matches the expected shape based on the scaling mode + # and quantization direction. Pads the scale_inv if necessary. + # """ + # assert self.flatten_axis > 0 + # assert ( + # 0 < self.flatten_axis < len(self.data.shape) + # ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" + # + # if self.scaling_mode == ScalingMode.NO_SCALING: + # self.scale_inv = jnp.empty((0,), dtype=jnp.float32) + # else: + # unpadded_scale_shape = self.scaling_mode.get_scale_shape( + # self.data.shape, + # data_layout=self.data_layout, + # is_colwise=self.is_colwise, + # is_padded=False, + # # expect the flatten_axis wrt the N layout + # flatten_axis=( + # self.flatten_axis + # if self.data_layout == "N" + # else self.data.ndim - self.flatten_axis + # ), + # ) + # unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape( + # self.data.shape, + # data_layout=self.data_layout, + # is_colwise=self.is_colwise, + # is_padded=False, + # # expect the flatten_axis wrt the N layout + # flatten_axis=( + # self.flatten_axis + # if self.data_layout == "N" + # else self.data.ndim - self.flatten_axis + # ), + # broadcast_2d_scale_shape_to_1d=True, + # ) + # # Check shape, allowing for batch dimensions from vmap + # # If vmapped, shape will be (batch_size, *expected_shape) + # actual_shape = self.scale_inv.shape + # if actual_shape not in (unpadded_scale_shape, unpadded_scale_shape_broadcast): + # # Check if it's a batched version (extra leading dimensions) + # if len(actual_shape) > len(unpadded_scale_shape): + # # Batched: check that trailing dimensions match + # trailing_shape = actual_shape[-(len(unpadded_scale_shape)):] + # if trailing_shape not in (unpadded_scale_shape, unpadded_scale_shape_broadcast): + # raise AssertionError( + # f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or " + # f"{unpadded_scale_shape_broadcast} (possibly with batch dims) but got {self.scale_inv.shape}." + # ) + # else: + # raise AssertionError( + # f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or " + # f"{unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}." + # ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. @@ -431,10 +445,21 @@ def __post_init__(self): flatten_axis=self.flatten_axis, ) - assert self.scale_inv.shape == expected_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv, got {self.scale_inv.shape}" - ) + # Check shape, allowing for batch dimensions from vmap + actual_shape = self.scale_inv.shape + if actual_shape != expected_scale_shape: + # Check if it's a batched version + if len(actual_shape) > len(expected_scale_shape): + trailing_shape = actual_shape[-(len(expected_scale_shape)) :] + assert trailing_shape == expected_scale_shape, ( + f"Unexpected scale_inv shape! Expected {expected_scale_shape} for padded " + f"scale_inv (possibly with batch dims), got {self.scale_inv.shape}" + ) + else: + raise AssertionError( + f"Unexpected scale_inv shape! Expected {expected_scale_shape} for padded " + f"scale_inv, got {self.scale_inv.shape}" + ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. From 1f02cf41c7b521b82d99058e8f0fb6f2bd5b048e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Dec 2025 21:08:42 +0000 Subject: [PATCH 02/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_einsum.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_einsum.py b/tests/jax/test_einsum.py index 39dffa6787..7580a14638 100644 --- a/tests/jax/test_einsum.py +++ b/tests/jax/test_einsum.py @@ -145,7 +145,9 @@ def test_full_moe_grad(self, B, S, M, E, C, H, recipe): mlp_up_quantizer_sets = self._get_quantizer_sets(recipe, E) mlp_down_quantizer_sets = self._get_quantizer_sets(recipe, E) - tokens = jax.random.normal(jax.random.PRNGKey(0), (B, S, M), dtype=jnp.bfloat16) / jnp.sqrt(M) + tokens = jax.random.normal(jax.random.PRNGKey(0), (B, S, M), dtype=jnp.bfloat16) / jnp.sqrt( + M + ) routing = jax.random.normal(jax.random.PRNGKey(1), (B, S, E, C), dtype=jnp.bfloat16) routing = jax.nn.softmax(routing, axis=-1) # Normalize routing weights up_weights = jax.random.normal( From bf3ebc2ccf98a016ff61f859df7fa2686f36114d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 10 Dec 2025 15:29:37 +0100 Subject: [PATCH 03/40] code drop Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_grouped_gemm.cu | 511 ++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 484 +++++++++++++++++ .../common/include/transformer_engine/gemm.h | 36 ++ 4 files changed, 1032 insertions(+) create mode 100644 tests/cpp/operator/test_grouped_gemm.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index b2f14b1892..1392ffdadc 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -30,6 +30,7 @@ add_executable(test_operator test_causal_softmax.cu test_swizzle.cu test_swap_first_dims.cu + test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu new file mode 100644 index 0000000000..0e9c6c6a4d --- /dev/null +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -0,0 +1,511 @@ +/*********************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * + * See LICENSE for license information. + **********************************************************************/ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum class InputCase { + kFP8Delayed, + kFP8Current, + kBF16, +}; + +enum class ShapeCase { + kAllSame, + kSameFirst, + kSameLast, + kAllDifferent, +}; + +// Helper owning GPU buffers that back NVTEGroupedTensor. +// NVTEGroupedTensor does not own memory; data/offsets/scales +// must be allocated and freed by the test. +struct GroupedBuffers { + NVTEGroupedTensor handle{nullptr}; + void* data{nullptr}; + void* scale_inv{nullptr}; + int64_t* first_dims_dev{nullptr}; + int64_t* last_dims_dev{nullptr}; + int64_t* offsets_dev{nullptr}; + void* columnwise_data{nullptr}; + NVTEShape logical_shape{}; + std::vector offsets_host; + std::vector tensor_bytes; + size_t num_tensors{0}; + size_t elem_size{0}; + DType dtype{DType::kFloat32}; + NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING}; + + GroupedBuffers() = default; + GroupedBuffers(const GroupedBuffers&) = delete; + GroupedBuffers& operator=(const GroupedBuffers&) = delete; + GroupedBuffers(GroupedBuffers&& other) noexcept { + *this = std::move(other); + } + GroupedBuffers& operator=(GroupedBuffers&& other) noexcept { + if (this == &other) return *this; + handle = other.handle; + data = other.data; + scale_inv = other.scale_inv; + first_dims_dev = other.first_dims_dev; + last_dims_dev = other.last_dims_dev; + offsets_dev = other.offsets_dev; + logical_shape = other.logical_shape; + offsets_host = std::move(other.offsets_host); + tensor_bytes = std::move(other.tensor_bytes); + num_tensors = other.num_tensors; + elem_size = other.elem_size; + dtype = other.dtype; + scaling_mode = other.scaling_mode; + + other.handle = nullptr; + other.data = nullptr; + other.scale_inv = nullptr; + other.first_dims_dev = nullptr; + other.last_dims_dev = nullptr; + other.offsets_dev = nullptr; + other.num_tensors = 0; + return *this; + } + + ~GroupedBuffers() { + if (data) { + cudaFree(data); + data = nullptr; + } + if (scale_inv) { + cudaFree(scale_inv); + scale_inv = nullptr; + } + if (columnwise_data) { + cudaFree(columnwise_data); + columnwise_data = nullptr; + } + if (first_dims_dev) { + cudaFree(first_dims_dev); + first_dims_dev = nullptr; + } + if (last_dims_dev) { + cudaFree(last_dims_dev); + last_dims_dev = nullptr; + } + if (offsets_dev) { + cudaFree(offsets_dev); + offsets_dev = nullptr; + } + if (handle) { + nvte_destroy_grouped_tensor(handle); + handle = nullptr; + } + } +}; + +size_t grouped_setup_workspace_size(const size_t num_tensors) { + const size_t ptr_bytes = num_tensors * sizeof(void*); + const size_t int_bytes = num_tensors * sizeof(int); + size_t size = 4 * ptr_bytes + 3 * int_bytes + 2 * ptr_bytes; + const size_t alignment = 256; + size = ((size + alignment - 1) / alignment) * alignment; + return size; +} + +GroupedBuffers build_grouped_tensor(const std::vector& tensors, + const NVTEScalingMode scaling_mode) { + NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); + const NVTEShape shape = tensors[0]->rowwise_shape(); + const DType dtype = tensors[0]->dtype(); + const size_t num_tensors = tensors.size(); + const size_t elem_size = typeToSize(dtype); + GroupedBuffers grouped; + grouped.elem_size = elem_size; + grouped.num_tensors = num_tensors; + grouped.dtype = dtype; + grouped.scaling_mode = scaling_mode; + grouped.tensor_bytes.resize(num_tensors); + grouped.offsets_host.resize(num_tensors, 0); + + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + const auto s = tensors[i]->rowwise_shape(); + NVTE_CHECK(s.ndim == 2, "Grouped GEMM test expects 2D tensors."); + first_dims[i] = static_cast(s.data[0]); + last_dims[i] = static_cast(s.data[1]); + grouped.tensor_bytes[i] = bytes(s, dtype); + } + + const bool same_first = std::all_of(first_dims.begin(), first_dims.end(), + [&](int64_t v) { return v == first_dims[0]; }); + const bool same_last = std::all_of(last_dims.begin(), last_dims.end(), + [&](int64_t v) { return v == last_dims[0]; }); + + std::vector offsets(num_tensors, 0); + auto random_padding = [&]() -> int64_t { + static std::mt19937 gen(12345); + std::uniform_int_distribution dist(0, 3); + return dist(gen); + }; + + auto numel = [&](size_t idx) -> int64_t { + return first_dims[idx] * last_dims[idx]; + }; + + const bool need_offsets = !same_first || !same_last; + if (need_offsets) { + offsets[0] = 0; + for (size_t i = 1; i < num_tensors; ++i) { + offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding(); + } + } else { + for (size_t i = 0; i < num_tensors; ++i) { + offsets[i] = static_cast(i) * numel(0); + } + } + grouped.offsets_host = offsets; + + int64_t logical_first = 0; + int64_t logical_last = 0; + if (same_first && same_last) { + logical_first = first_dims[0] * static_cast(num_tensors); + logical_last = last_dims[0]; + } else if (same_first && !same_last) { + logical_first = first_dims[0]; + logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0}); + } else if (!same_first && same_last) { + logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0}); + logical_last = last_dims[0]; + } else { + logical_first = 1; + logical_last = 0; + for (size_t i = 0; i < num_tensors; ++i) { + logical_last += first_dims[i] * last_dims[i]; + } + } + size_t logical_data[2] = {static_cast(logical_first), + static_cast(logical_last)}; + grouped.logical_shape = nvte_make_shape(logical_data, 2); + grouped.handle = nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape); + + const int64_t last_idx = static_cast(num_tensors - 1); + const int64_t total_elems = need_offsets + ? (offsets[last_idx] + numel(last_idx)) + : (logical_first * logical_last); + const size_t total_bytes = static_cast(total_elems) * elem_size; + + NVTE_CHECK_CUDA(cudaMalloc(&grouped.data, total_bytes)); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data) + offset_bytes, + tensors[i]->rowwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + + NVTEBasicTensor data_tensor{grouped.data, static_cast(dtype), grouped.logical_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseData, &data_tensor); + + const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); + if (include_columnwise) { + NVTE_CHECK_CUDA(cudaMalloc(&grouped.columnwise_data, total_bytes)); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data) + offset_bytes, + tensors[i]->columnwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + NVTEBasicTensor col_tensor{grouped.columnwise_data, + static_cast(dtype), + grouped.logical_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseData, &col_tensor); + } + + if (!same_first) { + NVTE_CHECK_CUDA(cudaMalloc(&grouped.first_dims_dev, num_tensors * sizeof(int64_t))); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev, first_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor fd_tensor{grouped.first_dims_dev, kNVTEInt64, fd_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedFirstDims, &fd_tensor); + } + + if (!same_last) { + NVTE_CHECK_CUDA(cudaMalloc(&grouped.last_dims_dev, num_tensors * sizeof(int64_t))); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev, last_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor ld_tensor{grouped.last_dims_dev, kNVTEInt64, ld_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedLastDims, &ld_tensor); + } + + if (!same_first || !same_last) { + NVTE_CHECK_CUDA(cudaMalloc(&grouped.offsets_dev, num_tensors * sizeof(int64_t))); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev, offsets.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape off_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor off_tensor{grouped.offsets_dev, kNVTEInt64, off_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedTensorOffsets, &off_tensor); + } + + if (isFp8Type(dtype)) { + std::vector scale_inv_cpu(num_tensors, 1.f); + for (size_t i = 0; i < num_tensors; ++i) { + tensors[i]->to_cpu(); + scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; + } + NVTE_CHECK_CUDA(cudaMalloc(&grouped.scale_inv, sizeof(float) * num_tensors)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv, scale_inv_cpu.data(), + sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); + NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor scale_tensor{grouped.scale_inv, kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseScaleInv, &scale_tensor); + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); + } + + return grouped; +} + +Tensor make_fp8_operand(const std::string& name, const std::vector& shape) { + Tensor input_fp32(name + "_fp32", shape, DType::kFloat32); + fillUniform(&input_fp32); + + Tensor fp8(name, shape, TypeInfo::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING); + + nvte_compute_amax(input_fp32.data(), fp8.data(), 0); + QuantizationConfigWrapper config; + nvte_compute_scale_from_amax(fp8.data(), config, 0); + nvte_quantize(input_fp32.data(), fp8.data(), 0); + return fp8; +} + +Tensor make_bf16_operand(const std::string& name, const std::vector& shape) { + Tensor t(name, shape, DType::kBFloat16); + fillUniform(&t); + return t; +} + +struct TestParams { + InputCase input_case; + bool transa; + bool transb; + ShapeCase shape_case; +}; + +std::vector> make_shapes(ShapeCase scase) { + switch (scase) { + case ShapeCase::kAllSame: + return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; + case ShapeCase::kSameFirst: // M wspólne, N/K zróżnicowane + return {{64, 64, 32}, {64, 96, 32}, {64, 80, 48}}; + case ShapeCase::kSameLast: // N wspólne, M/K zróżnicowane + return {{48, 80, 32}, {96, 80, 48}, {72, 80, 40}}; + case ShapeCase::kAllDifferent: + default: + return {{48, 80, 32}, {96, 64, 48}, {40, 72, 24}}; + } +} + +void run_grouped_gemm_case(const TestParams& params) { + if (params.input_case != InputCase::kBF16 && + getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP() << "FP8 grouped GEMM requires Hopper or newer."; + } + + const std::vector> shapes = make_shapes(params.shape_case); + + const size_t num_gemms = shapes.size(); + std::vector A_tensors; + std::vector B_tensors; + std::vector D_multi; + + A_tensors.reserve(num_gemms); + B_tensors.reserve(num_gemms); + D_multi.reserve(num_gemms); + + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + const std::vector a_shape = params.transa ? std::vector{K, M} + : std::vector{M, K}; + const std::vector b_shape = params.transb ? std::vector{N, K} + : std::vector{K, N}; + switch (params.input_case) { + case InputCase::kFP8Current: { + A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kBF16: { + A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + break; + } + } + D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + + std::vector A_ptrs(num_gemms); + std::vector B_ptrs(num_gemms); + std::vector D_ptrs(num_gemms); + std::vector bias_ptrs(num_gemms, nullptr); + std::vector gelu_ptrs(num_gemms, nullptr); + std::vector workspaces(num_gemms); + std::vector workspace_ptrs(num_gemms, nullptr); + + const size_t cublas_ws_bytes = 32ull * 1024 * 1024; + + for (size_t i = 0; i < num_gemms; ++i) { + A_ptrs[i] = A_tensors[i].data(); + B_ptrs[i] = B_tensors[i].data(); + D_ptrs[i] = D_multi[i].data(); + workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); + workspace_ptrs[i] = workspaces[i].data(); + } + + nvte_multi_tensor_gemm(A_ptrs.data(), + B_ptrs.data(), + D_ptrs.data(), + bias_ptrs.data(), + gelu_ptrs.data(), + static_cast(num_gemms), + params.transa, + params.transb, + false, + workspace_ptrs.data(), + false, + false, + 0, + 0); + + GroupedBuffers grouped_A = build_grouped_tensor(A_tensors, A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_tensors, B_tensors[0].scaling_mode()); + + std::vector C_tensors; + std::vector D_group_tensors; + C_tensors.reserve(num_gemms); + D_group_tensors.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + (void)K; + C_tensors.emplace_back(Tensor("C" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype()))); + } + + std::vector C_views, D_views; + for (size_t i = 0; i < num_gemms; ++i) { + C_views.push_back(&C_tensors[i]); + D_views.push_back(&D_group_tensors[i]); + } + + GroupedBuffers grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); + + Tensor alpha_tensor("alpha", std::vector{1}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{1}, DType::kFloat32); + const float alpha_val = 1.f; + const float beta_val = 0.f; + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), &alpha_val, sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), &beta_val, sizeof(float), cudaMemcpyHostToDevice)); + + const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); + Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); + + nvte_grouped_gemm(params.transa, + params.transb, + alpha_tensor.data(), + grouped_A.handle, + grouped_B.handle, + beta_tensor.data(), + grouped_C.handle, + grouped_D.handle, + setup_ws.data(), + cublas_ws.data(), + nullptr, + 0, + nullptr, + nullptr, + nullptr); + + for (size_t i = 0; i < num_gemms; ++i) { + Tensor grouped_split("grouped_D" + std::to_string(i), + std::vector{static_cast(std::get<0>(shapes[i])), + static_cast(std::get<1>(shapes[i]))}, + D_multi[i].dtype()); + const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), + static_cast(grouped_D.data) + offset_bytes, + grouped_D.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + grouped_split.to_cpu(); + D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(D_multi[i].dtype()); + compareResults("grouped_vs_multi", + grouped_split, + D_multi[i].rowwise_cpu_dptr(), + true, + atol, + rtol); + } +} + +class GroupedGemmTest : public ::testing::TestWithParam {}; + +TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { + run_grouped_gemm_case(GetParam()); +} + +std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { + constexpr const char* kInputNames[] = {"FP8Delayed", "FP8Current", "BF16"}; + constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; + const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") + + "tb" + (info.param.transb ? "T" : "N"); + return std::string(kInputNames[static_cast(info.param.input_case)]) + "_" + + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout; +} + +const std::vector kTestParams = { + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame}, + {InputCase::kBF16, true, false, ShapeCase::kSameFirst}, + {InputCase::kBF16, false, true, ShapeCase::kSameLast}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame}, + {InputCase::kBF16, true, true, ShapeCase::kAllDifferent}, +}; + +INSTANTIATE_TEST_SUITE_P(OperatorTest, + GroupedGemmTest, + ::testing::ValuesIn(kTestParams), + MakeGroupedGemmTestName); + +} // namespace + + diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 97e8ec9a3e..53be59cc00 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1104,3 +1104,487 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor cublas_path(); } } + + +// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) +struct TensorShapeInfo { + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr + + // Create from GroupedTensor + static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { + return { + t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, + t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, + t->get_common_first_dim(), + t->get_common_last_dim()}; + } + + // Create for C tensor (uses D's dimensions, only has offsets) + static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D) { + return { + nullptr, + nullptr, + C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) : nullptr, + D->get_common_first_dim(), + D->get_common_last_dim()}; + } +}; + +// Helper functions to compute average dimensions from logical_shape for heuristics +// These are hints for cuBLASLt algorithm selection, don't need to be exact +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor* t) { + // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) + // In both cases, dividing by num_tensors gives the average + return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); +} + +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor* t) { + if (t->all_same_last_dim()) { + // logical_shape[1] is the common N + return static_cast(t->logical_shape.data[1]); + } else { + // logical_shape[1] is sum_of_N, divide by num_tensors + return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); + } +} + +// Workspace layout for grouped GEMM +struct GroupedGemmSetupWorkspace { + void **A_ptrs; + void **B_ptrs; + void **C_ptrs; + void **D_ptrs; + int *M; + int *N; + int *K; + float **alpha_ptrs; + float **beta_ptrs; + + // Initialize from workspace buffer + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, size_t alignment) { + GroupedGemmSetupWorkspace ws; + size_t offset = 0; + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.M = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; + ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; + ws.K = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + + offset = ((offset + alignment - 1) / alignment) * alignment; + + return ws; + } + + // Calculate required size for setup workspace (pointer arrays + M/N/K + alpha/beta ptrs) + static size_t required_setup_size(size_t num_tensors, size_t alignment) { + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + size_t size = 4 * ptr_size + 3 * int_size + 2 * ptr_size; // M, N, K only (no LDA/LDB/LDC/LDD) + size = ((size + alignment - 1) / alignment) * alignment; + return size; + } +}; + +// ----------------------------------------------------------------------------- +// Helper routines to keep nvte_grouped_gemm readable +// ----------------------------------------------------------------------------- +inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor* inputA, + const transformer_engine::GroupedTensor* inputB, + const transformer_engine::GroupedTensor* inputC, + const transformer_engine::GroupedTensor* outputD) { + const size_t num_tensors = inputA->num_tensors; + NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); + NVTE_CHECK(inputB->num_tensors == num_tensors, + "Grouped GEMM: A and B must have the same num_tensors"); + NVTE_CHECK(inputC->num_tensors == num_tensors, + "Grouped GEMM: A and C must have the same num_tensors"); + NVTE_CHECK(outputD->num_tensors == num_tensors, + "Grouped GEMM: A and D must have the same num_tensors"); + + auto is_fp8_or_16bit = [](DType dtype) { + return dtype == DType::kFloat8E4M3 || dtype == DType::kFloat8E5M2 || + dtype == DType::kBFloat16 || dtype == DType::kFloat16; + }; + auto is_output_dtype = [](DType dtype) { + return dtype == DType::kBFloat16 || dtype == DType::kFloat16 || dtype == DType::kFloat32; + }; + NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), + "Grouped GEMM inputs must be FP8, BF16, or FP16."); + NVTE_CHECK(is_output_dtype(inputC->dtype()) && is_output_dtype(outputD->dtype()), + "Grouped GEMM outputs must be BF16, FP16, or FP32."); + NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), + "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); + NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), + "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); +} + +// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. +// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and +// fallback to column-wise data when row-wise is absent. +struct GroupedOperandSelection { + const char* base = nullptr; + transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; + bool trans = false; + bool use_columnwise = false; +}; + +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor* t, + bool trans, bool is_A) { + using namespace transformer_engine; + const bool has_row = t->has_data(); + const bool has_col = t->has_columnwise_data(); + NVTE_CHECK(has_row || has_col, "Grouped GEMM operand is missing both row-wise and column-wise data"); + + // Not yet supported in grouped GEMM: block scaling, MXFP8, NVFP4 specialized layouts. + const auto sm = t->scaling_mode; + NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && + !is_mxfp_scaling(sm) && !is_nvfp_scaling(sm), + "Grouped GEMM does not yet support NVFP4/MXFP8/block scaling operand selection"); + + const DType row_dtype = t->data.dtype; + const DType col_dtype = t->columnwise_data.dtype; + GroupedOperandSelection sel; + sel.trans = trans; + + const DType rep_dtype = has_row ? row_dtype : col_dtype; + const bool is_fp8 = is_fp8_dtype(rep_dtype); + const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + + // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. + if (is_fp8 && !non_tn_fp8_ok) { + if (is_A) { + if (!sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); + sel.base = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = true; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } else { // B + if (sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); + sel.base = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = false; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } + } + + // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). + if (!has_row && has_col) { + sel.base = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = !sel.trans; + sel.use_columnwise = true; + return sel; + } + + // Default: use row-wise data (or column-wise if row-wise absent, covered above). + sel.base = static_cast(has_row ? t->data.dptr : t->columnwise_data.dptr); + sel.dtype = has_row ? row_dtype : col_dtype; + sel.use_columnwise = !has_row && has_col; + return sel; +} + +inline void* validate_and_get_workspace_ptr(transformer_engine::Tensor* ws, size_t required_size, + const char* workspace_name) { + NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); + const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); + NVTE_CHECK(provided_size >= required_size, + "Grouped GEMM: Insufficient ", workspace_name, ". Required: ", required_size, + " bytes, Available: ", provided_size, " bytes."); + return ws->data.dptr; +} + +inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t& descA, + cublasLtMatrixLayoutOpaque_t& descB, + cublasLtMatrixLayoutOpaque_t& descC, + cublasLtMatrixLayoutOpaque_t& descD, + const GroupedGemmWorkspace& ws, bool transa, bool transb, + bool a_columnwise, bool b_columnwise, + size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, + cudaDataType_t D_type) { + // For column-major layout: leading dimension is the number of rows in storage. + // If columnwise data was chosen, storage is already transposed. + const int* rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); + const int* cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + const int* lda = rowa; + const int* rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); + const int* colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); + const int* ldb = rowb; + + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void*)rowa, (void*)cola, (void*)lda)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void*)rowb, (void*)colb, (void*)ldb)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); +} + +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t& matmulDesc, cublasOperation_t op_A, + cublasOperation_t op_B) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + + NVTE_CHECK_CUBLAS( + cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(op_A))); + NVTE_CHECK_CUBLAS( + cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(op_B))); + + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode))); + + int64_t alphabeta_batch_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); +} + +inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, + cublasLtMatmulDescOpaque_t& matmulDesc, + cublasLtMatrixLayoutOpaque_t& descA, + cublasLtMatrixLayoutOpaque_t& descB, + cublasLtMatrixLayoutOpaque_t& descC, + cublasLtMatrixLayoutOpaque_t& descD, int64_t avg_m, + int64_t avg_n, int64_t avg_k) { + cublasLtMatmulPreferenceOpaque_t preference; + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &kGroupedGemmCublasWorkspaceSize, + sizeof(size_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); + + cublasLtMatmulHeuristicResult_t heuristicResult; + int returnedResults = 0; + auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, + &preference, 1, &heuristicResult, &returnedResults); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK_CUBLAS(status); + NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); + return heuristicResult.algo; +} + +// Single kernel that sets up all GEMM parameters. +// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix M/N/K, +// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. +// We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. +__global__ void setup_grouped_gemm_kernel( + // Output arrays + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, + int *M, int *N, int *K, + float **alpha_ptrs, float **beta_ptrs, + // Base pointers + const char *a_base, const char *b_base, const char *c_base, char *d_base, + // Dimension info (per tensor) + TensorShapeInfo A_meta, TensorShapeInfo B_meta, + TensorShapeInfo C_meta, TensorShapeInfo D_meta, + // Element sizes + size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, + // Alpha/beta pointers (same for all groups) + float *alpha_ptr, float *beta_ptr, + // Transpose flags + bool transa, bool transb, + // Number of tensors + size_t num_tensors) { + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_tensors) return; + + // Get dimensions for this tensor (from array or uniform value) + int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; + int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; + int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first; + int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; + + // Compute offsets (from array or compute from uniform dims) + int64_t a_offset = A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); + int64_t b_offset = B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); + int64_t c_offset = C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); + int64_t d_offset = D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + + // Compute data pointers + A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; + B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; + C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; + D_ptrs[idx] = d_base + d_offset * d_elem_size; + + // Compute M, N, K dimensions + M[idx] = static_cast(transa ? a_last : a_first); + K[idx] = static_cast(transa ? a_first : a_last); + N[idx] = static_cast(transb ? b_first : b_last); + + // Fill alpha/beta pointers (same for all groups) + alpha_ptrs[idx] = alpha_ptr; + beta_ptrs[idx] = beta_ptr; +} + +// Launch the setup kernel to populate workspace arrays +inline void launch_grouped_gemm_setup( + const GroupedGemmWorkspace &ws, + const transformer_engine::GroupedTensor *A, + const transformer_engine::GroupedTensor *B, + const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, + const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, + const char *a_base, const char *b_base, + size_t a_elem_size, size_t b_elem_size, + bool transa, bool transb, + size_t num_tensors, cudaStream_t stream) { + + TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A); + TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B); + TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); + TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); + + const char *c_base = static_cast(C->data.dptr); + char *d_base = static_cast(D->data.dptr); + + const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); + const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); + + const int threads_per_block = 256; + const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; + + setup_grouped_gemm_kernel<<>>( + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, + ws.M, ws.N, ws.K, + ws.alpha_ptrs, ws.beta_ptrs, + a_base, b_base, c_base, d_base, + A_meta, B_meta, C_meta, D_meta, + a_elem_size, b_elem_size, c_elem_size, d_elem_size, + static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), + transa, transb, num_tensors); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Constants for grouped GEMM workspace +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + +inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { + return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); +} + +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, + const NVTEGroupedTensor A, const NVTEGroupedTensor B, + const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, + const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k) { + NVTE_API_CALL(nvte_grouped_gemm); + using namespace transformer_engine; + + // Convert to internal types + const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); + const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); + const GroupedTensor *inputC = convertNVTEGroupedTensorCheck(C); + GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); + const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); + const Tensor *beta_tensor = convertNVTETensorCheck(beta); + Tensor *wspace_setup = convertNVTETensor(workspace_setup); + Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + + // Validate inputs and num_tensors + validate_grouped_gemm_inputs(inputA, inputB, inputC, outputD); + const size_t num_tensors = inputA->num_tensors; + + // Select operand storage (row-wise vs column-wise) and adjust transpose flags to + // mirror the non-grouped GEMM logic for FP8 layout constraints. + bool transa_flag = static_cast(transa); + bool transb_flag = static_cast(transb); + const auto A_sel = select_grouped_operand(inputA, transa_flag, /*is_A=*/true); + const auto B_sel = select_grouped_operand(inputB, transb_flag, /*is_A=*/false); + transa_flag = A_sel.trans; + transb_flag = B_sel.trans; + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + + // Workspaces: setup (pointer arrays) and cuBLAS + const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); + const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; + + void* setup_workspace_ptr = + validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, "Grouped GEMM setup workspace"); + void* cublas_workspace_ptr = + validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace"); + + NVTE_CHECK(cublas_workspace_ptr != nullptr, "Grouped GEMM: cuBLAS workspace pointer is null"); + + auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( + static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); + launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, + alpha_tensor, beta_tensor, + A_sel.base, B_sel.base, a_elem_size, b_elem_size, + transa_flag, transb_flag, + num_tensors, stream); + + // Get cuBLAS handle + using cublasHandleManager = detail::HandleManager; + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); + + // Get data types + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(outputD->dtype()); + + // Setup cuBLAS operations + cublasOperation_t op_A = transa_flag ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = transb_flag ? CUBLAS_OP_T : CUBLAS_OP_N; + + // Create grouped matrix layouts + cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, + transa_flag, transb_flag, A_sel.use_columnwise, B_sel.use_columnwise, + num_tensors, A_type, B_type, D_type); + + // Create matmul descriptor + cublasLtMatmulDescOpaque_t matmulDesc; + init_matmul_desc(matmulDesc, op_A, op_B); + + // Compute average dimensions for heuristics + // K dimension: if transa, K is A's first dim; if not, K is A's last dim + int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); + int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); + int64_t avg_k_val = + avg_k ? *avg_k : (transa_flag ? compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); + + // Heuristic selection + cublasLtMatmulAlgo_t algo = + select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, descD, avg_m_val, avg_n_val, + avg_k_val); + + // Execute the grouped GEMM + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, + &descC, setup_workspace.D_ptrs, &descD, + &algo, cublas_workspace_ptr, + kGroupedGemmCublasWorkspaceSize, stream)); +} diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 950014cc9b..51241aef6b 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -228,6 +228,42 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C + * + * Performs batched GEMM on a collection of matrices with potentially different shapes. + * All tensors in the group must have compatible dimensions for matrix multiplication. + * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous + * memory layout and shape metadata. + * + * \param[in] transa Whether to transpose A matrices. + * \param[in] transb Whether to transpose B matrices. + * \param[in] alpha Scale multiplier for A @ B (NVTETensor with num_tensors elements, + * or single element for uniform alpha). + * \param[in] A Input grouped tensor A. + * \param[in] B Input grouped tensor B. + * \param[in] beta Scale multiplier for C (NVTETensor with num_tensors elements, + * or single element for uniform beta). + * \param[in] C Input grouped tensor C (can be NULL for beta=0). + * \param[out] D Output grouped tensor D. + * \param[in] workspace Workspace tensor for intermediate computations. + * \param[in] config Matrix multiplication configuration. + * \param[in] stream CUDA stream for the operation. + * + * Requirements: + * - A, B, C (if provided), D must have the same num_tensors + * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] + * - Shape compatibility: if transa=false, transb=false: + * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) + */ +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, + const NVTEGroupedTensor A, const NVTEGroupedTensor B, + const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, + const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus From 76293d4dc9ebb8a7e1c7ba2ae47f866d56998d33 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 14:32:15 +0000 Subject: [PATCH 04/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/cpp/operator/test_grouped_gemm.cu | 2 - .../common/gemm/cublaslt_gemm.cu | 279 +++++++++--------- .../common/include/transformer_engine/gemm.h | 11 +- 3 files changed, 141 insertions(+), 151 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 0e9c6c6a4d..d346e06887 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -507,5 +507,3 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, MakeGroupedGemmTestName); } // namespace - - diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 53be59cc00..2c8c2093c6 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1105,46 +1105,42 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor } } - // Helper struct to pass per-tensor shape/offset info (pointer or uniform value) struct TensorShapeInfo { - const int64_t *first_dims; // nullptr if uniform - const int64_t *last_dims; // nullptr if uniform - const int64_t *offsets; // nullptr if need to compute - int64_t uniform_first; // used if first_dims == nullptr - int64_t uniform_last; // used if last_dims == nullptr + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr // Create from GroupedTensor static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { - return { - t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, - t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, - t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, - t->get_common_first_dim(), - t->get_common_last_dim()}; + return {t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, + t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) + : nullptr, + t->get_common_first_dim(), t->get_common_last_dim()}; } // Create for C tensor (uses D's dimensions, only has offsets) static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D) { - return { - nullptr, - nullptr, - C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) : nullptr, - D->get_common_first_dim(), - D->get_common_last_dim()}; + return {nullptr, nullptr, + C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) + : nullptr, + D->get_common_first_dim(), D->get_common_last_dim()}; } }; // Helper functions to compute average dimensions from logical_shape for heuristics // These are hints for cuBLASLt algorithm selection, don't need to be exact -inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor* t) { +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) // In both cases, dividing by num_tensors gives the average return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); } -inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor* t) { +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { if (t->all_same_last_dim()) { // logical_shape[1] is the common N return static_cast(t->logical_shape.data[1]); @@ -1167,21 +1163,31 @@ struct GroupedGemmSetupWorkspace { float **beta_ptrs; // Initialize from workspace buffer - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, size_t alignment) { + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, + size_t alignment) { GroupedGemmSetupWorkspace ws; size_t offset = 0; const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.M = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.K = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.M = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.N = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.K = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; offset = ((offset + alignment - 1) / alignment) * alignment; @@ -1201,10 +1207,10 @@ struct GroupedGemmSetupWorkspace { // ----------------------------------------------------------------------------- // Helper routines to keep nvte_grouped_gemm readable // ----------------------------------------------------------------------------- -inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor* inputA, - const transformer_engine::GroupedTensor* inputB, - const transformer_engine::GroupedTensor* inputC, - const transformer_engine::GroupedTensor* outputD) { +inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, + const transformer_engine::GroupedTensor *inputB, + const transformer_engine::GroupedTensor *inputC, + const transformer_engine::GroupedTensor *outputD) { const size_t num_tensors = inputA->num_tensors; NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); NVTE_CHECK(inputB->num_tensors == num_tensors, @@ -1235,23 +1241,24 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and // fallback to column-wise data when row-wise is absent. struct GroupedOperandSelection { - const char* base = nullptr; + const char *base = nullptr; transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; bool trans = false; bool use_columnwise = false; }; -inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor* t, +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, bool trans, bool is_A) { using namespace transformer_engine; const bool has_row = t->has_data(); const bool has_col = t->has_columnwise_data(); - NVTE_CHECK(has_row || has_col, "Grouped GEMM operand is missing both row-wise and column-wise data"); + NVTE_CHECK(has_row || has_col, + "Grouped GEMM operand is missing both row-wise and column-wise data"); // Not yet supported in grouped GEMM: block scaling, MXFP8, NVFP4 specialized layouts. const auto sm = t->scaling_mode; - NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && - !is_mxfp_scaling(sm) && !is_nvfp_scaling(sm), + NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && !is_mxfp_scaling(sm) && + !is_nvfp_scaling(sm), "Grouped GEMM does not yet support NVFP4/MXFP8/block scaling operand selection"); const DType row_dtype = t->data.dtype; @@ -1268,7 +1275,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: if (is_A) { if (!sel.trans) { NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); + sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = true; // using pre-transposed storage sel.use_columnwise = true; @@ -1277,7 +1284,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: } else { // B if (sel.trans) { NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); + sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = false; // using pre-transposed storage sel.use_columnwise = true; @@ -1288,7 +1295,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). if (!has_row && has_col) { - sel.base = static_cast(t->columnwise_data.dptr); + sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = !sel.trans; sel.use_columnwise = true; @@ -1296,81 +1303,81 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: } // Default: use row-wise data (or column-wise if row-wise absent, covered above). - sel.base = static_cast(has_row ? t->data.dptr : t->columnwise_data.dptr); + sel.base = static_cast(has_row ? t->data.dptr : t->columnwise_data.dptr); sel.dtype = has_row ? row_dtype : col_dtype; - sel.use_columnwise = !has_row && has_col; + sel.use_columnwise = !has_row && has_col; return sel; } -inline void* validate_and_get_workspace_ptr(transformer_engine::Tensor* ws, size_t required_size, - const char* workspace_name) { +inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, + const char *workspace_name) { NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); - NVTE_CHECK(provided_size >= required_size, - "Grouped GEMM: Insufficient ", workspace_name, ". Required: ", required_size, - " bytes, Available: ", provided_size, " bytes."); + NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, + ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); return ws->data.dptr; } -inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t& descA, - cublasLtMatrixLayoutOpaque_t& descB, - cublasLtMatrixLayoutOpaque_t& descC, - cublasLtMatrixLayoutOpaque_t& descD, - const GroupedGemmWorkspace& ws, bool transa, bool transb, - bool a_columnwise, bool b_columnwise, +inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, const GroupedGemmWorkspace &ws, + bool transa, bool transb, bool a_columnwise, bool b_columnwise, size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, cudaDataType_t D_type) { // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. - const int* rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); - const int* cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); - const int* lda = rowa; - const int* rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); - const int* colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); - const int* ldb = rowb; - - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void*)rowa, (void*)cola, (void*)lda)); - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void*)rowb, (void*)colb, (void*)ldb)); - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); + const int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); + const int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + const int *lda = rowa; + const int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); + const int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); + const int *ldb = rowb; + + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void *)rowa, + (void *)cola, (void *)lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void *)rowb, + (void *)colb, (void *)ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void *)ws.M, + (void *)ws.N, (void *)ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void *)ws.M, + (void *)ws.N, (void *)ws.M)); } -inline void init_matmul_desc(cublasLtMatmulDescOpaque_t& matmulDesc, cublasOperation_t op_A, +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, cublasOperation_t op_B) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); - NVTE_CHECK_CUBLAS( - cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(op_A))); - NVTE_CHECK_CUBLAS( - cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(op_B))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, + sizeof(op_A))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, + sizeof(op_B))); cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); int64_t alphabeta_batch_stride = 1; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, &alphabeta_batch_stride, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, &alphabeta_batch_stride, sizeof(int64_t))); } inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, - cublasLtMatmulDescOpaque_t& matmulDesc, - cublasLtMatrixLayoutOpaque_t& descA, - cublasLtMatrixLayoutOpaque_t& descB, - cublasLtMatrixLayoutOpaque_t& descC, - cublasLtMatrixLayoutOpaque_t& descD, int64_t avg_m, - int64_t avg_n, int64_t avg_k) { + cublasLtMatmulDescOpaque_t &matmulDesc, + cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + int64_t avg_m, int64_t avg_n, int64_t avg_k) { cublasLtMatmulPreferenceOpaque_t preference; NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &kGroupedGemmCublasWorkspaceSize, - sizeof(size_t))); + NVTE_CHECK_CUBLAS( + cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -1382,7 +1389,8 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, int returnedResults = 0; auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, &preference, 1, &heuristicResult, &returnedResults); - NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, + "Unable to find suitable cuBLAS grouped GEMM algorithm"); NVTE_CHECK_CUBLAS(status); NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); return heuristicResult.algo; @@ -1394,14 +1402,12 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, // We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. __global__ void setup_grouped_gemm_kernel( // Output arrays - void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, - int *M, int *N, int *K, + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, float **alpha_ptrs, float **beta_ptrs, // Base pointers const char *a_base, const char *b_base, const char *c_base, char *d_base, // Dimension info (per tensor) - TensorShapeInfo A_meta, TensorShapeInfo B_meta, - TensorShapeInfo C_meta, TensorShapeInfo D_meta, + TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, // Element sizes size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, // Alpha/beta pointers (same for all groups) @@ -1410,7 +1416,6 @@ __global__ void setup_grouped_gemm_kernel( bool transa, bool transb, // Number of tensors size_t num_tensors) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= num_tensors) return; @@ -1421,10 +1426,14 @@ __global__ void setup_grouped_gemm_kernel( int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; // Compute offsets (from array or compute from uniform dims) - int64_t a_offset = A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); - int64_t b_offset = B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); - int64_t c_offset = C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); - int64_t d_offset = D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + int64_t a_offset = + A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); + int64_t b_offset = + B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); + int64_t c_offset = + C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); + int64_t d_offset = + D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); // Compute data pointers A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; @@ -1444,18 +1453,12 @@ __global__ void setup_grouped_gemm_kernel( // Launch the setup kernel to populate workspace arrays inline void launch_grouped_gemm_setup( - const GroupedGemmWorkspace &ws, - const transformer_engine::GroupedTensor *A, - const transformer_engine::GroupedTensor *B, - const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D, - const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, - const char *a_base, const char *b_base, - size_t a_elem_size, size_t b_elem_size, - bool transa, bool transb, - size_t num_tensors, cudaStream_t stream) { - + const GroupedGemmWorkspace &ws, const transformer_engine::GroupedTensor *A, + const transformer_engine::GroupedTensor *B, const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, const char *a_base, const char *b_base, + size_t a_elem_size, size_t b_elem_size, bool transa, bool transb, size_t num_tensors, + cudaStream_t stream) { TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A); TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B); TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); @@ -1471,15 +1474,10 @@ inline void launch_grouped_gemm_setup( const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; setup_grouped_gemm_kernel<<>>( - ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, - ws.M, ws.N, ws.K, - ws.alpha_ptrs, ws.beta_ptrs, - a_base, b_base, c_base, d_base, - A_meta, B_meta, C_meta, D_meta, - a_elem_size, b_elem_size, c_elem_size, d_elem_size, - static_cast(alpha_tensor->data.dptr), - static_cast(beta_tensor->data.dptr), - transa, transb, num_tensors); + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, + a_base, b_base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, + c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), transa, transb, num_tensors); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -1492,12 +1490,11 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); } -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, - const NVTEGroupedTensor A, const NVTEGroupedTensor B, - const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, - NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, - const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k) { +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, + const int64_t *avg_n, const int64_t *avg_k) { NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; @@ -1530,20 +1527,18 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; - void* setup_workspace_ptr = - validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, "Grouped GEMM setup workspace"); - void* cublas_workspace_ptr = - validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace"); + void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, + "Grouped GEMM setup workspace"); + void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, + "Grouped GEMM cuBLAS workspace"); NVTE_CHECK(cublas_workspace_ptr != nullptr, "Grouped GEMM: cuBLAS workspace pointer is null"); auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( - static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); - launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, - alpha_tensor, beta_tensor, - A_sel.base, B_sel.base, a_elem_size, b_elem_size, - transa_flag, transb_flag, - num_tensors, stream); + static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); + launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, alpha_tensor, + beta_tensor, A_sel.base, B_sel.base, a_elem_size, b_elem_size, + transa_flag, transb_flag, num_tensors, stream); // Get cuBLAS handle using cublasHandleManager = detail::HandleManager; @@ -1560,9 +1555,9 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, // Create grouped matrix layouts cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; - init_matrix_layouts(descA, descB, descC, descD, setup_workspace, - transa_flag, transb_flag, A_sel.use_columnwise, B_sel.use_columnwise, - num_tensors, A_type, B_type, D_type); + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, transa_flag, transb_flag, + A_sel.use_columnwise, B_sel.use_columnwise, num_tensors, A_type, B_type, + D_type); // Create matmul descriptor cublasLtMatmulDescOpaque_t matmulDesc; @@ -1576,15 +1571,13 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, avg_k ? *avg_k : (transa_flag ? compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); // Heuristic selection - cublasLtMatmulAlgo_t algo = - select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, descD, avg_m_val, avg_n_val, - avg_k_val); + cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, + descD, avg_m_val, avg_n_val, avg_k_val); // Execute the grouped GEMM NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, - setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, - setup_workspace.beta_ptrs, setup_workspace.C_ptrs, - &descC, setup_workspace.D_ptrs, &descD, - &algo, cublas_workspace_ptr, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, + setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, kGroupedGemmCublasWorkspaceSize, stream)); } diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 51241aef6b..948058295e 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -257,12 +257,11 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * - Shape compatibility: if transa=false, transb=false: * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) */ -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, - const NVTEGroupedTensor A, const NVTEGroupedTensor B, - const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, - NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, - const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k); +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, + const int64_t *avg_n, const int64_t *avg_k); #ifdef __cplusplus } // extern "C" From 296d77362099c52fa8e19a299f4a4134dc184096 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 10 Dec 2025 18:25:39 +0100 Subject: [PATCH 05/40] Add FP8 scale support and fix alignment for grouped GEMM - Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM - Fix random padding in tests to ensure 16-byte alignment for all dtypes - Reorder GroupedGemmSetupWorkspace members for natural alignment - Remove debug prints Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 55 +++++--- .../common/gemm/cublaslt_gemm.cu | 119 +++++++++++++----- .../common/include/transformer_engine/gemm.h | 2 + 3 files changed, 131 insertions(+), 45 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index d346e06887..bff175f405 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -1,8 +1,8 @@ -/*********************************************************************** - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. - **********************************************************************/ + ************************************************************************/ #include #include @@ -16,6 +16,8 @@ #include #include +#include +#include #include #include "../test_common.h" @@ -136,7 +138,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, const NVTEShape shape = tensors[0]->rowwise_shape(); const DType dtype = tensors[0]->dtype(); const size_t num_tensors = tensors.size(); - const size_t elem_size = typeToSize(dtype); + const size_t elem_size = typeToNumBits(dtype) / 8; GroupedBuffers grouped; grouped.elem_size = elem_size; grouped.num_tensors = num_tensors; @@ -162,9 +164,13 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, std::vector offsets(num_tensors, 0); auto random_padding = [&]() -> int64_t { + // Random padding ensuring 16-byte alignment regardless of element size + // cuBLAS requires aligned pointers for vectorized loads static std::mt19937 gen(12345); std::uniform_int_distribution dist(0, 3); - return dist(gen); + // Calculate elements needed for 16-byte alignment + const size_t align_elements = (16 * 8) / typeToNumBits(dtype); // 16 bytes / element_size + return dist(gen) * static_cast(align_elements); }; auto numel = [&](size_t idx) -> int64_t { @@ -301,7 +307,12 @@ Tensor make_fp8_operand(const std::string& name, const std::vector& shap Tensor make_bf16_operand(const std::string& name, const std::vector& shape) { Tensor t(name, shape, DType::kBFloat16); - fillUniform(&t); + // Fill with ones for easier debugging + //fillUniform(&t); + const size_t numel = shape[0] * shape[1]; + std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f)); + NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(), + numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice)); return t; } @@ -312,17 +323,21 @@ struct TestParams { ShapeCase shape_case; }; +// Returns a vector of (M, N, K) tuples for each GEMM in the group. +// M - number of rows in output D +// N - number of columns in output D +// K - reduction dimension shared between A and B std::vector> make_shapes(ShapeCase scase) { switch (scase) { case ShapeCase::kAllSame: return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; - case ShapeCase::kSameFirst: // M wspólne, N/K zróżnicowane - return {{64, 64, 32}, {64, 96, 32}, {64, 80, 48}}; - case ShapeCase::kSameLast: // N wspólne, M/K zróżnicowane - return {{48, 80, 32}, {96, 80, 48}, {72, 80, 40}}; + case ShapeCase::kSameFirst: + return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}}; + case ShapeCase::kSameLast: + return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}}; case ShapeCase::kAllDifferent: default: - return {{48, 80, 32}, {96, 64, 48}, {40, 72, 24}}; + return {{64, 96, 32}, {64, 96, 48}, {64, 96, 64}}; } } @@ -345,10 +360,10 @@ void run_grouped_gemm_case(const TestParams& params) { for (size_t i = 0; i < num_gemms; ++i) { const auto [M, N, K] = shapes[i]; - const std::vector a_shape = params.transa ? std::vector{K, M} - : std::vector{M, K}; - const std::vector b_shape = params.transb ? std::vector{N, K} - : std::vector{K, N}; + const std::vector a_shape = params.transa ? std::vector{M, K} + : std::vector{K, M}; + const std::vector b_shape = params.transb ? std::vector{K, N} + : std::vector{N, K}; switch (params.input_case) { case InputCase::kFP8Current: { A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); @@ -373,6 +388,10 @@ void run_grouped_gemm_case(const TestParams& params) { std::vector gelu_ptrs(num_gemms, nullptr); std::vector workspaces(num_gemms); std::vector workspace_ptrs(num_gemms, nullptr); + std::vector A_views; + std::vector B_views; + A_views.reserve(num_gemms); + B_views.reserve(num_gemms); const size_t cublas_ws_bytes = 32ull * 1024 * 1024; @@ -382,6 +401,8 @@ void run_grouped_gemm_case(const TestParams& params) { D_ptrs[i] = D_multi[i].data(); workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); workspace_ptrs[i] = workspaces[i].data(); + A_views.push_back(&A_tensors[i]); + B_views.push_back(&B_tensors[i]); } nvte_multi_tensor_gemm(A_ptrs.data(), @@ -399,8 +420,8 @@ void run_grouped_gemm_case(const TestParams& params) { 0, 0); - GroupedBuffers grouped_A = build_grouped_tensor(A_tensors, A_tensors[0].scaling_mode()); - GroupedBuffers grouped_B = build_grouped_tensor(B_tensors, B_tensors[0].scaling_mode()); + GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); std::vector C_tensors; std::vector D_group_tensors; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 2c8c2093c6..bb29d58de4 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1115,20 +1115,50 @@ struct TensorShapeInfo { // Create from GroupedTensor static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { - return {t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, - t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + // When per-tensor dims are not provided, we must be in the uniform-shape case. + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); + + return {first_ptr, + last_ptr, t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, - t->get_common_first_dim(), t->get_common_last_dim()}; + uniform_first, + uniform_last}; } // Create for C tensor (uses D's dimensions, only has offsets) static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D) { - return {nullptr, nullptr, + const bool has_first = D->first_dims.has_data(); + const bool has_last = D->last_dims.has_data(); + NVTE_CHECK(has_first || D->all_same_first_dim(), + "GroupedTensor D is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || D->all_same_last_dim(), + "GroupedTensor D is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? static_cast(D->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); + + return {first_ptr, + last_ptr, C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) : nullptr, - D->get_common_first_dim(), D->get_common_last_dim()}; + uniform_first, + uniform_last}; } }; @@ -1144,10 +1174,9 @@ inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) if (t->all_same_last_dim()) { // logical_shape[1] is the common N return static_cast(t->logical_shape.data[1]); - } else { - // logical_shape[1] is sum_of_N, divide by num_tensors - return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); } + // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. + return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); } // Workspace layout for grouped GEMM @@ -1163,6 +1192,7 @@ struct GroupedGemmSetupWorkspace { float **beta_ptrs; // Initialize from workspace buffer + // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, size_t alignment) { GroupedGemmSetupWorkspace ws; @@ -1170,6 +1200,7 @@ struct GroupedGemmSetupWorkspace { const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); + // Pointer arrays first (all 8-byte aligned) ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); @@ -1178,27 +1209,30 @@ struct GroupedGemmSetupWorkspace { offset += ptr_size; ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + + // Int arrays last (4-byte aligned, always satisfied after pointer arrays) ws.M = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.K = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; offset = ((offset + alignment - 1) / alignment) * alignment; return ws; } - // Calculate required size for setup workspace (pointer arrays + M/N/K + alpha/beta ptrs) + // Calculate required size for setup workspace (pointer arrays + M/N/K) static size_t required_setup_size(size_t num_tensors, size_t alignment) { const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - size_t size = 4 * ptr_size + 3 * int_size + 2 * ptr_size; // M, N, K only (no LDA/LDB/LDC/LDD) + // Layout: 6 ptr arrays, then 3 int arrays (no padding needed) + size_t size = 6 * ptr_size + 3 * int_size; size = ((size + alignment - 1) / alignment) * alignment; return size; } @@ -1220,12 +1254,16 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor NVTE_CHECK(outputD->num_tensors == num_tensors, "Grouped GEMM: A and D must have the same num_tensors"); - auto is_fp8_or_16bit = [](DType dtype) { - return dtype == DType::kFloat8E4M3 || dtype == DType::kFloat8E5M2 || - dtype == DType::kBFloat16 || dtype == DType::kFloat16; + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2 || + dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16; }; - auto is_output_dtype = [](DType dtype) { - return dtype == DType::kBFloat16 || dtype == DType::kFloat16 || dtype == DType::kFloat32; + auto is_output_dtype = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16 || + dtype == transformer_engine::DType::kFloat32; }; NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), "Grouped GEMM inputs must be FP8, BF16, or FP16."); @@ -1321,7 +1359,8 @@ inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB, cublasLtMatrixLayoutOpaque_t &descC, - cublasLtMatrixLayoutOpaque_t &descD, const GroupedGemmWorkspace &ws, + cublasLtMatrixLayoutOpaque_t &descD, + const GroupedGemmSetupWorkspace &ws, bool transa, bool transb, bool a_columnwise, bool b_columnwise, size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, cudaDataType_t D_type) { @@ -1366,6 +1405,10 @@ inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOpera &alphabeta_batch_stride, sizeof(int64_t))); } +// Constants for grouped GEMM workspace (declared early for use in heuristics) +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, cublasLtMatmulDescOpaque_t &matmulDesc, cublasLtMatrixLayoutOpaque_t &descA, @@ -1442,9 +1485,11 @@ __global__ void setup_grouped_gemm_kernel( D_ptrs[idx] = d_base + d_offset * d_elem_size; // Compute M, N, K dimensions - M[idx] = static_cast(transa ? a_last : a_first); - K[idx] = static_cast(transa ? a_first : a_last); - N[idx] = static_cast(transb ? b_first : b_last); + // Test stores A as {K,M} when !transa, {M,K} when transa + // Test stores B as {N,K} when !transb, {K,N} when transb + M[idx] = static_cast(transa ? a_first : a_last); + K[idx] = static_cast(transa ? a_last : a_first); + N[idx] = static_cast(transb ? b_last : b_first); // Fill alpha/beta pointers (same for all groups) alpha_ptrs[idx] = alpha_ptr; @@ -1453,7 +1498,7 @@ __global__ void setup_grouped_gemm_kernel( // Launch the setup kernel to populate workspace arrays inline void launch_grouped_gemm_setup( - const GroupedGemmWorkspace &ws, const transformer_engine::GroupedTensor *A, + const GroupedGemmSetupWorkspace &ws, const transformer_engine::GroupedTensor *A, const transformer_engine::GroupedTensor *B, const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, const transformer_engine::Tensor *beta_tensor, const char *a_base, const char *b_base, @@ -1482,10 +1527,6 @@ inline void launch_grouped_gemm_setup( NVTE_CHECK_CUDA(cudaGetLastError()); } -// Constants for grouped GEMM workspace -static constexpr size_t kGroupedGemmAlignment = 256; -static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB - inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); } @@ -1563,6 +1604,28 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT cublasLtMatmulDescOpaque_t matmulDesc; init_matmul_desc(matmulDesc, op_A, op_B); + // Set FP8 scale pointers if needed + const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); + const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); + if (is_fp8_a || is_fp8_b) { + // For FP8 grouped GEMM, we need to pass scale_inv pointers + // The scale_inv arrays contain one float per tensor in the group + if (is_fp8_a) { + void *a_scale_inv = A_sel.use_columnwise ? inputA->columnwise_scale_inv.dptr + : inputA->scale_inv.dptr; + NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + } + if (is_fp8_b) { + void *b_scale_inv = B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr + : inputB->scale_inv.dptr; + NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + } + } + // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 948058295e..246fb5fefd 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,6 +11,8 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ +#include + #include "transformer_engine.h" #ifdef __cplusplus From 785df3440a443b72340dfdf33db7391280e3a968 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 17:26:49 +0000 Subject: [PATCH 06/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_gemm.cu | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index bb29d58de4..55f52a1c4d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1123,18 +1123,17 @@ struct TensorShapeInfo { NVTE_CHECK(has_last || t->all_same_last_dim(), "GroupedTensor is missing last_dims for varying shapes"); - const int64_t *first_ptr = has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *first_ptr = + has_first ? static_cast(t->first_dims.dptr) : nullptr; const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); - return {first_ptr, - last_ptr, + return {first_ptr, last_ptr, t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, - uniform_first, - uniform_last}; + uniform_first, uniform_last}; } // Create for C tensor (uses D's dimensions, only has offsets) @@ -1153,12 +1152,10 @@ struct TensorShapeInfo { const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); - return {first_ptr, - last_ptr, + return {first_ptr, last_ptr, C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) : nullptr, - uniform_first, - uniform_last}; + uniform_first, uniform_last}; } }; @@ -1360,9 +1357,9 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB, cublasLtMatrixLayoutOpaque_t &descC, cublasLtMatrixLayoutOpaque_t &descD, - const GroupedGemmSetupWorkspace &ws, - bool transa, bool transb, bool a_columnwise, bool b_columnwise, - size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, + const GroupedGemmSetupWorkspace &ws, bool transa, bool transb, + bool a_columnwise, bool b_columnwise, size_t num_tensors, + cudaDataType_t A_type, cudaDataType_t B_type, cudaDataType_t D_type) { // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. @@ -1611,15 +1608,15 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // For FP8 grouped GEMM, we need to pass scale_inv pointers // The scale_inv arrays contain one float per tensor in the group if (is_fp8_a) { - void *a_scale_inv = A_sel.use_columnwise ? inputA->columnwise_scale_inv.dptr - : inputA->scale_inv.dptr; + void *a_scale_inv = + A_sel.use_columnwise ? inputA->columnwise_scale_inv.dptr : inputA->scale_inv.dptr; NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); } if (is_fp8_b) { - void *b_scale_inv = B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr - : inputB->scale_inv.dptr; + void *b_scale_inv = + B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr : inputB->scale_inv.dptr; NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); From 1329b3746abfe3f9d845e90da7945bede6e3893c Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 10 Dec 2025 22:34:16 +0100 Subject: [PATCH 07/40] fix Signed-off-by: Pawel Gadzinski --- .../common/gemm/cublaslt_gemm.cu | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 55f52a1c4d..3662247b51 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1217,9 +1217,6 @@ struct GroupedGemmSetupWorkspace { ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.K = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - - offset = ((offset + alignment - 1) / alignment) * alignment; return ws; } @@ -1363,21 +1360,21 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cudaDataType_t D_type) { // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. - const int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); - const int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); - const int *lda = rowa; - const int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); - const int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); - const int *ldb = rowb; - - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void *)rowa, - (void *)cola, (void *)lda)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void *)rowb, - (void *)colb, (void *)ldb)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void *)ws.M, - (void *)ws.N, (void *)ws.M)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void *)ws.M, - (void *)ws.N, (void *)ws.M)); + int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); + int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + int *lda = rowa; + int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); + int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); + int *ldb = rowb; + + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, + rowa, cola, lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, + rowb, colb, ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, + ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, + ws.M, ws.N, ws.M)); } inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, From 47c58be8ce0ee14fc26a90a2f8b3ad8035283b4c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 21:35:06 +0000 Subject: [PATCH 08/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_gemm.cu | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 3662247b51..91405bd42f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1367,14 +1367,10 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); int *ldb = rowb; - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, - rowa, cola, lda)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, - rowb, colb, ldb)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, - ws.M, ws.N, ws.M)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, - ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); } inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, From a155a8a3dd17663c82882f64b30a5a118ba3695b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 11 Dec 2025 11:55:44 +0100 Subject: [PATCH 09/40] Grouped GEMM: code cleanup and NULL C support - Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers - Simplify select_grouped_operand by removing dead code branches - Add GroupedOperandSelection.tensor field to avoid passing tensor separately - Extract set_fp8_scale_pointers and init_matrix_layouts helpers - Add safety check for FP8 on Hopper column-wise fallback - Support NULL C tensor when beta=0 (uses D as placeholder) - Remove unused get_scale_inv() from test - Add use_null_c test parameter and test case - Fix documentation: alpha/beta are single element tensors only Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 210 ++++++++---------- .../common/gemm/cublaslt_gemm.cu | 163 +++++++------- .../common/include/transformer_engine/gemm.h | 34 +-- 3 files changed, 203 insertions(+), 204 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index bff175f405..5e5144fa4c 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -28,7 +29,6 @@ using namespace test; namespace { enum class InputCase { - kFP8Delayed, kFP8Current, kBF16, }; @@ -40,17 +40,37 @@ enum class ShapeCase { kAllDifferent, }; +// Custom deleters for RAII +struct CudaDeleter { + void operator()(void* p) const { if (p) cudaFree(p); } +}; +struct GroupedTensorDeleter { + void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); } +}; + +template +using CudaPtr = std::unique_ptr; +using GroupedTensorHandle = std::unique_ptr, GroupedTensorDeleter>; + +// Helper to allocate CUDA memory into a CudaPtr +template +CudaPtr cuda_alloc(size_t bytes) { + void* ptr = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes)); + return CudaPtr(static_cast(ptr)); +} + // Helper owning GPU buffers that back NVTEGroupedTensor. // NVTEGroupedTensor does not own memory; data/offsets/scales // must be allocated and freed by the test. struct GroupedBuffers { - NVTEGroupedTensor handle{nullptr}; - void* data{nullptr}; - void* scale_inv{nullptr}; - int64_t* first_dims_dev{nullptr}; - int64_t* last_dims_dev{nullptr}; - int64_t* offsets_dev{nullptr}; - void* columnwise_data{nullptr}; + GroupedTensorHandle handle; + CudaPtr<> data; + CudaPtr<> scale_inv; + CudaPtr first_dims_dev; + CudaPtr last_dims_dev; + CudaPtr offsets_dev; + CudaPtr<> columnwise_data; NVTEShape logical_shape{}; std::vector offsets_host; std::vector tensor_bytes; @@ -62,65 +82,13 @@ struct GroupedBuffers { GroupedBuffers() = default; GroupedBuffers(const GroupedBuffers&) = delete; GroupedBuffers& operator=(const GroupedBuffers&) = delete; - GroupedBuffers(GroupedBuffers&& other) noexcept { - *this = std::move(other); - } - GroupedBuffers& operator=(GroupedBuffers&& other) noexcept { - if (this == &other) return *this; - handle = other.handle; - data = other.data; - scale_inv = other.scale_inv; - first_dims_dev = other.first_dims_dev; - last_dims_dev = other.last_dims_dev; - offsets_dev = other.offsets_dev; - logical_shape = other.logical_shape; - offsets_host = std::move(other.offsets_host); - tensor_bytes = std::move(other.tensor_bytes); - num_tensors = other.num_tensors; - elem_size = other.elem_size; - dtype = other.dtype; - scaling_mode = other.scaling_mode; - - other.handle = nullptr; - other.data = nullptr; - other.scale_inv = nullptr; - other.first_dims_dev = nullptr; - other.last_dims_dev = nullptr; - other.offsets_dev = nullptr; - other.num_tensors = 0; - return *this; - } + GroupedBuffers(GroupedBuffers&&) = default; + GroupedBuffers& operator=(GroupedBuffers&&) = default; + ~GroupedBuffers() = default; - ~GroupedBuffers() { - if (data) { - cudaFree(data); - data = nullptr; - } - if (scale_inv) { - cudaFree(scale_inv); - scale_inv = nullptr; - } - if (columnwise_data) { - cudaFree(columnwise_data); - columnwise_data = nullptr; - } - if (first_dims_dev) { - cudaFree(first_dims_dev); - first_dims_dev = nullptr; - } - if (last_dims_dev) { - cudaFree(last_dims_dev); - last_dims_dev = nullptr; - } - if (offsets_dev) { - cudaFree(offsets_dev); - offsets_dev = nullptr; - } - if (handle) { - nvte_destroy_grouped_tensor(handle); - handle = nullptr; - } - } + // Convenience accessors for raw pointers + NVTEGroupedTensor get_handle() const { return handle.get(); } + void* get_data() const { return data.get(); } }; size_t grouped_setup_workspace_size(const size_t num_tensors) { @@ -211,7 +179,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, size_t logical_data[2] = {static_cast(logical_first), static_cast(logical_last)}; grouped.logical_shape = nvte_make_shape(logical_data, 2); - grouped.handle = nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape); + grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape)); const int64_t last_idx = static_cast(num_tensors - 1); const int64_t total_elems = need_offsets @@ -219,59 +187,60 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, : (logical_first * logical_last); const size_t total_bytes = static_cast(total_elems) * elem_size; - NVTE_CHECK_CUDA(cudaMalloc(&grouped.data, total_bytes)); + grouped.data = cuda_alloc(total_bytes); for (size_t i = 0; i < num_tensors; ++i) { const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data) + offset_bytes, + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, tensors[i]->rowwise_dptr(), grouped.tensor_bytes[i], cudaMemcpyDeviceToDevice)); } - NVTEBasicTensor data_tensor{grouped.data, static_cast(dtype), grouped.logical_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseData, &data_tensor); + NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), grouped.logical_shape}; + NVTEGroupedTensor h = grouped.handle.get(); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor); const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); if (include_columnwise) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.columnwise_data, total_bytes)); + grouped.columnwise_data = cuda_alloc(total_bytes); for (size_t i = 0; i < num_tensors; ++i) { const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data) + offset_bytes, + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data.get()) + offset_bytes, tensors[i]->columnwise_dptr(), grouped.tensor_bytes[i], cudaMemcpyDeviceToDevice)); } - NVTEBasicTensor col_tensor{grouped.columnwise_data, + NVTEBasicTensor col_tensor{grouped.columnwise_data.get(), static_cast(dtype), grouped.logical_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseData, &col_tensor); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor); } if (!same_first) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.first_dims_dev, num_tensors * sizeof(int64_t))); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev, first_dims.data(), + grouped.first_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor fd_tensor{grouped.first_dims_dev, kNVTEInt64, fd_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedFirstDims, &fd_tensor); + NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor); } if (!same_last) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.last_dims_dev, num_tensors * sizeof(int64_t))); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev, last_dims.data(), + grouped.last_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor ld_tensor{grouped.last_dims_dev, kNVTEInt64, ld_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedLastDims, &ld_tensor); + NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor); } if (!same_first || !same_last) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.offsets_dev, num_tensors * sizeof(int64_t))); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev, offsets.data(), + grouped.offsets_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape off_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor off_tensor{grouped.offsets_dev, kNVTEInt64, off_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedTensorOffsets, &off_tensor); + NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor); } if (isFp8Type(dtype)) { @@ -280,13 +249,13 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, tensors[i]->to_cpu(); scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; } - NVTE_CHECK_CUDA(cudaMalloc(&grouped.scale_inv, sizeof(float) * num_tensors)); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv, scale_inv_cpu.data(), + grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor scale_tensor{grouped.scale_inv, kNVTEFloat32, scale_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseScaleInv, &scale_tensor); - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); + NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); } return grouped; @@ -321,6 +290,7 @@ struct TestParams { bool transa; bool transb; ShapeCase shape_case; + bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) }; // Returns a vector of (M, N, K) tuples for each GEMM in the group. @@ -332,12 +302,14 @@ std::vector> make_shapes(ShapeCase scase) { case ShapeCase::kAllSame: return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; case ShapeCase::kSameFirst: - return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}}; + // Same M (first dim), varying N and K + return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}}; case ShapeCase::kSameLast: - return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}}; + // Same N (last dim), varying M and K + return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}}; case ShapeCase::kAllDifferent: default: - return {{64, 96, 32}, {64, 96, 48}, {64, 96, 64}}; + return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}}; } } @@ -430,9 +402,11 @@ void run_grouped_gemm_case(const TestParams& params) { for (size_t i = 0; i < num_gemms; ++i) { const auto [M, N, K] = shapes[i]; (void)K; - C_tensors.emplace_back(Tensor("C" + std::to_string(i), - std::vector{static_cast(M), static_cast(N)}, - DType::kBFloat16)); + if (!params.use_null_c) { + C_tensors.emplace_back(Tensor("C" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + } D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), std::vector{static_cast(M), static_cast(N)}, DType::kBFloat16)); @@ -441,11 +415,16 @@ void run_grouped_gemm_case(const TestParams& params) { std::vector C_views, D_views; for (size_t i = 0; i < num_gemms; ++i) { - C_views.push_back(&C_tensors[i]); + if (!params.use_null_c) { + C_views.push_back(&C_tensors[i]); + } D_views.push_back(&D_group_tensors[i]); } - GroupedBuffers grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + std::optional grouped_C; + if (!params.use_null_c) { + grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + } GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); Tensor alpha_tensor("alpha", std::vector{1}, DType::kFloat32); @@ -462,11 +441,11 @@ void run_grouped_gemm_case(const TestParams& params) { nvte_grouped_gemm(params.transa, params.transb, alpha_tensor.data(), - grouped_A.handle, - grouped_B.handle, + grouped_A.get_handle(), + grouped_B.get_handle(), beta_tensor.data(), - grouped_C.handle, - grouped_D.handle, + params.use_null_c ? nullptr : grouped_C->get_handle(), + grouped_D.get_handle(), setup_ws.data(), cublas_ws.data(), nullptr, @@ -482,7 +461,7 @@ void run_grouped_gemm_case(const TestParams& params) { D_multi[i].dtype()); const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), - static_cast(grouped_D.data) + offset_bytes, + static_cast(grouped_D.get_data()) + offset_bytes, grouped_D.tensor_bytes[i], cudaMemcpyDeviceToDevice)); grouped_split.to_cpu(); @@ -504,22 +483,25 @@ TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { } std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { - constexpr const char* kInputNames[] = {"FP8Delayed", "FP8Current", "BF16"}; + constexpr const char* kInputNames[] = {"FP8Current", "BF16"}; constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") + "tb" + (info.param.transb ? "T" : "N"); + const std::string null_c = info.param.use_null_c ? "_NullC" : ""; return std::string(kInputNames[static_cast(info.param.input_case)]) + "_" + - kShapeNames[static_cast(info.param.shape_case)] + "_" + layout; + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c; } const std::vector kTestParams = { - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent}, - {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent}, - {InputCase::kFP8Current, false, false, ShapeCase::kAllSame}, - {InputCase::kBF16, true, false, ShapeCase::kSameFirst}, - {InputCase::kBF16, false, true, ShapeCase::kSameLast}, - {InputCase::kBF16, false, false, ShapeCase::kAllSame}, - {InputCase::kBF16, true, true, ShapeCase::kAllDifferent}, + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, + {InputCase::kBF16, false, true, ShapeCase::kSameLast, false}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false}, + // Test NULL C (valid when beta=0) + {InputCase::kBF16, false, false, ShapeCase::kAllSame, true}, }; INSTANTIATE_TEST_SUITE_P(OperatorTest, diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 91405bd42f..9d9a5097d4 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1190,8 +1190,7 @@ struct GroupedGemmSetupWorkspace { // Initialize from workspace buffer // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, - size_t alignment) { + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { GroupedGemmSetupWorkspace ws; size_t offset = 0; const size_t ptr_size = num_tensors * sizeof(void *); @@ -1243,8 +1242,11 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); NVTE_CHECK(inputB->num_tensors == num_tensors, "Grouped GEMM: A and B must have the same num_tensors"); - NVTE_CHECK(inputC->num_tensors == num_tensors, - "Grouped GEMM: A and C must have the same num_tensors"); + // C can be NULL (will use D as C when beta=0) + if (inputC != nullptr) { + NVTE_CHECK(inputC->num_tensors == num_tensors, + "Grouped GEMM: A and C must have the same num_tensors"); + } NVTE_CHECK(outputD->num_tensors == num_tensors, "Grouped GEMM: A and D must have the same num_tensors"); @@ -1261,8 +1263,13 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor }; NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), "Grouped GEMM inputs must be FP8, BF16, or FP16."); - NVTE_CHECK(is_output_dtype(inputC->dtype()) && is_output_dtype(outputD->dtype()), - "Grouped GEMM outputs must be BF16, FP16, or FP32."); + // Only check C dtype if C is provided + if (inputC != nullptr) { + NVTE_CHECK(is_output_dtype(inputC->dtype()), + "Grouped GEMM: C must be BF16, FP16, or FP32."); + } + NVTE_CHECK(is_output_dtype(outputD->dtype()), + "Grouped GEMM: D must be BF16, FP16, or FP32."); NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), @@ -1273,6 +1280,7 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and // fallback to column-wise data when row-wise is absent. struct GroupedOperandSelection { + const transformer_engine::GroupedTensor *tensor = nullptr; const char *base = nullptr; transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; bool trans = false; @@ -1296,6 +1304,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: const DType row_dtype = t->data.dtype; const DType col_dtype = t->columnwise_data.dtype; GroupedOperandSelection sel; + sel.tensor = t; sel.trans = trans; const DType rep_dtype = has_row ? row_dtype : col_dtype; @@ -1327,6 +1336,9 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). if (!has_row && has_col) { + // On Hopper FP8, this would break TN requirement - should have been handled above + NVTE_CHECK(!is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = !sel.trans; @@ -1334,10 +1346,10 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: return sel; } - // Default: use row-wise data (or column-wise if row-wise absent, covered above). - sel.base = static_cast(has_row ? t->data.dptr : t->columnwise_data.dptr); - sel.dtype = has_row ? row_dtype : col_dtype; - sel.use_columnwise = !has_row && has_col; + // Default: use row-wise data (column-wise case already handled above) + sel.base = static_cast(t->data.dptr); + sel.dtype = row_dtype; + sel.use_columnwise = false; return sel; } @@ -1354,17 +1366,22 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB, cublasLtMatrixLayoutOpaque_t &descC, cublasLtMatrixLayoutOpaque_t &descD, - const GroupedGemmSetupWorkspace &ws, bool transa, bool transb, - bool a_columnwise, bool b_columnwise, size_t num_tensors, - cudaDataType_t A_type, cudaDataType_t B_type, - cudaDataType_t D_type) { + const GroupedGemmSetupWorkspace &ws, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, + const transformer_engine::GroupedTensor *D, + size_t num_tensors) { + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); + // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. - int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); - int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); + int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); int *lda = rowa; - int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); - int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); + int *rowb = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); + int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? ws.K : ws.N); int *ldb = rowb; NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); @@ -1395,6 +1412,31 @@ inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOpera &alphabeta_batch_stride, sizeof(int64_t))); } +inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel) { + const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); + const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); + if (!is_fp8_a && !is_fp8_b) return; + + if (is_fp8_a) { + void *a_scale_inv = A_sel.use_columnwise + ? A_sel.tensor->columnwise_scale_inv.dptr + : A_sel.tensor->scale_inv.dptr; + NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + } + if (is_fp8_b) { + void *b_scale_inv = B_sel.use_columnwise + ? B_sel.tensor->columnwise_scale_inv.dptr + : B_sel.tensor->scale_inv.dptr; + NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + } +} + // Constants for grouped GEMM workspace (declared early for use in heuristics) static constexpr size_t kGroupedGemmAlignment = 256; static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB @@ -1488,20 +1530,20 @@ __global__ void setup_grouped_gemm_kernel( // Launch the setup kernel to populate workspace arrays inline void launch_grouped_gemm_setup( - const GroupedGemmSetupWorkspace &ws, const transformer_engine::GroupedTensor *A, - const transformer_engine::GroupedTensor *B, const transformer_engine::GroupedTensor *C, + const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, const char *a_base, const char *b_base, - size_t a_elem_size, size_t b_elem_size, bool transa, bool transb, size_t num_tensors, - cudaStream_t stream) { - TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A); - TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B); + const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { + TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); + TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); const char *c_base = static_cast(C->data.dptr); char *d_base = static_cast(D->data.dptr); + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); @@ -1510,9 +1552,9 @@ inline void launch_grouped_gemm_setup( setup_grouped_gemm_kernel<<>>( ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, - a_base, b_base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, - c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), - static_cast(beta_tensor->data.dptr), transa, transb, num_tensors); + A_sel.base, B_sel.base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, + b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), A_sel.trans, B_sel.trans, num_tensors); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -1532,7 +1574,7 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // Convert to internal types const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); - const GroupedTensor *inputC = convertNVTEGroupedTensorCheck(C); + const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); const Tensor *beta_tensor = convertNVTETensorCheck(beta); @@ -1540,19 +1582,16 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); // Validate inputs and num_tensors - validate_grouped_gemm_inputs(inputA, inputB, inputC, outputD); + validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD); + + // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) + const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; const size_t num_tensors = inputA->num_tensors; // Select operand storage (row-wise vs column-wise) and adjust transpose flags to // mirror the non-grouped GEMM logic for FP8 layout constraints. - bool transa_flag = static_cast(transa); - bool transb_flag = static_cast(transb); - const auto A_sel = select_grouped_operand(inputA, transa_flag, /*is_A=*/true); - const auto B_sel = select_grouped_operand(inputB, transb_flag, /*is_A=*/false); - transa_flag = A_sel.trans; - transb_flag = B_sel.trans; - const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); - const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); + const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); // Workspaces: setup (pointer arrays) and cuBLAS const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); @@ -1563,65 +1602,35 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace"); - NVTE_CHECK(cublas_workspace_ptr != nullptr, "Grouped GEMM: cuBLAS workspace pointer is null"); - auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( - static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); - launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, alpha_tensor, - beta_tensor, A_sel.base, B_sel.base, a_elem_size, b_elem_size, - transa_flag, transb_flag, num_tensors, stream); + static_cast(setup_workspace_ptr), num_tensors); + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, + alpha_tensor, beta_tensor, num_tensors, stream); // Get cuBLAS handle using cublasHandleManager = detail::HandleManager; cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); - // Get data types - const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); - const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); - const cudaDataType_t D_type = get_cuda_dtype(outputD->dtype()); - // Setup cuBLAS operations - cublasOperation_t op_A = transa_flag ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t op_B = transb_flag ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; // Create grouped matrix layouts cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; - init_matrix_layouts(descA, descB, descC, descD, setup_workspace, transa_flag, transb_flag, - A_sel.use_columnwise, B_sel.use_columnwise, num_tensors, A_type, B_type, - D_type); + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, + num_tensors); // Create matmul descriptor cublasLtMatmulDescOpaque_t matmulDesc; init_matmul_desc(matmulDesc, op_A, op_B); - - // Set FP8 scale pointers if needed - const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); - const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); - if (is_fp8_a || is_fp8_b) { - // For FP8 grouped GEMM, we need to pass scale_inv pointers - // The scale_inv arrays contain one float per tensor in the group - if (is_fp8_a) { - void *a_scale_inv = - A_sel.use_columnwise ? inputA->columnwise_scale_inv.dptr : inputA->scale_inv.dptr; - NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); - } - if (is_fp8_b) { - void *b_scale_inv = - B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr : inputB->scale_inv.dptr; - NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); - } - } + set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); int64_t avg_k_val = - avg_k ? *avg_k : (transa_flag ? compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); + avg_k ? *avg_k : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) : compute_avg_last_dim(A_sel.tensor)); // Heuristic selection cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 246fb5fefd..02cf01853d 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -239,19 +239,27 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous * memory layout and shape metadata. * - * \param[in] transa Whether to transpose A matrices. - * \param[in] transb Whether to transpose B matrices. - * \param[in] alpha Scale multiplier for A @ B (NVTETensor with num_tensors elements, - * or single element for uniform alpha). - * \param[in] A Input grouped tensor A. - * \param[in] B Input grouped tensor B. - * \param[in] beta Scale multiplier for C (NVTETensor with num_tensors elements, - * or single element for uniform beta). - * \param[in] C Input grouped tensor C (can be NULL for beta=0). - * \param[out] D Output grouped tensor D. - * \param[in] workspace Workspace tensor for intermediate computations. - * \param[in] config Matrix multiplication configuration. - * \param[in] stream CUDA stream for the operation. + * \param[in] transa Whether to transpose A matrices. + * \param[in] transb Whether to transpose B matrices. + * \param[in] alpha Scale multiplier for A @ B (single element NVTETensor). + * \param[in] A Input grouped tensor A. + * \param[in] B Input grouped tensor B. + * \param[in] beta Scale multiplier for C (single element NVTETensor). + * \param[in] C Input grouped tensor C (can be NULL for beta=0). + * \param[out] D Output grouped tensor D. + * \param[in] workspace_setup Workspace tensor for pointer array setup. + * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. + * \param[in] config Matrix multiplication configuration. + * \param[in] stream CUDA stream for the operation. + * \param[in] avg_m Optional hint for average M dimension across all matrices in the + * group. Used by cuBLASLt for algorithm selection heuristics. + * If NULL, computed automatically from D's logical shape. + * \param[in] avg_n Optional hint for average N dimension across all matrices in the + * group. Used by cuBLASLt for algorithm selection heuristics. + * If NULL, computed automatically from D's logical shape. + * \param[in] avg_k Optional hint for average K (reduction) dimension across all + * matrices in the group. Used by cuBLASLt for algorithm selection + * heuristics. If NULL, computed automatically from A's logical shape. * * Requirements: * - A, B, C (if provided), D must have the same num_tensors From 3b2fcdf3137cec31b83dc6dc0f64e2e367aa6f9b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 10:57:26 +0000 Subject: [PATCH 10/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_gemm.cu | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9d9a5097d4..7f2635943b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1265,11 +1265,9 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor "Grouped GEMM inputs must be FP8, BF16, or FP16."); // Only check C dtype if C is provided if (inputC != nullptr) { - NVTE_CHECK(is_output_dtype(inputC->dtype()), - "Grouped GEMM: C must be BF16, FP16, or FP32."); + NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); } - NVTE_CHECK(is_output_dtype(outputD->dtype()), - "Grouped GEMM: D must be BF16, FP16, or FP32."); + NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), @@ -1337,8 +1335,9 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). if (!has_row && has_col) { // On Hopper FP8, this would break TN requirement - should have been handled above - NVTE_CHECK(!is_fp8 || non_tn_fp8_ok, - "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); + NVTE_CHECK( + !is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = !sel.trans; @@ -1369,8 +1368,7 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, const GroupedOperandSelection &B_sel, - const transformer_engine::GroupedTensor *D, - size_t num_tensors) { + const transformer_engine::GroupedTensor *D, size_t num_tensors) { const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); @@ -1420,17 +1418,15 @@ inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, if (!is_fp8_a && !is_fp8_b) return; if (is_fp8_a) { - void *a_scale_inv = A_sel.use_columnwise - ? A_sel.tensor->columnwise_scale_inv.dptr - : A_sel.tensor->scale_inv.dptr; + void *a_scale_inv = A_sel.use_columnwise ? A_sel.tensor->columnwise_scale_inv.dptr + : A_sel.tensor->scale_inv.dptr; NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); } if (is_fp8_b) { - void *b_scale_inv = B_sel.use_columnwise - ? B_sel.tensor->columnwise_scale_inv.dptr - : B_sel.tensor->scale_inv.dptr; + void *b_scale_inv = B_sel.use_columnwise ? B_sel.tensor->columnwise_scale_inv.dptr + : B_sel.tensor->scale_inv.dptr; NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); @@ -1604,8 +1600,8 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( static_cast(setup_workspace_ptr), num_tensors); - launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, - alpha_tensor, beta_tensor, num_tensors, stream); + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, + beta_tensor, num_tensors, stream); // Get cuBLAS handle using cublasHandleManager = detail::HandleManager; @@ -1629,8 +1625,9 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // K dimension: if transa, K is A's first dim; if not, K is A's last dim int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); - int64_t avg_k_val = - avg_k ? *avg_k : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) : compute_avg_last_dim(A_sel.tensor)); + int64_t avg_k_val = avg_k ? *avg_k + : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) + : compute_avg_last_dim(A_sel.tensor)); // Heuristic selection cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, From 5b0582bbf0fd05773242df67836ec263014d52dd Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 11 Dec 2025 12:15:12 +0100 Subject: [PATCH 11/40] Grouped GEMM: per-matrix alpha/beta support - Change alpha/beta from single values to per-matrix arrays - Validate alpha/beta have exactly num_tensors elements - Update kernel to index alpha_ptr[idx] and beta_ptr[idx] - Move alpha/beta validation to validate_grouped_gemm_inputs - Update tests to use per-matrix alpha/beta arrays - Update documentation Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 15 +++++++----- .../common/gemm/cublaslt_gemm.cu | 24 ++++++++++++++----- .../common/include/transformer_engine/gemm.h | 4 ++-- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 5e5144fa4c..82b5bd3803 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -427,12 +427,15 @@ void run_grouped_gemm_case(const TestParams& params) { } GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); - Tensor alpha_tensor("alpha", std::vector{1}, DType::kFloat32); - Tensor beta_tensor("beta", std::vector{1}, DType::kFloat32); - const float alpha_val = 1.f; - const float beta_val = 0.f; - NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), &alpha_val, sizeof(float), cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), &beta_val, sizeof(float), cudaMemcpyHostToDevice)); + // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) + Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); + std::vector alpha_vals(num_gemms, 1.f); + std::vector beta_vals(num_gemms, 0.f); + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7f2635943b..caa394d549 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1237,7 +1237,9 @@ struct GroupedGemmSetupWorkspace { inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, const transformer_engine::GroupedTensor *inputB, const transformer_engine::GroupedTensor *inputC, - const transformer_engine::GroupedTensor *outputD) { + const transformer_engine::GroupedTensor *outputD, + const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor) { const size_t num_tensors = inputA->num_tensors; NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); NVTE_CHECK(inputB->num_tensors == num_tensors, @@ -1250,6 +1252,16 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor NVTE_CHECK(outputD->num_tensors == num_tensors, "Grouped GEMM: A and D must have the same num_tensors"); + // Validate alpha/beta have per-matrix values + const size_t alpha_numel = alpha_tensor->data.shape.numel(); + const size_t beta_numel = beta_tensor->data.shape.numel(); + NVTE_CHECK(alpha_numel == num_tensors, + "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", + alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, + "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", + beta_numel); + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2 || @@ -1481,7 +1493,7 @@ __global__ void setup_grouped_gemm_kernel( TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, // Element sizes size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, - // Alpha/beta pointers (same for all groups) + // Alpha/beta pointers (per-matrix arrays) float *alpha_ptr, float *beta_ptr, // Transpose flags bool transa, bool transb, @@ -1519,9 +1531,9 @@ __global__ void setup_grouped_gemm_kernel( K[idx] = static_cast(transa ? a_last : a_first); N[idx] = static_cast(transb ? b_last : b_first); - // Fill alpha/beta pointers (same for all groups) - alpha_ptrs[idx] = alpha_ptr; - beta_ptrs[idx] = beta_ptr; + // Fill alpha/beta pointers (per-matrix) + alpha_ptrs[idx] = alpha_ptr + idx; + beta_ptrs[idx] = beta_ptr + idx; } // Launch the setup kernel to populate workspace arrays @@ -1578,7 +1590,7 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); // Validate inputs and num_tensors - validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD); + validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 02cf01853d..9dfa009115 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -241,10 +241,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * * \param[in] transa Whether to transpose A matrices. * \param[in] transb Whether to transpose B matrices. - * \param[in] alpha Scale multiplier for A @ B (single element NVTETensor). + * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). * \param[in] A Input grouped tensor A. * \param[in] B Input grouped tensor B. - * \param[in] beta Scale multiplier for C (single element NVTETensor). + * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). * \param[in] C Input grouped tensor C (can be NULL for beta=0). * \param[out] D Output grouped tensor D. * \param[in] workspace_setup Workspace tensor for pointer array setup. From 101766bcb15e9cd6a9df01eaa6e5b5b9d9989f40 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 11:17:48 +0000 Subject: [PATCH 12/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_gemm.cu | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index caa394d549..1d63cf65cf 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1255,12 +1255,10 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Validate alpha/beta have per-matrix values const size_t alpha_numel = alpha_tensor->data.shape.numel(); const size_t beta_numel = beta_tensor->data.shape.numel(); - NVTE_CHECK(alpha_numel == num_tensors, - "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", - alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, - "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", - beta_numel); + NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, + ") elements, got ", alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, + ") elements, got ", beta_numel); auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || From 1167f7539fb91a7d8cb7de2ea252e89415967073 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 11 Dec 2025 12:25:28 +0100 Subject: [PATCH 13/40] Fix alpha/beta numel - use SimpleTensor::numel() Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski --- transformer_engine/common/gemm/cublaslt_gemm.cu | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1d63cf65cf..b8aa2a8ba3 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1253,12 +1253,14 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor "Grouped GEMM: A and D must have the same num_tensors"); // Validate alpha/beta have per-matrix values - const size_t alpha_numel = alpha_tensor->data.shape.numel(); - const size_t beta_numel = beta_tensor->data.shape.numel(); - NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, - ") elements, got ", alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, - ") elements, got ", beta_numel); + const size_t alpha_numel = alpha_tensor->data.numel(); + const size_t beta_numel = beta_tensor->data.numel(); + NVTE_CHECK(alpha_numel == num_tensors, + "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", + alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, + "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", + beta_numel); auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || From 00eb18662846645875c9da5edaeb37b216c8833c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 16 Dec 2025 17:41:24 -0800 Subject: [PATCH 14/40] Einsum WIP 1 --- build_tools/build_ext.py | 6 ++ transformer_engine/jax/cpp_extensions/base.py | 12 +-- transformer_engine/jax/cpp_extensions/gemm.py | 10 +-- transformer_engine/jax/dense.py | 87 +++++++++++++------ transformer_engine/jax/sharding.py | 2 + 5 files changed, 79 insertions(+), 38 deletions(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 349858ac49..c269a29874 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -61,6 +61,12 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: f"-DCMAKE_BUILD_TYPE={build_type}", f"-DCMAKE_INSTALL_PREFIX={install_dir}", ] + if bool(int(os.getenv("NVTE_USE_CCACHE", "0"))): + ccache_bin = os.getenv("NVTE_CCACHE_BIN", "ccache") + configure_command += [ + f"-DCMAKE_CXX_COMPILER_LAUNCHER={ccache_bin}", + f"-DCMAKE_CUDA_COMPILER_LAUNCHER={ccache_bin}", + ] configure_command += self.cmake_flags import pybind11 diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 22a4b7dda4..70734ad4c4 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -207,12 +207,12 @@ def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): if batch_dim is None: batch_dim = bdim batch_size = arg.shape[bdim] - elif bdim != batch_dim: - raise ValueError( - "All batched arguments must have the same batch dimension. " - f"Got batch_dims={batch_dims}" - ) - assert batch_dim is not None and batch_size is not None, "Invalid batching config!" + # elif bdim != batch_dim: + # raise ValueError( + # "All batched arguments must have the same batch dimension. " + # f"Got batch_dims={batch_dims}" + # ) + # assert batch_dim is not None and batch_size is not None, "Invalid batching config!" # Loop over batch dimension and collect results all_results = [] diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 55a1700838..7d44643046 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -812,11 +812,11 @@ def batcher( lhs_bdims, _, rhs_bdims, *_ = batch_dims # Validate batch dimensions - if lhs_bdims is not None or rhs_bdims is not None: - assert lhs_bdims == rhs_bdims, ( - "Batched GEMM requires matching batch dimensions, " - f"got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" - ) + # if lhs_bdims is not None or rhs_bdims is not None: + # assert lhs_bdims == rhs_bdims, ( + # "Batched GEMM requires matching batch dimensions, " + # f"got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" + # ) # Use general batcher from BasePrimitive return GemmPrimitive.batcher_impl( diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index c499b0651e..f941e598ae 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -69,6 +69,7 @@ def dense( output_axes: Tuple[str, ...] = None, collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set, quantizer_set: QuantizerSet = noop_quantizer_set, + batch_dims : Tuple[Sequence[int], Sequence[int]] = ((), ()), ): """Perform dense layer transformation with optional quantization. @@ -109,11 +110,12 @@ def dense( output_axes, collective_op_set, quantizer_set, + batch_dims, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 10)) def _dense( x, kernel, @@ -125,6 +127,7 @@ def _dense( output_axes, collective_op_set, quantizer_set, # need to be a diff_arg for DelayedScaling state management + batch_dims, ): """Internal implementation of dense layer transformation with custom VJP. @@ -157,6 +160,7 @@ def _dense( output_axes, collective_op_set, quantizer_set, + batch_dims, ) return output @@ -172,6 +176,7 @@ def _dense_fwd_rule( output_axes, collective_op_set, quantizer_set, + batch_dims, ): """Forward pass rule for dense layer transformation. @@ -185,9 +190,9 @@ def _dense_fwd_rule( # Check supported input layout x_is_transposed = x.ndim - 1 not in x_contracting_dims k_is_transposed = kernel.ndim - 1 in k_contracting_dims - assert ( - not x_is_transposed and not k_is_transposed - ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." + # assert ( + # not x_is_transposed and not k_is_transposed + # ), f"Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel. {x_contracting_dims=},{x.ndim=},{k_contracting_dims=},{kernel.ndim=}" flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) @@ -237,6 +242,47 @@ def _dense_fwd_rule( ) return output, ctx +def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, + swap_ans=False): + # from: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py#L198 + import itertools + import numpy as np + def _remaining(original, *removed_lists): + removed = set(itertools.chain(*removed_lists)) + return [i for i in original if i not in removed] + + def _ranges_like(*xs): + start = 0 + for x in xs: + x_len = len(x) + yield range(start, start + x_len) + start += x_len + + (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers + x_ndim = x.ndim + x_kept = _remaining(range(x_ndim), x_contract, x_batch) + y_kept = _remaining(range(y.ndim), y_contract, y_batch) + if swap_ans: + ans_batch, ans_y, _ = _ranges_like(x_batch, y_kept, x_kept) + else: + ans_batch, _, ans_y = _ranges_like(x_batch, x_kept, y_kept) + dims = ((ans_y, y_kept), (ans_batch, y_batch)) + x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) + out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) + x_bar = jax.lax.transpose( + # TODO(jberchtold): I'm ignoring the batch_dims here, do I need to explicitly use vmap or something? + tex.gemm(g, y, contracting_dims=dims[0]), + tuple(out_axes) + ) + return x_bar + +def dot_general_transpose_rhs(g, x, y, *, dimension_numbers): + (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers + swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) + y_bar = dot_general_transpose_lhs( + g, y, x, dimension_numbers=swapped_dimension_numbers, + swap_ans=True) + return y_bar def _dense_bwd_rule( contracting_dims, @@ -245,6 +291,7 @@ def _dense_bwd_rule( kernel_axes, output_axes, collective_op_set, + batch_dims, ctx, grad, ): @@ -277,35 +324,21 @@ def _dense_bwd_rule( transpose_batch_sequence=transpose_batch_sequence, ) - # GEMM NT - # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim - g_contracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) - ) - # k_non_contracting_dims - k_contracting_dim = tuple( - dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims - ) + fwd_cdims = (fwd_x_contracting_dims, fwd_k_contracting_dims) + dims = (fwd_cdims, batch_dims) - dgrad = tex.gemm( + dgrad = dot_general_transpose_lhs( casted_grad.get_tensor(usage=TensorUsage.LHS), + casted_x_lhs, casted_kernel_rhs, - contracting_dims=(g_contracting_dim, k_contracting_dim), - transpose_batch_sequence=transpose_batch_sequence, - collective_op=collective_op_set.backward, + dimension_numbers=dims, ) - # GEMM TN - # x_non_contracting_dims - g_contracting_dim = x_contracting_dim = tuple( - range(0, len(x_shape) - len(fwd_x_contracting_dims)) - ) - - wgrad = tex.gemm( + wgrad = dot_general_transpose_rhs( + casted_grad.get_tensor(usage=TensorUsage.LHS), # TODO(jberchtold): should be RHS to use fused kernel for 2x layout? but would need to update dims accordingly casted_x_lhs, - casted_grad.get_tensor(usage=TensorUsage.RHS), - contracting_dims=(x_contracting_dim, g_contracting_dim), - transpose_batch_sequence=transpose_batch_sequence, + casted_kernel_rhs, + dimension_numbers=dims, ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 6cb0dd257c..01405ba87a 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -261,6 +261,8 @@ def get_mesh_axis_size(axis, mesh=None): if axis is None: return 1 + print(mesh) + assert axis in mesh.shape, f"{axis} is not a axis of the given mesh {mesh.shape}" return mesh.shape[axis] From 38defb8ec354055f0a14017d5a525e1cc911d57c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 18 Dec 2025 08:45:19 -0800 Subject: [PATCH 15/40] Test --- transformer_engine/jax/cpp_extensions/base.py | 2 +- transformer_engine/jax/cpp_extensions/quantization.py | 2 +- transformer_engine/jax/dense.py | 9 ++------- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 70734ad4c4..defdce7b68 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -212,7 +212,7 @@ def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): # "All batched arguments must have the same batch dimension. " # f"Got batch_dims={batch_dims}" # ) - # assert batch_dim is not None and batch_size is not None, "Invalid batching config!" + assert batch_dim is not None and batch_size is not None, "Invalid batching config!" # Loop over batch dimension and collect results all_results = [] diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 53c6937fb4..c5d76cf28c 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -362,7 +362,7 @@ def batcher( use_rht, ): """Batch rule for quantization primitive using general batcher.""" - check_valid_batch_dims(batch_dims) + # check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None return BaseDBiasQuantizePrimitive.batcher_impl( diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index f941e598ae..62b0e054aa 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -69,7 +69,6 @@ def dense( output_axes: Tuple[str, ...] = None, collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set, quantizer_set: QuantizerSet = noop_quantizer_set, - batch_dims : Tuple[Sequence[int], Sequence[int]] = ((), ()), ): """Perform dense layer transformation with optional quantization. @@ -110,12 +109,11 @@ def dense( output_axes, collective_op_set, quantizer_set, - batch_dims, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 10)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8)) def _dense( x, kernel, @@ -127,7 +125,6 @@ def _dense( output_axes, collective_op_set, quantizer_set, # need to be a diff_arg for DelayedScaling state management - batch_dims, ): """Internal implementation of dense layer transformation with custom VJP. @@ -160,7 +157,6 @@ def _dense( output_axes, collective_op_set, quantizer_set, - batch_dims, ) return output @@ -176,7 +172,6 @@ def _dense_fwd_rule( output_axes, collective_op_set, quantizer_set, - batch_dims, ): """Forward pass rule for dense layer transformation. @@ -291,7 +286,6 @@ def _dense_bwd_rule( kernel_axes, output_axes, collective_op_set, - batch_dims, ctx, grad, ): @@ -325,6 +319,7 @@ def _dense_bwd_rule( ) fwd_cdims = (fwd_x_contracting_dims, fwd_k_contracting_dims) + batch_dims = ((), ()) # vmap is done outside dense VJP if needed dims = (fwd_cdims, batch_dims) dgrad = dot_general_transpose_lhs( From e4a80a3522b8d1b29199d807a4770ebc815ca487 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 09:57:33 +0100 Subject: [PATCH 16/40] Refactor: move grouped GEMM to separate file and cleanup API Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 12 +- .../common/gemm/cublaslt_gemm.cu | 549 +--------------- .../common/gemm/cublaslt_grouped_gemm.cu | 599 ++++++++++++++++++ .../common/gemm/cublaslt_grouped_gemm.cuh | 18 + .../common/include/transformer_engine/gemm.h | 12 +- 5 files changed, 635 insertions(+), 555 deletions(-) create mode 100644 transformer_engine/common/gemm/cublaslt_grouped_gemm.cu create mode 100644 transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 82b5bd3803..0ea76946bc 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include #include #include #include @@ -314,9 +315,12 @@ std::vector> make_shapes(ShapeCase scase) { } void run_grouped_gemm_case(const TestParams& params) { - if (params.input_case != InputCase::kBF16 && - getDeviceComputeCapability() < hopperComputeCapability) { - GTEST_SKIP() << "FP8 grouped GEMM requires Hopper or newer."; +#if CUBLAS_VERSION < 130200 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " + << CUBLAS_VERSION << "."; +#else + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Hopper (SM90) or newer."; } const std::vector> shapes = make_shapes(params.shape_case); @@ -451,7 +455,6 @@ void run_grouped_gemm_case(const TestParams& params) { grouped_D.get_handle(), setup_ws.data(), cublas_ws.data(), - nullptr, 0, nullptr, nullptr, @@ -477,6 +480,7 @@ void run_grouped_gemm_case(const TestParams& params) { atol, rtol); } +#endif // CUBLAS_VERSION >= 130200 } class GroupedGemmTest : public ::testing::TestWithParam {}; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index b8aa2a8ba3..86f517af7d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -23,6 +23,7 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "./config.h" +#include "./cublaslt_grouped_gemm.cuh" #include "./cutlass_grouped_gemm.cuh" namespace { @@ -1104,551 +1105,3 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor cublas_path(); } } - -// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) -struct TensorShapeInfo { - const int64_t *first_dims; // nullptr if uniform - const int64_t *last_dims; // nullptr if uniform - const int64_t *offsets; // nullptr if need to compute - int64_t uniform_first; // used if first_dims == nullptr - int64_t uniform_last; // used if last_dims == nullptr - - // Create from GroupedTensor - static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { - const bool has_first = t->first_dims.has_data(); - const bool has_last = t->last_dims.has_data(); - // When per-tensor dims are not provided, we must be in the uniform-shape case. - NVTE_CHECK(has_first || t->all_same_first_dim(), - "GroupedTensor is missing first_dims for varying shapes"); - NVTE_CHECK(has_last || t->all_same_last_dim(), - "GroupedTensor is missing last_dims for varying shapes"); - - const int64_t *first_ptr = - has_first ? static_cast(t->first_dims.dptr) : nullptr; - const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; - - const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); - const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); - - return {first_ptr, last_ptr, - t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) - : nullptr, - uniform_first, uniform_last}; - } - - // Create for C tensor (uses D's dimensions, only has offsets) - static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D) { - const bool has_first = D->first_dims.has_data(); - const bool has_last = D->last_dims.has_data(); - NVTE_CHECK(has_first || D->all_same_first_dim(), - "GroupedTensor D is missing first_dims for varying shapes"); - NVTE_CHECK(has_last || D->all_same_last_dim(), - "GroupedTensor D is missing last_dims for varying shapes"); - - const int64_t *first_ptr = - has_first ? static_cast(D->first_dims.dptr) : nullptr; - const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; - const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); - const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); - - return {first_ptr, last_ptr, - C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) - : nullptr, - uniform_first, uniform_last}; - } -}; - -// Helper functions to compute average dimensions from logical_shape for heuristics -// These are hints for cuBLASLt algorithm selection, don't need to be exact -inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { - // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) - // In both cases, dividing by num_tensors gives the average - return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); -} - -inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { - if (t->all_same_last_dim()) { - // logical_shape[1] is the common N - return static_cast(t->logical_shape.data[1]); - } - // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. - return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); -} - -// Workspace layout for grouped GEMM -struct GroupedGemmSetupWorkspace { - void **A_ptrs; - void **B_ptrs; - void **C_ptrs; - void **D_ptrs; - int *M; - int *N; - int *K; - float **alpha_ptrs; - float **beta_ptrs; - - // Initialize from workspace buffer - // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { - GroupedGemmSetupWorkspace ws; - size_t offset = 0; - const size_t ptr_size = num_tensors * sizeof(void *); - const size_t int_size = num_tensors * sizeof(int); - - // Pointer arrays first (all 8-byte aligned) - ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - - // Int arrays last (4-byte aligned, always satisfied after pointer arrays) - ws.M = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.N = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.K = reinterpret_cast(setup_ws_ptr + offset); - - return ws; - } - - // Calculate required size for setup workspace (pointer arrays + M/N/K) - static size_t required_setup_size(size_t num_tensors, size_t alignment) { - const size_t ptr_size = num_tensors * sizeof(void *); - const size_t int_size = num_tensors * sizeof(int); - // Layout: 6 ptr arrays, then 3 int arrays (no padding needed) - size_t size = 6 * ptr_size + 3 * int_size; - size = ((size + alignment - 1) / alignment) * alignment; - return size; - } -}; - -// ----------------------------------------------------------------------------- -// Helper routines to keep nvte_grouped_gemm readable -// ----------------------------------------------------------------------------- -inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, - const transformer_engine::GroupedTensor *inputB, - const transformer_engine::GroupedTensor *inputC, - const transformer_engine::GroupedTensor *outputD, - const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor) { - const size_t num_tensors = inputA->num_tensors; - NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); - NVTE_CHECK(inputB->num_tensors == num_tensors, - "Grouped GEMM: A and B must have the same num_tensors"); - // C can be NULL (will use D as C when beta=0) - if (inputC != nullptr) { - NVTE_CHECK(inputC->num_tensors == num_tensors, - "Grouped GEMM: A and C must have the same num_tensors"); - } - NVTE_CHECK(outputD->num_tensors == num_tensors, - "Grouped GEMM: A and D must have the same num_tensors"); - - // Validate alpha/beta have per-matrix values - const size_t alpha_numel = alpha_tensor->data.numel(); - const size_t beta_numel = beta_tensor->data.numel(); - NVTE_CHECK(alpha_numel == num_tensors, - "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", - alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, - "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", - beta_numel); - - auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { - return dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2 || - dtype == transformer_engine::DType::kBFloat16 || - dtype == transformer_engine::DType::kFloat16; - }; - auto is_output_dtype = [](transformer_engine::DType dtype) { - return dtype == transformer_engine::DType::kBFloat16 || - dtype == transformer_engine::DType::kFloat16 || - dtype == transformer_engine::DType::kFloat32; - }; - NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), - "Grouped GEMM inputs must be FP8, BF16, or FP16."); - // Only check C dtype if C is provided - if (inputC != nullptr) { - NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); - } - NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); - NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), - "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); - NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), - "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); -} - -// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. -// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and -// fallback to column-wise data when row-wise is absent. -struct GroupedOperandSelection { - const transformer_engine::GroupedTensor *tensor = nullptr; - const char *base = nullptr; - transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; - bool trans = false; - bool use_columnwise = false; -}; - -inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, - bool trans, bool is_A) { - using namespace transformer_engine; - const bool has_row = t->has_data(); - const bool has_col = t->has_columnwise_data(); - NVTE_CHECK(has_row || has_col, - "Grouped GEMM operand is missing both row-wise and column-wise data"); - - // Not yet supported in grouped GEMM: block scaling, MXFP8, NVFP4 specialized layouts. - const auto sm = t->scaling_mode; - NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && !is_mxfp_scaling(sm) && - !is_nvfp_scaling(sm), - "Grouped GEMM does not yet support NVFP4/MXFP8/block scaling operand selection"); - - const DType row_dtype = t->data.dtype; - const DType col_dtype = t->columnwise_data.dtype; - GroupedOperandSelection sel; - sel.tensor = t; - sel.trans = trans; - - const DType rep_dtype = has_row ? row_dtype : col_dtype; - const bool is_fp8 = is_fp8_dtype(rep_dtype); - const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); - - // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. - if (is_fp8 && !non_tn_fp8_ok) { - if (is_A) { - if (!sel.trans) { - NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; - sel.trans = true; // using pre-transposed storage - sel.use_columnwise = true; - return sel; - } - } else { // B - if (sel.trans) { - NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; - sel.trans = false; // using pre-transposed storage - sel.use_columnwise = true; - return sel; - } - } - } - - // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). - if (!has_row && has_col) { - // On Hopper FP8, this would break TN requirement - should have been handled above - NVTE_CHECK( - !is_fp8 || non_tn_fp8_ok, - "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); - sel.base = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; - sel.trans = !sel.trans; - sel.use_columnwise = true; - return sel; - } - - // Default: use row-wise data (column-wise case already handled above) - sel.base = static_cast(t->data.dptr); - sel.dtype = row_dtype; - sel.use_columnwise = false; - return sel; -} - -inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, - const char *workspace_name) { - NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); - const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); - NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, - ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); - return ws->data.dptr; -} - -inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, - cublasLtMatrixLayoutOpaque_t &descB, - cublasLtMatrixLayoutOpaque_t &descC, - cublasLtMatrixLayoutOpaque_t &descD, - const GroupedGemmSetupWorkspace &ws, - const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel, - const transformer_engine::GroupedTensor *D, size_t num_tensors) { - const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); - const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); - const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); - - // For column-major layout: leading dimension is the number of rows in storage. - // If columnwise data was chosen, storage is already transposed. - int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); - int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); - int *lda = rowa; - int *rowb = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); - int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? ws.K : ws.N); - int *ldb = rowb; - - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); -} - -inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, - cublasOperation_t op_B) { - NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); - - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, - sizeof(op_A))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, - sizeof(op_B))); - - cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointer_mode, sizeof(pointer_mode))); - - int64_t alphabeta_batch_stride = 1; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, - CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, - &alphabeta_batch_stride, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, - CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, - &alphabeta_batch_stride, sizeof(int64_t))); -} - -inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, - const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel) { - const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); - const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); - if (!is_fp8_a && !is_fp8_b) return; - - if (is_fp8_a) { - void *a_scale_inv = A_sel.use_columnwise ? A_sel.tensor->columnwise_scale_inv.dptr - : A_sel.tensor->scale_inv.dptr; - NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); - } - if (is_fp8_b) { - void *b_scale_inv = B_sel.use_columnwise ? B_sel.tensor->columnwise_scale_inv.dptr - : B_sel.tensor->scale_inv.dptr; - NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); - } -} - -// Constants for grouped GEMM workspace (declared early for use in heuristics) -static constexpr size_t kGroupedGemmAlignment = 256; -static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB - -inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, - cublasLtMatmulDescOpaque_t &matmulDesc, - cublasLtMatrixLayoutOpaque_t &descA, - cublasLtMatrixLayoutOpaque_t &descB, - cublasLtMatrixLayoutOpaque_t &descC, - cublasLtMatrixLayoutOpaque_t &descD, - int64_t avg_m, int64_t avg_n, int64_t avg_k) { - cublasLtMatmulPreferenceOpaque_t preference; - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); - NVTE_CHECK_CUBLAS( - cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); - - cublasLtMatmulHeuristicResult_t heuristicResult; - int returnedResults = 0; - auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, - &preference, 1, &heuristicResult, &returnedResults); - NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, - "Unable to find suitable cuBLAS grouped GEMM algorithm"); - NVTE_CHECK_CUBLAS(status); - NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); - return heuristicResult.algo; -} - -// Single kernel that sets up all GEMM parameters. -// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix M/N/K, -// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. -// We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. -__global__ void setup_grouped_gemm_kernel( - // Output arrays - void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, - float **alpha_ptrs, float **beta_ptrs, - // Base pointers - const char *a_base, const char *b_base, const char *c_base, char *d_base, - // Dimension info (per tensor) - TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, - // Element sizes - size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, - // Alpha/beta pointers (per-matrix arrays) - float *alpha_ptr, float *beta_ptr, - // Transpose flags - bool transa, bool transb, - // Number of tensors - size_t num_tensors) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num_tensors) return; - - // Get dimensions for this tensor (from array or uniform value) - int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; - int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; - int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first; - int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; - - // Compute offsets (from array or compute from uniform dims) - int64_t a_offset = - A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); - int64_t b_offset = - B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); - int64_t c_offset = - C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); - int64_t d_offset = - D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); - - // Compute data pointers - A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; - B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; - C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; - D_ptrs[idx] = d_base + d_offset * d_elem_size; - - // Compute M, N, K dimensions - // Test stores A as {K,M} when !transa, {M,K} when transa - // Test stores B as {N,K} when !transb, {K,N} when transb - M[idx] = static_cast(transa ? a_first : a_last); - K[idx] = static_cast(transa ? a_last : a_first); - N[idx] = static_cast(transb ? b_last : b_first); - - // Fill alpha/beta pointers (per-matrix) - alpha_ptrs[idx] = alpha_ptr + idx; - beta_ptrs[idx] = beta_ptr + idx; -} - -// Launch the setup kernel to populate workspace arrays -inline void launch_grouped_gemm_setup( - const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { - TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); - TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); - TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); - TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); - - const char *c_base = static_cast(C->data.dptr); - char *d_base = static_cast(D->data.dptr); - - const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); - const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); - const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); - const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); - - const int threads_per_block = 256; - const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; - - setup_grouped_gemm_kernel<<>>( - ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, - A_sel.base, B_sel.base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, - b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), - static_cast(beta_tensor->data.dptr), A_sel.trans, B_sel.trans, num_tensors); - - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { - return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); -} - -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, - const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, - NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, - const int64_t *avg_n, const int64_t *avg_k) { - NVTE_API_CALL(nvte_grouped_gemm); - using namespace transformer_engine; - - // Convert to internal types - const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); - const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); - const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL - GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); - const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); - const Tensor *beta_tensor = convertNVTETensorCheck(beta); - Tensor *wspace_setup = convertNVTETensor(workspace_setup); - Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); - - // Validate inputs and num_tensors - validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); - - // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) - const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; - const size_t num_tensors = inputA->num_tensors; - - // Select operand storage (row-wise vs column-wise) and adjust transpose flags to - // mirror the non-grouped GEMM logic for FP8 layout constraints. - const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); - const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); - - // Workspaces: setup (pointer arrays) and cuBLAS - const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); - const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; - - void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, - "Grouped GEMM setup workspace"); - void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, - "Grouped GEMM cuBLAS workspace"); - - auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( - static_cast(setup_workspace_ptr), num_tensors); - launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, - beta_tensor, num_tensors, stream); - - // Get cuBLAS handle - using cublasHandleManager = detail::HandleManager; - cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); - - // Setup cuBLAS operations - cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; - - // Create grouped matrix layouts - cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; - init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, - num_tensors); - - // Create matmul descriptor - cublasLtMatmulDescOpaque_t matmulDesc; - init_matmul_desc(matmulDesc, op_A, op_B); - set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); - - // Compute average dimensions for heuristics - // K dimension: if transa, K is A's first dim; if not, K is A's last dim - int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); - int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); - int64_t avg_k_val = avg_k ? *avg_k - : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) - : compute_avg_last_dim(A_sel.tensor)); - - // Heuristic selection - cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, - descD, avg_m_val, avg_n_val, avg_k_val); - - // Execute the grouped GEMM - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, - setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, - setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, - setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, - kGroupedGemmCublasWorkspaceSize, stream)); -} diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu new file mode 100644 index 0000000000..4125bd82bf --- /dev/null +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -0,0 +1,599 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/handle_manager.h" +#include "../util/logging.h" +#include "./cublaslt_grouped_gemm.cuh" + +namespace { + +inline void CreateCublasHandle(cublasLtHandle_t *handle) { + NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); +} + +} // namespace + +#if CUBLAS_VERSION >= 130100 + +namespace { + +// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) +struct TensorShapeInfo { + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr + + // Create from GroupedTensor + static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + // When per-tensor dims are not provided, we must be in the uniform-shape case. + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); + + return {first_ptr, last_ptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) + : nullptr, + uniform_first, uniform_last}; + } + + // Create for C tensor (uses D's dimensions, only has offsets) + static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D) { + const bool has_first = D->first_dims.has_data(); + const bool has_last = D->last_dims.has_data(); + NVTE_CHECK(has_first || D->all_same_first_dim(), + "GroupedTensor D is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || D->all_same_last_dim(), + "GroupedTensor D is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? static_cast(D->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); + + return {first_ptr, last_ptr, + C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) + : nullptr, + uniform_first, uniform_last}; + } +}; + +// Helper functions to compute average dimensions from logical_shape for heuristics +// These are hints for cuBLASLt algorithm selection, don't need to be exact +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { + // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) + // In both cases, dividing by num_tensors gives the average + return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); +} + +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { + if (t->all_same_last_dim()) { + // logical_shape[1] is the common N + return static_cast(t->logical_shape.data[1]); + } + // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. + return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); +} + +// Workspace layout for grouped GEMM +struct GroupedGemmSetupWorkspace { + void **A_ptrs; + void **B_ptrs; + void **C_ptrs; + void **D_ptrs; + int *M; + int *N; + int *K; + float **alpha_ptrs; + float **beta_ptrs; + + // Initialize from workspace buffer + // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { + GroupedGemmSetupWorkspace ws; + size_t offset = 0; + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + + // Pointer arrays first (all 8-byte aligned) + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + + // Int arrays last (4-byte aligned, always satisfied after pointer arrays) + ws.M = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.N = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.K = reinterpret_cast(setup_ws_ptr + offset); + + return ws; + } + + // Calculate required size for setup workspace (pointer arrays + M/N/K) + static size_t required_setup_size(size_t num_tensors, size_t alignment) { + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + // Layout: 6 ptr arrays, then 3 int arrays (no padding needed) + size_t size = 6 * ptr_size + 3 * int_size; + size = ((size + alignment - 1) / alignment) * alignment; + return size; + } +}; + +// ----------------------------------------------------------------------------- +// Helper routines to keep nvte_grouped_gemm readable +// ----------------------------------------------------------------------------- +inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, + const transformer_engine::GroupedTensor *inputB, + const transformer_engine::GroupedTensor *inputC, + const transformer_engine::GroupedTensor *outputD, + const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor) { + const size_t num_tensors = inputA->num_tensors; + NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); + NVTE_CHECK(inputB->num_tensors == num_tensors, + "Grouped GEMM: A and B must have the same num_tensors"); + // C can be NULL (will use D as C when beta=0) + if (inputC != nullptr) { + NVTE_CHECK(inputC->num_tensors == num_tensors, + "Grouped GEMM: A and C must have the same num_tensors"); + } + NVTE_CHECK(outputD->num_tensors == num_tensors, + "Grouped GEMM: A and D must have the same num_tensors"); + + // Validate alpha/beta have per-matrix values + const size_t alpha_numel = alpha_tensor->data.numel(); + const size_t beta_numel = beta_tensor->data.numel(); + NVTE_CHECK(alpha_numel == num_tensors, + "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", + alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, + "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", + beta_numel); + + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2 || + dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16; + }; + auto is_output_dtype = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16 || + dtype == transformer_engine::DType::kFloat32; + }; + NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), + "Grouped GEMM inputs must be FP8, BF16, or FP16."); + // Only check C dtype if C is provided + if (inputC != nullptr) { + NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); + } + NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); + NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), + "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); + NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), + "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); +} + +// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. +// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and +// fallback to column-wise data when row-wise is absent. +struct GroupedOperandSelection { + const transformer_engine::GroupedTensor *tensor = nullptr; + const char *dptr = nullptr; + transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; + bool trans = false; + bool use_columnwise = false; +}; + +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, + bool trans, bool is_A) { + using namespace transformer_engine; + const bool has_row = t->has_data(); + const bool has_col = t->has_columnwise_data(); + NVTE_CHECK(has_row || has_col, + "Grouped GEMM operand is missing both row-wise and column-wise data"); + + // Currently only unquantized data and tensor-scaled FP8 are supported. + const auto sm = t->scaling_mode; + NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING, + "Grouped GEMM is only supported with unquantized data and tensor-scaled FP8 data"); + + const DType row_dtype = t->data.dtype; + const DType col_dtype = t->columnwise_data.dtype; + GroupedOperandSelection sel; + sel.tensor = t; + sel.trans = trans; + + const DType rep_dtype = has_row ? row_dtype : col_dtype; + const bool is_fp8 = is_fp8_dtype(rep_dtype); + const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + + // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. + if (is_fp8 && !non_tn_fp8_ok) { + if (is_A) { + if (!sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = true; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } else { // B + if (sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = false; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } + } + + // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). + if (!has_row && has_col) { + // On Hopper FP8, this would break TN requirement - should have been handled above + NVTE_CHECK( + !is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = !sel.trans; + sel.use_columnwise = true; + return sel; + } + + // Default: use row-wise data (column-wise case already handled above) + sel.dptr = static_cast(t->data.dptr); + sel.dtype = row_dtype; + sel.use_columnwise = false; + return sel; +} + +inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, + const char *workspace_name) { + NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); + const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); + NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, + ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); + return ws->data.dptr; +} + +inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + const GroupedGemmSetupWorkspace &ws, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, + const transformer_engine::GroupedTensor *D, size_t num_tensors) { + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); + + // For column-major layout: leading dimension is the number of rows in storage. + // If columnwise data was chosen, storage is already transposed. + int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); + int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); + int *lda = rowa; + int *rowb = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); + int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? ws.K : ws.N); + int *ldb = rowb; + + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); +} + +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, + cublasOperation_t op_B) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, + sizeof(op_A))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, + sizeof(op_B))); + + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode))); + + int64_t alphabeta_batch_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); +} + +inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel) { + const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); + const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); + if (!is_fp8_a && !is_fp8_b) return; + + if (is_fp8_a) { + void *a_scale_inv = A_sel.use_columnwise ? A_sel.tensor->columnwise_scale_inv.dptr + : A_sel.tensor->scale_inv.dptr; + NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + } + if (is_fp8_b) { + void *b_scale_inv = B_sel.use_columnwise ? B_sel.tensor->columnwise_scale_inv.dptr + : B_sel.tensor->scale_inv.dptr; + NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + } +} + +// Constants for grouped GEMM workspace (declared early for use in heuristics) +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + +inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, + cublasLtMatmulDescOpaque_t &matmulDesc, + cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + int64_t avg_m, int64_t avg_n, int64_t avg_k) { + cublasLtMatmulPreferenceOpaque_t preference; + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); + NVTE_CHECK_CUBLAS( + cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); + + cublasLtMatmulHeuristicResult_t heuristicResult; + int returnedResults = 0; + auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, + &preference, 1, &heuristicResult, &returnedResults); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, + "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK_CUBLAS(status); + NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); + return heuristicResult.algo; +} + +// Single kernel that sets up all GEMM parameters. +// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix M/N/K, +// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. +// We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. +__global__ void setup_grouped_gemm_kernel( + // Output arrays + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, + float **alpha_ptrs, float **beta_ptrs, + // Base pointers + const char *a_base, const char *b_base, const char *c_base, char *d_base, + // Dimension info (per tensor) + TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, + // Element sizes + size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, + // Alpha/beta pointers (per-matrix arrays) + float *alpha_ptr, float *beta_ptr, + // Transpose flags + bool transa, bool transb, + // Number of tensors + size_t num_tensors) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_tensors) return; + + // Get dimensions for this tensor (from array or uniform value) + int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; + int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; + int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first; + int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; + + // Compute offsets (from array or compute from uniform dims) + int64_t a_offset = + A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); + int64_t b_offset = + B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); + int64_t c_offset = + C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); + int64_t d_offset = + D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + + // Compute data pointers + A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; + B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; + C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; + D_ptrs[idx] = d_base + d_offset * d_elem_size; + + // Compute M, N, K dimensions + // Test stores A as {K,M} when !transa, {M,K} when transa + // Test stores B as {N,K} when !transb, {K,N} when transb + M[idx] = static_cast(transa ? a_first : a_last); + K[idx] = static_cast(transa ? a_last : a_first); + N[idx] = static_cast(transb ? b_last : b_first); + + // Fill alpha/beta pointers (per-matrix) + alpha_ptrs[idx] = alpha_ptr + idx; + beta_ptrs[idx] = beta_ptr + idx; +} + +// Launch the setup kernel to populate workspace arrays +inline void launch_grouped_gemm_setup( + const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { + TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); + TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); + TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); + TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); + + const char *c_base = static_cast(C->data.dptr); + char *d_base = static_cast(D->data.dptr); + + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); + const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); + + const int threads_per_block = 256; + const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; + + setup_grouped_gemm_kernel<<>>( + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, + A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, + b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), A_sel.trans, B_sel.trans, num_tensors); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { + return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); +} + +} // namespace + +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, + const int64_t *avg_k) { + NVTE_API_CALL(nvte_grouped_gemm); + using namespace transformer_engine; + + // Grouped GEMM requires Hopper (SM90) or newer + const int current_device = cuda::current_device(); + NVTE_CHECK(cuda::sm_arch(current_device) >= 90, + "nvte_grouped_gemm requires Hopper (SM90) or newer architecture."); + + // Convert to internal types + const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); + const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); + const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL + GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); + const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); + const Tensor *beta_tensor = convertNVTETensorCheck(beta); + Tensor *wspace_setup = convertNVTETensor(workspace_setup); + Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + + // Validate inputs and num_tensors + validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); + + // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) + const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; + const size_t num_tensors = inputA->num_tensors; + + // Select operand storage (row-wise vs column-wise) and adjust transpose flags to + // mirror the non-grouped GEMM logic for FP8 layout constraints. + const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); + const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); + + // Workspaces: setup (pointer arrays) and cuBLAS + const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); + const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; + + void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, + "Grouped GEMM setup workspace"); + void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, + "Grouped GEMM cuBLAS workspace"); + + auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( + static_cast(setup_workspace_ptr), num_tensors); + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, + beta_tensor, num_tensors, stream); + + // Get cuBLAS handle + using cublasHandleManager = detail::HandleManager; + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); + + // Setup cuBLAS operations + cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + + // Create grouped matrix layouts + cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, + num_tensors); + + // Create matmul descriptor + cublasLtMatmulDescOpaque_t matmulDesc; + init_matmul_desc(matmulDesc, op_A, op_B); + set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); + + // Compute average dimensions for heuristics + // K dimension: if transa, K is A's first dim; if not, K is A's last dim + int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); + int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); + int64_t avg_k_val = avg_k ? *avg_k + : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) + : compute_avg_last_dim(A_sel.tensor)); + + // Heuristic selection + cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, + descD, avg_m_val, avg_n_val, avg_k_val); + + // Execute the grouped GEMM + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, + setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, + kGroupedGemmCublasWorkspaceSize, stream)); +} + +#else // CUBLAS_VERSION < 130100 + +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, + const int64_t *avg_k) { + NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.2 or newer."); +} + +#endif // CUBLAS_VERSION >= 130100 + diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh new file mode 100644 index 0000000000..6514ba2f97 --- /dev/null +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh @@ -0,0 +1,18 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ +#define TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ + +#include +#include +#include + +// nvte_grouped_gemm is declared in transformer_engine/gemm.h +// This header is for internal use only. + +#endif // TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ + diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 9dfa009115..b2e42bd66f 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,7 +11,7 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ -#include +#include #include "transformer_engine.h" @@ -233,6 +233,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C + * + * \note Requires cuBLAS 13.2+ (CUDA 13.2+) and Hopper (SM90) or newer GPU architecture. + * Will error at runtime if compiled with an older cuBLAS version or run on + * a pre-Hopper GPU. * * Performs batched GEMM on a collection of matrices with potentially different shapes. * All tensors in the group must have compatible dimensions for matrix multiplication. @@ -262,6 +266,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * heuristics. If NULL, computed automatically from A's logical shape. * * Requirements: + * - cuBLAS 13.2+ (CUDA 13.2+) + * - Hopper (SM90) or newer GPU architecture * - A, B, C (if provided), D must have the same num_tensors * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] * - Shape compatibility: if transa=false, transb=false: @@ -270,8 +276,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, - const int64_t *avg_n, const int64_t *avg_k); + cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, + const int64_t *avg_k); #ifdef __cplusplus } // extern "C" From 047a9f93bd5252241883077e0a904b2c7f1c6e57 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 12:29:12 +0100 Subject: [PATCH 17/40] fix Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 5 +++-- transformer_engine/common/CMakeLists.txt | 1 + transformer_engine/common/include/transformer_engine/gemm.h | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 0ea76946bc..3336dbc6d5 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -137,8 +137,9 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, // cuBLAS requires aligned pointers for vectorized loads static std::mt19937 gen(12345); std::uniform_int_distribution dist(0, 3); - // Calculate elements needed for 16-byte alignment - const size_t align_elements = (16 * 8) / typeToNumBits(dtype); // 16 bytes / element_size + // Calculate elements needed for 16-byte alignment in bytes, rounded up + const size_t align_elements = + std::max(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size return dist(gen) * static_cast(align_elements); }; diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 264f7f9a78..e25bf02439 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -144,6 +144,7 @@ list(APPEND transformer_engine_cuda_sources fused_attn/fused_attn_fp8.cu fused_attn/utils.cu gemm/cublaslt_gemm.cu + gemm/cublaslt_grouped_gemm.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index b2e42bd66f..f1e2776158 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -234,7 +234,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C * - * \note Requires cuBLAS 13.2+ (CUDA 13.2+) and Hopper (SM90) or newer GPU architecture. + * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Hopper (SM90) or newer GPU architecture. * Will error at runtime if compiled with an older cuBLAS version or run on * a pre-Hopper GPU. * @@ -253,7 +253,6 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * \param[out] D Output grouped tensor D. * \param[in] workspace_setup Workspace tensor for pointer array setup. * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. - * \param[in] config Matrix multiplication configuration. * \param[in] stream CUDA stream for the operation. * \param[in] avg_m Optional hint for average M dimension across all matrices in the * group. Used by cuBLASLt for algorithm selection heuristics. From c490e06ab71f9919d69bfc2c67eb6b7cf6bc20ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 11:32:34 +0000 Subject: [PATCH 18/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_grouped_gemm.cu | 11 ++++------- .../common/gemm/cublaslt_grouped_gemm.cuh | 1 - .../common/include/transformer_engine/gemm.h | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 4125bd82bf..3647a4c39e 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -180,12 +180,10 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Validate alpha/beta have per-matrix values const size_t alpha_numel = alpha_tensor->data.numel(); const size_t beta_numel = beta_tensor->data.numel(); - NVTE_CHECK(alpha_numel == num_tensors, - "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", - alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, - "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", - beta_numel); + NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, + ") elements, got ", alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, + ") elements, got ", beta_numel); auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || @@ -596,4 +594,3 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT } #endif // CUBLAS_VERSION >= 130100 - diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh index 6514ba2f97..a032e594d5 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh @@ -15,4 +15,3 @@ // This header is for internal use only. #endif // TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ - diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index f1e2776158..0c8d601d50 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,7 +11,7 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ -#include +#include #include "transformer_engine.h" From e39784572a83cb560fca20f2e7f77f7f7795a834 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 19 Dec 2025 08:35:50 -0800 Subject: [PATCH 19/40] batching working correctly for quant and gemm but slow Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/base.py | 30 ++++-- transformer_engine/jax/cpp_extensions/gemm.py | 94 ++++++++++++++----- .../jax/cpp_extensions/quantization.py | 10 +- transformer_engine/jax/sharding.py | 2 - 4 files changed, 102 insertions(+), 34 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index defdce7b68..335af2eb47 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -175,6 +175,7 @@ def batcher_impl( batched_args: Sequence[Any], batch_dims: Sequence[Union[int, None]], static_kwargs: dict, + output_bdims: Union[Sequence[Union[int, None]], None] = None, ) -> Tuple[Tuple[Any, ...], Tuple[Union[int, None], ...]]: """Batcher implementation for JAX primitives. @@ -207,13 +208,21 @@ def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): if batch_dim is None: batch_dim = bdim batch_size = arg.shape[bdim] - # elif bdim != batch_dim: - # raise ValueError( - # "All batched arguments must have the same batch dimension. " - # f"Got batch_dims={batch_dims}" - # ) + elif output_bdims is None and bdim != batch_dim: + raise ValueError( + "All batched arguments must have the same batch dimension. " + f"Got batch_dims={batch_dims}" + ) + elif arg.shape[bdim] != batch_size: + raise ValueError( + "All batched arguments must have the same batch size. " + f"Got sizes {[arg.shape[bdim] for arg, bdim in zip(batched_args, batch_dims) if bdim is not None]}. " + f"Got batched_args={[arg.shape for arg, bdim in zip(batched_args, batch_dims) if bdim is not None]}." + ) assert batch_dim is not None and batch_size is not None, "Invalid batching config!" + print(f"[{cls.__name__}] Batching with size {batch_size}") + # Loop over batch dimension and collect results all_results = [] @@ -244,9 +253,14 @@ def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): transposed = tuple(zip(*all_results)) # Stack each output along the batch dimension - stacked_results = tuple( - jnp.stack(list(out_list), axis=batch_dim) for out_list in transposed - ) + if output_bdims is not None: + stacked_results = tuple( + jnp.stack(list(out_list), axis=out_bdim) for out_list, out_bdim in zip(transposed, output_bdims) + ) + else: + stacked_results = tuple( + jnp.stack(list(out_list), axis=batch_dim) for out_list in transposed + ) # Single output: return unwrapped result if len(stacked_results) == 1: diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 7d44643046..28100c9715 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -583,27 +583,27 @@ def lowering( ) lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) - lhs_contracting_size = ( - reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) - if lhs_transposed - else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) - ) - assert_cublas_requirements( - scaling_mode, - lhs_contracting_size, - "LHS", - ) - rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) - rhs_contracting_size = ( - reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) - if rhs_transposed - else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) - ) - assert_cublas_requirements( - scaling_mode, - rhs_contracting_size, - "RHS", - ) + # lhs_contracting_size = ( + # reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) + # if lhs_transposed + # else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) + # ) + # assert_cublas_requirements( + # scaling_mode, + # lhs_contracting_size, + # f"LHS {lhs_aval.shape} with contracting dims {lhs_cdims}", + # ) + # rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) + # rhs_contracting_size = ( + # reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) + # if rhs_transposed + # else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) + # ) + # assert_cublas_requirements( + # scaling_mode, + # rhs_contracting_size, + # f"RHS {rhs_aval.shape} with contracting dims {rhs_cdims}", + # ) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) kwargs = { @@ -818,10 +818,60 @@ def batcher( # f"got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" # ) + f = partial(GemmPrimitive.outer_impl, + **{ + "out_dtype": out_dtype, + "contracting_dims": contracting_dims, + "scaling_mode": scaling_mode, + "fuse_bias": fuse_bias, + "fuse_gelu": fuse_gelu, + "grad": grad, + "use_split_accumulator": use_split_accumulator, + "collective_op": collective_op, + "transpose_batch_sequence": transpose_batch_sequence, + "sequence_dim": sequence_dim, + "is_outer": is_outer, + }) + + lhs_cdims, rhs_cdims = contracting_dims + # Calculate output batch dimension based on input batch dims and contracting dims + # Both lhs and rhs have batch dimensions that may be at different indices + if lhs_bdims is not None and rhs_bdims is not None: + # Count non-contracting dimensions in LHS before the batch dimension + lhs_non_contracting_before_batch = sum( + 1 for i in range(lhs_bdims) + if i not in lhs_cdims + ) + # The output batch dimension will be at the position corresponding to + # the LHS batch dimension's position among non-contracting dimensions + output_bdim = lhs_non_contracting_before_batch + elif lhs_bdims is not None: + # LHS has a batch dimension - this will be the output batch dimension + output_bdim = 0 + elif rhs_bdims is not None: + # RHS has a batch dimension - need to account for LHS non-contracting dims + lhs_non_contracting = len([i for i in range(len(batched_args[0].shape)) + if i not in lhs_cdims and i != lhs_bdims]) + output_bdim = lhs_non_contracting + else: + # No batch dimensions in either operand + output_bdim = None + # Use general batcher from BasePrimitive return GemmPrimitive.batcher_impl( batched_args, - batch_dims, + batch_dims=( + lhs_bdims, # lhs + 0, # lhs_scale_inv + rhs_bdims, # rhs + 0, # rhs_scale_inv + *(None for _ in batched_args[4:]), # bias, gelu_input, alpha, beta + ), + output_bdims=( + output_bdim, # output + 0, # bias_grad + 0, # pre_gelu_out + ), static_kwargs={ "out_dtype": out_dtype, "contracting_dims": contracting_dims, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index c5d76cf28c..a95afe8b8e 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -20,7 +20,6 @@ from .base import BasePrimitive, register_primitive from .misc import ( get_padded_spec, - check_valid_batch_dims, te_dtype_to_jax_dtype, jax_dtype_to_te_dtype, multidim_transpose, @@ -362,12 +361,19 @@ def batcher( use_rht, ): """Batch rule for quantization primitive using general batcher.""" - # check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None return BaseDBiasQuantizePrimitive.batcher_impl( batched_args, batch_dims, + output_bdims=( + batch_dims[0], # out + batch_dims[0], # colwise_out (probably need to transpose according if scaling mode does it) + 0, # scale_inv + 0, # colwise_scale_inv + 0, # updated_amax + 0, # dbias + ), static_kwargs={ "out_dtype": out_dtype, "scaling_mode": scaling_mode, diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 01405ba87a..6cb0dd257c 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -261,8 +261,6 @@ def get_mesh_axis_size(axis, mesh=None): if axis is None: return 1 - print(mesh) - assert axis in mesh.shape, f"{axis} is not a axis of the given mesh {mesh.shape}" return mesh.shape[axis] From 59145cc2a7d4e4cb92addbd39c374541cbed5eb9 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 22 Dec 2025 10:21:19 +0100 Subject: [PATCH 20/40] fix Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 7 ++++--- .../common/gemm/cublaslt_grouped_gemm.cu | 10 +++++----- .../common/include/transformer_engine/gemm.h | 6 +++--- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 3336dbc6d5..bdcfa68a4f 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -95,7 +95,8 @@ struct GroupedBuffers { size_t grouped_setup_workspace_size(const size_t num_tensors) { const size_t ptr_bytes = num_tensors * sizeof(void*); const size_t int_bytes = num_tensors * sizeof(int); - size_t size = 4 * ptr_bytes + 3 * int_bytes + 2 * ptr_bytes; + // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 3 int arrays (M, N, K) + size_t size = 6 * ptr_bytes + 3 * int_bytes; const size_t alignment = 256; size = ((size + alignment - 1) / alignment) * alignment; return size; @@ -320,8 +321,8 @@ void run_grouped_gemm_case(const TestParams& params) { GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " << CUBLAS_VERSION << "."; #else - if (getDeviceComputeCapability() < hopperComputeCapability) { - GTEST_SKIP() << "Grouped GEMM requires Hopper (SM90) or newer."; + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; } const std::vector> shapes = make_shapes(params.shape_case); diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 3647a4c39e..40180fe760 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -503,10 +503,10 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; - // Grouped GEMM requires Hopper (SM90) or newer + // Grouped GEMM requires Blackwell (SM100) or newer const int current_device = cuda::current_device(); - NVTE_CHECK(cuda::sm_arch(current_device) >= 90, - "nvte_grouped_gemm requires Hopper (SM90) or newer architecture."); + NVTE_CHECK(cuda::sm_arch(current_device) >= 100, + "nvte_grouped_gemm requires Blackwell (SM100) or newer architecture."); // Convert to internal types const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); @@ -589,8 +589,8 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, const int64_t *avg_k) { - NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", - CUBLAS_VERSION, ". Please upgrade to CUDA 13.2 or newer."); + NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.1+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); } #endif // CUBLAS_VERSION >= 130100 diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 0c8d601d50..168141224c 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -234,9 +234,9 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C * - * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Hopper (SM90) or newer GPU architecture. + * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. * Will error at runtime if compiled with an older cuBLAS version or run on - * a pre-Hopper GPU. + * a pre-Blackwell GPU. * * Performs batched GEMM on a collection of matrices with potentially different shapes. * All tensors in the group must have compatible dimensions for matrix multiplication. @@ -266,7 +266,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * * Requirements: * - cuBLAS 13.2+ (CUDA 13.2+) - * - Hopper (SM90) or newer GPU architecture + * - Blackwell (SM100) or newer GPU architecture * - A, B, C (if provided), D must have the same num_tensors * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] * - Shape compatibility: if transa=false, transb=false: From 77b422ac8d6e33bb5d56651a2e956629c17a5db8 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 22 Dec 2025 10:47:19 +0100 Subject: [PATCH 21/40] Require Blackwell (SM100) and cuBLAS 13.1+ for grouped GEMM Signed-off-by: Pawel Gadzinski --- 3rdparty/cudnn-frontend | 2 +- tests/cpp/operator/test_grouped_gemm.cu | 4 ++-- transformer_engine/common/include/transformer_engine/gemm.h | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 0258951d4d..be6c079be8 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 +Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index bdcfa68a4f..2514f11ab3 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -317,8 +317,8 @@ std::vector> make_shapes(ShapeCase scase) { } void run_grouped_gemm_case(const TestParams& params) { -#if CUBLAS_VERSION < 130200 - GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " +#if CUBLAS_VERSION < 130100 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.1+, but compile-time cuBLAS version is " << CUBLAS_VERSION << "."; #else if (getDeviceComputeCapability() < blackwellComputeCapability) { diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 168141224c..f4c60ca3fe 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -265,7 +265,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * heuristics. If NULL, computed automatically from A's logical shape. * * Requirements: - * - cuBLAS 13.2+ (CUDA 13.2+) + * - cuBLAS 13.1+ (CUDA 13.1+) * - Blackwell (SM100) or newer GPU architecture * - A, B, C (if provided), D must have the same num_tensors * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] From 9c8158ee86a30699710c0dc1cb17c5d9b9aa4ced Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 22 Dec 2025 11:28:47 +0100 Subject: [PATCH 22/40] fix Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 2514f11ab3..ada6980858 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -482,7 +482,7 @@ void run_grouped_gemm_case(const TestParams& params) { atol, rtol); } -#endif // CUBLAS_VERSION >= 130200 +#endif // CUBLAS_VERSION >= 130100 } class GroupedGemmTest : public ::testing::TestWithParam {}; From b1e0893be9eb00495765f65c636b23eae698afc1 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 22 Dec 2025 11:22:11 -0800 Subject: [PATCH 23/40] fix --- transformer_engine/common/gemm/cublaslt_gemm.cu | 8 ++++---- transformer_engine/jax/dense.py | 13 ++++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 118bf19335..92d89b425f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -154,8 +154,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.lda % 16 == 0, - "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + // NVTE_CHECK(ret.lda % 16 == 0, + // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -245,8 +245,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.ldb % 16 == 0, - "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + // NVTE_CHECK(ret.ldb % 16 == 0, + // "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { if (is_B_transposed) { diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 62b0e054aa..9db60d3bd8 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -244,28 +244,27 @@ def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, import numpy as np def _remaining(original, *removed_lists): removed = set(itertools.chain(*removed_lists)) - return [i for i in original if i not in removed] + return tuple(i for i in original if i not in removed) def _ranges_like(*xs): start = 0 for x in xs: x_len = len(x) - yield range(start, start + x_len) + yield tuple(range(start, start + x_len)) start += x_len (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.ndim - x_kept = _remaining(range(x_ndim), x_contract, x_batch) - y_kept = _remaining(range(y.ndim), y_contract, y_batch) + x_kept = _remaining(tuple(range(x_ndim)), x_contract, x_batch) + y_kept = _remaining(tuple(range(y.ndim)), y_contract, y_batch) if swap_ans: ans_batch, ans_y, _ = _ranges_like(x_batch, y_kept, x_kept) else: ans_batch, _, ans_y = _ranges_like(x_batch, x_kept, y_kept) dims = ((ans_y, y_kept), (ans_batch, y_batch)) - x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) - out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) + x_contract_sorted_by_y = tuple(np.take(x_contract, np.argsort(y_contract))) + out_axes = np.argsort(tuple(x_batch) + x_kept + x_contract_sorted_by_y) x_bar = jax.lax.transpose( - # TODO(jberchtold): I'm ignoring the batch_dims here, do I need to explicitly use vmap or something? tex.gemm(g, y, contracting_dims=dims[0]), tuple(out_axes) ) From fb2067bacb9c21b71ff6cd329cae542415400887 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 23 Dec 2025 10:03:29 -0800 Subject: [PATCH 24/40] move einsum logic into TE --- transformer_engine/jax/flax/__init__.py | 3 +- transformer_engine/jax/flax/module.py | 62 +++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index d1a9cb47f8..59a0958b7b 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -4,7 +4,7 @@ """Transformer Engine bindings for JAX""" from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP -from .module import wrap_function_in_te_state_module, make_dot_general_cls +from .module import wrap_function_in_te_state_module, make_dot_general_cls, make_einsum_cls from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -16,6 +16,7 @@ "LayerNormMLP", "wrap_function_in_te_state_module", "make_dot_general_cls", + "make_einsum_cls", "extend_logical_axis_rules", "DotProductAttention", "MultiHeadAttention", diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index dcfb812896..ca84d46d6b 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1438,3 +1438,65 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): ) return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") + +def make_einsum_cls(quantization_recipe): + import functools + import jax + def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs): + quantizer_set = generate_quantizer_set() + def dot_general(x, kernel, dims, *args, **kwargs): + # print(f"TE dot_general called with dims: {dims}, args: {args}, kwargs: {kwargs}") + contracting_dims, batch_dims = dims + ((x_bdim,), (k_bdim,)) = batch_dims + batch_dims = (x_bdim, k_bdim) + + if x_bdim != 0 or k_bdim != 0: + print(f"{x_bdim=}, {k_bdim=}") + return jax.lax.dot_general(x, kernel, dims, *args, **kwargs) + + if x.dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]: + # HACK: because x input is bool for dispatch mask + x = x.astype(kernel.dtype) + + # Adjust for unbatched + contracting_dims = tuple( + tuple(dim - (1 if dim > bdim else 0) for dim in cdims) + for bdim, cdims in zip(batch_dims, contracting_dims)) + + f = functools.partial( + dense, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set) + return jax.vmap(f, in_axes=(x_bdim, k_bdim))( + x, + kernel, + ) + + group_sizes = None + + # assuming x batch dim is axis 0, squash dims so we have (B*M, K) + # import math + # num_groups = x.shape[0] + # group_size = math.prod(x.shape[1:-1]) + # x_orig_ndim = x.ndim + # # FIXME: breaks partitioning + # x = x.reshape(x.shape[0] * group_size, x.shape[-1]) + # contracting_dims = ( + # tuple([c - (x_orig_ndim - x.ndim) for c in contracting_dims[0]]), + # *contracting_dims[1:], + # ) + + # group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) + + # print(f'{group_sizes=}, {contracting_dims=}, {x.shape=}, {kernel.shape=}, {contracting_dims=}') + + # return transformer_engine.jax.dense.grouped_dense( + # x, + # kernel, + # group_sizes=group_sizes, + # contracting_dims=contracting_dims, + # # quantizer_set=quantizer_set + # ) + return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) + + return wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")() From 30716a622c2d1f381de0e09800ef9936b030c420 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 23 Dec 2025 10:42:36 -0800 Subject: [PATCH 25/40] einsum unit tests --- tests/jax/test_custom_call_compute.py | 41 +++++++++++++++++++++++++++ transformer_engine/jax/flax/module.py | 7 ++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 897d9f683e..7a81683bc7 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1974,3 +1974,44 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype) + +class TestEinsum: + + def _te_einsum(self, eqn, a, b, quantization_recipe): + from transformer_engine.jax.flax import make_einsum_cls + + te_einsum = make_einsum_cls(quantization_recipe=quantization_recipe) + var_collect = te_einsum.init(jax.random.PRNGKey(0), eqn, a, b) + return te_einsum.apply(var_collect, eqn, a, b) + + def _ref_einsum(self, eqn, a, b): + return jnp.einsum(eqn, a, b) + + @pytest_parametrize_wrapper('eqn,a_shape,b_shape', [ + # ('ij,jk->ik', (64, 32), (32, 128)), + # ('bij,bjk->bik', (8, 64, 32), (8, 32, 128)), + # ('abc,cde->abde', (4, 8, 16), (16, 32, 64)), + ('BSM,BSEC->EBCM', (2, 4096, 4096), (2, 4096, 8, 1024)), + ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)) , + ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)), + ('EBCH,EHM->EBCM', (8, 2, 1024, 14336), (8, 14336, 4096)), + ('EBCM,BSEC->BSM', (8, 2, 1024, 4096), (2, 4096, 8, 1024)), + ]) + @pytest_parametrize_wrapper('dtype', [jnp.bfloat16]) + @pytest_parametrize_wrapper('quantization_recipe', supported_recipes) + def test_einsum(self, eqn, a_shape, b_shape, dtype, quantization_recipe): + from transformer_engine.common.recipe import Float8CurrentScaling + import functools + + if not isinstance(quantization_recipe, Float8CurrentScaling): + pytest.skip("Einsum currently only supports Float8CurrentScaling recipe.") + return + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 2) + a = jax.random.uniform(subkeys[0], a_shape, dtype=dtype) + b = jax.random.uniform(subkeys[1], b_shape, dtype=dtype) + + te_out = jax.jit(functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe))(a, b) + ref_out = jax.jit(functools.partial(self._ref_einsum, eqn))(a, b) + + assert_allclose(te_out, ref_out, dtype=dtype) \ No newline at end of file diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index ca84d46d6b..0399ccfabf 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1443,7 +1443,8 @@ def make_einsum_cls(quantization_recipe): import functools import jax def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs): - quantizer_set = generate_quantizer_set() + # with open("/tmp/te_einsum_log.txt", "a") as f: + # f.write(f"{(s, x.shape, kernel.shape)}\n") def dot_general(x, kernel, dims, *args, **kwargs): # print(f"TE dot_general called with dims: {dims}, args: {args}, kwargs: {kwargs}") contracting_dims, batch_dims = dims @@ -1453,6 +1454,10 @@ def dot_general(x, kernel, dims, *args, **kwargs): if x_bdim != 0 or k_bdim != 0: print(f"{x_bdim=}, {k_bdim=}") return jax.lax.dot_general(x, kernel, dims, *args, **kwargs) + + quantizer_set = generate_quantizer_set() + print(f'{quantizer_set=}') + # import pdb; pdb.set_trace() if x.dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]: # HACK: because x input is bool for dispatch mask From 349c3155fdd34b1fc1ca009252ac64105fc6c24e Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 23 Dec 2025 10:47:19 -0800 Subject: [PATCH 26/40] fwd bwd einsum test --- tests/jax/test_custom_call_compute.py | 56 ++++++++++++++++++++------- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 7a81683bc7..082a99cd8b 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1975,6 +1975,18 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype) +@pytest_parametrize_wrapper('eqn,a_shape,b_shape', [ + # ('ij,jk->ik', (64, 32), (32, 128)), + # ('bij,bjk->bik', (8, 64, 32), (8, 32, 128)), + # ('abc,cde->abde', (4, 8, 16), (16, 32, 64)), + ('BSM,BSEC->EBCM', (2, 4096, 4096), (2, 4096, 8, 1024)), + ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)) , + ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)), + ('EBCH,EHM->EBCM', (8, 2, 1024, 14336), (8, 14336, 4096)), + ('EBCM,BSEC->BSM', (8, 2, 1024, 4096), (2, 4096, 8, 1024)), +]) +@pytest_parametrize_wrapper('dtype', [jnp.bfloat16]) +@pytest_parametrize_wrapper('quantization_recipe', supported_recipes) class TestEinsum: def _te_einsum(self, eqn, a, b, quantization_recipe): @@ -1987,19 +1999,7 @@ def _te_einsum(self, eqn, a, b, quantization_recipe): def _ref_einsum(self, eqn, a, b): return jnp.einsum(eqn, a, b) - @pytest_parametrize_wrapper('eqn,a_shape,b_shape', [ - # ('ij,jk->ik', (64, 32), (32, 128)), - # ('bij,bjk->bik', (8, 64, 32), (8, 32, 128)), - # ('abc,cde->abde', (4, 8, 16), (16, 32, 64)), - ('BSM,BSEC->EBCM', (2, 4096, 4096), (2, 4096, 8, 1024)), - ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)) , - ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)), - ('EBCH,EHM->EBCM', (8, 2, 1024, 14336), (8, 14336, 4096)), - ('EBCM,BSEC->BSM', (8, 2, 1024, 4096), (2, 4096, 8, 1024)), - ]) - @pytest_parametrize_wrapper('dtype', [jnp.bfloat16]) - @pytest_parametrize_wrapper('quantization_recipe', supported_recipes) - def test_einsum(self, eqn, a_shape, b_shape, dtype, quantization_recipe): + def test_einsum_fwd(self, eqn, a_shape, b_shape, dtype, quantization_recipe): from transformer_engine.common.recipe import Float8CurrentScaling import functools @@ -2014,4 +2014,32 @@ def test_einsum(self, eqn, a_shape, b_shape, dtype, quantization_recipe): te_out = jax.jit(functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe))(a, b) ref_out = jax.jit(functools.partial(self._ref_einsum, eqn))(a, b) - assert_allclose(te_out, ref_out, dtype=dtype) \ No newline at end of file + assert_allclose(te_out, ref_out, dtype=dtype) + + def test_einsum_fwd_and_bwd(self, eqn, a_shape, b_shape, dtype, quantization_recipe): + from transformer_engine.common.recipe import Float8CurrentScaling + import functools + + if not isinstance(quantization_recipe, Float8CurrentScaling): + pytest.skip("Einsum currently only supports Float8CurrentScaling recipe.") + return + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 2) + a = jax.random.uniform(subkeys[0], a_shape, dtype=dtype) + b = jax.random.uniform(subkeys[1], b_shape, dtype=dtype) + + def wrap_in_mean(f): + @functools.wraps(f) + def wrapped(*args): + return jnp.mean(f(*args)) + return wrapped + + te_fwd, te_grads = jax.jit(jax.value_and_grad(wrap_in_mean(functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe))))(a, b) + ref_fwd, ref_grads = jax.jit(jax.value_and_grad(wrap_in_mean(functools.partial(self._ref_einsum, eqn))))(a, b) + + assert_allclose(te_fwd, ref_fwd, dtype=dtype) + + assert len(te_grads) == len(ref_grads), f"Number of gradients differ: {len(te_grads)=} vs {len(ref_grads)=}" + + for te_grad, ref_grad in zip(te_grads, ref_grads): + assert_allclose(te_grad, ref_grad, dtype=dtype) \ No newline at end of file From 57ab3b09c9baf1587aaca4ecb5632b91021e1c14 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 23 Dec 2025 11:12:59 -0800 Subject: [PATCH 27/40] unit tests passed with grouped gemm in bf16 --- transformer_engine/jax/flax/module.py | 78 +++++++++++++++------------ 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 0399ccfabf..733eaf513b 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -17,7 +17,7 @@ from jax.ad_checkpoint import checkpoint_name -from ..dense import dense +from ..dense import dense, grouped_dense from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm @@ -1455,9 +1455,9 @@ def dot_general(x, kernel, dims, *args, **kwargs): print(f"{x_bdim=}, {k_bdim=}") return jax.lax.dot_general(x, kernel, dims, *args, **kwargs) + target_out_shape = jax.lax.dot_general(x, kernel, dims).shape + # TODO: add num groups to make grouped quantizer set quantizer_set = generate_quantizer_set() - print(f'{quantizer_set=}') - # import pdb; pdb.set_trace() if x.dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]: # HACK: because x input is bool for dispatch mask @@ -1468,40 +1468,50 @@ def dot_general(x, kernel, dims, *args, **kwargs): tuple(dim - (1 if dim > bdim else 0) for dim in cdims) for bdim, cdims in zip(batch_dims, contracting_dims)) - f = functools.partial( - dense, - contracting_dims=contracting_dims, - quantizer_set=quantizer_set) - return jax.vmap(f, in_axes=(x_bdim, k_bdim))( + group_sizes = None + print(f'{x.shape=}, {kernel.shape=}, {dims=}') + + def reorder_lhs_for_grouped_gemm(tensor, cdims): + # (B*M, K) + assert len(cdims) == 1, f"Only support single contracting dim for now, got {cdims}" + cdim = cdims[0] + 1 # account for batch dim at front + out = jnp.transpose(tensor, tuple(range(cdim)) + tuple(range(cdim + 1, tensor.ndim)) + (cdim,)) + return out.reshape((-1, out.shape[-1])) + + + def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): + # (B, K, N) + assert len(bdims) == 1 and len(cdims) == 1, f"Only support single batch and contracting dim for now, got {bdims}, {cdims}" + bdim = bdims[0] + assert bdim == 0, f"Only support batch dim 0 for now, got {bdim}" + cdim = cdims[0] + 1 # account for batch dim at front + out = jnp.transpose(tensor, (bdim, cdim) + tuple(i for i in range(tensor.ndim) if i != bdim and i != cdim)) + return out.reshape((*out.shape[:2], -1)) + + x = reorder_lhs_for_grouped_gemm(x, contracting_dims[0]) + kernel = reorder_rhs_for_grouped_gemm(kernel, (batch_dims[1],), contracting_dims[1]) + + num_groups = kernel.shape[0] + group_size = x.shape[0] // num_groups + + group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) + + print(f'{group_sizes=}, {contracting_dims=}, {x.shape=}, {kernel.shape=}, {contracting_dims=}') + + contracting_dims = ( + # (B*M, K) + (1,), + # (B, K, N) + (1,), + ) + out = grouped_dense( x, kernel, + group_sizes=group_sizes, + contracting_dims=contracting_dims, + # quantizer_set=quantizer_set ) - - group_sizes = None - - # assuming x batch dim is axis 0, squash dims so we have (B*M, K) - # import math - # num_groups = x.shape[0] - # group_size = math.prod(x.shape[1:-1]) - # x_orig_ndim = x.ndim - # # FIXME: breaks partitioning - # x = x.reshape(x.shape[0] * group_size, x.shape[-1]) - # contracting_dims = ( - # tuple([c - (x_orig_ndim - x.ndim) for c in contracting_dims[0]]), - # *contracting_dims[1:], - # ) - - # group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) - - # print(f'{group_sizes=}, {contracting_dims=}, {x.shape=}, {kernel.shape=}, {contracting_dims=}') - - # return transformer_engine.jax.dense.grouped_dense( - # x, - # kernel, - # group_sizes=group_sizes, - # contracting_dims=contracting_dims, - # # quantizer_set=quantizer_set - # ) + return out.reshape(target_out_shape) return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) return wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")() From ab98852671870d1ebabeaf22eb65609d536ca744 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 23 Dec 2025 11:26:56 -0800 Subject: [PATCH 28/40] grouped quantization working for single gpu --- transformer_engine/jax/flax/module.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 733eaf513b..cc6088e8d2 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -377,6 +377,7 @@ def generate_quantizer_set( variable_collection: str = None, quantization_checkpoint_name: Optional[str] = None, fp8_recipe=None, + n_groups: int = None, ): """ Generate a set of FP8 meta for a GEMM. @@ -409,6 +410,7 @@ def generate_quantizer_set( fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set, checkpoint_name=quantization_checkpoint_name, + n_groups=n_groups, ) return quantizer_set @@ -1379,12 +1381,13 @@ def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] class TEWrapper(te.flax.module.TransformerEngineBase): """Wrapper Flax module for TransformerEngine quantization support.""" - def generate_quantizer_set(self, postfix: str = ""): + def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, fp8_recipe=quantization_recipe, + n_groups=n_groups, ) @nn.compact @@ -1456,8 +1459,6 @@ def dot_general(x, kernel, dims, *args, **kwargs): return jax.lax.dot_general(x, kernel, dims, *args, **kwargs) target_out_shape = jax.lax.dot_general(x, kernel, dims).shape - # TODO: add num groups to make grouped quantizer set - quantizer_set = generate_quantizer_set() if x.dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]: # HACK: because x input is bool for dispatch mask @@ -1496,6 +1497,8 @@ def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) + quantizer_set = generate_quantizer_set(n_groups=num_groups) + print(f'{group_sizes=}, {contracting_dims=}, {x.shape=}, {kernel.shape=}, {contracting_dims=}') contracting_dims = ( @@ -1509,7 +1512,7 @@ def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): kernel, group_sizes=group_sizes, contracting_dims=contracting_dims, - # quantizer_set=quantizer_set + quantizer_set=quantizer_set ) return out.reshape(target_out_shape) return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) From f1fc31c5d043f9b4224d0a6e95e0e55335788383 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 5 Jan 2026 08:59:34 -0800 Subject: [PATCH 29/40] wip --- .../jax/csrc/extensions/gemm.cpp | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 6566ff1689..79418c138e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -768,10 +768,24 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); } - nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, - grad, workspace_list.data(), accumulate, use_split_accumulator, - num_math_sm, stream); + // nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), + // pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, + // grad, workspace_list.data(), accumulate, use_split_accumulator, + // num_math_sm, stream); + int64_t avg_m = 0, avg_n = 0, avg_k = 0; + nvte_grouped_gemm( + rhs_is_trans, lhs_is_trans, + alpha, + rhs_list, lhs_list, + beta, + C, + out_list, + workspace_setup, + workspace_cublas, + stream, + &avg_m, + &avg_n, + &avg_k); return ffi_with_cuda_error_check(); } From c8cf7633aa29fbb93a05d7b70475ff1366fc43f0 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 7 Jan 2026 11:22:48 -0800 Subject: [PATCH 30/40] with many hacks grouped gemm with new api works for a particular hardcoded shape --- transformer_engine/jax/cpp_extensions/gemm.py | 15 ++- .../jax/csrc/extensions/gemm.cpp | 105 ++++++++++++++---- 2 files changed, 96 insertions(+), 24 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 28100c9715..38d21f26ec 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1463,7 +1463,7 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + impl_static_args = (9, 10, 11, 12, 13, 14, 15, 16, 17, 18) inner_primitive = None outer_primitive = None @@ -1476,6 +1476,8 @@ def abstract( bias_aval, group_sizes_aval, group_offset_aval, + alpha, + beta, *, M, N, @@ -1535,6 +1537,8 @@ def abstract( # We also pad scale_inv swizzle buffers size for 256 bytes alignment. workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + + workspace_size += 1024*1024 # HACK: properly make a workspace_setup buffer in addition to the workspace_cublas buffer workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) out_shape = (M, N) @@ -1587,6 +1591,8 @@ def impl( bias, group_sizes, group_offset, + alpha, + beta, M, N, K, @@ -1607,6 +1613,8 @@ def impl( bias, group_sizes, group_offset, + alpha, + beta, M=M, N=N, K=K, @@ -2115,6 +2123,9 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias + num_gemms = group_sizes.shape[0] + alpha = jnp.ones((num_gemms,), jnp.float32) + beta = jnp.zeros((num_gemms,), jnp.float32) (out,) = GroupedGemmPrimitive.outer_primitive.bind( lhs_data, lhs_scale_inv, @@ -2123,6 +2134,8 @@ def grouped_gemm( bias, group_sizes, group_offset, + alpha, + beta, M=M, N=N, K=K_lhs, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 79418c138e..7c2d4c81e6 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -399,10 +399,62 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); +NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors) { + printf("make_grouped_tensor data shape: "); + for (auto dim : data.dimensions()) { + printf("%zu, ", dim); + } + printf("\n"); + NVTEShape logical_shape{}; + if (data.dimensions().size() == 1) { + // HACK + size_t cdim_size = 4096; + logical_shape.ndim = 2; + logical_shape.data[0] = data.dimensions()[0] / cdim_size; + logical_shape.data[1] = cdim_size; + } + else { + NVTE_CHECK(data.dimensions().size() == 2, "Expected 2D tensor for GEMM operand but received ndim=", data.dimensions().size()); + + logical_shape.ndim = 2; + logical_shape.data[0] = data.dimensions()[0]; + logical_shape.data[1] = data.dimensions()[1]; + } + + NVTEGroupedTensor grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, logical_shape); + + NVTEBasicTensor data_tensor{reinterpret_cast(data.untyped_data()), + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())), + logical_shape}; + nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseData, &data_tensor); + + if (scale_inv.has_value()) { + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", scale_inv->dimensions().size()); + } + NVTEBasicTensor scale_inv_tensor{reinterpret_cast(scale_inv->untyped_data()), + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())), + logical_scale_shape}; + nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseScaleInv, &scale_inv_tensor); + } + + return grouped_tensor; +} + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, - Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, + Buffer_Type group_sizes, Buffer_Type group_offset, + Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type workspace, + size_t m, size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: @@ -577,7 +629,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type std::vector bias_list; std::vector pre_gelu_list; std::vector out_list; - std::vector workspace_list; size_t lhs_sinv_total_size = 0; size_t rhs_sinv_total_size = 0; @@ -724,15 +775,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type out_list.push_back(out_wrapper_list.back().data()); } - auto workspace_shape = std::vector{workspace_size}; - for (int i = 0; i < num_streams; i++) { - auto workspace_i = - TensorWrapper(static_cast(workspace_ptr), workspace_shape, DType::kByte); - workspace_wrapper_list.push_back(std::move(workspace_i)); - workspace_list.push_back(workspace_wrapper_list.back().data()); - workspace_ptr += workspace_size; - } - if (is_fp8_gemm) { if (is_tensor_scaling) { lhs_sinv_size *= tensor_scaling_sinv_aligment; @@ -772,20 +814,35 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, // grad, workspace_list.data(), accumulate, use_split_accumulator, // num_math_sm, stream); - int64_t avg_m = 0, avg_n = 0, avg_k = 0; + + constexpr size_t workspace_setup_size = 1024 * 1024; // HACK: dummy workspace for setup + TensorWrapper workspace_setup(workspace_ptr, + std::vector{workspace_setup_size}, DType::kByte); + TensorWrapper workspace_cublas(workspace_ptr + workspace_setup_size, + std::vector{workspace_size}, DType::kByte); + + TensorWrapper alpha_tensor(static_cast(alpha.untyped_data()), std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(alpha.element_type())); + TensorWrapper beta_tensor(static_cast(beta.untyped_data()), std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(beta.element_type())); + + NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms); + NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms); + NVTEGroupedTensor out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms); + nvte_grouped_gemm( rhs_is_trans, lhs_is_trans, - alpha, - rhs_list, lhs_list, - beta, - C, - out_list, - workspace_setup, - workspace_cublas, + alpha_tensor.data(), + rhs_tensor, lhs_tensor, + beta_tensor.data(), + nullptr, + out_tensor, + workspace_setup.data(), + workspace_cublas.data(), stream, - &avg_m, - &avg_n, - &avg_k); + nullptr, + nullptr, + nullptr); return ffi_with_cuda_error_check(); } @@ -800,6 +857,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Arg() // bias .Arg() // group_sizes .Arg() // group_offset + .Arg() // alpha + .Arg() // beta .Ret() // output .Ret() // workspace .Attr("M") From 21e7002991831ecd933388f4ad95a53d0d64d69b Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 7 Jan 2026 11:59:53 -0800 Subject: [PATCH 31/40] progress --- transformer_engine/jax/cpp_extensions/gemm.py | 7 +++ .../jax/csrc/extensions/gemm.cpp | 60 ++++++++++--------- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 38d21f26ec..25f3315ba7 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2123,6 +2123,13 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias + print(f"{lhs_data.shape=}, {rhs_data.shape=}, {group_sizes.shape=}") + print(f"{M=}, {N=}, {K_lhs=}, {K_rhs=}") + # import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() + # print(f"{lhs_is_trans=}, {rhs_is_trans=}") + # import pdb; pdb.set_trace() + num_gemms = group_sizes.shape[0] alpha = jnp.ones((num_gemms,), jnp.float32) beta = jnp.zeros((num_gemms,), jnp.float32) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 7c2d4c81e6..9543c66356 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -399,33 +399,34 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); -NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors) { - printf("make_grouped_tensor data shape: "); - for (auto dim : data.dimensions()) { - printf("%zu, ", dim); - } - printf("\n"); - NVTEShape logical_shape{}; - if (data.dimensions().size() == 1) { - // HACK - size_t cdim_size = 4096; - logical_shape.ndim = 2; - logical_shape.data[0] = data.dimensions()[0] / cdim_size; - logical_shape.data[1] = cdim_size; - } - else { - NVTE_CHECK(data.dimensions().size() == 2, "Expected 2D tensor for GEMM operand but received ndim=", data.dimensions().size()); - - logical_shape.ndim = 2; - logical_shape.data[0] = data.dimensions()[0]; - logical_shape.data[1] = data.dimensions()[1]; - } - - NVTEGroupedTensor grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, logical_shape); +NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors, NVTEShape const& dataShape) { + // printf("make_grouped_tensor data shape: "); + // for (auto dim : data.dimensions()) { + // printf("%zu, ", dim); + // } + // printf("\n"); + // NVTEShape logical_shape{}; + // if (data.dimensions().size() == 1) { + // // HACK + // size_t cdim_size = 4096; + // logical_shape.ndim = 2; + // logical_shape.data[0] = data.dimensions()[0] / cdim_size; + // logical_shape.data[1] = cdim_size; + // printf("NUM TENSORS: %zu\n", num_tensors); + // } + // else { + // NVTE_CHECK(data.dimensions().size() == 2, "Expected 2D tensor for GEMM operand but received ndim=", data.dimensions().size()); + + // logical_shape.ndim = 2; + // logical_shape.data[0] = data.dimensions()[0]; + // logical_shape.data[1] = data.dimensions()[1]; + // } + + NVTEGroupedTensor grouped_tensor = nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, dataShape); NVTEBasicTensor data_tensor{reinterpret_cast(data.untyped_data()), static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())), - logical_shape}; + dataShape}; nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseData, &data_tensor); if (scale_inv.has_value()) { @@ -826,9 +827,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type TensorWrapper beta_tensor(static_cast(beta.untyped_data()), std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms); - NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms); - NVTEGroupedTensor out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms); + NVTEShape rhsShape{.data={num_gemms * k, n}, .ndim=2}; + NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); + NVTEShape lhsShape{.data={m, k}, .ndim=2}; + NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + NVTEShape outShape{.data={m, n}, .ndim=2}; + NVTEGroupedTensor out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); + + NVTE_CHECK(!rhs_is_trans && !lhs_is_trans, "TE grouped GEMM only supports non-transposed inputs but received rhs_is_trans=", rhs_is_trans, " lhs_is_trans=", lhs_is_trans); nvte_grouped_gemm( rhs_is_trans, lhs_is_trans, From 1ae08ddd7dfde42a9c2fea90128f19a74f9a191c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 7 Jan 2026 13:46:24 -0800 Subject: [PATCH 32/40] more tests pass --- test_einsum.py | 74 +++++++++++++++++++ .../jax/csrc/extensions/gemm.cpp | 16 +++- 2 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 test_einsum.py diff --git a/test_einsum.py b/test_einsum.py new file mode 100644 index 0000000000..5bb05403f2 --- /dev/null +++ b/test_einsum.py @@ -0,0 +1,74 @@ +from enum import Enum + +import jax +import jax.numpy as jnp +import numpy as np +import transformer_engine.jax as te +from transformer_engine.common.recipe import Recipe, Float8CurrentScaling, MXFP8BlockScaling, DelayedScaling, NVFP4BlockScaling +from flax import linen as nn + +def make_einsum_cls(quantization_recipe): + def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs): + def dot_general(x, kernel, dims, *args, **kwargs): + contracting_dims, batch_dims = dims + assert batch_dims == ((), ()), "Batch dims not supported in TE/JAX yet" + + quantizer_set = generate_quantizer_set("quantizer_set_for_einsum") + return te.dense.dense( + x, + kernel, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, + ) + return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) + + return te.flax.wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")() + +class EinsumType(Enum): + JAX = 'jax' + TE = 'te' + +def main(): + + class SimpleModel(nn.Module): + + einsum_type: EinsumType + quantization_recipe: Recipe = None + + def _einsum(self, *args, **kwargs): + if self.einsum_type == EinsumType.JAX: + return jnp.einsum(*args, **kwargs) + elif self.einsum_type == EinsumType.TE: + # It is important that we call make_einsum_cls(recipe) here each time einsum + # is called. If we were to call make_einsum_cls only once and re-use it, the state for some recipes such as DelayedScaling would become incorrectly shared instead of each call having its own state. + return make_einsum_cls(self.quantization_recipe)(*args, **kwargs) + else: + raise ValueError(f"Unsupported einsum type: {self.einsum_type}") + + @nn.compact + def __call__(self, x): + kernel = self.param('kernel', jax.nn.initializers.lecun_normal(), (32, 32), jnp.bfloat16) + return self._einsum("ij,jk->ik", x, kernel) + + + def test_model(einsum_type: EinsumType, quantization_recipe: Recipe = None): + model = SimpleModel(einsum_type=einsum_type, quantization_recipe=quantization_recipe) + x = jax.random.uniform(jax.random.PRNGKey(2), (32, 32), jnp.bfloat16) + var_collect = model.init(jax.random.PRNGKey(3), x) + # It is important to use var_collect here to ensure all state (e.g., quantizer states) is properly handled. If you use var_collect['params'] only, TE's state management will not work correctly for recipes that require state (e.g. DelayedScaling). + y = model.apply(var_collect, x) + return y + + # einsum_cls = None, so standard JAX computation + ref_out = test_model(einsum_type=EinsumType.JAX) + + # einsum using Transformer Engine's Float8CurrentScaling recipe + te_out = test_model(einsum_type=EinsumType.TE, quantization_recipe=Float8CurrentScaling()) + + # Compare outputs + atol = float(jnp.finfo(jnp.float8_e4m3fn).eps) + np.testing.assert_allclose(ref_out, te_out, atol=atol) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 9543c66356..61e241b197 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -827,14 +827,26 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type TensorWrapper beta_tensor(static_cast(beta.untyped_data()), std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - NVTEShape rhsShape{.data={num_gemms * k, n}, .ndim=2}; + NVTEShape rhsShape{.data={k, n}, .ndim=2}; + if (!is_grouped_dense_wgrad) { + rhsShape.data[0] *= num_gemms; + } + if (rhs_is_trans) { + std::swap(rhsShape.data[0], rhsShape.data[1]); + } NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); NVTEShape lhsShape{.data={m, k}, .ndim=2}; + if (lhs_is_trans) { + std::swap(lhsShape.data[0], lhsShape.data[1]); + } NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); NVTEShape outShape{.data={m, n}, .ndim=2}; + if (is_grouped_dense_wgrad) { + outShape.data[0] *= num_gemms; + } NVTEGroupedTensor out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); - NVTE_CHECK(!rhs_is_trans && !lhs_is_trans, "TE grouped GEMM only supports non-transposed inputs but received rhs_is_trans=", rhs_is_trans, " lhs_is_trans=", lhs_is_trans); + // NVTE_CHECK(!rhs_is_trans && !lhs_is_trans, "TE grouped GEMM only supports non-transposed inputs but received rhs_is_trans=", rhs_is_trans, " lhs_is_trans=", lhs_is_trans); nvte_grouped_gemm( rhs_is_trans, lhs_is_trans, From fe39e39be1abfa46642fdde9e3ede365bc1dfb3c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 7 Jan 2026 14:25:47 -0800 Subject: [PATCH 33/40] einsum tests pass --- transformer_engine/jax/csrc/extensions/gemm.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 61e241b197..f49530ee1c 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -828,15 +828,16 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type convert_ffi_datatype_to_te_dtype(beta.element_type())); NVTEShape rhsShape{.data={k, n}, .ndim=2}; - if (!is_grouped_dense_wgrad) { - rhsShape.data[0] *= num_gemms; - } if (rhs_is_trans) { std::swap(rhsShape.data[0], rhsShape.data[1]); } + if (!is_grouped_dense_wgrad) { + // If is_grouped_dense_wgrad, then n already includes num_gemms (G) pre-multiplied in gemm.py, so we don't need to multiply it here. + rhsShape.data[0] *= num_gemms; + } NVTEGroupedTensor rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); NVTEShape lhsShape{.data={m, k}, .ndim=2}; - if (lhs_is_trans) { + if (lhs_is_trans && is_grouped_dense_wgrad) { std::swap(lhsShape.data[0], lhsShape.data[1]); } NVTEGroupedTensor lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); From 5e47d57b3e670d86ce37e5e2e44397158360adb4 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 8 Jan 2026 09:37:17 -0800 Subject: [PATCH 34/40] more progress, works in maxtext single-gpu and is closer to bf16 batched gemm speed --- transformer_engine/jax/cpp_extensions/gemm.py | 4 +- .../jax/csrc/extensions/gemm.cpp | 246 +----------------- .../jax/csrc/extensions/quantization.cpp | 26 +- transformer_engine/jax/flax/module.py | 4 +- 4 files changed, 27 insertions(+), 253 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 25f3315ba7..5c53dedb8a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2123,8 +2123,8 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias - print(f"{lhs_data.shape=}, {rhs_data.shape=}, {group_sizes.shape=}") - print(f"{M=}, {N=}, {K_lhs=}, {K_rhs=}") + # print(f"{lhs_data.shape=}, {rhs_data.shape=}, {group_sizes.shape=}") + # print(f"{M=}, {N=}, {K_lhs=}, {K_rhs=}") # import pdb; pdb.set_trace() # import pdb; pdb.set_trace() # print(f"{lhs_is_trans=}, {rhs_is_trans=}") diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f49530ee1c..0bfab2d7dc 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -534,22 +534,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - if (is_tensor_scaling) { - size_t dpitch = tensor_scaling_sinv_aligment; - size_t spitch = lhs_sinv_dtype_bytes; - size_t width = lhs_sinv_dtype_bytes; - size_t height = lhs_sinv_size; - cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream); - spitch = rhs_sinv_dtype_bytes; - width = rhs_sinv_dtype_bytes; - height = rhs_sinv_size; - cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream); - lhs_sinv_ptr = lhs_scatter_aligned_ptr; - rhs_sinv_ptr = rhs_scatter_aligned_ptr; - } - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); @@ -576,29 +560,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type " = ", expected_out_size, ", got ", actual_out_size); } - size_t dim_list_bytes = sizeof(int32_t) * num_gemms; - std::vector dim_list_host(num_gemms); - size_t host_num_gemms = 0; - if (use_async_d2h_group_sizes) { - host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); - NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, - " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); - } else { - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); - } - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); bool grad = false; bool accumulate = false; @@ -612,210 +573,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); } - // These lists are to keep the TensorWrapper objects alive - std::vector lhs_wrapper_list; - std::vector rhs_wrapper_list; - std::vector lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling - std::vector rhs_swizzle_wrapper_list; - std::vector bias_wrapper_list; - std::vector pre_gelu_wrapper_list; - std::vector out_wrapper_list; - std::vector workspace_wrapper_list; - - // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM - std::vector lhs_list; - std::vector rhs_list; - std::vector lhs_swizzle_list; - std::vector rhs_swizzle_list; - std::vector bias_list; - std::vector pre_gelu_list; - std::vector out_list; - - size_t lhs_sinv_total_size = 0; - size_t rhs_sinv_total_size = 0; - - std::vector zero_out_dptr_list; - std::vector zero_out_size_list; - - for (size_t i = 0; i < num_gemms; i++) { - // Matrix data shapes - size_t m_i = dim_list_host[i]; - auto lhs_shape_i = std::vector{m_i, k}; - auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; - auto out_shape_i = std::vector{m_i, n}; - if (is_grouped_dense_wgrad) { - size_t k_i = dim_list_host[i]; - lhs_shape_i[0] = lhs_is_trans ? k_i : m; - lhs_shape_i[1] = lhs_is_trans ? m : k_i; - rhs_shape_i[0] = rhs_is_trans ? n : k_i; - rhs_shape_i[1] = rhs_is_trans ? k_i : n; - out_shape_i[0] = m; - out_shape_i[1] = n; - } - - size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1]; - size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1]; - size_t out_size = out_shape_i[0] * out_shape_i[1]; - bool is_empty_gemm = lhs_size == 0 || rhs_size == 0; - if (is_empty_gemm && out_size > 0) { - zero_out_dptr_list.push_back(out_ptr); - zero_out_size_list.push_back(out_size * out_dtype_bytes); - } - - // Set matrix data pointers - auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto out_i = TensorWrapper(static_cast(out_ptr), out_shape_i, out_dtype); - void *lhs_vptr = static_cast(lhs_ptr); - void *rhs_vptr = static_cast(rhs_ptr); - if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - else - rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - else - lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - - // Set scale_inv shapes and pointers - void *rhs_sinv_vptr = static_cast(rhs_sinv_ptr); - void *lhs_sinv_vptr = static_cast(lhs_sinv_ptr); - size_t lhs_sinv_size_i = 0; - size_t rhs_sinv_size_i = 0; - if (is_tensor_scaling) { - auto tensor_scaling_sinv_shape = std::vector{1}; - // If is_empty_gemm, scale_inv does not have the corresponding value, do not move the pointers - if (!is_empty_gemm) { - lhs_sinv_size_i = tensor_scaling_sinv_aligment / lhs_sinv_dtype_bytes; - rhs_sinv_size_i = tensor_scaling_sinv_aligment / rhs_sinv_dtype_bytes; - } - if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); - else - rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); - if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); - else - lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); - } else if (is_mxfp8_scaling) { - auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - void *swizzled_lhs_sinv_vptr = static_cast(swizzled_lhs_sinv_ptr); - void *swizzled_rhs_sinv_vptr = static_cast(swizzled_rhs_sinv_ptr); - - // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i - // point to swizzled scale_inv data (store on workspace, only used for GEMM). - // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers - auto lhs_sinv_shape_i = - get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); - auto rhs_sinv_shape_i = - get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); - lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; - rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; - if (lhs_use_colwise) { - lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - } else { - lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - } - if (rhs_use_colwise) { - rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - } else { - rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - } - - if (!is_empty_gemm) { - lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); - rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i)); - lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data()); - rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data()); - } - } else { - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, - "Unsupported scaling mode: ", static_cast(scaling_mode)); - } - - auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - auto pre_gelu_i = TensorWrapper(nullptr, std::vector{0}, out_dtype); - - // Update pointer for the next GEMM pair - lhs_ptr += lhs_size * lhs_dtype_bytes; - rhs_ptr += rhs_size * rhs_dtype_bytes; - out_ptr += out_size * out_dtype_bytes; - if (is_fp8_gemm) { - lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - lhs_sinv_total_size += lhs_sinv_size_i; - rhs_sinv_total_size += rhs_sinv_size_i; - if (is_mxfp8_scaling) { - swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - } - } - if (has_bias) bias_ptr += n * bias_dtype_bytes; - - // Move objects to the lists to keep them alive - if (is_empty_gemm) continue; - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - out_wrapper_list.push_back(std::move(out_i)); - bias_wrapper_list.push_back(std::move(bias_i)); - pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); - - lhs_list.push_back(lhs_wrapper_list.back().data()); - rhs_list.push_back(rhs_wrapper_list.back().data()); - bias_list.push_back(bias_wrapper_list.back().data()); - pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data()); - out_list.push_back(out_wrapper_list.back().data()); - } - - if (is_fp8_gemm) { - if (is_tensor_scaling) { - lhs_sinv_size *= tensor_scaling_sinv_aligment; - rhs_sinv_size *= tensor_scaling_sinv_aligment; - } - NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ", - lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size); - NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ", - rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size); - } - - size_t num_non_empty_gemms = lhs_list.size(); - - if (is_mxfp8_scaling) { - for (int i = 0; i < num_non_empty_gemms; i++) { - // The i-th GEMM will use the (i % num_streams)-th stream to compute, - // use the same stream to swizzle the scaling factors to make sure that - // the swizzling is done before the GEMM computation starts. - int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); - nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); - } - } - - // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM - size_t num_zero_outs = zero_out_dptr_list.size(); - for (int i = 0; i < num_zero_outs; i++) { - int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - void *dptr = zero_out_dptr_list[i]; - size_t count = zero_out_size_list[i]; - NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); - } - - // nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - // pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, - // grad, workspace_list.data(), accumulate, use_split_accumulator, - // num_math_sm, stream); - constexpr size_t workspace_setup_size = 1024 * 1024; // HACK: dummy workspace for setup TensorWrapper workspace_setup(workspace_ptr, std::vector{workspace_setup_size}, DType::kByte); @@ -888,7 +645,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode") .Attr("has_bias") .Attr("is_grouped_dense_wgrad") - .Attr("use_async_d2h_group_sizes")); + .Attr("use_async_d2h_group_sizes"), + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 1f7db84383..ad3553313f 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -375,11 +375,24 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty size_t num_groups = group_sizes.dimensions()[0]; size_t dim_list_bytes = group_size_dtype_bytes * num_groups; std::vector dim_list_host(num_groups); - auto *group_size_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); + // HACK: assumes batched gemm with equal group sizes + for (size_t i = 0; i < num_groups; i++) { + if (input_dims[0] == num_groups) { + dim_list_host[i] = 1; + continue; + } + dim_list_host[i] = m / num_groups; + } + // auto *group_size_ptr = reinterpret_cast(group_sizes.untyped_data()); + // cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + // stream); + // // Note: This may break cudaGraph. + // cudaStreamSynchronize(stream); + // printf("GroupedQuantizeFFI: m=%zu, n=%zu, group sizes = ", m, n); + // for (size_t i = 0; i < num_groups; i++) { + // printf("%d ", dim_list_host[i]); + // } + // printf("\n"); size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, @@ -492,7 +505,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Ret() // amax .Attr("scaling_mode") .Attr("q_layout") - .Attr("flatten_axis")); + .Attr("flatten_axis"), + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index cc6088e8d2..3b4a5ef148 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1444,6 +1444,7 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): def make_einsum_cls(quantization_recipe): import functools + import math import jax def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs): # with open("/tmp/te_einsum_log.txt", "a") as f: @@ -1493,7 +1494,8 @@ def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): kernel = reorder_rhs_for_grouped_gemm(kernel, (batch_dims[1],), contracting_dims[1]) num_groups = kernel.shape[0] - group_size = x.shape[0] // num_groups + group_size = math.prod(x.shape[:-1]) // num_groups + print(f'{num_groups=}, {group_size=}, {x.shape=}, {kernel.shape=}') group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) From bc6cf66512bf4a4a35ce9e014768bb34f749744b Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 8 Jan 2026 10:44:12 -0800 Subject: [PATCH 35/40] attempt at passing thru stateful args for DS --- transformer_engine/jax/quantize/quantizer.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 4edc187795..6831758875 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -7,7 +7,7 @@ This module provides classes and utilities for quantizing tensors in JAX. """ from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass, field, InitVar from functools import partial from typing import Union, Optional, Tuple import warnings @@ -893,6 +893,7 @@ class GroupedQuantizer(Quantizer): data_layout: str = None n_groups: int = 1 quantizers: Tuple[Quantizer] = field(default_factory=lambda: (None,)) + extra_kwargs: InitVar[dict] = None def tree_flatten(self): """Flatten the quantizer for JAX tree operations. @@ -911,10 +912,12 @@ def tree_flatten(self): ) return (children, aux_data) - def __post_init__(self): + def __post_init__(self, extra_kwargs: dict): + print(f"QuantizerFactory creating quantizers for GroupedQuantizer: {self.n_groups=}, {self.scaling_mode=}, {self.q_dtype=}, {self.q_layout=}, {extra_kwargs=}, {self.quantizers=}") if self.quantizers[0] is None: quantizers = QuantizerFactory.create( - self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout + n_quantizers=self.n_groups, + scaling_mode=self.scaling_mode, q_dtype=self.q_dtype, q_layout=self.q_layout, **extra_kwargs ) self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers self.data_layout = self.quantizers[0].data_layout @@ -1106,8 +1109,14 @@ def create( warnings.warn( "Using more than one GroupedQuantizer for a grouped input is not recommended" ) - quantizer_type = GroupedQuantizer - kwargs["n_groups"] = n_groups + quantizer_type = lambda q_dtype, scaling_mode, q_layout, checkpoint_name, **kwargs: GroupedQuantizer( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + checkpoint_name=checkpoint_name, + n_groups=n_groups, + extra_kwargs=kwargs, + ) else: quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) From bcbe864825fa8f40103c72b8b750a807490de28f Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 8 Jan 2026 10:44:18 -0800 Subject: [PATCH 36/40] Revert "attempt at passing thru stateful args for DS" This reverts commit bc6cf66512bf4a4a35ce9e014768bb34f749744b. --- transformer_engine/jax/quantize/quantizer.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 6831758875..4edc187795 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -7,7 +7,7 @@ This module provides classes and utilities for quantizing tensors in JAX. """ from abc import ABC, abstractmethod -from dataclasses import dataclass, field, InitVar +from dataclasses import dataclass, field from functools import partial from typing import Union, Optional, Tuple import warnings @@ -893,7 +893,6 @@ class GroupedQuantizer(Quantizer): data_layout: str = None n_groups: int = 1 quantizers: Tuple[Quantizer] = field(default_factory=lambda: (None,)) - extra_kwargs: InitVar[dict] = None def tree_flatten(self): """Flatten the quantizer for JAX tree operations. @@ -912,12 +911,10 @@ def tree_flatten(self): ) return (children, aux_data) - def __post_init__(self, extra_kwargs: dict): - print(f"QuantizerFactory creating quantizers for GroupedQuantizer: {self.n_groups=}, {self.scaling_mode=}, {self.q_dtype=}, {self.q_layout=}, {extra_kwargs=}, {self.quantizers=}") + def __post_init__(self): if self.quantizers[0] is None: quantizers = QuantizerFactory.create( - n_quantizers=self.n_groups, - scaling_mode=self.scaling_mode, q_dtype=self.q_dtype, q_layout=self.q_layout, **extra_kwargs + self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout ) self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers self.data_layout = self.quantizers[0].data_layout @@ -1109,14 +1106,8 @@ def create( warnings.warn( "Using more than one GroupedQuantizer for a grouped input is not recommended" ) - quantizer_type = lambda q_dtype, scaling_mode, q_layout, checkpoint_name, **kwargs: GroupedQuantizer( - q_dtype=q_dtype, - scaling_mode=scaling_mode, - q_layout=q_layout, - checkpoint_name=checkpoint_name, - n_groups=n_groups, - extra_kwargs=kwargs, - ) + quantizer_type = GroupedQuantizer + kwargs["n_groups"] = n_groups else: quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) From b40353fbad69d3b90197f1ea8dd28dee9263d593 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 8 Jan 2026 14:06:45 -0800 Subject: [PATCH 37/40] batch gemm specialization for CS amax calc --- .../jax/cpp_extensions/quantization.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index a95afe8b8e..b8ea3bd4f4 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1209,21 +1209,26 @@ def grouped_quantize( assert n_groups == len( quantizer.quantizers ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}" - scale = jnp.empty((n_groups,), jnp.float32) + scale = jnp.ones((n_groups,), jnp.float32) if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: for i, quantizer_i in enumerate(quantizer.quantizers): scale = scale.at[i].set(quantizer_i.scale[0]) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - if amax is not None: - row_amax = amax - else: - row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) - segment_ids = jnp.repeat( - jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] - ) - grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) + # TODO fixme, measure perf with always scale/amax of 1 to just isolate quant and gemm + # HACK: assumes equal group sizes + assert group_axis == 0, f"Currently only group_axis = 0 is supported for current-tensor-scaling, but received {group_axis=}" + grouped_amax = jnp.max(jnp.abs(x.reshape((n_groups, x.shape[0]//n_groups, *x.shape[1:]))), axis=tuple(range(1, x.ndim+1))) + # import pdb; pdb.set_trace() + # if amax is not None: + # row_amax = amax + # else: + # row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) + # segment_ids = jnp.repeat( + # jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] + # ) + # grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0) scale = scale.at[i].set(tmp_scale[0]) From ee71c96552c4065bee9826992e1cadfd9556c012 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 15 Jan 2026 10:35:06 -0800 Subject: [PATCH 38/40] multi-GPU grouped quantize working now in shard_map (with hack to use single-stream for multi tensor quantize) --- transformer_engine/common/cast/cast.cu | 22 +---------- .../jax/cpp_extensions/quantization.py | 21 ++++------ .../jax/csrc/extensions/quantization.cpp | 39 +++++++------------ transformer_engine/jax/flax/__init__.py | 3 +- transformer_engine/jax/flax/module.py | 20 ++++++++++ transformer_engine/jax/sharding.py | 6 ++- 6 files changed, 49 insertions(+), 62 deletions(-) diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 73467d7275..dc77a35886 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -75,29 +75,9 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, constexpr bool IS_ACT = false; - const size_t num_streams = nvte_get_num_compute_streams(); - - int num_stream_used = std::min(num_streams, num_tensors); - // wait for current stream to finish - NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream)); - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA( - cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0))); - } - for (int i = 0; i < num_tensors; i++) { dispatch::quantize_fwd_helper( - inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams)); - } - - // record events on compute streams - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA( - cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s))); - } - // wait for all compute streams to finish - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); + inputs[i], outputs[i], quant_configs, stream); } } diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index b8ea3bd4f4..4a2c001f5b 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1216,19 +1216,14 @@ def grouped_quantize( scale = scale.at[i].set(quantizer_i.scale[0]) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - # TODO fixme, measure perf with always scale/amax of 1 to just isolate quant and gemm - # HACK: assumes equal group sizes - assert group_axis == 0, f"Currently only group_axis = 0 is supported for current-tensor-scaling, but received {group_axis=}" - grouped_amax = jnp.max(jnp.abs(x.reshape((n_groups, x.shape[0]//n_groups, *x.shape[1:]))), axis=tuple(range(1, x.ndim+1))) - # import pdb; pdb.set_trace() - # if amax is not None: - # row_amax = amax - # else: - # row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) - # segment_ids = jnp.repeat( - # jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] - # ) - # grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) + if amax is not None: + row_amax = amax + else: + row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) + segment_ids = jnp.repeat( + jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] + ) + grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0) scale = scale.at[i].set(tmp_scale[0]) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index ad3553313f..2b7beb8d6b 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -375,29 +375,19 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty size_t num_groups = group_sizes.dimensions()[0]; size_t dim_list_bytes = group_size_dtype_bytes * num_groups; std::vector dim_list_host(num_groups); - // HACK: assumes batched gemm with equal group sizes - for (size_t i = 0; i < num_groups; i++) { - if (input_dims[0] == num_groups) { - dim_list_host[i] = 1; - continue; - } - dim_list_host[i] = m / num_groups; - } - // auto *group_size_ptr = reinterpret_cast(group_sizes.untyped_data()); - // cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - // stream); - // // Note: This may break cudaGraph. - // cudaStreamSynchronize(stream); - // printf("GroupedQuantizeFFI: m=%zu, n=%zu, group sizes = ", m, n); - // for (size_t i = 0; i < num_groups; i++) { - // printf("%d ", dim_list_host[i]); - // } - // printf("\n"); - - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, - "Unexpected group_sizes! Got %zu (M=%zu, input_dims[0] = %zu)", sum_group_sizes, m, - input_dims[0]); + auto *group_size_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + + // For MaxText case, I think is okay if this check fails as we are expecting to overallocate the buffers in the current use_ring_of_experts impl, which will result in the group sizes not filling the whole tensor. + // size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + // NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, + // "Unexpected group_sizes! Got ", sum_group_sizes, " (M=", m, ", input_dims[0] = ", input_dims[0], ")"); + + // TODO(jberchtold): This is a temporary fix to zero out the output buffers to prevent NaNs in output when this buffer is over-allocated and the groups do not fill the whole buffer. Though these NaNs should be ignored in the downstream GEMM, so more debugging is needed to see why they cause issues. + cudaMemsetAsync(outputs->untyped_data(), 0, outputs->size_bytes(), stream); if (is_delayed_scaling) { NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size, Expected ", num_groups, @@ -505,8 +495,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Ret() // amax .Attr("scaling_mode") .Attr("q_layout") - .Attr("flatten_axis"), - FFI_CudaGraph_Traits); + .Attr("flatten_axis")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 59a0958b7b..1a19685697 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -4,7 +4,7 @@ """Transformer Engine bindings for JAX""" from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP -from .module import wrap_function_in_te_state_module, make_dot_general_cls, make_einsum_cls +from .module import wrap_function_in_te_state_module, make_dot_general_cls, make_einsum_cls, make_ragged_dot_cls from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -17,6 +17,7 @@ "wrap_function_in_te_state_module", "make_dot_general_cls", "make_einsum_cls", + "make_ragged_dot_cls", "extend_logical_axis_rules", "DotProductAttention", "MultiHeadAttention", diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 3b4a5ef148..03d5581ae6 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1520,3 +1520,23 @@ def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) return wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")() + +def make_ragged_dot_cls(quantization_recipe): + import jax + def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): + num_groups = group_sizes.shape[0] + quantizer_set = generate_quantizer_set(n_groups=num_groups) + + target_out_shape = jax.lax.ragged_dot(x, kernel, group_sizes=group_sizes).shape + + out = grouped_dense( + x, + kernel, + group_sizes=group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set + ) + + return out.reshape(target_out_shape) + + return wrap_function_in_te_state_module(te_grouped_dot_general, quantization_recipe, "ragged_dot")() diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index b4b8c42027..4171d1c7b0 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -51,7 +51,8 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): return mesh.shape[resource], resource -def _validate_mesh_resource_configuration(mesh_resource): +# TODO(jberchtold): FIXME, this validation fails in FP8CS amax reduction because the GlobalMeshResource is set but there is no active mesh in the context (afaict shard_map does not share it's mesh as a context), so this is triggering a FalsePositive assert. However, I am not sure if we can safely ignore this when the mesh is empty or all axes are manual as some users may use shard_map with some axes manual and some auto. +# def _validate_mesh_resource_configuration(mesh_resource): """Validate that the mesh resource configuration is consistent and conflict-free.""" is_tp_enabled = ( mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1 @@ -375,7 +376,8 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) - _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + # TODO(jberchtold): FIXME, this validation fails in FP8CS amax reduction because the GlobalMeshResource is set but there is no active mesh in the context (afaict shard_map does not share it's mesh as a context), so this is triggering a FalsePositive assert. However, I am not sure if we can safely ignore this when the mesh is empty or all axes are manual as some users may use shard_map with some axes manual and some auto. + # _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) return _GLOBAL_MESH_RESOURCE From 9856862450547b2cbd688f30dc4fa8ecda111227 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 15 Jan 2026 11:11:55 -0800 Subject: [PATCH 39/40] reduce size of zero'ing memset to only uninitialized part of quantization buffer --- .../jax/csrc/extensions/quantization.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 2b7beb8d6b..3d98126290 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -382,13 +382,10 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty cudaStreamSynchronize(stream); // For MaxText case, I think is okay if this check fails as we are expecting to overallocate the buffers in the current use_ring_of_experts impl, which will result in the group sizes not filling the whole tensor. - // size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); // NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, // "Unexpected group_sizes! Got ", sum_group_sizes, " (M=", m, ", input_dims[0] = ", input_dims[0], ")"); - // TODO(jberchtold): This is a temporary fix to zero out the output buffers to prevent NaNs in output when this buffer is over-allocated and the groups do not fill the whole buffer. Though these NaNs should be ignored in the downstream GEMM, so more debugging is needed to see why they cause issues. - cudaMemsetAsync(outputs->untyped_data(), 0, outputs->size_bytes(), stream); - if (is_delayed_scaling) { NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size, Expected ", num_groups, ", got ", amaxs->dimensions()[0]); @@ -402,6 +399,13 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty size_t num_non_empty_groups = 0; size_t total_rowwise_sinv_size = 0; size_t total_colwise_sinv_size = 0; + + + // TODO(jberchtold): This is a temporary fix to zero out the output buffers to prevent NaNs in output when this buffer is over-allocated and the groups do not fill the whole buffer. Though these NaNs should be ignored in the downstream GEMM, so more debugging is needed to see why they cause issues. + size_t used_output_size = (sum_group_sizes*non_group_m) * n * output_dtype_bytes; + cudaMemsetAsync(outputs->untyped_data() + used_output_size, 0, outputs->size_bytes() - used_output_size, stream); + + for (size_t i = 0; i < num_groups; i++) { size_t m_i = dim_list_host[i] * non_group_m; // Skip for zero-size input + shiff the scale ptr From f58ba23bf667b375a71e752e8069c7d4f47c85ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jan 2026 19:52:42 +0000 Subject: [PATCH 40/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test_einsum.py | 51 +++--- tests/jax/test_custom_call_compute.py | 51 ++++-- transformer_engine/common/cast/cast.cu | 4 +- .../common/gemm/cublaslt_gemm.cu | 2 +- transformer_engine/jax/cpp_extensions/base.py | 6 +- transformer_engine/jax/cpp_extensions/gemm.py | 42 +++-- .../jax/cpp_extensions/quantization.py | 6 +- .../jax/csrc/extensions/gemm.cpp | 83 +++++---- .../jax/csrc/extensions/quantization.cpp | 7 +- transformer_engine/jax/dense.py | 82 ++++----- transformer_engine/jax/flax/__init__.py | 7 +- transformer_engine/jax/flax/module.py | 169 ++++++++++-------- transformer_engine/jax/sharding.py | 5 +- 13 files changed, 288 insertions(+), 227 deletions(-) diff --git a/test_einsum.py b/test_einsum.py index 5bb05403f2..1b1f502c51 100644 --- a/test_einsum.py +++ b/test_einsum.py @@ -4,29 +4,39 @@ import jax.numpy as jnp import numpy as np import transformer_engine.jax as te -from transformer_engine.common.recipe import Recipe, Float8CurrentScaling, MXFP8BlockScaling, DelayedScaling, NVFP4BlockScaling +from transformer_engine.common.recipe import ( + Recipe, + Float8CurrentScaling, + MXFP8BlockScaling, + DelayedScaling, + NVFP4BlockScaling, +) from flax import linen as nn + def make_einsum_cls(quantization_recipe): def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs): - def dot_general(x, kernel, dims, *args, **kwargs): - contracting_dims, batch_dims = dims - assert batch_dims == ((), ()), "Batch dims not supported in TE/JAX yet" - - quantizer_set = generate_quantizer_set("quantizer_set_for_einsum") - return te.dense.dense( - x, - kernel, - contracting_dims=contracting_dims, - quantizer_set=quantizer_set, - ) - return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) - + def dot_general(x, kernel, dims, *args, **kwargs): + contracting_dims, batch_dims = dims + assert batch_dims == ((), ()), "Batch dims not supported in TE/JAX yet" + + quantizer_set = generate_quantizer_set("quantizer_set_for_einsum") + return te.dense.dense( + x, + kernel, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, + ) + + return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) + return te.flax.wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")() + class EinsumType(Enum): - JAX = 'jax' - TE = 'te' + JAX = "jax" + TE = "te" + def main(): @@ -47,9 +57,10 @@ def _einsum(self, *args, **kwargs): @nn.compact def __call__(self, x): - kernel = self.param('kernel', jax.nn.initializers.lecun_normal(), (32, 32), jnp.bfloat16) + kernel = self.param( + "kernel", jax.nn.initializers.lecun_normal(), (32, 32), jnp.bfloat16 + ) return self._einsum("ij,jk->ik", x, kernel) - def test_model(einsum_type: EinsumType, quantization_recipe: Recipe = None): model = SimpleModel(einsum_type=einsum_type, quantization_recipe=quantization_recipe) @@ -68,7 +79,7 @@ def test_model(einsum_type: EinsumType, quantization_recipe: Recipe = None): # Compare outputs atol = float(jnp.finfo(jnp.float8_e4m3fn).eps) np.testing.assert_allclose(ref_out, te_out, atol=atol) - + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 082a99cd8b..674e3e76f7 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1975,18 +1975,22 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype) -@pytest_parametrize_wrapper('eqn,a_shape,b_shape', [ - # ('ij,jk->ik', (64, 32), (32, 128)), - # ('bij,bjk->bik', (8, 64, 32), (8, 32, 128)), - # ('abc,cde->abde', (4, 8, 16), (16, 32, 64)), - ('BSM,BSEC->EBCM', (2, 4096, 4096), (2, 4096, 8, 1024)), - ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)) , - ('EBCM,EMH->EBCH', (8, 2, 1024, 4096), (8, 4096, 14336)), - ('EBCH,EHM->EBCM', (8, 2, 1024, 14336), (8, 14336, 4096)), - ('EBCM,BSEC->BSM', (8, 2, 1024, 4096), (2, 4096, 8, 1024)), -]) -@pytest_parametrize_wrapper('dtype', [jnp.bfloat16]) -@pytest_parametrize_wrapper('quantization_recipe', supported_recipes) + +@pytest_parametrize_wrapper( + "eqn,a_shape,b_shape", + [ + # ('ij,jk->ik', (64, 32), (32, 128)), + # ('bij,bjk->bik', (8, 64, 32), (8, 32, 128)), + # ('abc,cde->abde', (4, 8, 16), (16, 32, 64)), + ("BSM,BSEC->EBCM", (2, 4096, 4096), (2, 4096, 8, 1024)), + ("EBCM,EMH->EBCH", (8, 2, 1024, 4096), (8, 4096, 14336)), + ("EBCM,EMH->EBCH", (8, 2, 1024, 4096), (8, 4096, 14336)), + ("EBCH,EHM->EBCM", (8, 2, 1024, 14336), (8, 14336, 4096)), + ("EBCM,BSEC->BSM", (8, 2, 1024, 4096), (2, 4096, 8, 1024)), + ], +) +@pytest_parametrize_wrapper("dtype", [jnp.bfloat16]) +@pytest_parametrize_wrapper("quantization_recipe", supported_recipes) class TestEinsum: def _te_einsum(self, eqn, a, b, quantization_recipe): @@ -2011,7 +2015,9 @@ def test_einsum_fwd(self, eqn, a_shape, b_shape, dtype, quantization_recipe): a = jax.random.uniform(subkeys[0], a_shape, dtype=dtype) b = jax.random.uniform(subkeys[1], b_shape, dtype=dtype) - te_out = jax.jit(functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe))(a, b) + te_out = jax.jit( + functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe) + )(a, b) ref_out = jax.jit(functools.partial(self._ref_einsum, eqn))(a, b) assert_allclose(te_out, ref_out, dtype=dtype) @@ -2032,14 +2038,25 @@ def wrap_in_mean(f): @functools.wraps(f) def wrapped(*args): return jnp.mean(f(*args)) + return wrapped - te_fwd, te_grads = jax.jit(jax.value_and_grad(wrap_in_mean(functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe))))(a, b) - ref_fwd, ref_grads = jax.jit(jax.value_and_grad(wrap_in_mean(functools.partial(self._ref_einsum, eqn))))(a, b) + te_fwd, te_grads = jax.jit( + jax.value_and_grad( + wrap_in_mean( + functools.partial(self._te_einsum, eqn, quantization_recipe=quantization_recipe) + ) + ) + )(a, b) + ref_fwd, ref_grads = jax.jit( + jax.value_and_grad(wrap_in_mean(functools.partial(self._ref_einsum, eqn))) + )(a, b) assert_allclose(te_fwd, ref_fwd, dtype=dtype) - assert len(te_grads) == len(ref_grads), f"Number of gradients differ: {len(te_grads)=} vs {len(ref_grads)=}" + assert len(te_grads) == len( + ref_grads + ), f"Number of gradients differ: {len(te_grads)=} vs {len(ref_grads)=}" for te_grad, ref_grad in zip(te_grads, ref_grads): - assert_allclose(te_grad, ref_grad, dtype=dtype) \ No newline at end of file + assert_allclose(te_grad, ref_grad, dtype=dtype) diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index dc77a35886..e6f9b70549 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -76,8 +76,8 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, constexpr bool IS_ACT = false; for (int i = 0; i < num_tensors; i++) { - dispatch::quantize_fwd_helper( - inputs[i], outputs[i], quant_configs, stream); + dispatch::quantize_fwd_helper(inputs[i], outputs[i], quant_configs, + stream); } } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index cc4a29b304..6cdf3c95b9 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -156,7 +156,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage // NVTE_CHECK(ret.lda % 16 == 0, - // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 335af2eb47..8096f57396 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -216,7 +216,8 @@ def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): elif arg.shape[bdim] != batch_size: raise ValueError( "All batched arguments must have the same batch size. " - f"Got sizes {[arg.shape[bdim] for arg, bdim in zip(batched_args, batch_dims) if bdim is not None]}. " + "Got sizes" + f" {[arg.shape[bdim] for arg, bdim in zip(batched_args, batch_dims) if bdim is not None]}. " f"Got batched_args={[arg.shape for arg, bdim in zip(batched_args, batch_dims) if bdim is not None]}." ) assert batch_dim is not None and batch_size is not None, "Invalid batching config!" @@ -255,7 +256,8 @@ def batcher(batched_args, batch_dims, *, arg1, arg2, arg3): # Stack each output along the batch dimension if output_bdims is not None: stacked_results = tuple( - jnp.stack(list(out_list), axis=out_bdim) for out_list, out_bdim in zip(transposed, output_bdims) + jnp.stack(list(out_list), axis=out_bdim) + for out_list, out_bdim in zip(transposed, output_bdims) ) else: stacked_results = tuple( diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 5c53dedb8a..4227965707 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -818,8 +818,9 @@ def batcher( # f"got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" # ) - f = partial(GemmPrimitive.outer_impl, - **{ + f = partial( + GemmPrimitive.outer_impl, + **{ "out_dtype": out_dtype, "contracting_dims": contracting_dims, "scaling_mode": scaling_mode, @@ -831,16 +832,16 @@ def batcher( "transpose_batch_sequence": transpose_batch_sequence, "sequence_dim": sequence_dim, "is_outer": is_outer, - }) - + }, + ) + lhs_cdims, rhs_cdims = contracting_dims # Calculate output batch dimension based on input batch dims and contracting dims # Both lhs and rhs have batch dimensions that may be at different indices if lhs_bdims is not None and rhs_bdims is not None: # Count non-contracting dimensions in LHS before the batch dimension lhs_non_contracting_before_batch = sum( - 1 for i in range(lhs_bdims) - if i not in lhs_cdims + 1 for i in range(lhs_bdims) if i not in lhs_cdims ) # The output batch dimension will be at the position corresponding to # the LHS batch dimension's position among non-contracting dimensions @@ -850,8 +851,13 @@ def batcher( output_bdim = 0 elif rhs_bdims is not None: # RHS has a batch dimension - need to account for LHS non-contracting dims - lhs_non_contracting = len([i for i in range(len(batched_args[0].shape)) - if i not in lhs_cdims and i != lhs_bdims]) + lhs_non_contracting = len( + [ + i + for i in range(len(batched_args[0].shape)) + if i not in lhs_cdims and i != lhs_bdims + ] + ) output_bdim = lhs_non_contracting else: # No batch dimensions in either operand @@ -861,16 +867,16 @@ def batcher( return GemmPrimitive.batcher_impl( batched_args, batch_dims=( - lhs_bdims, # lhs - 0, # lhs_scale_inv - rhs_bdims, # rhs - 0, # rhs_scale_inv - *(None for _ in batched_args[4:]), # bias, gelu_input, alpha, beta + lhs_bdims, # lhs + 0, # lhs_scale_inv + rhs_bdims, # rhs + 0, # rhs_scale_inv + *(None for _ in batched_args[4:]), # bias, gelu_input, alpha, beta ), output_bdims=( - output_bdim, # output - 0, # bias_grad - 0, # pre_gelu_out + output_bdim, # output + 0, # bias_grad + 0, # pre_gelu_out ), static_kwargs={ "out_dtype": out_dtype, @@ -1538,7 +1544,9 @@ def abstract( workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding - workspace_size += 1024*1024 # HACK: properly make a workspace_setup buffer in addition to the workspace_cublas buffer + workspace_size += ( + 1024 * 1024 + ) # HACK: properly make a workspace_setup buffer in addition to the workspace_cublas buffer workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) out_shape = (M, N) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 4a2c001f5b..5e01a4cece 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -368,10 +368,12 @@ def batcher( batch_dims, output_bdims=( batch_dims[0], # out - batch_dims[0], # colwise_out (probably need to transpose according if scaling mode does it) + batch_dims[ + 0 + ], # colwise_out (probably need to transpose according if scaling mode does it) 0, # scale_inv 0, # colwise_scale_inv - 0, # updated_amax + 0, # updated_amax 0, # dbias ), static_kwargs={ diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 0bfab2d7dc..9496caf406 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -399,7 +399,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); -NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional scale_inv, JAXX_Scaling_Mode scaling_mode, size_t num_tensors, NVTEShape const& dataShape) { +NVTEGroupedTensor make_grouped_tensor(Buffer_Type const &data, std::optional scale_inv, + JAXX_Scaling_Mode scaling_mode, size_t num_tensors, + NVTEShape const &dataShape) { // printf("make_grouped_tensor data shape: "); // for (auto dim : data.dimensions()) { // printf("%zu, ", dim); @@ -422,11 +424,12 @@ NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional(data.untyped_data()), - static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())), - dataShape}; + NVTEBasicTensor data_tensor{ + reinterpret_cast(data.untyped_data()), + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())), dataShape}; nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseData, &data_tensor); if (scale_inv.has_value()) { @@ -439,11 +442,13 @@ NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optionaldimensions()[0]; logical_scale_shape.data[1] = scale_inv->dimensions()[1]; } else { - NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", scale_inv->dimensions().size()); + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", + scale_inv->dimensions().size()); } - NVTEBasicTensor scale_inv_tensor{reinterpret_cast(scale_inv->untyped_data()), - static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())), - logical_scale_shape}; + NVTEBasicTensor scale_inv_tensor{ + reinterpret_cast(scale_inv->untyped_data()), + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())), + logical_scale_shape}; nvte_set_grouped_tensor_param(&grouped_tensor, kNVTEGroupedRowwiseScaleInv, &scale_inv_tensor); } @@ -452,11 +457,10 @@ NVTEGroupedTensor make_grouped_tensor(Buffer_Type const& data, std::optional("is_grouped_dense_wgrad") .Attr("use_async_d2h_group_sizes"), - FFI_CudaGraph_Traits); + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 3d98126290..4460232852 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -400,12 +400,11 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty size_t total_rowwise_sinv_size = 0; size_t total_colwise_sinv_size = 0; - // TODO(jberchtold): This is a temporary fix to zero out the output buffers to prevent NaNs in output when this buffer is over-allocated and the groups do not fill the whole buffer. Though these NaNs should be ignored in the downstream GEMM, so more debugging is needed to see why they cause issues. - size_t used_output_size = (sum_group_sizes*non_group_m) * n * output_dtype_bytes; - cudaMemsetAsync(outputs->untyped_data() + used_output_size, 0, outputs->size_bytes() - used_output_size, stream); + size_t used_output_size = (sum_group_sizes * non_group_m) * n * output_dtype_bytes; + cudaMemsetAsync(outputs->untyped_data() + used_output_size, 0, + outputs->size_bytes() - used_output_size, stream); - for (size_t i = 0; i < num_groups; i++) { size_t m_i = dim_list_host[i] * non_group_m; // Skip for zero-size input + shiff the scale ptr diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 9db60d3bd8..aa31be1842 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -237,46 +237,46 @@ def _dense_fwd_rule( ) return output, ctx -def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, - swap_ans=False): - # from: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py#L198 - import itertools - import numpy as np - def _remaining(original, *removed_lists): - removed = set(itertools.chain(*removed_lists)) - return tuple(i for i in original if i not in removed) - - def _ranges_like(*xs): - start = 0 - for x in xs: - x_len = len(x) - yield tuple(range(start, start + x_len)) - start += x_len - - (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers - x_ndim = x.ndim - x_kept = _remaining(tuple(range(x_ndim)), x_contract, x_batch) - y_kept = _remaining(tuple(range(y.ndim)), y_contract, y_batch) - if swap_ans: - ans_batch, ans_y, _ = _ranges_like(x_batch, y_kept, x_kept) - else: - ans_batch, _, ans_y = _ranges_like(x_batch, x_kept, y_kept) - dims = ((ans_y, y_kept), (ans_batch, y_batch)) - x_contract_sorted_by_y = tuple(np.take(x_contract, np.argsort(y_contract))) - out_axes = np.argsort(tuple(x_batch) + x_kept + x_contract_sorted_by_y) - x_bar = jax.lax.transpose( - tex.gemm(g, y, contracting_dims=dims[0]), - tuple(out_axes) - ) - return x_bar + +def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, swap_ans=False): + # from: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py#L198 + import itertools + import numpy as np + + def _remaining(original, *removed_lists): + removed = set(itertools.chain(*removed_lists)) + return tuple(i for i in original if i not in removed) + + def _ranges_like(*xs): + start = 0 + for x in xs: + x_len = len(x) + yield tuple(range(start, start + x_len)) + start += x_len + + (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers + x_ndim = x.ndim + x_kept = _remaining(tuple(range(x_ndim)), x_contract, x_batch) + y_kept = _remaining(tuple(range(y.ndim)), y_contract, y_batch) + if swap_ans: + ans_batch, ans_y, _ = _ranges_like(x_batch, y_kept, x_kept) + else: + ans_batch, _, ans_y = _ranges_like(x_batch, x_kept, y_kept) + dims = ((ans_y, y_kept), (ans_batch, y_batch)) + x_contract_sorted_by_y = tuple(np.take(x_contract, np.argsort(y_contract))) + out_axes = np.argsort(tuple(x_batch) + x_kept + x_contract_sorted_by_y) + x_bar = jax.lax.transpose(tex.gemm(g, y, contracting_dims=dims[0]), tuple(out_axes)) + return x_bar + def dot_general_transpose_rhs(g, x, y, *, dimension_numbers): - (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers - swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) - y_bar = dot_general_transpose_lhs( - g, y, x, dimension_numbers=swapped_dimension_numbers, - swap_ans=True) - return y_bar + (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers + swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) + y_bar = dot_general_transpose_lhs( + g, y, x, dimension_numbers=swapped_dimension_numbers, swap_ans=True + ) + return y_bar + def _dense_bwd_rule( contracting_dims, @@ -318,7 +318,7 @@ def _dense_bwd_rule( ) fwd_cdims = (fwd_x_contracting_dims, fwd_k_contracting_dims) - batch_dims = ((), ()) # vmap is done outside dense VJP if needed + batch_dims = ((), ()) # vmap is done outside dense VJP if needed dims = (fwd_cdims, batch_dims) dgrad = dot_general_transpose_lhs( @@ -329,7 +329,9 @@ def _dense_bwd_rule( ) wgrad = dot_general_transpose_rhs( - casted_grad.get_tensor(usage=TensorUsage.LHS), # TODO(jberchtold): should be RHS to use fused kernel for 2x layout? but would need to update dims accordingly + casted_grad.get_tensor( + usage=TensorUsage.LHS + ), # TODO(jberchtold): should be RHS to use fused kernel for 2x layout? but would need to update dims accordingly casted_x_lhs, casted_kernel_rhs, dimension_numbers=dims, diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 1a19685697..5805d80734 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -4,7 +4,12 @@ """Transformer Engine bindings for JAX""" from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP -from .module import wrap_function_in_te_state_module, make_dot_general_cls, make_einsum_cls, make_ragged_dot_cls +from .module import ( + wrap_function_in_te_state_module, + make_dot_general_cls, + make_einsum_cls, + make_ragged_dot_cls, +) from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 03d5581ae6..b15f8e77f2 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1442,87 +1442,102 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") + def make_einsum_cls(quantization_recipe): import functools import math import jax + def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs): - # with open("/tmp/te_einsum_log.txt", "a") as f: - # f.write(f"{(s, x.shape, kernel.shape)}\n") - def dot_general(x, kernel, dims, *args, **kwargs): - # print(f"TE dot_general called with dims: {dims}, args: {args}, kwargs: {kwargs}") - contracting_dims, batch_dims = dims - ((x_bdim,), (k_bdim,)) = batch_dims - batch_dims = (x_bdim, k_bdim) - - if x_bdim != 0 or k_bdim != 0: - print(f"{x_bdim=}, {k_bdim=}") - return jax.lax.dot_general(x, kernel, dims, *args, **kwargs) - - target_out_shape = jax.lax.dot_general(x, kernel, dims).shape - - if x.dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]: - # HACK: because x input is bool for dispatch mask - x = x.astype(kernel.dtype) - - # Adjust for unbatched - contracting_dims = tuple( - tuple(dim - (1 if dim > bdim else 0) for dim in cdims) - for bdim, cdims in zip(batch_dims, contracting_dims)) - - group_sizes = None - print(f'{x.shape=}, {kernel.shape=}, {dims=}') - - def reorder_lhs_for_grouped_gemm(tensor, cdims): - # (B*M, K) - assert len(cdims) == 1, f"Only support single contracting dim for now, got {cdims}" - cdim = cdims[0] + 1 # account for batch dim at front - out = jnp.transpose(tensor, tuple(range(cdim)) + tuple(range(cdim + 1, tensor.ndim)) + (cdim,)) - return out.reshape((-1, out.shape[-1])) - - - def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): - # (B, K, N) - assert len(bdims) == 1 and len(cdims) == 1, f"Only support single batch and contracting dim for now, got {bdims}, {cdims}" - bdim = bdims[0] - assert bdim == 0, f"Only support batch dim 0 for now, got {bdim}" - cdim = cdims[0] + 1 # account for batch dim at front - out = jnp.transpose(tensor, (bdim, cdim) + tuple(i for i in range(tensor.ndim) if i != bdim and i != cdim)) - return out.reshape((*out.shape[:2], -1)) - - x = reorder_lhs_for_grouped_gemm(x, contracting_dims[0]) - kernel = reorder_rhs_for_grouped_gemm(kernel, (batch_dims[1],), contracting_dims[1]) - - num_groups = kernel.shape[0] - group_size = math.prod(x.shape[:-1]) // num_groups - print(f'{num_groups=}, {group_size=}, {x.shape=}, {kernel.shape=}') - - group_sizes = jnp.array([group_size]*num_groups, dtype=jnp.int32) + # with open("/tmp/te_einsum_log.txt", "a") as f: + # f.write(f"{(s, x.shape, kernel.shape)}\n") + def dot_general(x, kernel, dims, *args, **kwargs): + # print(f"TE dot_general called with dims: {dims}, args: {args}, kwargs: {kwargs}") + contracting_dims, batch_dims = dims + ((x_bdim,), (k_bdim,)) = batch_dims + batch_dims = (x_bdim, k_bdim) + + if x_bdim != 0 or k_bdim != 0: + print(f"{x_bdim=}, {k_bdim=}") + return jax.lax.dot_general(x, kernel, dims, *args, **kwargs) + + target_out_shape = jax.lax.dot_general(x, kernel, dims).shape + + if x.dtype not in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]: + # HACK: because x input is bool for dispatch mask + x = x.astype(kernel.dtype) + + # Adjust for unbatched + contracting_dims = tuple( + tuple(dim - (1 if dim > bdim else 0) for dim in cdims) + for bdim, cdims in zip(batch_dims, contracting_dims) + ) - quantizer_set = generate_quantizer_set(n_groups=num_groups) + group_sizes = None + print(f"{x.shape=}, {kernel.shape=}, {dims=}") - print(f'{group_sizes=}, {contracting_dims=}, {x.shape=}, {kernel.shape=}, {contracting_dims=}') + def reorder_lhs_for_grouped_gemm(tensor, cdims): + # (B*M, K) + assert len(cdims) == 1, f"Only support single contracting dim for now, got {cdims}" + cdim = cdims[0] + 1 # account for batch dim at front + out = jnp.transpose( + tensor, tuple(range(cdim)) + tuple(range(cdim + 1, tensor.ndim)) + (cdim,) + ) + return out.reshape((-1, out.shape[-1])) + + def reorder_rhs_for_grouped_gemm(tensor, bdims, cdims): + # (B, K, N) + assert ( + len(bdims) == 1 and len(cdims) == 1 + ), f"Only support single batch and contracting dim for now, got {bdims}, {cdims}" + bdim = bdims[0] + assert bdim == 0, f"Only support batch dim 0 for now, got {bdim}" + cdim = cdims[0] + 1 # account for batch dim at front + out = jnp.transpose( + tensor, + (bdim, cdim) + tuple(i for i in range(tensor.ndim) if i != bdim and i != cdim), + ) + return out.reshape((*out.shape[:2], -1)) + + x = reorder_lhs_for_grouped_gemm(x, contracting_dims[0]) + kernel = reorder_rhs_for_grouped_gemm(kernel, (batch_dims[1],), contracting_dims[1]) + + num_groups = kernel.shape[0] + group_size = math.prod(x.shape[:-1]) // num_groups + print(f"{num_groups=}, {group_size=}, {x.shape=}, {kernel.shape=}") + + group_sizes = jnp.array([group_size] * num_groups, dtype=jnp.int32) + + quantizer_set = generate_quantizer_set(n_groups=num_groups) + + print( + f"{group_sizes=}, {contracting_dims=}, {x.shape=}, {kernel.shape=}," + f" {contracting_dims=}" + ) + + contracting_dims = ( + # (B*M, K) + (1,), + # (B, K, N) + (1,), + ) + out = grouped_dense( + x, + kernel, + group_sizes=group_sizes, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, + ) + return out.reshape(target_out_shape) + + return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) - contracting_dims = ( - # (B*M, K) - (1,), - # (B, K, N) - (1,), - ) - out = grouped_dense( - x, - kernel, - group_sizes=group_sizes, - contracting_dims=contracting_dims, - quantizer_set=quantizer_set - ) - return out.reshape(target_out_shape) - return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs) - return wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")() + def make_ragged_dot_cls(quantization_recipe): import jax + def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): num_groups = group_sizes.shape[0] quantizer_set = generate_quantizer_set(n_groups=num_groups) @@ -1530,13 +1545,15 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa target_out_shape = jax.lax.ragged_dot(x, kernel, group_sizes=group_sizes).shape out = grouped_dense( - x, - kernel, - group_sizes=group_sizes, - contracting_dims=((1,), (1,)), - quantizer_set=quantizer_set + x, + kernel, + group_sizes=group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, ) return out.reshape(target_out_shape) - - return wrap_function_in_te_state_module(te_grouped_dot_general, quantization_recipe, "ragged_dot")() + + return wrap_function_in_te_state_module( + te_grouped_dot_general, quantization_recipe, "ragged_dot" + )() diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 4171d1c7b0..37edea4024 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -50,9 +50,8 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}." return mesh.shape[resource], resource - -# TODO(jberchtold): FIXME, this validation fails in FP8CS amax reduction because the GlobalMeshResource is set but there is no active mesh in the context (afaict shard_map does not share it's mesh as a context), so this is triggering a FalsePositive assert. However, I am not sure if we can safely ignore this when the mesh is empty or all axes are manual as some users may use shard_map with some axes manual and some auto. -# def _validate_mesh_resource_configuration(mesh_resource): + # TODO(jberchtold): FIXME, this validation fails in FP8CS amax reduction because the GlobalMeshResource is set but there is no active mesh in the context (afaict shard_map does not share it's mesh as a context), so this is triggering a FalsePositive assert. However, I am not sure if we can safely ignore this when the mesh is empty or all axes are manual as some users may use shard_map with some axes manual and some auto. + # def _validate_mesh_resource_configuration(mesh_resource): """Validate that the mesh resource configuration is consistent and conflict-free.""" is_tp_enabled = ( mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1