diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py new file mode 100644 index 0000000000..f35f23f6be --- /dev/null +++ b/tests/pytorch/test_grouped_tensor.py @@ -0,0 +1,429 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for GroupedTensor class""" + +from typing import List, Tuple +import pytest +import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor +from transformer_engine.pytorch import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, + Float8BlockQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, +) +import transformer_engine_torch as tex + +# Check available recipes +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +_quantization_params = [ + pytest.param( + "fp8_delayed_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_blockwise", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + ), +] + + +def make_quantizers(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]): + """Create quantizers for given quantization scheme""" + quantizers = [] + for i in range(num_tensors): + if quantization == "fp8_delayed_scaling": + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device="cuda"), + amax=torch.zeros(1, dtype=torch.float32, device="cuda"), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + quantizer.set_usage(rowwise=True, columnwise=False) + elif quantization == "fp8_blockwise": + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=False, + force_pow_2_scales=True, + amax_epsilon=0.0, + block_scaling_dim=1, + ) + elif quantization == "mxfp8": + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + elif quantization == "nvfp4": + quantizer = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + else: + raise ValueError(f"Unknown quantization scheme: {quantization}") + + quantizer.internal = False + quantizers.append(quantizer) + + return quantizers + + +def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor: + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"): + return qtensor._data + if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"): + return qtensor._rowwise_data + raise ValueError(f"Unknown quantization scheme: {quantization}") + + +def _rowwise_offset_bytes(numel: int, quantization: str) -> int: + if quantization == "nvfp4": + return numel // 2 + return 
numel + + +class TestGroupedTensor: + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_basic_construction_all_same_shape(self) -> None: + """Test GroupedTensor construction with all tensors having same shape""" + num_tensors = 4 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + shape=shape, + quantizers=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.all_same_shape() + assert grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.logical_shape == (num_tensors * 256, 512) + assert grouped_tensor.get_common_first_dim() == 256 + assert grouped_tensor.get_common_last_dim() == 512 + assert grouped_tensor.has_data() + + def test_basic_construction_varying_first_dim(self) -> None: + """Test GroupedTensor construction with varying first dimension""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + shape=shape, + quantizers=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert not grouped_tensor.all_same_shape() + assert not grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.get_common_last_dim() == shape[0][1] + assert grouped_tensor.logical_shape == ( + sum(v for v, _ in shape), + shape[0][1], + ) # sum of first dims + + def test_basic_construction_varying_last_dim(self) -> None: + """Test GroupedTensor construction with varying last dimension""" + num_tensors = 3 + shape = [(512, 128), (512, 256), (512, 384)] + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + shape=shape, + quantizers=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert not grouped_tensor.all_same_shape() + assert grouped_tensor.all_same_first_dim() + assert not grouped_tensor.all_same_last_dim() + assert grouped_tensor.get_common_first_dim() == shape[0][0] + assert grouped_tensor.logical_shape == ( + shape[0][0], + sum(v for _, v in shape), + ) # sum of last dims + + def test_basic_construction_varying_both_dims(self) -> None: + """Test GroupedTensor construction with varying both dimensions""" + num_tensors = 3 + shape = [(128, 256), (256, 384), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + shape=shape, + quantizers=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert not grouped_tensor.all_same_shape() + assert not grouped_tensor.all_same_first_dim() + assert not grouped_tensor.all_same_last_dim() + assert grouped_tensor.varying_both_dims() + total_elements = sum(s[0] * s[1] for s in shape) + assert grouped_tensor.logical_shape == (1, total_elements) + + def test_split_into_quantized_tensors_no_quantization(self) -> None: + """Test split_into_quantized_tensors for unquantized tensors""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + shape=shape, + quantizers=None, + device="cuda", + dtype=torch.float32, + ) + + # Get the original data pointer + original_data_ptr = grouped_tensor.data.data_ptr() + + # Split 
into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor has correct shape and shares storage + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + assert isinstance(tensor, torch.Tensor) + assert not hasattr(tensor, "_data") # Not a quantized tensor + + # Verify data pointer is within the original grouped tensor storage + # The tensor should be a view of the original data + assert tensor.data_ptr() >= original_data_ptr + + # Calculate expected offset + expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: + """Test split_into_quantized_tensors for quantized tensors""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizers(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + shape=shape, + quantizers=quantizers, + device="cuda", + ) + + # Get the original data pointer + original_data_ptr = grouped_tensor.data.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor shares storage with the grouped tensor + for i, tensor in enumerate(tensors): + rowwise_data = _get_rowwise_data_tensor(tensor, quantization) + assert rowwise_data is not None + assert rowwise_data.data_ptr() >= original_data_ptr + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + def test_split_varying_shapes(self) -> None: + """Test split_into_quantized_tensors with varying shapes""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + shape=shape, + quantizers=None, + device="cuda", + dtype=torch.float32, + ) + + original_data_ptr = grouped_tensor.data.data_ptr() + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify shapes and storage + cumulative_offset = 0 + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + expected_offset = cumulative_offset * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + cumulative_offset += shape[i][0] * shape[i][1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_inplace(self, quantization: str) -> None: + """Test that quantize is done in-place for all recipes""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizers(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + shape=shape, + quantizers=quantizers, + device="cuda", + ) + + # Get original data pointers before quantization + original_data_ptr = grouped_tensor.data.data_ptr() + original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr() + original_scale_ptr = ( + grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None + ) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = 
grouped_tensor.quantize(input_tensors) + + # Verify data pointers haven't changed (in-place operation) + assert grouped_tensor.data.data_ptr() == original_data_ptr + assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr + if original_scale_ptr is not None: + assert grouped_tensor.scale.data_ptr() == original_scale_ptr + + # Verify returned tensors point to the same storage + for i, qtensor in enumerate(quantized_tensors): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_varying_shapes(self, quantization: str) -> None: + """Test quantize with varying shapes""" + num_tensors = 3 + shape = [(256, 512), (512, 512), (768, 512)] + quantizers = make_quantizers(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + shape=shape, + quantizers=quantizers, + device="cuda", + ) + + # Get original data pointers + original_data_ptr = grouped_tensor.data.data_ptr() + + # Create input tensors with varying shapes + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointer hasn't changed + assert grouped_tensor.data.data_ptr() == original_data_ptr + + # Verify each tensor points to correct location + cumulative_numel = 0 + for qtensor, tensor_shape in zip(quantized_tensors, shape): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + cumulative_numel += tensor_shape[0] * tensor_shape[1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_static_quantize_method(self, quantization: str) -> None: + """Test the static quantize method""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizers(quantization, num_tensors, shape) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Use static quantize method + grouped_tensor = GroupedTensor.create_and_quantize( + tensors=input_tensors, + quantizers=quantizers, + device="cuda", + ) + + # Verify the grouped tensor was created correctly + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.has_data() + + # Verify quantized_tensors were created and point to same storage + assert grouped_tensor.quantized_tensors is not None + assert len(grouped_tensor.quantized_tensors) == num_tensors + + original_data_ptr = grouped_tensor.data.data_ptr() + for i, qtensor in enumerate(grouped_tensor.quantized_tensors): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + def test_clear(self) -> None: + """Test clear method""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + shape=shape, + quantizers=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.has_data() + assert 
grouped_tensor.num_tensors == num_tensors + + grouped_tensor.clear() + + assert not grouped_tensor.has_data() + assert grouped_tensor.num_tensors == 0 + assert grouped_tensor.data is None + assert grouped_tensor.logical_shape == (0, 0) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index e9d24c1a8e..f5bb47ab5c 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -from typing import Optional +from typing import Optional, List import torch import pytest @@ -137,6 +137,117 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() +def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"): + """ + Verify that tensors are stored in contiguous memory. + + Args: + tensors: List or iterable of tensors to check + num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4) + tensor_name: Name to use in error messages + """ + tensor_list = list(tensors) + if len(tensor_list) < 2: + return # Nothing to check + + for i in range(1, len(tensor_list)): + prev_tensor = tensor_list[i - 1] + curr_tensor = tensor_list[i] + + # Calculate expected offset based on previous tensor size + prev_numel = prev_tensor.numel() + expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size() + + # Verify current tensor's data pointer is correctly offset + expected_ptr = prev_tensor.data_ptr() + expected_offset + actual_ptr = curr_tensor.data_ptr() + + assert ( + actual_ptr == expected_ptr + ), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}" + + +def check_grouped_tensor_pointers( + weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None +): + """ + Verify that the pointers of the weights are in contiguous memory for GroupedTensor. + TODO(ksivaman): This check can be made way more efficient but for now leaving the brute force approach. + """ + + num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1 + + # Check data. + if hasattr(weights[0], "_data") and weights[0]._data is not None: + data_tensors = [w._data for w in weights] + check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data") + + # Check transpose. + if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None: + transpose_tensors = [w._transpose for w in weights] + check_grouped_tensor_pointers_helper( + transpose_tensors, num_elems_in_byte=1, tensor_name="transpose" + ) + + # Check scale_inv. + if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None: + scale_inv_tensors = [w._scale_inv for w in weights] + check_grouped_tensor_pointers_helper( + scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv" + ) + + # Check rowwise scale_inv. + if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None: + scale_inv_tensors = [w._rowwise_scale_inv for w in weights] + check_grouped_tensor_pointers_helper( + scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv" + ) + + # Check columnwise scale_inv. + if ( + hasattr(weights[0], "_columnwise_scale_inv") + and weights[0]._columnwise_scale_inv is not None + ): + columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights] + check_grouped_tensor_pointers_helper( + columnwise_scale_inv_tensors, + num_elems_in_byte=1, + tensor_name="columnwise scale_inv", + ) + + # Check rowwise amax. 
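+    # Amax checks below reuse the same contiguity helper: each tensor's buffer must start
+    # exactly prev.numel() * prev.element_size() bytes after the previous one, i.e. the
+    # per-tensor amax values are expected to be packed back-to-back in the grouped storage.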
+ if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None: + rowwise_amax_tensors = [w._rowwise_amax for w in weights] + check_grouped_tensor_pointers_helper( + rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax" + ) + + # Check columnwise amax. + if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None: + columnwise_amax_tensors = [w._columnwise_amax for w in weights] + check_grouped_tensor_pointers_helper( + columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax" + ) + + # Check rowwise data. + if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None: + rowwise_data_tensors = [w._rowwise_data for w in weights] + check_grouped_tensor_pointers_helper( + rowwise_data_tensors, + num_elems_in_byte=num_elems_in_a_data_byte, + tensor_name="rowwise data", + ) + + # Check columnwise data. + if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None: + columnwise_data_tensors = [w._columnwise_data for w in weights] + check_grouped_tensor_pointers_helper( + columnwise_data_tensors, + num_elems_in_byte=num_elems_in_a_data_byte, + tensor_name="columnwise data", + ) + + def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( (config.max_seqlen_q, config.batch_size, config.hidden_size), @@ -495,9 +606,17 @@ def test_sanity_grouped_linear( use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_grouped_linear = GroupedLinear( - num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype + num_gemms, + config.hidden_size, + ffn_hidden_size, + bias=use_bias, + params_dtype=dtype, ).cuda() + # Verify that weights are stored in contiguous GroupedTensor storage. 
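+        # GroupedLinear exposes one parameter per GEMM (weight0, weight1, ...); gather them
+        # in order so the pointer check can verify they share a single contiguous allocation.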
+ weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)] + check_grouped_tensor_pointers(weights, fp8_recipe) + inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() @@ -956,7 +1075,13 @@ def test_replace_raw_data_for_float8tensor(): random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda") fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor) - attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"] + attrs_to_check = [ + "_quantizer", + "_fp8_dtype", + "_scale_inv", + "_transpose", + "_transpose_invalid", + ] attrs = {} for attr in attrs_to_check: attrs[attr] = getattr(fp8_tensor, attr) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a83cbe3e30..9c53770ebb 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -172,10 +172,12 @@ list(APPEND transformer_engine_cuda_arch_specific_sources cast/cast.cu gemm/cutlass_grouped_gemm.cu hadamard_transform/group_hadamard_transform.cu + hadamard_transform/graph_safe_group_hadamard_transform.cu hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu hadamard_transform/group_hadamard_transform_cast_fusion.cu hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu + hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu transpose/quantize_transpose_square_blockwise.cu diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu new file mode 100644 index 0000000000..bee69d891c --- /dev/null +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -0,0 +1,582 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "hadamard_transform_utils.cuh" + +namespace transformer_engine { +namespace { + +constexpr int kMaxTensorsPerKernel = 64; +constexpr int kThreadsPerWarp = 32; + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +__device__ __forceinline__ size_t get_current_tensor_id( + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t* const __restrict__ offsets_ptr) { + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t current_row = current_offset / last_logical_dim; + const size_t rows_per_tensor = first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } else { + // upper_bound(offsets, current_offset) - 1 in range i in [0..num_tensors) + size_t low = 0; + size_t hi = num_tensors; // half-open [low, hi) + + while (low < hi) { + const size_t mid = low + (hi - low) / 2; + const size_t mid_offset = static_cast(offsets_ptr[mid]); + + if (mid_offset <= current_offset) { + low = mid + 1; + } else { + hi = mid; + } + } + + // low = first index where offsets[low] > current_offset (or low == num_tensors) + // id = low - 1, but need to evaluate if current_offset < offsets[0] + return (low == 0) ? 0 : (low - 1); + } +} + +template +__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], + IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + uint32_t& local_amax_reg, + uint32_t& local_amax_t_reg) { + uint32_t a_frag[4]; // A matrix fragment + uint32_t c_frag[4]; // Result fragment + + int warp_id = threadIdx.x / kThreadsPerWarp; + int local_rank = (threadIdx.x % kThreadsPerWarp); + + int ld_row_idx = local_rank % kHadamardDimension; + int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + + uint32_t temp_amax_reg; + uint32_t temp_amax_t_reg; + + if (kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } + + if (kReturnTransposedAmax) { + // TODO(Frank): This is not efficient, since we could directly load the + // matrix in transposed layout. 
+ if (!kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], + b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_t_reg) + : "r"(local_amax_t_reg), "r"(temp_amax_t_reg)); + } + + if (kReturnPreRhtAmax) { + if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[1])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[2]) + : "r"(a_frag[2]), "r"(a_frag[3])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[2])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_pre_rht_amax_reg) + : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); + } +} + +template +__device__ __host__ constexpr int NextPowerOf2() { + static_assert(kN > 0, "kN must be > 0"); + // Round up to the next power of 2 by counting leading zeros. + return 1 << (32 - __builtin_clz(kN - 1)); +} + +template +__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax, + const float transpose_amax, float* staging_for_pre_rht, + float* staging_for_identity, float* staging_for_transpose, + float* output_pre_rht_amax_ptr, + float* output_identity_amax_ptr, + float* output_transpose_amax_ptr, const int warpid) { + // intra-warp reduction + constexpr int kWarpSize = 32; + int local_rank = threadIdx.x % 32; + float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max(pre_rht_amax) : 0.0f; + float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max(identity_amax) : 0.0f; + float warp_transpose_amax = + kReturnTransposedAmax ? warp_reduce_max(transpose_amax) : 0.0f; + + // inter-warp reduction + if (threadIdx.x % 32 == 0) { + if (kReturnPreRhtAmax) { + staging_for_pre_rht[warpid] = warp_pre_rht_amax; + } + if (kReturnIdentityAmax) { + staging_for_identity[warpid] = warp_identity_amax; + } + if (kReturnTransposedAmax) { + staging_for_transpose[warpid] = warp_transpose_amax; + } + } + __syncthreads(); + constexpr int kNumWarpsPow2 = NextPowerOf2(); + if (warpid == 0) { + if (kReturnIdentityAmax) { + float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f; + identity_accum = warp_reduce_max(identity_accum); + if (local_rank == 0) { + atomicMaxFloat(output_identity_amax_ptr, identity_accum); + } + } + } + if (warpid == 1) { + if (kReturnTransposedAmax) { + float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f; + transpose_accum = warp_reduce_max(transpose_accum); + if (local_rank == 0) { + atomicMaxFloat(output_transpose_amax_ptr, transpose_accum); + } + } + } + if (warpid == 2) { + if (kReturnPreRhtAmax) { + float pre_rht_accum = local_rank < kNumWarps ? 
staging_for_pre_rht[local_rank] : 0.0f; + pre_rht_accum = warp_reduce_max(pre_rht_accum); + if (local_rank == 0) { + atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum); + } + } + } +} + +__global__ void GraphSafeMultiZeroAmaxKernel(const size_t num_tensors, float* amax_rowwise_ptr, + float* amax_colwise_ptr) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + for (; tid < num_tensors; tid += stride) { + amax_rowwise_ptr[tid] = 0; + amax_colwise_ptr[tid] = 0; + } +} + +__global__ void GraphSafeMultiAmaxMemcpyD2DKernelPreRHT(const size_t num_tensors, + float* amax_rowwise_ptr, + float* amax_colwise_ptr) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + for (; tid < num_tensors; tid += stride) { + float* output_pre_rht_amax_ptr = amax_rowwise_ptr + tid; + float* output_transpose_amax_ptr = amax_colwise_ptr + tid; + if (output_pre_rht_amax_ptr != nullptr) { + float pre_rht_amax = *output_pre_rht_amax_ptr; + if (output_transpose_amax_ptr != nullptr) { + *output_transpose_amax_ptr = pre_rht_amax; + } + } + } +} + +template +__global__ void GraphSafeGroupHadamardAmaxTmaKernel( + const __grid_constant__ CUtensorMap tensor_map_input, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t* const __restrict__ offsets_ptr, const int64_t* const __restrict__ first_dims_ptr, + float* const __restrict__ amax_rowwise_ptr, float* const __restrict__ amax_colwise_ptr) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + float* output_pre_rht_amax_ptr; + float* output_identity_amax_ptr = nullptr; + float* output_transpose_amax_ptr; + + // calculate the global offset to get tensor id + size_t global_offset = blockIdx.y * CHUNK_DIM_Y * last_logical_dim; + int tensor_id = get_current_tensor_id(shape_rep, num_tensors, global_offset, first_logical_dim, + last_logical_dim, offsets_ptr); + output_pre_rht_amax_ptr = static_cast(amax_rowwise_ptr) + tensor_id; + output_transpose_amax_ptr = static_cast(amax_colwise_ptr) + tensor_id; + + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0); + static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0); + + constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y; + constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X; + + constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp; + + const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X; + + extern __shared__ __align__(128) char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
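+  // Round the shared-memory base address up to the next 128-byte boundary:
+  // adding 127 and clearing the low 7 bits yields the smallest aligned pointer >= base.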
+ uint8_t* dshmem = reinterpret_cast((base_shmem_ptr + 127) & ~127ULL); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + IType* in_sh_0 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + IType* in_sh_1 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + + IType* in_shs[2] = {in_sh_0, in_sh_1}; + + constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + + const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0); + + // Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + uint64_t* mbar = reinterpret_cast(dshmem); + dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y); + + float* max_staging_identity = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_transpose = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_pre_rht = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + + initialize_barriers(mbar, + is_master_thread); + + copy_2d_to_shared(in_shs[0], reinterpret_cast(&tensor_map_input), + input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0], + is_master_thread); + + uint32_t had_frag_i[4]; + uint32_t had_frag_t[4]; + get_hadamard_matrix_fragment( + had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t); + + float local_pre_rht_amax = 0.0; + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_pre_rht_amax_reg = *reinterpret_cast(&local_pre_rht_amax); + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { + for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { + int stage = STAGES_X * stage_y + stage_x; + + const int next_stage = stage + 1; + const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1; + const int next_stage_y = stage_x + 1 == STAGES_X ? 
stage_y + 1 : stage_y; + + if (next_stage < STAGES_X * STAGES_Y) { + const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y; + const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X; + + copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong + reinterpret_cast(&tensor_map_input), input_global_offset_X, + input_global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + const size_t compute_stage_x_num = + BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)); + const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y); + + const size_t in_row_stride = BUFF_DIM_X; + + IType* in_sh_ptr = in_shs[stage % 2]; + +#pragma unroll + for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) { + const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y + + threadIdx.y * kHadamardDimension); + const int in_row_offset = row_idx_offset * in_row_stride; + +#pragma unroll + for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) { + ComputeKernel( + had_frag_i, had_frag_t, + in_sh_ptr + in_row_offset + + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), + local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + } + + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); + } + } + } + + const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp; + + if constexpr (kReturnPreRhtAmax) { + unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax); + } + if constexpr (kReturnIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + } + if constexpr (kReturnTransposedAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + } + + ReduceMax( + local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity, + max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr, + output_transpose_amax_ptr, warpid); + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +} // namespace + +// broadcast_pre_rht_amax: when it's true, hadamard transform will be disabled +// if at this time, the amax buffers for output expects both amax_rowwise and amax_colwise +// then call MultiAmaxMemcpyD2DKernelPreRHT to D2D copy the amax values +void group_hadamard_transform_amax_graph_safe(const GroupedTensor* input, GroupedTensor* output, + uint16_t random_sign_mask, + uint16_t random_sign_mask_t, + bool broadcast_pre_rht_amax, cudaStream_t stream) { + NVTE_API_CALL(group_hadamard_transform_amax_graph_safe); +#if CUDA_VERSION >= 12080 + + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Number of input and output tensors must be same."); + NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); + + checkCuDriverContext(stream); + + bool all_return_pre_rht_amax = output->has_data(); + // there is no rowwise RHT transform in current recipe + bool all_return_identity_amax = false; + bool all_return_transposed_amax = output->has_columnwise_data(); + + NVTE_CHECK(all_return_pre_rht_amax || all_return_identity_amax || all_return_transposed_amax, + 
"At least one of return_pre_rht_amax, return_identity_amax, or return_transposed_amax " + "must be true"); + + if (broadcast_pre_rht_amax) { + NVTE_CHECK(all_return_pre_rht_amax, + "broadcast_pre_rht_amax is only supported when we compute pre-RHT amax"); + // if all_return_identity_amax and all_return_transposed_amax both are false, there is no need to broadcast anything + broadcast_pre_rht_amax &= (all_return_identity_amax || all_return_transposed_amax); + } + + const size_t num_tensors = input->num_tensors; + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + // const size_t elts_total = first_logical_dim * last_logical_dim; + NVTE_CHECK(first_logical_dim % 128 == 0, + "First dimension of a grouped tensor should be divisible by 128."); + NVTE_CHECK(last_logical_dim % 128 == 0, + "Last dimension of a grouped tensor should be divisible by 128."); + + float* const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + float* const amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + + const int64_t* const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); + const int64_t* const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); + // const int64_t *const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + + // some sanity checks + if (all_return_pre_rht_amax) { + NVTE_CHECK(amax_rowwise_ptr != nullptr, "Amax rowwise pointer should not be nullptr."); + } + if (all_return_transposed_amax) { + NVTE_CHECK(amax_colwise_ptr != nullptr, "Amax columnwise pointer should not be nullptr."); + } + + // Multi zero out multiple amaxes if needed + dim3 block_setup_amax(kMaxTensorsPerKernel); + dim3 grid_setup_amax(1); + GraphSafeMultiZeroAmaxKernel<<>>( + num_tensors, amax_rowwise_ptr, amax_colwise_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); + + using IType = bf16; + constexpr int kHadamardDimension = 16; + + // four (1x4) 64x64 sub-tiles for ping-pong overlap + constexpr uint64_t kChunkBlockXSmall = 256; + constexpr uint64_t kChunkBlockYSmall = 64; + constexpr uint64_t kBuffDimX = 64; + constexpr uint64_t kBuffDimY = 64; + + alignas(64) CUtensorMap tensor_map_input{}; + + create_2D_tensor_map( + /*tensorMap=*/tensor_map_input, + /*tensor=*/input->data, + /*globalY=*/first_logical_dim, + /*globalX=*/last_logical_dim, + /*shmemY=*/kBuffDimY, + /*shmemX=*/kBuffDimX, + /*stride_elems=*/last_logical_dim, + /*offset_elems=*/0, + /*type_num_bits=*/sizeof(IType) * 8, + /*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B); + + constexpr uint64_t kThreadBlockX = 4; + constexpr uint64_t kThreadBlockY = 1; + constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY; + + dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY); + dim3 grid(DIVUP(last_logical_dim, kChunkBlockXSmall), + DIVUP(first_logical_dim, kChunkBlockYSmall)); + + ShapeRepresentation shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + if (output->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (output->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (output->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (output->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + + const bool is_const_last_dim = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS || + shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); + + NVTE_CHECK(is_const_last_dim, + "Currently we only 
support const last dimension for graph safe hadamard transform."); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + (all_return_transposed_amax && !broadcast_pre_rht_amax), kReturnTransposedAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + (all_return_identity_amax && !broadcast_pre_rht_amax), kReturnIdentityAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_return_pre_rht_amax, kReturnPreRhtAmax, + + // *2 for ping-pong + size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType); + size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) * + (kChunkBlockYSmall / kBuffDimY); + size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3; + // Add padding in case shmem ptr is not aligned to 128 bytes. + shmem_bytes = (shmem_bytes + 128); + + auto kernel = GraphSafeGroupHadamardAmaxTmaKernel< + IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY, + kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax, + kReturnIdentityAmax, kReturnTransposedAmax>; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_bytes); + + kernel<<>>( + tensor_map_input, random_sign_mask, random_sign_mask_t, shape_rep, num_tensors, + first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr, + amax_rowwise_ptr, amax_colwise_ptr); + if (broadcast_pre_rht_amax) { + GraphSafeMultiAmaxMemcpyD2DKernelPreRHT<<>>(num_tensors, amax_rowwise_ptr, + amax_colwise_ptr); + }))); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", + CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace transformer_engine + +void nvte_group_hadamard_transform_amax_graph_safe(const NVTEGroupedTensor input, + NVTEGroupedTensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_hadamard_transform_amax_graph_safe); + using namespace transformer_engine; + + GroupedTensor* input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor* output_tensor = convertNVTEGroupedTensorCheck(output); + + if (input_tensor->num_tensors == 0) { + return; + } + + // Call the group tensor Hadamard transform amax implementation. + group_hadamard_transform_amax_graph_safe( + input_tensor, output_tensor, static_cast(random_sign_mask), + static_cast(random_sign_mask_t), false, stream); +} + +// Grouped-tensor amax without doing hadamard transform +void nvte_group_amax_graph_safe(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_amax_graph_safe); + using namespace transformer_engine; + + GroupedTensor* input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor* output_tensor = convertNVTEGroupedTensorCheck(output); + + if (input_tensor->num_tensors == 0) { + return; + } + + group_hadamard_transform_amax_graph_safe(input_tensor, output_tensor, 0, 0, true, stream); +} diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu new file mode 100644 index 0000000000..030dddfce4 --- /dev/null +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -0,0 +1,1513 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/curanddx.hpp" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "customized_pipeline.cuh" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/float8.h" +#include "cutlass/float_subbyte.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/platform/platform.h" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/print_error.hpp" + +namespace transformer_engine { +namespace detail { +namespace { + +using namespace cute; + +// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor +using cute::Tensor; + +constexpr int kMaxTensorsPerKernel = 64; +constexpr int kNVFP4BlockSize = 16; + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +__device__ __forceinline__ size_t get_current_tensor_id( + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr) { + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t current_row = current_offset / last_logical_dim; + const size_t rows_per_tensor = first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } else { + // upper_bound(offsets, current_offset) - 1 in range i in [0..num_tensors) + size_t low = 0; + size_t hi = num_tensors; // half-open [low, hi) + + while (low < hi) { + const size_t mid = low + (hi - low) / 2; + const size_t mid_offset = static_cast(offsets_ptr[mid]); + + if (mid_offset <= current_offset) { + low = mid + 1; + } else { + hi = mid; + } + } + + // low = first index where offsets[low] > current_offset (or low == num_tensors) + // id = low - 1, but need to evaluate if current_offset < offsets[0] + return (low == 0) ? 0 : (low - 1); + } +} + +CUTLASS_DEVICE +cutlass::Array StochasticNumericConverterBase( + cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + auto output_ptr = reinterpret_cast(&output); + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" + "}" + : "=h"(output_ptr[0]), "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), + "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return output; +} + +CUTLASS_DEVICE +cutlass::Array StochasticNumericConverter( + cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = + reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = + reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = + reinterpret_cast const *>(&rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +template +struct SharedStorage { + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr EpilogueUnrollFactor = EpilogueUnrollFactor_; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = + cutlass::detail::CustomizedPipelineTmaUmmaAsync, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + using SchedPipeline = cutlass::PipelineCLCFetchAsync; + using SchedPipelineStorage = typename SchedPipeline::SharedStorage; + using SchedThrottlePipeline = cutlass::PipelineAsync; + using SchedThrottlePipelineStorage = typename SchedThrottlePipeline::SharedStorage; + + struct TensorStorage : cute::aligned_struct<128, _1> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) cute::uint64_t tma_barrier[1]; + alignas(16) SchedPipelineStorage sched; + alignas(16) SchedThrottlePipelineStorage sched_throttle; + alignas(16) int32_t atomic_tile_id[SchedulerPipelineStageCount_]; + alignas(16) float global_a_amax[kMaxTensorsPerKernel]; + alignas(16) float global_d_amax[kMaxTensorsPerKernel]; + uint32_t atomic_tile_counter[SchedulerPipelineStageCount_]; + uint32_t tmem_base_ptr; +}; + +// Main RHT GEMM kernel entry -- highly templated for flexible architecture/config support +template +__launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_graph_safe( + MShape M, NShape packed_N, KShape K, ClusterShape cluster_shape, ClusterTileShape cluster_tile, + TA const *A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const *B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + TQA *QA, QAStride dQA, TSFA *SFA, TSFALayout sfa_layout, TQA *QA_COLWISE, TSFA *SFA_COLWISE, + float *amax_rowwise, float *amax_colwise, const int64_t *offsets, const int64_t *first_dims, + size_t num_tensors, ShapeRepresentation shape_rep, uint32_t *tile_scheduler_workspace, + TiledMMA mma, const size_t *rng_state) { + using namespace cute; + + // Abort immediately if compilation is not supported + constexpr bool is_blackwell_arch = ARCH_BLACKWELL_FAMILY; + if constexpr (!is_blackwell_arch) { + NVTE_DEVICE_ERROR( + "group_row_col_rht_gemm_device_graph_safe is only supported on Blackwell " + "with architecture-specific compilation. 
" + "Try recompiling with sm_100a or similar."); + return; + } + static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, + "group_row_col_rht_gemm_device_graph_safe must generate row-wise " + "and/or column-wise output."); +#if !defined(CUTLASS_ARCH_CLC_ENABLED) + CUTLASS_NOT_IMPLEMENTED(); + return; +#endif + + using X = Underscore; + // Accumulator data type for main computation + using ElementAccumulator = float; + static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); + static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; + static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; + static constexpr bool kEnableRowQuant = kEnableRowQuant_; + static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_; + static constexpr bool kUseFastMath = kUseFastMath_; + + // Constant for RHT tensor processing (tile size etc) + static int constexpr RhtTensorSize = 16; + + // Get the total number of tokens to process + // Note that here M is the hidden size, which is the last logical dimension of the input tensor x + // The kernel is designed in column major, so M is the hidden size + size_t sum_token_dims = offsets[num_tensors] / M; + + // Transaction bytes for TMA transfer on RHT tensor blocks + static int constexpr kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + + // Mainloop pipeline stage calculation, vectorization parameters for scaling factors + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + static int constexpr SFVecSize = 16; + // Swizzle output layout for scaling factor arrays + using SwizzledSFALayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + + // Mainloop pipeline types for TMA async execution and epilogue cluster scheduling + using MainloopPipeline = + cutlass::detail::CustomizedPipelineTmaUmmaAsync; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + using SchedPipeline = cutlass::PipelineCLCFetchAsync; + using SchedPipelineState = typename SchedPipeline::PipelineState; + using SchedThrottlePipeline = cutlass::PipelineAsync; + using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState; + + static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>"); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static int constexpr VectorSize = RhtTensorSize; + + // Compile-time safety: static shapes required for shared memory layouts + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + // CUTE_STATIC_ASSERT(is_static::value); + + auto cluster_size = size<0>(cluster_shape); + auto mainloop_tiler = Shape<_128, _16, _128>{}; + auto epilogue_tiler = Shape<_128, _128, _128>{}; + + static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile); + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + int const K_TILE_MAX = min(packed_N, K) / 
size<2>(epilogue_tiler); + + struct TileScheduler { + uint32_t tiles_in_m = 0; + uint32_t tiles_in_n = 0; + uint32_t linear_idx = 0; + uint32_t next_linear_idx = 0; + uint32_t start_idx = 0; + uint32_t tile_m_idx = 0; + uint32_t tile_n_idx = 0; + int k_tile_max = 0; + uint32_t *atomic_tile_index_; + uint32_t *smem_tile_counter; + uint32_t atomic_offset; + cutlass::FastDivmodU64 divmod_tiles_in_m; + + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, + uint32_t *atomic_tile_index, uint32_t *smem_tile_counter) + : tiles_in_m(tiles_m), + tiles_in_n(tiles_n), + linear_idx(blockIdx.x), + next_linear_idx(blockIdx.x), + start_idx(blockIdx.x), + k_tile_max(kmax), + atomic_tile_index_(atomic_tile_index), + smem_tile_counter(smem_tile_counter), + atomic_offset(gridDim.x), + divmod_tiles_in_m(uint64_t(tiles_m)) { + update_tile_idx(); + } + CUTLASS_DEVICE void update_tile_idx() { + uint64_t q, r; + divmod_tiles_in_m(q, r, uint64_t(linear_idx)); + tile_m_idx = static_cast(r); + tile_n_idx = static_cast(q) * uint32_t(k_tile_max); + } + CUTLASS_DEVICE uint32_t tile_m() const { return tile_m_idx; } + CUTLASS_DEVICE uint32_t tile_n_base() const { return tile_n_idx; } + CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; } + + CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; } + + CUTLASS_DEVICE bool is_valid() const { + return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()), + cute::make_coord(tiles_in_m, tiles_in_n)); + } + + CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; } + + CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; } + + // Fetch a new tile_id using atomics. + CUTLASS_DEVICE uint32_t fetch_tile_id_counter(int pred) { + uint32_t tile_id_counter = 0; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p atom.global.add.u32 %0, [%1], 1; \n\t" + "}" + : "=r"(tile_id_counter) + : "l"(atomic_tile_index_), "r"(pred)); + + return tile_id_counter; + } + + CUTLASS_DEVICE auto fetch_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState sched_pipeline_consumer_state) { + sched_pipeline.consumer_wait(sched_pipeline_consumer_state); + next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()]; + cutlass::arch::fence_view_async_shared(); + sched_pipeline.consumer_release(sched_pipeline_consumer_state); + return; + } + + CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState sched_pipeline_producer_state) { + uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state); + // Wait for clcID buffer to become empty with a flipped phase + sched_pipeline.producer_acquire(sched_pipeline_producer_state); + auto is_leading_thread = cute::elect_one_sync(); + uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset; + uint32_t smem_addr = + cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]); + if (is_leading_thread) { + cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0); + } + + ++sched_pipeline_producer_state; + return sched_pipeline_producer_state; + } + + CUTLASS_DEVICE auto update_work_tile_info() { + linear_idx = next_linear_idx; + update_tile_idx(); + return; + } + }; + + // Allocate and alias shared memory to the kernel's shared storage type + extern __shared__ char shared_memory[]; + using SharedStorage = + SharedStorage; + SharedStorage &shared_storage = *reinterpret_cast(shared_memory); + + // 
Compute the number of tiles in M and N after tiling and assign scheduler + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); + uint32_t tiles_in_n = uint32_t(size(ceil_div(sum_token_dims, size<2>(epilogue_tiler)))); + + TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace, + shared_storage.atomic_tile_counter); + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Shapes for accumulated tiles in mainloop and epilogue + auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{}); + auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{}); + + // Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended + auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); + auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); + + // Number of threads assigned for various epilogue roles depending on quantization settings + static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0; + static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; + static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0; + static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 1 : 0; + static int constexpr NumSchedThreads = 32; + static int constexpr NumMainloopLoadThreads = 32; + static int constexpr NumEpilogueThreads = + NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount; + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + NumMmaThreadCount + NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // warp assignment + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_sched_warp = (warp_idx == 2); + bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); + bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + mainloop_pipeline_params.num_consumers = NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount; + + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, + cluster_shape, cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + using AccumulatorPipelineInitBarriers = cute::bool_constant; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = + cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if 
(is_epilogue_col_quant_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = + size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, + cluster_shape, AccumulatorPipelineInitBarriers{}, + cute::true_type{}); // Delay mask calculation + typename SchedPipeline::Params sched_pipeline_params; + if (is_sched_warp) { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer; + } else { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer; + } + sched_pipeline_params.producer_blockid = 0; + sched_pipeline_params.producer_arv_count = 1; + sched_pipeline_params.consumer_arv_count = + NumSchedThreads + + cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); + sched_pipeline_params.transaction_bytes = sizeof(uint32_t); + sched_pipeline_params.initializing_warp = 3; + SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape); + SchedPipelineState sched_pipeline_consumer_state; + SchedPipelineState sched_pipeline_producer_state = + cutlass::make_producer_start_state(); + + typename SchedThrottlePipeline::Params sched_throttle_pipeline_params; + if (is_dma_warp) { + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Producer; + } + if (is_sched_warp) { + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Consumer; + } + sched_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + sched_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + sched_throttle_pipeline_params.dst_blockid = 0; + sched_throttle_pipeline_params.initializing_warp = 4; + + SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle, + sched_throttle_pipeline_params); + SchedThrottlePipelineState sched_pipeline_throttle_consumer_state; + SchedThrottlePipelineState sched_pipeline_throttle_producer_state = + cutlass::make_producer_start_state(); + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + + // Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer + if (is_dma_warp) { + // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access). + cutlass::arch::warpgroup_reg_dealloc<32>(); + // Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory. + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); + + // Partition tensors for tiling according to the mainloop and cluster tilers. 
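+      // A is tiled with the mainloop tiler (Shape<_128,_16,_128>), while B, the 16x16 Hadamard matrix, is tiled with the cluster tile.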
+ Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gB_nk = + local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) + + // Shared memory tensors for pipeline + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + // Determine warp/tile positioning + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Partition global to local fragments for A and B + Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = + tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = + tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); + + auto [tBgB, tBsB] = + tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + if constexpr (kEnableRHTColQuant) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], + kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), + tBsB(_, 0)); + } + } + + do { + // is_first_wave indicates whether this scheduler wave is the first among a group. + bool is_first_wave = scheduler.is_first_wave(); + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_, scheduler.tile_m(), _); + int k_tile = 0; + + sched_throttle_pipeline.producer_acquire(sched_pipeline_throttle_producer_state); + sched_throttle_pipeline.producer_commit(sched_pipeline_throttle_producer_state); + ++sched_pipeline_throttle_producer_state; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { + int k_tile_idx_n = scheduler.tile_n_base() + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType *tma_barrier = + mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), + tAsA(_, write_stage)); + } + } + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + // scheduler.advance(); + } while (scheduler.is_valid()); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + // This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform. 
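+      // Only the column-wise (post-RHT) quantization epilogue consumes these MMA results, so the matmul body below is guarded by if constexpr (kEnableRHTColQuant) and compiles away when that path is disabled.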
+ cutlass::arch::warpgroup_reg_dealloc<32>(); + if constexpr (kEnableRHTColQuant) { + // Setup shared memory fragments for A and B tiles. + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + // Wait until the B (Hadamard) tensor copy is complete + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + + auto barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_, _, _, read_stage); + auto tCrB_nk = tCrB(_, _, 0, 0); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) { + int accumulator_k_block = + accumulator_pipe_producer_state.index() * EpilogueUnrollFactor; + int tCrA_k_block = k_block * EpilogueUnrollFactor; + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < EpilogueUnrollFactor; i++) { + auto accumulators = bulk_tmem_mma(_, _, _, accumulator_k_block + i); + gemm(mma, tCrA_mk(_, _, tCrA_k_block + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state); + barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } else if (is_sched_warp) { + // Scheduler warp manages tile assignment and pipeline progress for warps + cutlass::arch::warpgroup_reg_dealloc<32>(); + do { + sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state); + sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state); + ++sched_pipeline_throttle_consumer_state; + sched_pipeline_producer_state = + scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state); + 
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } else if (is_epilogue_col_quant_warp) { + // Warp responsible for quantizing output of Hadamard transform to FP4 for columnwise usage, + // and writing result tensors/scales to global memory. + cutlass::arch::warpgroup_reg_alloc<192>(); + if constexpr (kEnableRHTColQuant) { + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + auto acc_epilogue_pipelined_shape = + append(acc_shape_epilogue, Int{}); + auto bulk_tmem_epilogue_layout = make_layout( + acc_epilogue_pipelined_shape, + make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler))); + auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); + + // Use 256-bit fragments for aligned bulk stores + static int constexpr FragmentSize = 256 / sizeof_bits_v; + + // Wait for TMEM allocation for this pipeline to finish + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup; + // g2s load all global_d_amax + CUTLASS_PRAGMA_NO_UNROLL + for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueColQuantThreadCount) { + shared_storage.global_d_amax[g] = __ldg(reinterpret_cast(amax_colwise + g)); + } + + size_t rng_seed = 0; + size_t rng_offset = 0; + // Setup RNG for stochastic rounding + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; + } + // TODO(zhongbo): double check the logic here + int group_idx = get_current_tensor_id(shape_rep, num_tensors, + (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, + packed_N, M, offsets); + + // Determine quantization scale factor layouts/output splits for this group + TSFDLayout sfd_layout; + int cur_N = static_cast(first_dims[group_idx]); + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // Build output tensors for columns and their quant scales + // TODO(zhongbo): double check the logic here + Tensor mD = make_tensor(cute::subbyte_iterator(reinterpret_cast( + reinterpret_cast(QA_COLWISE) + offsets[group_idx] / 2)), + make_shape(M, cur_N), DStride{}); // (M,packed_N) + Tensor gD_mn = + local_tile(mD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + // for every tensor [x, y] row major, x y both a multiple of 128 + // both of its rowwise and colwise scaling factors will have exactly x * y / 16 elements in FP8 E4M3 + Tensor mSFD = make_tensor( + make_gmem_ptr(reinterpret_cast(reinterpret_cast(SFA_COLWISE) + + offsets[group_idx] / kNVFP4BlockSize)), + sfd_layout); + Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + + // Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); + auto tiled_r2g = + 
make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(local_thread_idx); + auto thr_r2g = tiled_r2g.get_slice(local_thread_idx); + + cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + float c_global_amax_val = shared_storage.global_d_amax[group_idx]; + float global_encode_scale = c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + float global_decode_scale = 1.0f / global_encode_scale; + + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + + do { + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); + ++k_tile) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + // TODO(zhongbo): double check the logic here + int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors, + global_tile_n_offset * M, packed_N, M, offsets); + + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + c_global_amax_val = shared_storage.global_d_amax[group_idx]; + // update amax + global_encode_scale = c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + // TODO(zhongbo): double check the logic here + cur_N = first_dims[group_idx]; + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = + tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = + make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // update tensor + mD = make_tensor(cute::subbyte_iterator(reinterpret_cast( + reinterpret_cast(QA_COLWISE) + offsets[group_idx] / 2)), + make_shape(M, cur_N), DStride{}); + gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + mSFD = make_tensor( + make_gmem_ptr(reinterpret_cast(reinterpret_cast(SFA_COLWISE) + + offsets[group_idx] / kNVFP4BlockSize)), + sfd_layout); + gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + } + int group_start_offset = offsets[group_idx] / M; + int local_tile_n_idx = + (global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler); + Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx); + + Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx); + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); + Tensor tDtAcc = thr_t2r.partition_S(Acc); // 
((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = + make_tensor(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrD = make_tensor(shape(tDgD)); + Tensor tTR_rAcc_frag = + recast>(coalesce(tTR_rAcc)); + Tensor tDrD_frag = recast>(coalesce(tDrD)); + + Tensor src = thr_r2g.retile_S(tDrD); + Tensor dst = thr_r2g.retile_D(tDgD); + + Tensor tDgSFD_view = make_tensor( + tDgSFD_mn.data(), make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); + Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); + Tensor tDrSFD = make_tensor(shape(tDgSFD)); + + static int constexpr NumVecs = size(tDgD) / VectorSize; + Tensor tD_rRowSFD_frg = recast>(tDrSFD); + + // Compute amax and quantization scales for this tile + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // Copy from TMEM to registers + copy(tiled_t2r, tDtAcc, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + ++accumulator_pipe_consumer_state; + + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with + // unfused kernels + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); + } + + auto compute_frgs = reinterpret_cast *>( + tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales = + cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); + } + auto pvscales_cvted = + cutlass::NumericArrayConverter{}(pvscales); + + tD_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}( + tD_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}( + qpvscale_ups, global_decode_scale); + cutlass::Array acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = cutlass::divides>{}( + 1.0, qpvscale_scaled); + } + + // Prepare stochastic rounding random state if enabled + uint4 random_uint4 = uint4{0, 0, 0, 0}; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + // "Prefetch" a stochastic rounding state for the first tile + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + // Apply round/quantize to each fragment, with or without stochastic rounding + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], 
cutlass::platform::numeric_limits::max()); + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale)); + } + } + + // Write quantized FP4 tile and dequant scale to gmem + copy(tiled_r2g, src, dst); + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + } else if (is_epilogue_row_quant_warp) { + // Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage. + cutlass::arch::warpgroup_reg_alloc<136>(); + if constexpr (kEnableRowQuant) { + using S2RVectorType = uint128_t; + + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % 256; + size_t rng_seed = 0; + size_t rng_offset = 0; + // g2s load all global_a_amax for all groups/tensors + CUTLASS_PRAGMA_NO_UNROLL + for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueRowQuantThreadCount) { + shared_storage.global_a_amax[g] = __ldg(reinterpret_cast(amax_rowwise + g)); + } + // RNG for stochastic rounding + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; + } + // Input/output tensors/partitions for row quant warp + Tensor mQA = + make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, packed_N), dQA)); + Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); + + Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_N) + // Swizzled shared memory A tile, with layout + Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>( + coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) + + // Set up layouts for partitioning – tile-by-warp, with vector granularity + using S2RWarpLayout = Layout>; + using WarpGroupLayout = Layout>; + using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); + using S2RValLayout = Layout, _1>>; + using S2RAtomA = Copy_Atom; + using R2GAtomQA = Copy_Atom; + using R2GAtomSFA = Copy_Atom; + auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_SFA = make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{}); + + auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); + auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); + auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx); + Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) + + // Allocate temporary register tensors for copying quantization => output + Tensor tQArA = make_tensor_like( + make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) + Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn); + Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); + + Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn); + Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); + + // Will result in barrier_id=10 passed to bar.sync instr as cutlass adds 8 + // in order to 
go over the reserved named barrier count. + constexpr int row_quant_barrier_id = 2; + cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id); + + int group_idx = get_current_tensor_id(shape_rep, num_tensors, + (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, + packed_N, M, offsets); + float a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + float global_encode_scale = a_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float global_decode_scale = 1.0f / global_encode_scale; + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + auto sfa_converter = cutlass::NumericConverter{}; + do { + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors, + global_tile_n_offset * M, packed_N, M, offsets); + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Update group quantization parameters/scaling + global_encode_scale = a_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + } + + auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state); + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA); + cutlass::arch::fence_view_async_shared(); + mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); + ++mainloop_pipe_consumer_state; + ++k_tile; + + // static int constexpr NumVecs = size(tQArA) / VectorSize; + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; + auto compute_frgs = reinterpret_cast *>(tQArA.data()); + auto output_frgs = + reinterpret_cast *>(raw_pointer_cast(tQArQA.data())); + Tensor amax = + make_tensor(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{})); + Tensor pvscales = make_tensor_like(amax); + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 + + tiles_in_m * tiles_in_n * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) { + auto amax_view = group_modes<1, rank(amax)>(amax); + auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales); + auto compute_frgs_up = + 
cutlass::NumericArrayConverter{}( + compute_frgs[v]); + amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up); + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales_view(_0{}, v) = cutlass::multiplies{}( + amax_view(_0{}, v), global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales_view(_0{}, v) = + cutlass::divides{}(amax_view(_0{}, v), fp4_max); + pvscales_view(_0{}, v) = cutlass::multiplies{}( + pvscales_view(_0{}, v), global_encode_scale); + } + filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v)); + auto qpvscale_ups = + cutlass::NumericConverter{}(filter(tQArSFA)(v)); + auto qpvscale_scaled = + cutlass::multiplies{}(qpvscale_ups, global_decode_scale); + ElementAccumulator acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = cutlass::divides{}(1.0, qpvscale_scaled); + } + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales, cutlass::platform::numeric_limits::max()); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = + cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale)); + } + } + copy(tiled_r2g_QA, tQArQA, tQAgQA_mn); + copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn)); + } + // scheduler.advance(); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + + } else { + cutlass::arch::warpgroup_reg_dealloc<32>(); + } +} // NOLINT(readability/fn_size) + +template +void group_row_col_rht_gemm_ntt_w_sfc_graph_safe( + int packed_sequence_length, int hidden_size, size_t num_tensors, ShapeRepresentation shape_rep, + TA const *A, TB const *B, TQA *QA, TSFA *SFA, TQA *QA_COLWISE, TSFA *SFA_COLWISE, + float *amax_rowwise, float *amax_colwise, const int64_t *offsets, const int64_t *first_dims, + const size_t *rng_state, uint32_t *tile_scheduler_workspace, uint32_t sm_count, + cudaStream_t stream, int k_tile_size = 1024) { + using namespace cute; + static int constexpr SFVecSize = 16; + static int constexpr RhtTensorSize = 16; + + static_assert(RhtTensorSize == 16, "RhtTensorSize must be 16"); + using LinearSFALayout = decltype(make_layout(make_shape(make_shape(Int{}, 0), 0), + make_stride(make_stride(_0{}, _1{}), 0))); + using LinearSFDLayout = decltype(make_layout(make_shape(0, make_shape(Int{}, 0)), + make_stride(0, make_stride(_0{}, _1{})))); + + using SwizzledSFALayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFALayout = decltype(tile_to_shape( + SwizzledSFALayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{})); + using SwizzledSFDLayout = decltype(tile_to_shape( + SwizzledSFDLayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{})); + + using SFALayout = cute::conditional_t; + using SFDLayout = cute::conditional_t; + SFALayout sfa_layout; + SFDLayout sfd_layout; + + if constexpr 
(kEnableSwizzleSFOutput) { + sfa_layout = tile_to_shape(SwizzledSFALayoutAtom{}, + make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{}); + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, + make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{}); + } else { + sfa_layout = make_layout( + make_shape(make_shape(Int{}, hidden_size / SFVecSize), packed_sequence_length), + make_stride(make_stride(_0{}, _1{}), hidden_size / SFVecSize)); + sfd_layout = make_layout( + make_shape(hidden_size, make_shape(Int{}, packed_sequence_length / SFVecSize)), + make_stride(packed_sequence_length / SFVecSize, make_stride(_0{}, _1{}))); + } + + // Define shapes (dynamic) + auto M = hidden_size; + auto N = packed_sequence_length; + Tensor tensorA = make_tensor(A, make_shape(hidden_size, packed_sequence_length), LayoutLeft{}); + Tensor tensorB = make_tensor(B, make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{}); + Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, packed_sequence_length), LayoutLeft{}); + Tensor tensorSFA = make_tensor(SFA, sfa_layout); + + // Define strides (from tensors) + auto dA = stride(tensorA); // (dM,dK) + auto dB = stride(tensorB); // (dN,dK) + auto dD = LayoutRight{}; // (dM,dN) + auto dQA = stride(tensorQA); // (dM,dK) + using ClusterShape = Shape<_1, _1, _1>; + auto cluster_shape = ClusterShape{}; + auto cluster_tile_shape = Shape<_128, Int, Int>{}; + auto cluster_tile_mainloop = Shape<_128, Int, _128>{}; + + // Each mainloop / epilogue loads 128 x 64 tiles while each MMA proceeds with 128 x 16 tiles + static int constexpr EpilogueUnrollFactor = + size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape); + // Construct the MMA + auto mma = make_tiled_mma( + SM100_MMA_F16BF16_SS(cluster_tile_shape), size<1>(cluster_tile_shape), + UMMA::Major::MN, UMMA::Major::MN>{}, + Layout>{}); + + // Assert that the TiledMMA uses all CTAs in the CGA. 
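+  // Also verify that the MMA tile evenly divides the cluster tile shape.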
+ CUTE_STATIC_ASSERT_V(size(cluster_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cluster_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = + partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div( + shape<0>(cluster_tile_shape), + shape_div(shape<0>(cluster_tile_shape), size<0>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div( + shape<1>(cluster_tile_shape), + shape_div(shape<1>(cluster_tile_shape), size<1>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cluster_tile_shape)); + + using SmemLayoutAtomB = + decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + + auto mma_shape_A = partition_shape_A( + mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = + decltype(shape_div(shape<0>(cluster_tile_mainloop), + shape_div(shape<0>(cluster_tile_mainloop), + size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + static uint32_t constexpr TotalTmemRows = 128; + static uint32_t constexpr Sm100TmemCapacityColumns = 512; + static uint32_t constexpr TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns; + static uint32_t constexpr AccumulatorPipelineStageCount = + TotalTmem / (cute::size<0>(cluster_tile_shape) * cute::size<1>(cluster_tile_shape)); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int SchedulerPipelineStageCount = 4; + static int constexpr MainloopPipelineBytes = sizeof( + typename cutlass::detail::CustomizedPipelineTmaUmmaAsync<1, Shape<_1, _1, _1>, + Shape<_1, _1, _1>>::SharedStorage); + + static int constexpr SchedulerWorkspaceBytes = sizeof(int) * SchedulerPipelineStageCount; + static int constexpr SchedulerThrottlePipelineBytes = + sizeof(typename cutlass::PipelineAsync::SharedStorage); + static int constexpr SchedulerPipelineBytes = + sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + + static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier); + static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB); + static int constexpr AccPipelineBytes = sizeof( + typename cutlass::PipelineUmmaAsync>::SharedStorage); + static int constexpr TmemBasePtrsBytes = sizeof(uint32_t); + static int constexpr kBlackwellSmemSize = 232448; // 232KB in bytes + static int constexpr kBytesPerStage = + cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes; + static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes + + SchedulerPipelineBytes + TmemBasePtrsBytes + + TmemDeallocBytes + BTensorBytes + + AccPipelineBytes; // Reserve for barriers and other uses + static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP), + Step<_2, _1, _3>{}); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, + append(mma_shape_B, _1{})); // (MMA,MMA_N,MMA_K, _1) + auto sD = Layout<_1>{}; // 
XXX Dummy + + auto tma_load_a = + make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma); + auto tma_load_b = + make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cluster_tile_shape, mma); + + // Assert checks on tile sizes -- no predication + assert(M % size<0>(cluster_tile_shape) == 0); + assert(N % size<1>(cluster_tile_shape) == 0); + + dim3 dimBlock(512); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(sm_count, 1, 1); + + int smem_size = sizeof( + SharedStorage); + + auto *kernel_ptr = &group_row_col_rht_gemm_device_graph_safe< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cluster_shape), + decltype(cluster_tile_shape), TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB, + decltype(dB), decltype(sB), decltype(tma_load_b), TD, decltype(dD), decltype(sD), TSFD, + decltype(sfd_layout), TQA, decltype(dQA), TSFA, decltype(sfa_layout), decltype(mma), + AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding, + kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kUseFastMath>; + + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Set workspace and set to zero + NVTE_CHECK_CUDA(cudaMemsetAsync(reinterpret_cast(tile_scheduler_workspace), 0, + sizeof(uint32_t), stream)); + + // Launch kernel + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream}; + cutlass::Status status = cutlass::launch_kernel_on_cluster( + params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, A, dA, + sA, tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, QA_COLWISE, SFA_COLWISE, + amax_rowwise, amax_colwise, offsets, first_dims, num_tensors, shape_rep, + tile_scheduler_workspace, mma, rng_state); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed."); +} + +} // namespace +} // namespace detail + +void group_hadamard_transform_cast_fusion_graph_safe(const GroupedTensor *input, + GroupedTensor *output, + const Tensor &hadamard_matrix_, + QuantizationConfig &quant_config, + Tensor &quant_workspace, cudaStream_t stream) { + NVTE_API_CALL(group_hadamard_transform_cast_fusion_graph_safe); + + using transformer_engine::detail::kMaxTensorsPerKernel; + using transformer_engine::detail::ShapeRepresentation; + + void *input_base_ptr = reinterpret_cast(input->data.dptr); + // TODO(zhongbo): add input sanity checks here + + bool all_has_row_quant = output->has_data(); + bool all_has_col_quant = output->has_columnwise_data(); + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (use_stochastic_rounding) { + NVTE_CHECK(quant_config.rng_state != nullptr, + "Enabled stochastic rounding without providing RNG state"); + const Tensor &rng_state_tensor = *convertNVTETensorCheck(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + uint32_t *tile_scheduler_workspace = nullptr; + NVTE_CHECK(quant_workspace.data.dptr != nullptr, "Quantization workspace must be provided."); + 
NVTE_CHECK(quant_workspace.data.buffer_size_bytes() >= sizeof(uint32_t), + "Quantization workspace must be at least 4 bytes."); + tile_scheduler_workspace = reinterpret_cast(quant_workspace.data.dptr); + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TD = cutlass::float_e2m1_t; + using TSFD = cutlass::float_ue4m3_t; + using TQA = TD; + using TSFA = TSFD; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t num_tensors = input->num_tensors; + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + // const size_t elts_total = first_logical_dim * last_logical_dim; + NVTE_CHECK(first_logical_dim % 128 == 0, + "First dimension of a grouped tensor should be divisible by 128."); + NVTE_CHECK(last_logical_dim % 128 == 0, + "Last dimension of a grouped tensor should be divisible by 128."); + NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel, + "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel); + + ShapeRepresentation shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + if (output->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (output->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (output->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (output->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + + TQA *const rowwise_data_base_ptr = reinterpret_cast(output->data.dptr); + TSFA *const rowwise_scale_inv_base_ptr = reinterpret_cast(output->scale_inv.dptr); + TQA *const colwise_data_base_ptr = reinterpret_cast(output->columnwise_data.dptr); + TSFA *const colwise_scale_inv_base_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + float *const amax_rowwise_base_ptr = reinterpret_cast(output->amax.dptr); + float *const amax_colwise_base_ptr = reinterpret_cast(output->columnwise_amax.dptr); + + const int64_t *const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); + // const int64_t *const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + + const bool is_const_last_dim = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS || + shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); + NVTE_CHECK(is_const_last_dim, + "Currently we only support const last dimension for graph safe hadamard transform."); + + auto sm_count = transformer_engine::cuda::sm_count(); + + int k_tile_size = 1024; + + const bool use_swizzle_sf_output = false; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kEnableStochasticRounding, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_has_col_quant, kEnableRhtColQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_has_row_quant, kEnableRowQuant, + 
TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_swizzle_sf_output, kEnableSwizzleSFOutput, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config.use_fast_math, kUseFastMath, + + if constexpr (kEnableRhtColQuant || kEnableRowQuant) { + detail::group_row_col_rht_gemm_ntt_w_sfc_graph_safe< + kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant, + kEnableSwizzleSFOutput, TA, TB, TQA, TSFA, TD, TSFD, kUseFastMath>( + /*packed_sequence_length=*/first_logical_dim, + /*hidden_size=*/last_logical_dim, + /*num_tensors=*/num_tensors, + /*shape_rep=*/shape_rep, + /*A=*/reinterpret_cast(input_base_ptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*QA=*/reinterpret_cast(rowwise_data_base_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), + /*QA_COLWISE=*/reinterpret_cast(colwise_data_base_ptr), + /*SFA_COLWISE=*/reinterpret_cast(colwise_scale_inv_base_ptr), + /*amax_rowwise=*/reinterpret_cast(amax_rowwise_base_ptr), + /*amax_colwise=*/reinterpret_cast(amax_colwise_base_ptr), + /*offsets=*/offsets_ptr, + /*first_dims=*/first_dims_ptr, + /*rng_state=*/rng_state, + /*tile_scheduler_workspace=*/tile_scheduler_workspace, + /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + } else { + NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", + kEnableRhtColQuant, ", kEnableRowQuant=", kEnableRowQuant, ")."); + } + + ););););); +} + +} // namespace transformer_engine + +void nvte_group_hadamard_transform_cast_fusion_graph_safe( + const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion_graph_safe); + using namespace transformer_engine; + + GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + + Tensor *quant_workspace_tensor = convertNVTETensorCheck(quant_workspace); + + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + if (input_tensor->num_tensors == 0) { + return; + } + + // Call the multi-tensor Hadamard transform amax implementation. + group_hadamard_transform_cast_fusion_graph_safe( + input_tensor, output_tensor, *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, + *quant_workspace_tensor, stream); +} diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index 13103cc388..bee939f0cd 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -86,6 +86,24 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp int random_sign_mask, int random_sign_mask_t, cudaStream_t stream); +/*! \brief Grouped-tensor amax with Hadamard transform (graph safe, device-managed grouping). + * + * This function is experimental and the API is not stable. + * + * This API assumes that the split info (grouping of tensors) is on device and unknown to the host; + * therefore, this is a graph safe API and the grouped-tensor argument is passed as a single device structure. + * + * \param[in] input NVTEGroupedTensor representing grouped input tensors. + * \param[in,out] output NVTEGroupedTensor for output amax (row/col). Only the row-wise and + * column-wise amaxes are updated. 
+ * \param[in] random_sign_mask 16-bit sign mask for RHT. + * \param[in] random_sign_mask_t 16-bit sign mask for transposed RHT. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_hadamard_transform_amax_graph_safe(const NVTEGroupedTensor input, + NVTEGroupedTensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream); + /*! * \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation. * @@ -124,6 +142,22 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream); +/*! + * \brief Perform the grouped-tensor Hadamard transform cast fusion operation in graph-safe mode. + * + * This function is experimental and the API is not stable. Group_ prefix means contiguous input concatenated. + * + * \param[in] input NVTEGroupedTensor representing grouped input tensors. + * \param[in,out] output NVTEGroupedTensor for output (row/column-wise quantized results). + * \param[in] hadamard_matrix Hadamard matrix to use for transformation. + * \param[in] quant_config Quantization configuration. + * \param[in] quant_workspace Workspace buffer. Must be at least 4 bytes. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_hadamard_transform_cast_fusion_graph_safe( + const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index 303801a88a..b5eadcf678 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -296,6 +296,17 @@ void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor ** void nvte_group_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections, size_t num_tensors, cudaStream_t stream); +/*! \brief Grouped-tensor amax without doing hadamard transform. + * + * This function is experimental and the API is not stable. + * + * \param[in] input NVTEGroupedTensor Input tensor. + * \param[in,out] output NVTEGroupedTensor Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_amax_graph_safe(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index fd0125c8d0..711d46e7e0 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -919,6 +919,210 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; +/*! \struct GroupedTensorWrapper + * \brief C++ wrapper for the NVTEGroupedTensor class. + */ +class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * TE grouped tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). 
+ * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. */ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. */ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(&tensor_, param, &data); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } 
+
+  template <typename ShapeType>
+  GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept {
+    return set_parameter(kNVTEGroupedLastDims, dptr, type, shape);
+  }
+
+  template <typename ShapeType>
+  GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type,
+                                           const ShapeType &shape) noexcept {
+    return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape);
+  }
+
+  // Parameter getters
+  NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept {
+    return nvte_get_grouped_tensor_param(tensor_, param);
+  }
+
+  NVTEBasicTensor get_rowwise_data() const noexcept {
+    return get_parameter(kNVTEGroupedRowwiseData);
+  }
+
+  NVTEBasicTensor get_columnwise_data() const noexcept {
+    return get_parameter(kNVTEGroupedColumnwiseData);
+  }
+
+  NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); }
+
+  NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); }
+
+  NVTEBasicTensor get_rowwise_scale_inv() const noexcept {
+    return get_parameter(kNVTEGroupedRowwiseScaleInv);
+  }
+
+  NVTEBasicTensor get_columnwise_scale_inv() const noexcept {
+    return get_parameter(kNVTEGroupedColumnwiseScaleInv);
+  }
+
+  NVTEBasicTensor get_columnwise_amax() const noexcept {
+    return get_parameter(kNVTEGroupedColumnwiseAmax);
+  }
+
+  NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); }
+
+  NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); }
+
+  NVTEBasicTensor get_tensor_offsets() const noexcept {
+    return get_parameter(kNVTEGroupedTensorOffsets);
+  }
+
+  /*! \brief Get an underlying NVTEGroupedTensor.
+   *
+   * \return NVTEGroupedTensor held by this GroupedTensorWrapper.
+   */
+  NVTEGroupedTensor data() const noexcept { return tensor_; }
+
+  /*! \brief Get the number of tensors in this GroupedTensorWrapper. */
+  size_t num_tensors() const noexcept {
+    if (tensor_ == nullptr) return 0;
+    return nvte_grouped_tensor_num_tensors(tensor_);
+  }
+
+  /*! \brief Get the data type of this GroupedTensorWrapper. */
+  DType dtype() const noexcept {
+    if (tensor_ == nullptr) return DType::kNumTypes;
+    return static_cast<DType>(nvte_grouped_tensor_type(tensor_));
+  }
+
+  /*! \brief Get a scaling mode of the grouped tensor. */
+  NVTEScalingMode scaling_mode() const noexcept {
+    if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING;
+    return nvte_grouped_tensor_scaling_mode(tensor_);
+  }
+
+  /*! \brief Get the logical shape of this GroupedTensorWrapper. */
+  const NVTEShape logical_shape() const noexcept {
+    if (tensor_ == nullptr) {
+      return emptyShape;
+    }
+    return nvte_get_grouped_tensor_logical_shape(tensor_);
+  }
+
+  static constexpr size_t defaultData = 1;
+  static constexpr NVTEShape defaultShape = {
+      {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
+  static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
+
+ private:
+  NVTEShape convertShape(const NVTEShape &s) { return s; }
+
+  NVTEShape convertShape(const std::vector<size_t> &s) {
+    return nvte_make_shape(s.data(), s.size());
+  }
+
+  /*! \brief Wrapped NVTEGroupedTensor. */
+  NVTEGroupedTensor tensor_ = nullptr;
+};
+
 /*!
\enum Float8BlockScaleTensorFormat * \brief Data format for an FP8 block-scaled tensor */ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..18577b0eb4 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -88,33 +88,40 @@ class Recipe: Base recipe class. """ - def nvfp4(self): + @classmethod + def nvfp4(cls): """Whether the given recipe is NVFP4 1D block scaling.""" - return isinstance(self, NVFP4BlockScaling) + return issubclass(cls, NVFP4BlockScaling) - def mxfp8(self): + @classmethod + def mxfp8(cls): """Whether the given recipe is MXFP8 block scaling.""" - return isinstance(self, MXFP8BlockScaling) + return issubclass(cls, MXFP8BlockScaling) - def delayed(self): + @classmethod + def delayed(cls): """Whether the given recipe is delayed scaling.""" - return isinstance(self, DelayedScaling) + return issubclass(cls, DelayedScaling) - def float8_current_scaling(self): + @classmethod + def float8_current_scaling(cls): """Whether the given recipe is (per-tensor) current scaling.""" - return isinstance(self, Float8CurrentScaling) + return issubclass(cls, Float8CurrentScaling) - def float8_per_tensor_scaling(self): + @classmethod + def float8_per_tensor_scaling(cls): """Whether the given recipe is per-tensor scaling.""" - return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + return issubclass(cls, (DelayedScaling, Float8CurrentScaling)) - def float8_block_scaling(self): + @classmethod + def float8_block_scaling(cls): """Whether the given recipe is float8 blockwise scaling.""" - return isinstance(self, Float8BlockScaling) + return issubclass(cls, Float8BlockScaling) - def custom(self): + @classmethod + def custom(cls): """Whether the given recipe is custom.""" - return isinstance(self, CustomRecipe) + return issubclass(cls, CustomRecipe) @dataclass() diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 16171054cb..5f2851e5fa 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -156,6 +156,104 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return NVTE_MXFP8_1D_SCALING; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return NVTE_NVFP4_1D_SCALING; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + const int block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); + return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; + } + return NVTE_DELAYED_TENSOR_SCALING; +} + +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { + // Returns a GroupedTensorWrapper from a PyTorch GroupedTensor. 
+ const auto num_tensors = tensor.attr("num_tensors").cast(); + const auto logical_shape = tensor.attr("logical_shape").cast>(); + + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + if (!tensor.attr("quantizers").is_none()) { + const auto quantizers = tensor.attr("quantizers").cast(); + if (!quantizers.empty() && !quantizers[0].is_none()) { + scaling_mode = ScalingModeFromQuantizer(quantizers[0]); + } + } + + auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); + + // Rowwise data + if (!tensor.attr("data").is_none()) { + const auto &data = tensor.attr("data").cast(); + ret.set_rowwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()), + getTensorShape(data)); + } + + // Columnwise data + if (!tensor.attr("columnwise_data").is_none()) { + const auto &data = tensor.attr("columnwise_data").cast(); + ret.set_columnwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()), + getTensorShape(data)); + } + + // Scale + if (!tensor.attr("scale").is_none()) { + const auto &scale = tensor.attr("scale").cast(); + ret.set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + } + + // Amax + if (!tensor.attr("amax").is_none()) { + const auto &amax = tensor.attr("amax").cast(); + ret.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + if (!tensor.attr("columnwise_amax").is_none()) { + const auto &amax = tensor.attr("columnwise_amax").cast(); + ret.set_columnwise_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + + // Scale inverse + if (!tensor.attr("scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("scale_inv").cast(); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDType(scale_inv.scalar_type()), + getTensorShape(scale_inv)); + } + if (!tensor.attr("columnwise_scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("columnwise_scale_inv").cast(); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDType(scale_inv.scalar_type()), + getTensorShape(scale_inv)); + } + + // Shape metadata + if (!tensor.attr("first_dims").is_none()) { + const auto &first_dims = tensor.attr("first_dims").cast(); + ret.set_first_dims(first_dims.data_ptr(), GetTransformerEngineDType(first_dims.scalar_type()), + getTensorShape(first_dims)); + } + if (!tensor.attr("last_dims").is_none()) { + const auto &last_dims = tensor.attr("last_dims").cast(); + ret.set_last_dims(last_dims.data_ptr(), GetTransformerEngineDType(last_dims.scalar_type()), + getTensorShape(last_dims)); + } + if (!tensor.attr("tensor_offsets").is_none()) { + const auto &tensor_offsets = tensor.attr("tensor_offsets").cast(); + ret.set_tensor_offsets(tensor_offsets.data_ptr(), + GetTransformerEngineDType(tensor_offsets.scalar_type()), + getTensorShape(tensor_offsets)); + } + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1e6f0b00ab..9536105534 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -13,6 +13,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from .base import ( get_dummy_wgrad, TransformerEngineBaseModule, @@ -147,7 
+148,10 @@ def forward( # tensors (like scales), but bulk allocation shares storage across all tensors, # so if scales can't be offloaded, nothing in the group can be offloaded. inputmats = tex.split_quantize( - inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading + inp_view, + m_splits, + input_quantizers, + disable_bulk_allocation=cpu_offloading, ) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -365,7 +369,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], for i in range(ctx.num_gemms): grad_biases[i] = grad_output_mats[i].sum(dim=0) grad_output = DebugQuantizer.multi_tensor_quantize( - grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype + grad_output_view, + ctx.grad_output_quantizers, + ctx.m_splits, + ctx.activation_dtype, ) else: # Only split grad output. Grad bias is fused with @@ -436,7 +443,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.input_quantizers[0] is not None: for input_quantizer in ctx.input_quantizers: if isinstance( - input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + input_quantizer, + (Float8Quantizer, Float8CurrentScalingQuantizer), ): input_quantizer.set_usage(rowwise=True, columnwise=True) else: @@ -446,7 +454,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( - inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype + inp_view, + ctx.input_quantizers, + ctx.m_splits, + ctx.activation_dtype, ) else: inputmats = torch.split( @@ -616,7 +627,7 @@ def __init__( ) -> None: super().__init__() - params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.num_gemms = num_gemms self.in_features = in_features self.out_features = out_features @@ -631,13 +642,20 @@ def __init__( assert ( not ub_overlap_rs and not ub_overlap_ag ), "GroupedLinear doesn't support Userbuffer overlap." 
+ self.init_method = init_method self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute) - self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1} + self._offsets = { + "input": 0, + "weight": 1, + "output": 2, + "grad_output": 0, + "grad_input": 1, + } self._num_fp8_tensors_per_gemm = { "fwd": 3, "bwd": 2, @@ -679,7 +697,7 @@ def __init__( self.out_features, self.in_features, device=device, - dtype=params_dtype, + dtype=self.params_dtype, ), ), init_fn=init_method, @@ -695,13 +713,13 @@ def __init__( torch.empty( self.out_features, device=device, - dtype=params_dtype, + dtype=self.params_dtype, ), ), init_fn=init_method_constant(0.0), ) else: - bias = torch.Tensor().to(dtype=params_dtype, device=device) + bias = torch.Tensor().to(dtype=self.params_dtype, device=device) setattr(self, f"bias{i}", bias) if self.primary_weights_in_fp8: @@ -709,6 +727,7 @@ def __init__( is_meta = torch.device(device).type == "meta" self.reset_parameters(defer_init=is_meta) + self.make_grouped_weights(defer_init=is_meta) if self.wgrad_store.delay_wgrad_compute(): for name, param in self.named_parameters(): @@ -729,8 +748,49 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: ) self._customize_quantizers_float8_current_scaling(fwd, recipe) + def make_grouped_weights(self, defer_init=False) -> None: + """ + Convert parameters into a GroupedTensor and re-register them as parameters. + """ + + if defer_init: + return + + weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + weight_quantizers = self._get_weight_quantizers() + + # Create the weight storage. + grouped_weights = GroupedTensor.make_grouped_tensor( + num_tensors=self.num_gemms, + shape=[(self.out_features, self.in_features)] * self.num_gemms, + quantizers=weight_quantizers, + dtype=self.params_dtype, + ) + + # Copy existing params into storage. + # TODO(ksivamani): Verify correctness of copy for all recipes. + with torch.no_grad(): + for i in range(self.num_gemms): + grouped_weights.quantized_tensors[i].copy_(weights[i]) + + # Re-register the grouped weights as parameters. 
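+        # Worked example of the index arithmetic below (illustrative only): with
+        # self._offsets["weight"] == 1 and 3 forward FP8 tensors per GEMM, the weight
+        # of GEMM i is registered with fp8_meta_index 1 + 3 * i, i.e. 1, 4, 7, ...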
+ for i in range(self.num_gemms): + self.register_parameter( + f"weight{i}", + torch.nn.Parameter(grouped_weights.quantized_tensors[i]), + init_fn=self.init_method, + get_rng_state_tracker=self.get_rng_state_tracker, + fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"], + ) + + self.set_tensor_parallel_attributes(defer_init=defer_init) + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) + self.set_tensor_parallel_attributes(defer_init=defer_init) + + def set_tensor_parallel_attributes(self, defer_init=False) -> None: + """Set attributes needed for TP""" if not defer_init: # Set parallelism attributes for linear weights diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 43cbdcf9e6..38cf8a3b86 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -11,7 +11,11 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + Recipe, +) from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func from ..quantized_tensor import QuantizedTensor, Quantizer @@ -154,6 +158,10 @@ def calibrate(self, tensor: torch.Tensor) -> None: amin, amax = tensor.aminmax() self.amax.copy_(torch.max(-amin, amax)) + def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for Float8 1D blockwise quantization.""" + return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1]) + def create_tensor_from_data( self, data: torch.Tensor, @@ -407,6 +415,10 @@ def create_tensor_from_data( quantizer=self, ) + def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for Float8 1D blockwise quantization.""" + return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1]) + def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: """Function using primitives with ONNX defined translations.""" if tensor.dtype != torch.float32: @@ -768,7 +780,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): kwargs, ) return Float8Tensor.make_like( - tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape + tensor, + data=func_out, + data_transpose=func_transposed_out, + shape=func_out.shape, ) if func == torch.ops.aten.detach.default: diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 88081f51bf..4fa408bcb1 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -162,6 +162,49 @@ def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass + def get_scale_shape( + self, + shape: Iterable[int], + columnwise: bool, + ) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. 
+ + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For MXFP8 1D blockwise quantization, blocksize is 32 + Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. + CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + if columnwise: + # Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]] + # with padding to multiples of [4, 128] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + ) + # Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE] + # with padding to multiples of [128, 4] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + ) + + def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for MXFP8 1D blockwise quantization.""" + return rowwise_data_shape + def create_tensor_from_data( self, data: torch.Tensor, @@ -694,7 +737,7 @@ def fsdp_post_all_gather( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=fp8_dtype, dtype=param_dtype, - shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, + shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape), quantizer=self._quantizer, ) out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 8b707af3b2..a5e83c9602 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -340,7 +340,10 @@ def make_empty( ) columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) columnwise_scale_inv = torch.empty( - columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + columnwise_scale_shape, + dtype=torch.uint8, + device=device, + pin_memory=pin_memory, ) amax_columnwise = torch.zeros( 1, dtype=torch.float32, device=device, pin_memory=pin_memory diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py index d7a2719200..7c8a014c1d 100644 --- a/transformer_engine/pytorch/tensor/storage/__init__.py +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -7,3 +7,4 @@ from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 +from .grouped_tensor import GroupedTensor # noqa: F401 diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py new file mode 100644 index 0000000000..7a2a977796 --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -0,0 +1,918 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Grouped tensor class for handling collections of tensors with different shapes""" +from __future__ import annotations +from typing import Optional, Tuple, List, Union +import math + +import torch + +from transformer_engine_torch import Float8BlockScaleTensorFormat + +from ...quantized_tensor import QuantizedTensorStorage, Quantizer + +from ..mxfp8_tensor import MXFP8Tensor +from ..nvfp4_tensor import NVFP4Tensor +from ..float8_tensor import Float8Tensor +from ..float8_blockwise_tensor import Float8BlockwiseQTensor +from .float8_tensor_storage import Float8TensorStorage +from .mxfp8_tensor_storage import MXFP8TensorStorage +from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .nvfp4_tensor_storage import NVFP4TensorStorage + + +class GroupedTensor: + """ + EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. + + Grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode. + + Shape Representation: + - logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors + are flattened to 2D and stacked together (REQUIRED) + + When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N) + + When varying_first_dim(): [~sum_of_first_dims, N] where N is common + + When varying_last_dim(): [M, ~sum_of_last_dims] where M is common + + When varying_both_dims(): [1, total_elements] (fully flattened) + + - first_dims and last_dims are OPTIONAL (None if dimension is uniform) + + None first_dims: all tensors have the same first dimension + + None last_dims: all tensors have the same last dimension + + Both None: all tensors have identical shapes + + Both set: each tensor has unique shape (first_dims[i], last_dims[i]) + + Data Layout: + - ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.) + - logical_shape provides the conceptual 2D interpretation + - All data is stored on device in contiguous layout + + Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode. + """ + + def __init__( + self, + num_tensors: int, + shape: List[Tuple[int, int]], + quantizers: List[Optional[Quantizer]] = None, + dtype: Optional[torch.dtype] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + logical_shape: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Initialize a GroupedTensor. 
+ + Args: + num_tensors: Number of tensors in the group + shape: 2D shape of each tensor (len num_tensors) + quantizers: Quantizers for the grouped tensor + data: Row-wise data buffer (1D flattened) + columnwise_data: Column-wise data buffer (1D flattened) + scale_inv: Row-wise scale inverse buffer + columnwise_scale_inv: Column-wise scale inverse buffer + amax: Row-wise amax buffer + columnwise_amax: Column-wise amax buffer + scale: Scale buffer (for FP8-DS only) + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) + offsets: Vector of integer offsets for each tensor. + logical_shape: 2D tuple representing conceptual shape + """ + self.num_tensors = num_tensors + self.quantizers = quantizers + self.shape = shape + self.dtype = ( + dtype if dtype is not None else torch.float32 + ) # Default to float32 if not provided + + # Data buffers + self.data = data + self.columnwise_data = columnwise_data + self.scale_inv = scale_inv + self.columnwise_scale_inv = columnwise_scale_inv + self.amax = amax + self.columnwise_amax = columnwise_amax + self.scale = scale + + # For convenient indexing for python GroupedTensor API. + self.scale_inv_offsets = scale_inv_offsets + self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets + + # Shape information (OPTIONAL - None if dimension is uniform across all tensors) + # first_dims[i] = first dimension of tensor i (None if all tensors have same first dim) + # last_dims[i] = last dimension of tensor i (None if all tensors have same last dim) + self.first_dims = ( + first_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + self.last_dims = ( + last_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + + # Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape()) + # tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1) + # Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size + # If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions) + self.tensor_offsets = ( + tensor_offsets # Device pointer to int64_t array of length num_tensors (or None) + ) + self.offsets = offsets # Vector of integer offsets for each tensor. + + # Logical shape: conceptual 2D shape of the grouped data (REQUIRED) + # Represents how the 1D flattened data should be interpreted as 2D + # Always 2D with positive dimensions + self.logical_shape = logical_shape if logical_shape is not None else (0, 0) + + # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor. + # Used as a convenience. + self.quantized_tensors = None + + def has_data(self) -> bool: + """ + Check if the tensor has row-wise data. + + Returns: + True if data buffer is initialized, False otherwise + """ + return self.data is not None + + def has_columnwise_data(self) -> bool: + """ + Check if the tensor has column-wise data. + + Returns: + True if columnwise_data buffer is initialized, False otherwise + """ + return self.columnwise_data is not None + + def all_same_first_dim(self) -> bool: + """ + Check if all tensors in the group have the same first dimension. 
+ + Returns: + True if first dimension is uniform across all tensors + """ + return self.first_dims is None + + def all_same_last_dim(self) -> bool: + """ + Check if all tensors in the group have the same last dimension. + + Returns: + True if last dimension is uniform across all tensors + """ + return self.last_dims is None + + def all_same_shape(self) -> bool: + """ + Check if all tensors in the group have identical shapes. + + Returns: + True if all tensors have the same shape + """ + return self.first_dims is None and self.last_dims is None + + def varying_both_dims(self) -> bool: + """ + Check if both dimensions vary across tensors. + + Returns: + True if both first and last dimensions vary + """ + return self.first_dims is not None and self.last_dims is not None + + def get_common_first_dim(self) -> int: + """ + Get the common first dimension when all tensors share it. + + Returns: + The common first dimension + + Raises: + RuntimeError: If first dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_first_dim(): + raise RuntimeError("First dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + if self.all_same_shape(): + # When both dims are uniform: logical_shape = [num_tensors * M, N] + return self.logical_shape[0] // self.num_tensors + # When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims] + return self.logical_shape[0] + + def get_common_last_dim(self) -> int: + """ + Get the common last dimension when all tensors share it. + + Returns: + The common last dimension + + Raises: + RuntimeError: If last dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_last_dim(): + raise RuntimeError("Last dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + # For both uniform and varying first dim cases: logical_shape[1] is the common last dim + return self.logical_shape[1] + + def get_dtype(self) -> torch.dtype: + """ + Get the high precision data type of the tensor. + + Returns: + The high precision dtype of the data buffer + """ + + return self.dtype + + def clear(self) -> None: + """ + Reset tensor data and clear all buffers. 
+ """ + self.data = None + self.columnwise_data = None + self.scale_inv = None + self.columnwise_scale_inv = None + self.amax = None + self.columnwise_amax = None + self.scale = None + self.first_dims = None + self.last_dims = None + self.tensor_offsets = None + self.logical_shape = (0, 0) + self.num_tensors = 0 + self.quantizers = None + self.quantized_tensors = None + self.offsets = None + self.scale_inv_offsets = None + self.columnwise_scale_inv_offsets = None + + def __repr__(self) -> str: + """String representation of the GroupedTensor.""" + return ( + f"GroupedTensor(num_tensors={self.num_tensors}, " + f"shape={self.shape}, " + f"logical_shape={self.logical_shape}, " + f"dtype={self.get_dtype()})" + ) + + def __str__(self) -> str: + """User-friendly string representation.""" + shape_info = [] + if self.all_same_shape(): + shape_info.append("uniform shape") + else: + if not self.all_same_first_dim(): + shape_info.append("varying first dim") + if not self.all_same_last_dim(): + shape_info.append("varying last dim") + + return ( + f"GroupedTensor with {self.num_tensors} tensors " + f"({', '.join(shape_info) if shape_info else 'uniform'}), " + f"logical_shape={self.logical_shape}, " + f"dtype={self.get_dtype()}" + ) + + @staticmethod + def make_grouped_tensor( + num_tensors: int, + shape: List[Tuple[int, int]], + quantizers: List[Optional[Quantizer]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensor: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + shape: 2D shape of each tensor (len num_tensors) + quantizers: List of quantizers for each tensor (len num_tensors) + Used to figure out the recipe and what to allocate. + device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. + """ + # Input validation + assert ( + len(shape) == num_tensors + ), f"Shape list length {len(shape)} must match num_tensors {num_tensors}" + assert all(len(s) == 2 for s in shape), "All shapes must be 2D tuples" + assert all(s[0] > 0 and s[1] > 0 for s in shape), "All dimensions must be positive" + + # Set device + if device is None: + device = torch.cuda.current_device() + + # Analyze shape patterns + first_dims_list = [s[0] for s in shape] + last_dims_list = [s[1] for s in shape] + + all_same_first = len(set(first_dims_list)) == 1 + all_same_last = len(set(last_dims_list)) == 1 + + # Create dimension arrays if needed + first_dims = ( + None + if all_same_first + else torch.tensor(first_dims_list, dtype=torch.int64, device=device) + ) + last_dims = ( + None + if all_same_last + else torch.tensor(last_dims_list, dtype=torch.int64, device=device) + ) + + # Calculate tensor offsets (cumulative element offsets) + tensor_offsets = None + offsets = None + if not (all_same_first and all_same_last): + # Need explicit offsets for non-uniform shapes + # Offsets are based on number of elements and not pointers. + # Kernels need to calculate precise pointers based on size of elements. 
+            numels = [s[0] * s[1] for s in shape]
+            offsets = [0]
+            for i in range(num_tensors - 1):
+                offsets.append(offsets[-1] + numels[i])
+            tensor_offsets = torch.tensor(offsets, dtype=torch.int64, device=device)
+
+        # Calculate logical shape based on shape pattern
+        if all_same_first and all_same_last:
+            # All same shape: [num_tensors * M, N]
+            M, N = shape[0]
+            logical_shape = (num_tensors * M, N)
+        elif all_same_first and not all_same_last:
+            # Varying last dim only: [M, sum_of_last_dims]
+            M = first_dims_list[0]
+            sum_last = sum(last_dims_list)
+            logical_shape = (M, sum_last)
+        elif not all_same_first and all_same_last:
+            # Varying first dim only: [sum_of_first_dims, N]
+            sum_first = sum(first_dims_list)
+            N = last_dims_list[0]
+            logical_shape = (sum_first, N)
+        else:
+            # Varying both dims: [1, total_elements]
+            total_elements = sum(s[0] * s[1] for s in shape)
+            logical_shape = (1, total_elements)
+
+        no_quantization = quantizers is None or len(quantizers) == 0 or quantizers[0] is None
+
+        # TODO(ksivaman): Do we need multiple quantizers?
+        # The current implementation assumes all tensors have different quantizer
+        # instances but effectively the same quantizer.
+        rowwise_usage = quantizers[0].rowwise_usage if not no_quantization else True
+        columnwise_usage = quantizers[0].columnwise_usage if not no_quantization else False
+
+        # Calculate total elements across all tensors
+        total_elements = sum(s[0] * s[1] for s in shape)
+
+        data = None
+        columnwise_data = None
+        scale_inv = None
+        columnwise_scale_inv = None
+        amax = None
+        columnwise_amax = None
+        scale = None
+        scale_inv_offsets = None
+        columnwise_scale_inv_offsets = None
+        if no_quantization:
+            assert dtype is not None, "dtype must be provided for unquantized GroupedTensor"
+            if rowwise_usage:
+                # Allocate rowwise data buffer (1D flattened, high-precision dtype)
+                data = torch.empty(total_elements, dtype=dtype, device=device)
+
+            if columnwise_usage:
+                # Allocate columnwise data buffer (1D flattened, high-precision dtype)
+                columnwise_data = torch.empty(total_elements, dtype=dtype, device=device)
+        elif quantizers[0]._get_compatible_recipe().mxfp8():
+            if rowwise_usage:
+                # Allocate rowwise data buffer (1D flattened, uint8)
+                data = torch.empty(total_elements, dtype=torch.uint8, device=device)
+                # Scale inverse buffer for MXFP8 - complex shape based on block scaling
+                # For grouped tensors, we need to calculate scale_inv size for all tensors
+                total_scale_elements = 0
+                scale_inv_offsets = [0]
+                for i, s in enumerate(shape):
+                    scale_inv_shape = quantizers[i].get_scale_shape(s, False)
+                    scale_elements = math.prod(scale_inv_shape)
+                    total_scale_elements += scale_elements
+                    if i < num_tensors - 1:
+                        scale_inv_offsets.append(total_scale_elements)
+                scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device)
+
+            if columnwise_usage:
+                # Allocate columnwise data buffer (1D flattened, uint8)
+                columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device)
+                # Columnwise scale inverse buffer (columnwise scale shape, padded differently
+                # from the rowwise one)
+                total_columnwise_scale_elements = 0
+                columnwise_scale_inv_offsets = [0]
+                for i, s in enumerate(shape):
+                    columnwise_scale_inv_shape = quantizers[i].get_scale_shape(s, True)
+                    columnwise_scale_elements = math.prod(columnwise_scale_inv_shape)
+                    total_columnwise_scale_elements += columnwise_scale_elements
+                    if i < num_tensors - 1:
+                        columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
+                columnwise_scale_inv = torch.empty(
+                    total_columnwise_scale_elements, dtype=torch.uint8, device=device
+                )
+        elif 
quantizers[0]._get_compatible_recipe().delayed(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + scale_inv_offsets = list(range(num_tensors)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + columnwise_scale_inv_offsets = list(range(num_tensors)) + + # Amax buffer for delayed scaling - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizers[0]._get_compatible_recipe().nvfp4(): + + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte) + data = torch.empty((total_elements) // 2, dtype=torch.uint8, device=device) + # Scale inverse buffer for NVFP4 - complex shape based on block scaling + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizers[i].get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + # Amax buffer - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed) + columnwise_data = torch.empty( + (total_elements) // 2, dtype=torch.uint8, device=device + ) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizers[i].get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + # Columnwise amax buffer - one per tensor + columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizers[0]._get_compatible_recipe().float8_block_scaling(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - size depends on block configuration + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizers[i].get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse + total_columnwise_scale_elements = 0 + 
columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizers[i].get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.float32, device=device + ) + elif quantizers[0]._get_compatible_recipe().float8_current_scaling(): + # Current scaling - per-tensor scaling computed on the fly + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + scale_inv_offsets = list(range(num_tensors)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + columnwise_scale_inv_offsets = list(range(num_tensors)) + + # Scale and amax buffers for current scaling - one per tensor + scale = torch.empty(num_tensors, dtype=torch.float32, device=device) + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + else: + raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizers[0]}") + + grouped_tensor = GroupedTensor( + num_tensors=num_tensors, + shape=shape, + dtype=dtype, + quantizers=quantizers, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=amax, + columnwise_amax=columnwise_amax, + scale=scale, + first_dims=first_dims, + last_dims=last_dims, + tensor_offsets=tensor_offsets, + offsets=offsets, + scale_inv_offsets=scale_inv_offsets, + columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + logical_shape=logical_shape, + ) + + grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() + return grouped_tensor + + def split_into_quantized_tensors( + self, + ) -> List[Union[QuantizedTensorStorage, torch.Tensor]]: + """ + Split the GroupedTensor into a list of `num_tensors` + quantized tensors based on the quantizer. No additional memory allocation is performed, + so the tensors returned are the same as the ones used to create the GroupedTensor. + + If quantizer is None, returns normal torch tensors. + If quantizer.internal is True, returns QuantizedTensorStorage. + Otherwise, returns QuantizedTensor. 
+ """ + + result = [] + + no_quantization = ( + self.quantizers is None or len(self.quantizers) == 0 or self.quantizers[0] is None + ) + + # Case 1: No quantization - return regular torch tensors + if no_quantization: + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.shape[i] + + # Get tensor data slice + if self.offsets is not None: + start_offset = self.offsets[i] + numel = tensor_shape[0] * tensor_shape[1] + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + else: + # All same shape case + numel = tensor_shape[0] * tensor_shape[1] + start_offset = i * numel + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + + return result + + # Case 2: Quantized tensors + recipe = self.quantizers[0]._get_compatible_recipe() + + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.shape[i] + numel = tensor_shape[0] * tensor_shape[1] + + # Get data offsets + if self.offsets is not None: + data_start = self.offsets[i] + data_end = data_start + numel + else: + # All same shape + data_start = i * numel + data_end = data_start + numel + + # Special shape handling for NVFP4. 
+ nvfp4 = self.quantizers[i]._get_compatible_recipe().nvfp4() + if nvfp4: + data_start = data_start // 2 + data_end = data_end // 2 + + # Extract rowwise and columnwise data + rowwise_data = None + columnwise_data = None + + if self.has_data(): + if nvfp4: + rowwise_tensor_shape = self.quantizers[i].convert_shape_for_fp4(tensor_shape) + else: + rowwise_tensor_shape = tensor_shape + rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape) + + if self.has_columnwise_data(): + columnwise_tensor_shape = self.quantizers[i].get_columnwise_shape(tensor_shape) + if nvfp4: + columnwise_tensor_shape = self.quantizers[i].convert_shape_for_fp4( + columnwise_tensor_shape + ) + columnwise_data = self.columnwise_data[data_start:data_end].view( + columnwise_tensor_shape + ) + + # MXFP8 format + if recipe.mxfp8(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Calculate expected scale shape for MXFP8 + scale_shape = self.quantizers[i].get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + cscale_shape = self.quantizers[i].get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + if self.quantizers[i].internal: + mxfp8_tensor_class = MXFP8TensorStorage + else: + mxfp8_tensor_class = MXFP8Tensor + tensor = mxfp8_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=self.quantizers[i].dtype, + quantizer=self.quantizers[i], + ) + result.append(tensor) + + # Delayed scaling or current scaling (both use Float8TensorStorage) + elif recipe.delayed() or recipe.float8_current_scaling(): + # Scale inverse - one per tensor + scale_inv = None + if self.scale_inv is not None: + scale_inv = self.scale_inv[i : i + 1] + + if self.quantizers[i].internal: + float8_tensor_class = Float8TensorStorage + else: + float8_tensor_class = Float8Tensor + + tensor = float8_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + data=rowwise_data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.quantizers[i].dtype, + quantizer=self.quantizers[i], + data_transpose=columnwise_data, + ) + result.append(tensor) + + # Float8 block scaling + elif recipe.float8_block_scaling(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Get scale shape from quantizer + scale_shape = self.quantizers[i].get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and 
self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + # Get columnwise scale shape from quantizer + cscale_shape = self.quantizers[i].get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + # Compute is_2D_scaled and data_format from quantizer attributes + is_2D_scaled = self.quantizers[i].block_scaling_dim == 2 + data_format = ( + Float8BlockScaleTensorFormat.COMPACT + if self.quantizers[i].all_gather_usage + else Float8BlockScaleTensorFormat.GEMM_READY + ) + + if self.quantizers[i].internal: + float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage + else: + float8_blockwise_q_tensor_class = Float8BlockwiseQTensor + + tensor = float8_blockwise_q_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=self.quantizers[i].dtype, + quantizer=self.quantizers[i], + is_2D_scaled=is_2D_scaled, + data_format=data_format, + ) + result.append(tensor) + + # NVFP4 format + elif recipe.nvfp4(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + amax_rowwise = None + amax_columnwise = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Get scale shape from quantizer + scale_shape = self.quantizers[i].get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + # Get columnwise scale shape from quantizer + cscale_shape = self.quantizers[i].get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + # Extract amax - one per tensor + if self.amax is not None: + amax_rowwise = self.amax[i : i + 1] + + if self.columnwise_amax is not None: + amax_columnwise = self.columnwise_amax[i : i + 1] + + if self.quantizers[i].internal: + nvfp4_tensor_class = NVFP4TensorStorage + else: + nvfp4_tensor_class = NVFP4Tensor + + tensor = nvfp4_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + fp4_dtype=self.quantizers[i].dtype, + quantizer=self.quantizers[i], + ) + result.append(tensor) + + else: + raise ValueError(f"Unsupported quantization recipe: {recipe}") + + return result + + @staticmethod + def create_and_quantize( + tensors: int, + quantizers: None | List[Quantizer], + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + noop_flag: Optional[torch.Tensor] = None, + ) -> Tuple[QuantizedTensorStorage, ...]: + """ + Quantize given tensors into quantized tensors with 
underlying + storage allocated in a GroupedTensor. + """ + + num_tensors = len(tensors) + + if quantizers is not None: + assert num_tensors == len(quantizers), "Number of tensors and quantizers must match" + + grouped_tensor = GroupedTensor.make_grouped_tensor( + num_tensors=len(tensors), + shape=[t.shape for t in tensors], + quantizers=quantizers, + device=device, + dtype=dtype, + ) + + grouped_tensor.quantize(tensors, noop_flag=noop_flag) + + return grouped_tensor + + def quantize( + self, + tensors: List[torch.Tensor], + noop_flag: Optional[torch.Tensor] = None, + ) -> Tuple[QuantizedTensorStorage, ...]: + """ + Quantize the GroupedTensor inplace. + """ + + quantized_tensors = self.split_into_quantized_tensors() + for i in range(self.num_tensors): + self.quantizers[i].update_quantized( + tensors[i], quantized_tensors[i], noop_flag=noop_flag + ) + return quantized_tensors
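
A minimal usage sketch of the unquantized GroupedTensor path added above (quantizers=None), assuming a CUDA device is available; the shapes and dtype below are arbitrary and only illustrate how the flat buffer, offsets, and per-tensor views relate.

import torch
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor

# Three 2D tensors with a varying first dimension and a common last dimension.
shapes = [(128, 512), (256, 512), (64, 512)]

grouped = GroupedTensor.make_grouped_tensor(
    num_tensors=len(shapes),
    shape=shapes,
    quantizers=None,  # unquantized: a single high-precision buffer is allocated
    device="cuda",
    dtype=torch.bfloat16,
)

# Varying first dim only, so the logical shape is (sum of first dims, common last dim).
assert grouped.logical_shape == (sum(m for m, _ in shapes), 512)

# The per-tensor views created at construction share storage with the flat buffer.
views = grouped.quantized_tensors  # plain torch.Tensor views when quantizers is None
views[1].fill_(1.0)
expected_ptr = grouped.data.data_ptr() + grouped.offsets[1] * views[1].element_size()
assert views[1].data_ptr() == expected_ptr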