diff --git a/benchmarks/linear/benchmark_linear.py b/benchmarks/linear/benchmark_linear.py new file mode 100644 index 0000000000..b293c44fc9 --- /dev/null +++ b/benchmarks/linear/benchmark_linear.py @@ -0,0 +1,330 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import argparse +import torch +import torch.utils.benchmark as benchmark +import pandas as pd + +from transformer_engine.pytorch.module import Linear as TELinear +from transformer_engine.common.recipe import ( + Float8BlockScaling, + MXFP8BlockScaling, + NVFP4BlockScaling, +) +from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager +from contextlib import nullcontext + +""" +# Profile BF16 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_linear_bf16 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_linear.py --profile --recipe bf16 + +# Profile FP8 sub-channel recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_linear_fp8_sub_channel \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_linear.py --profile --recipe fp8_sub_channel + +# Profile MXFP8 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_linear_mxfp8 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_linear.py --profile --recipe mxfp8 + +# Profile NVFP4 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_linear_nvfp4_rht_cast_fusion \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_linear.py --profile --recipe nvfp4 + +# Example to look at a single kernel target with NCU, like the fused hadamard amax kernel for NVFP4 recipe +ncu -f -o ./benchmarks/linear/ncu_b200_linear_nvfp4_rht_cast_fusion \ + --set=full \ + --kernel-name "row_col_rht_gemm_device" \ + -s 5 -c 5 \ + python benchmarks/linear/benchmark_linear.py --profile --recipe nvfp4 + +""" + +RECIPES = { + "bf16": None, + "fp8_sub_channel": Float8BlockScaling(), + "mxfp8": MXFP8BlockScaling(), + "nvfp4": NVFP4BlockScaling(), +} + +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() + + +def run_linear_multiple_steps(layer, x, mode, gradient, run_num_steps=1, recipe=None): + assert mode in ["fwd_only", "fwd_bwd"] + quantization_context = ( + autocast(enabled=True, recipe=recipe) if recipe is not None else nullcontext() + ) + + if mode == "fwd_only": + with torch.no_grad(), quantization_context: + for i in range(run_num_steps): + y_q = layer.forward( + x, + is_first_microbatch=(i == 0), + ) + return y_q + else: + # reset gradients + layer.zero_grad() + x.grad = None + + with quantization_context: + for i in range(run_num_steps): + label = f"step_{i}" + torch.cuda.nvtx.range_push(label) + y_q = layer.forward( + x, + is_first_microbatch=(i == 0), + ) + y_q.backward(gradient) + torch.cuda.nvtx.range_pop() + + grads_q = [] + grads_q.append(x.grad) + # remaining derivatives are in respect to model parameters + for p in layer.parameters(): + if p.requires_grad: + grads_q.append(p.grad) + + return y_q, grads_q + + +def benchmark_linear( + x, + w, + bias, + recipe_name, + mode, 
+): + params_dtype = torch.bfloat16 + recipe = RECIPES[recipe_name] + + in_features = x.shape[1] + out_features = w.shape[0] + gradient = torch.ones((x.shape[0], out_features), dtype=torch.bfloat16, device=x.device) + + layer = TELinear( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + ) + + layer = layer.to("cuda") + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + num_microbatches = 32 + + label = f"{recipe_name}_{'linear'}" + torch.cuda.nvtx.range_push(label) + timing = benchmark.Timer( + stmt="run_linear_multiple_steps(layer, x, mode, gradient, num_microbatches, recipe)", + globals={ + "run_linear_multiple_steps": run_linear_multiple_steps, + "layer": layer, + "x": x, + "mode": mode, + "gradient": gradient, + "num_microbatches": num_microbatches, + "recipe": recipe, + }, + num_threads=1, + ).blocked_autorange(min_run_time=10) + print(f"{recipe_name}: {timing} \n") + timing_ms = timing.median * 1000 / num_microbatches + + return timing_ms + + +def run_benchmark_linear(mkns, recipe_name, use_bias, fwd_only=False): + data = [] + assert not use_bias, "Bias is not supported in this benchmark script" + + print(f"========== Benchmarking {recipe_name} ==========") + for m, k, n in mkns: + device = "cuda" + x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True) + w = torch.randn((n, k), dtype=torch.bfloat16, device=device) + bias = None + + # Run the benchmark + print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}") + print(f"fwd_only: {fwd_only}") + + linear_fwd_bwd_timing_ms = benchmark_linear( + x, + w, + bias, + recipe_name, + mode="fwd_only" if fwd_only else "fwd_bwd", + ) + + # Append the results + data.append( + [ + m, + k, + n, + recipe_name, + linear_fwd_bwd_timing_ms, + ] + ) + + timing_notation = "linear_fwd_time_ms" if fwd_only else "linear_fwd_bwd_time_ms" + + df = pd.DataFrame( + data=data, + columns=[ + "m", + "k", + "n", + "recipe", + timing_notation, + ], + ) + + print(df, "\n") + return df + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable profiling mode") + parser.add_argument( + "--output-dir", + type=str, + default="benchmark_output/", + help="output path for report", + ) + # arguments for recipe, options are fp8_sub_channel, mxfp8, bf16, all + parser.add_argument( + "--recipe", + type=str, + default="bf16", + help="Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all", + ) + parser.add_argument( + "--token-dim", + type=int, + default=None, + help="Token dimension to use, calculated by SEQ_LEN * MBS / TP_SIZE", + ) + parser.add_argument( + "--hidden-dim", + type=int, + default=None, + help="Hidden dimension to use", + ) + parser.add_argument( + "--output-dim", + type=int, + default=None, + help="Output dimension to use", + ) + parser.add_argument( + "--fwd-only", + action="store_true", + default=False, + help="Run forward pass only, default is both forward and backward passes", + ) + args = parser.parse_args() + + use_bias = False + + token_dim_list = [16384] + hidden_dim_list = [4096] + output_dim_list = [4096] + + if args.token_dim is not None: + token_dim_list = [args.token_dim] + + if args.hidden_dim is not None: + hidden_dim_list = [args.hidden_dim] + + if args.output_dim is not None: + output_dim_list = [args.output_dim] + + # MKN for linear + mkns = [] + for m in token_dim_list: + for k in hidden_dim_list: + for n in output_dim_list: + mkns.append((m, k, n)) + + # 
default recipes to run if not specified + recipe_list = ["bf16"] + + if args.recipe == "all": + recipe_list = ["bf16", "fp8_sub_channel", "mxfp8", "nvfp4"] + else: + recipe_list = [args.recipe] + + if args.profile: + hidden_dim_to_profile = 4096 if args.hidden_dim is None else args.hidden_dim + output_dim_to_profile = 4096 if args.output_dim is None else args.output_dim + token_dim_to_profile = 16384 if args.token_dim is None else args.token_dim + mkns = [(token_dim_to_profile, hidden_dim_to_profile, output_dim_to_profile)] + # in profile mode, only run one recipe specified in args.recipe + assert args.recipe != "all", ( + "In profile mode, only one recipe can be specified, please specify the recipe as" + " fp8_sub_channel, mxfp8, nvfp4, or bf16" + ) + recipe_list = [args.recipe] + torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() + + # Initialize a dataframe to store the results + df_linears = pd.DataFrame() + + # Run the fp8 benchmarks + for recipe_name in recipe_list: + assert recipe_name in [ + "bf16", + "fp8_sub_channel", + "mxfp8", + "nvfp4", + ], "Recipe must be one of bf16, fp8_sub_channel, mxfp8, or nvfp4" + if recipe_name == "mxfp8" and not mxfp8_available: + print(f"MXFP8 is not available, skipping {recipe_name}") + continue + if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available: + print(f"FP8 block scaling is not available, skipping {recipe_name}") + continue + if recipe_name == "nvfp4" and not nvfp4_available: + print(f"NVFP4 is not available, skipping {recipe_name}") + continue + + df = run_benchmark_linear( + mkns, + recipe_name, + use_bias, + fwd_only=args.fwd_only, + ) + df_linears = pd.concat([df_linears, df]) + + print(df_linears) + + if args.profile: + torch.autograd.profiler.emit_nvtx().__exit__(None, None, None) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 98be9a4f54..5826c4b95f 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -35,6 +35,7 @@ def check_quantization_nvfp4_versus_reference( M: int, N: int, contiguous: bool, + return_identity: bool, return_transpose: bool, use_cpp_allocator: bool, swizzled_scale: bool = False, @@ -61,7 +62,7 @@ def check_quantization_nvfp4_versus_reference( # Quantize nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, - rowwise=True, + rowwise=return_identity, columnwise=return_transpose, with_amax_reduction=False, amax_reduction_group=None, @@ -78,9 +79,11 @@ def check_quantization_nvfp4_versus_reference( x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) # Extract data from NVFP4Tensor - assert x_nvfp4_sut._rowwise_data is not None - qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) - assert x_nvfp4_sut._rowwise_scale_inv is not None + qx: torch.Tensor = ( + x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._rowwise_data is not None + else None + ) sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv qx_t = ( x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) @@ -91,13 +94,13 @@ def check_quantization_nvfp4_versus_reference( amax_rowwise = x_nvfp4_sut._amax_rowwise amax_colwise = x_nvfp4_sut._amax_columnwise - qx = unpack_fp4(qx) + qx = unpack_fp4(qx) if qx is not None else None qx_t = unpack_fp4(qx_t) if qx_t is not None else None # Reference quantization using NVFP4QuantizerRef with built-in RHT ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, - rowwise=True, + 
rowwise=return_identity, columnwise=return_transpose, pow_2_scales=False, eps=0.0, @@ -130,13 +133,14 @@ def check_quantization_nvfp4_versus_reference( sx_t_ref = None ref_amax_colwise_t = None - torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0) + if return_identity: + torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0) - torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) - # Compare only the valid portion of scale tensors (reference may not have padding) - ref_sx_shape = sx_ref.shape - sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] - torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + # Compare only the valid portion of scale tensors (reference may not have padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) if return_transpose: torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0) @@ -185,7 +189,7 @@ def check_quantization_nvfp4_versus_reference( ) @pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] ) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] @@ -197,15 +201,29 @@ def test_rht_with_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, N: int, - return_transpose: bool, + quantize_mode: str, use_cpp_allocator: bool, with_random_sign_mask: bool, ) -> None: + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, M=M, N=N, contiguous=True, + return_identity=return_identity, return_transpose=return_transpose, use_cpp_allocator=use_cpp_allocator, with_random_sign_mask=with_random_sign_mask, @@ -221,7 +239,7 @@ def test_rht_with_quantization_block_tiling_versus_reference( ) @pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] ) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] @@ -233,15 +251,29 @@ def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, N: int, - return_transpose: bool, + quantize_mode: str, use_cpp_allocator: bool, with_random_sign_mask: bool, ): + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, M=M, N=N, contiguous=False, + return_identity=return_identity, return_transpose=return_transpose, use_cpp_allocator=use_cpp_allocator, 
with_random_sign_mask=with_random_sign_mask, diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 5975efedaf..58ae83f168 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -174,6 +174,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources hadamard_transform/group_hadamard_transform.cu hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu + hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu hadamard_transform/group_hadamard_transform_cast_fusion.cu hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu multi_tensor/compute_scale.cu diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu new file mode 100644 index 0000000000..2eace1ce00 --- /dev/null +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -0,0 +1,1386 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/curanddx.hpp" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "customized_pipeline.cuh" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/float8.h" +#include "cutlass/float_subbyte.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/platform/platform.h" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/print_error.hpp" + +// clang-format off + +namespace transformer_engine { +namespace detail { +namespace { + +using namespace cute; + +// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor +using cute::Tensor; + +struct CLCResponse { uint32_t data[4] = {0}; }; + + +CUTLASS_DEVICE +cutlass::Array StochasticNumericConverterBase( + cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + auto output_ptr = reinterpret_cast(&output); + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" + "}" + : "=h"(output_ptr[0]), "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), + "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return output; +} + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverter(cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = reinterpret_cast const *>(&rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +template < + class ElementA, + class ElementB, + class ASmemLayout, + class BSmemLayout, + class ClusterShape, + int AccumulatorPipelineStageCount_, + int EpilogueUnrollFactor_, + int SchedulerPipelineStageCount_> +struct SharedStorage { + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr EpilogueUnrollFactor = EpilogueUnrollFactor_; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::detail::CustomizedPipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + struct TensorStorage : cute::aligned_struct<128, _1> { + // cute::array_aligned> smem_A; + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) cute::uint64_t tma_barrier[1]; + alignas(16) CLCPipelineStorage clc; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) CLCResponse clc_response[SchedulerPipelineStageCount_]; + uint32_t tmem_base_ptr; +}; + +template +__launch_bounds__(512, 1) +__global__ static void row_col_rht_gemm_device( + MShape M, + NShape N, + KShape K, + ClusterShape cluster_shape, + ClusterTileShape cluster_tile, + TA const* A, + AStride dA, + ASmemLayout sAlayout, + CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const* B, + BStride dB, + BSmemLayout sBlayout, + CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + TD* D, + DStride dD, + DSmemLayout, + TSFD* SFD, + TSFDLayout sfd_layout, + TQA* QA, + QAStride dQA, + TSFA* SFA, + TSFALayout sfa_layout, + TiledMMA mma, + float const* a_global_amax, + float const* c_global_amax, + const size_t* rng_state) { + using namespace cute; + + // Abort immediately if compilation is not supported + constexpr bool is_blackwell_arch = ARCH_BLACKWELL_FAMILY; + if constexpr (!is_blackwell_arch) { + NVTE_DEVICE_ERROR( + "row_col_rht_gemm_device is only supported on Blackwell " + "with architecture-specific compilation. 
" + "Try recompiling with sm_100a or similar."); + return; + } + static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, + "row_col_rht_gemm_device must generate row-wise " + "and/or column-wise output."); +#if !defined(CUTLASS_ARCH_CLC_ENABLED) + CUTLASS_NOT_IMPLEMENTED(); + return; +#endif + + using X = Underscore; + // static constexpr bool kApplyStochasticRounding = true; + using ElementAccumulator = float; + static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v); + static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; + static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; + static constexpr bool kEnableRowQuant = kEnableRowQuant_; + static constexpr bool kUseFastMath = kUseFastMath_; + static int constexpr RhtTensorSize = 16; + static int constexpr kTmaRhtTensorTransactionBytes = cutlass::bits_to_bytes( + RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::detail::CustomizedPipelineTmaUmmaAsync< + MainloopPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + static_assert(ClusterShape{} == Shape<_1,_1,_1>{}, "ClusterShape must be Shape<_1,_1,_1>"); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static int constexpr VectorSize = RhtTensorSize; + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + auto cluster_size = size<0>(cluster_shape); + auto mainloop_tiler = Shape<_128,_16,_128>{}; + auto epilogue_tiler = Shape<_128,_128,_128>{}; + + static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile); + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + int const K_TILE_MAX = ceil_div(min(N, K), size<2>(epilogue_tiler)); + + struct TileScheduler { + struct WorkTileInfo { + uint32_t m_idx = 0; + uint32_t n_idx = 0; + uint32_t l_idx = 0; + bool is_valid_tile = false; + }; + uint32_t tiles_in_m = 0; + uint32_t tiles_in_n = 0; + + int k_tile_max = 0; + + int wave_cnt = 0; + WorkTileInfo work_tile_info; + WorkTileInfo next_work_tile_info; + CLCResponse* clc_response_ptr_; + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, CLCResponse* clc_response_ptr) + : tiles_in_m(tiles_m), + tiles_in_n(tiles_n), + + k_tile_max(kmax), + work_tile_info({blockIdx.x, blockIdx.y, blockIdx.z, blockIdx.x( + &clc_response_ptr[state.index()])); + asm volatile( + "{\n\t" + "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n\t" + "}\n" + : + : "r"(result_addr), "r"(mbarrier_addr)); + #else + CUTLASS_NOT_IMPLEMENTED(); + #endif + } + 
CUTLASS_DEVICE + static WorkTileInfo + work_tile_info_from_clc_response(uint32_t result_addr) { + WorkTileInfo work_tile_info; + uint32_t valid = 0; + #if defined(CUTLASS_ARCH_CLC_ENABLED) + asm volatile( + "{\n" + ".reg .pred p1;\n\t" + ".reg .b128 clc_result;\n\t" + "ld.shared.b128 clc_result, [%4];\n\t" + "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\n\t" + "selp.u32 %3, 1, 0, p1;\n\t" + "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {%0, %1, %2, _}, clc_result;\n\t" + "}\n" + : "=r"(work_tile_info.m_idx), "=r"(work_tile_info.n_idx), "=r"(work_tile_info.l_idx), "=r"(valid) + : "r"(result_addr) + : "memory" + ); + + cutlass::arch::fence_view_async_shared(); + #else + CUTLASS_NOT_IMPLEMENTED(); + #endif + work_tile_info.is_valid_tile = (valid == 1); + return work_tile_info; + } + }; + + + + // Allocate SMEMork + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); + uint32_t tiles_in_n = uint32_t(size(ceil_div(N, size<2>(epilogue_tiler)))); + TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, shared_storage.clc_response); + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto acc_shape_mma = make_shape(take<0,2>(mainloop_tiler), _1{}, _1{}); + auto acc_shape_epilogue = make_shape(take<0,2>(epilogue_tiler), _1{}, _1{}); + + auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); + auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); + + static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0; + static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; + static int constexpr NumMmaThreadCount = kEnableRHTColQuant? 32: 0; + static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant? 
1: 0; + static int constexpr NumSchedThreads = 32; + static int constexpr NumMainloopLoadThreads = 32; + static int constexpr NumEpilogueThreads = NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount; + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + NumMmaThreadCount + NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // warp assignment + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_sched_warp = (warp_idx == 2); + bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); + bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); + + if (is_epilogue_col_quant_warp && elect_one_sync()) { + cute::prefetch(raw_pointer_cast(c_global_amax)); + } + if (is_epilogue_row_quant_warp && elect_one_sync()) { + cute::prefetch(raw_pointer_cast(a_global_amax)); + } + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + mainloop_pipeline_params.num_consumers = NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount; + MainloopPipeline mainloop_pipeline( + shared_storage.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_col_quant_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. 
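+  // (The producer is the single elected thread of the MMA warp, which fills
+  // accumulator stages in TMEM; every column-quant epilogue thread must arrive
+  // before a stage can be recycled, hence the consumer count below.)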
+ accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; + accumulator_pipeline_params.initializing_warp = 1; + using IsInitAccumulatorPipeline = cute::conditional_t; + AccumulatorPipeline accumulator_pipeline( + shared_storage.accumulator, + accumulator_pipeline_params, + cluster_shape, + IsInitAccumulatorPipeline{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (is_sched_warp) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); + clc_pipeline_params.transaction_bytes = sizeof(CLCResponse); + clc_pipeline_params.initializing_warp = 3; + CLCPipeline clc_pipeline(shared_storage.clc, clc_pipeline_params, cluster_shape); + CLCPipelineState clc_pipeline_consumer_state; + CLCPipelineState clc_pipeline_producer_state = cutlass::make_producer_start_state(); + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (is_dma_warp) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (is_sched_warp) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 4; + + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + + if (is_dma_warp) { + cutlass::arch::warpgroup_reg_dealloc<32>(); + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); + + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{}); + Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = tma_partition( + tma_load_a, + get<2>(cta_coord_vmnk), + make_layout(size<2>(cta_layout_vmnk)), 
+ group_modes<0,3>(tCsA), + group_modes<0,3>(tCgA)); + + auto [tBgB, tBsB] = tma_partition( + tma_load_b, + get<1>(cta_coord_vmnk), + make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(tCsB), + group_modes<0,3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + if constexpr (kEnableRHTColQuant) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0)); + } + } + + do { + bool is_first_wave = scheduler.is_first_wave(); + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_,scheduler.tile_m(),_); + int k_tile = 0; + // Throttle CLC producer + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { + + int k_tile_idx_n = scheduler.tile_n_base() + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier( + mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + if (cute::elect_one_sync()) { + copy( + tma_load_a.with(*tma_barrier, tma_mcast_mask_a), + tAgA_mk(_,k_tile_idx_n), + tAsA(_,write_stage)); + } + } + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + cutlass::arch::warpgroup_reg_dealloc<32>(); + if constexpr (kEnableRHTColQuant) { + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + + auto barrier_token = mainloop_pipeline.consumer_try_wait( + mainloop_pipe_consumer_state, + skip_wait); + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ) { + 
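+          // Per k-tile: wait for the DMA warp's TMA load of this A stage, fill
+          // accumulator stages (one 128x128 epilogue tile each) with
+          // EpilogueUnrollFactor back-to-back 128x16 UMMAs against the 16x16
+          // RHT matrix held in SMEM, then release the SMEM stage to the DMA warp.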
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_,_,_,read_stage); + auto tCrB_nk = tCrB(_,_,0,0); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) + { + int accumulator_k_block = accumulator_pipe_producer_state.index() * EpilogueUnrollFactor; + int tCrA_k_block = k_block * EpilogueUnrollFactor; + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < EpilogueUnrollFactor; i++) { + auto accumulators = bulk_tmem_mma(_,_,_,accumulator_k_block + i); + gemm(mma, tCrA_mk(_,_,tCrA_k_block + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state); + barrier_token = mainloop_pipeline.consumer_try_wait( + mainloop_pipe_consumer_state, + skip_wait); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } else if(is_sched_warp) { + cutlass::arch::warpgroup_reg_dealloc<32>(); + do { + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + clc_pipeline_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipeline_producer_state); + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } else if (is_epilogue_col_quant_warp) { + cutlass::arch::warpgroup_reg_alloc<192>(); + if constexpr (kEnableRHTColQuant) { + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + float const c_global_amax_val = *c_global_amax; + auto acc_epilogue_pipelined_shape = append(acc_shape_epilogue, Int{}); + auto bulk_tmem_epilogue_layout = make_layout( + acc_epilogue_pipelined_shape, + make_stride( + stride<0>(bulk_tmem_mma), + Int<0>{}, + Int<0>{}, + size<1>(epilogue_tiler))); + auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); + + // leveraging 256-bit writes to global memory + static int constexpr FragmentSize = 256 / sizeof_bits_v; + + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup; + + size_t rng_seed = 0; + size_t rng_offset = 0; + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? 
__ldg(rng_state + 1) : 0; + } + + Tensor mD = make_tensor( + cute::subbyte_iterator(D), + make_shape(M,N), + dD); // (M,N) + Tensor gD_mn = local_tile( + mD, + epilogue_tiler, + make_coord(_,_, _), + Step<_1,_1, X>{}); // (BLK_M,BLK_N) + Tensor pD = make_identity_tensor(mD.shape()); + Tensor pD_mn = local_tile( + pD, + epilogue_tiler, + make_coord(_,_, _), + Step<_1,_1, X>{}); // (BLK_M,BLK_N) + Tensor mSFD = make_tensor(make_gmem_ptr(SFD), sfd_layout); + Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + Tensor pSFD = make_identity_tensor(mSFD.shape()); + Tensor pSFD_mn = local_tile(pSFD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + Tensor gD_mn_view = tiled_divide(gD_mn, take<0,2>(epilogue_tiler)); + Tensor pD_mn_view = tiled_divide(pD_mn, take<0,2>(epilogue_tiler)); + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{})); + auto tiled_r2g = make_tiled_copy_D( + Copy_Atom{}, + tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(local_thread_idx); + auto thr_r2g = tiled_r2g.get_slice(local_thread_idx); + + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + float const fp4_max_inv = 1.0f / fp4_max; + float const global_encode_scale = c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float const global_decode_scale = 1.0f / global_encode_scale; + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + auto sfc_converter = cutlass::NumericConverter{}; + + do { + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ++k_tile) { + Tensor tDgD_mn = gD_mn_view(_,_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + Tensor tDgSFD_mn = gSFD_mn(_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + Tensor tDpD_mn = pD_mn_view(_,_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + Tensor tDpSFD_mn = pSFD_mn(_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto Acc = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index()); + Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDpD = thr_t2r.partition_D(tDpD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tTR_rAcc = make_tensor(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrD = make_tensor(shape(tDgD)); + Tensor tTR_rAcc_frag = recast>(coalesce(tTR_rAcc)); + Tensor tDrD_frag = recast>(coalesce(tDrD)); + + Tensor src = thr_r2g.retile_S(tDrD); + Tensor dst = thr_r2g.retile_D(tDgD); + Tensor pSrc = thr_r2g.retile_D(tDpD); + + Tensor tDgSFD_view = make_tensor( + tDgSFD_mn.data(), + make_layout( + make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); + Tensor tDpSFD_view = make_tensor( + tDpSFD_mn.data(), + make_layout( + make_shape(shape(tDpSFD_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tDpSFD_mn), 
Int<0>{}, Int<0>{}))); + Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); + Tensor tDrSFD = make_tensor(shape(tDgSFD)); + Tensor tDpSFD = filter(thr_t2r.partition_D(tDpSFD_view)); + static int constexpr NumVecs = size(tDgD) / VectorSize; + Tensor tD_rRowSFD_frg = recast>(tDrSFD); + + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // TMEM_LOAD + copy(tiled_t2r, tDtAcc, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + ++accumulator_pipe_consumer_state; + + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with + // unfused kernels + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); + } + + auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales = + cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); + } + auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); + + tD_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}(tD_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}( + qpvscale_ups, + global_decode_scale); + + cutlass::Array acc_scales; + if constexpr (kUseFastMath) { + // fast math: use reciprocal approximate to replace div + acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // regular path for slower math, use divide to replace div + acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + } + + uint4 random_uint4 = uint4{0, 0, 0, 0}; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + // "Prefetch" a stochastic rounding state for the first tile + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], + cutlass::platform::numeric_limits::max()); + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter(cutlass::multiplies>{}(compute_frgs[v], acc_scale), *reinterpret_cast*>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], + acc_scale)); + } + + } + + Tensor pred_pSrc = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(dst), _1{})), [&](auto coord){ + Tensor pSrc_view = group_modes<1,rank(pSrc)>(pSrc); + return elem_less(pSrc_view(_0{},coord), shape(mD)); + }); + copy_if(tiled_r2g, pred_pSrc, src, dst); + // 32bit vectorization copy 4 e4m3 SFD for per 64 or(16,4):(0, 1) element 
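
The per-vector scaling applied above (and again in the row-wise epilogue below) follows the usual two-level NVFP4 recipe: a tensor-wide encode scale derived from the global amax (the `c_global_amax` / `a_global_amax` inputs), plus one scale factor per SFVecSize elements stored in E4M3. A float-only sketch of that arithmetic, with the E4M3/E2M1 casts, clamps, and the fast-math reciprocal path omitted, and with the input values invented for illustration:

```cpp
// Float-only sketch of the two-level NVFP4 scaling in the epilogues above.
// The kernel stores block_scale as FP8 E4M3 and the scaled values as FP4 E2M1;
// here everything stays in float so only the scaling arithmetic is visible.
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>

int main() {
  constexpr float kFp4Max = 6.0f;    // largest magnitude representable in E2M1
  constexpr float kFp8Max = 448.0f;  // largest magnitude representable in E4M3
  const float global_amax = 37.5f;   // tensor-wide amax, computed elsewhere

  // Level 1: tensor-wide encode/decode scales.
  const float global_encode =
      global_amax > 0.0f ? (kFp8Max * kFp4Max) / global_amax : 1.0f;
  const float global_decode = 1.0f / global_encode;

  // Level 2: one scale factor per SFVecSize (16) elements, from the block amax
  // (8 made-up values here for brevity).
  const std::array<float, 8> block{0.5f, -3.2f, 1.1f, 7.9f, -6.3f, 2.4f, 0.3f, 4.1f};
  float block_amax = 0.0f;
  for (float v : block) block_amax = std::max(block_amax, std::fabs(v));

  const float block_scale = (block_amax / kFp4Max) * global_encode;  // stored as E4M3
  const float quant_scale = 1.0f / (block_scale * global_decode);    // applied to data

  for (float v : block)  // the block amax lands on +/-kFp4Max before the cast
    std::printf("%8.3f -> %8.3f (then cast to E2M1)\n", v, v * quant_scale);
  return 0;
}
```
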
+ + constexpr int vec_len = 32 / sizeof_bits_v; + Tensor tDrSFD_v = recast>(tDrSFD); + Tensor tDgSFD_v = recast>(tDgSFD); + copy_if( + [&](auto coord){ + Tensor tDpSFD_view = group_modes<1,rank(tDpSFD)>(tDpSFD); + return elem_less(tDpSFD_view(_0{}, coord * vec_len), shape(mSFD)); + }, + tDrSFD_v, tDgSFD_v); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + } else if (is_epilogue_row_quant_warp) { + cutlass::arch::warpgroup_reg_alloc<136>(); + if constexpr (kEnableRowQuant) { + using S2RVectorType = uint128_t; + float const a_global_amax_val = *a_global_amax; + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % 256; + size_t rng_seed = 0; + size_t rng_offset = 0; + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; + } + Tensor mQA = make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, N), dQA)); + Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); + Tensor pQA = make_identity_tensor(mQA.shape()); + Tensor pQA_mn = local_tile(pQA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); + + Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); + Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); // (BLK_M,BLK_N) + Tensor pSFA = make_identity_tensor(mSFA.shape()); + Tensor pSFA_mn = local_tile(pSFA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); + Tensor sA = as_position_independent_swizzle_tensor( + group_modes<0,2>(coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) + using S2RWarpLayout = Layout>; + using WarpGroupLayout = Layout>; + using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); + using S2RValLayout = Layout, _1>>; + using S2RAtomA = Copy_Atom; + using R2GAtomQA = Copy_Atom; + + auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_QA = make_tiled_copy_D(R2GAtomQA{}, tiled_s2r); + + auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); + auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); + + Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) + + Tensor tQArA = make_tensor_like(make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) + // Tensor tQArA_PI = thr_s2r.partition_S(sA_PI); + Tensor tQAgQA = thr_r2g_QA.partition_D(gQA_mn); + Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); + Tensor tQApQA = thr_r2g_QA.partition_D(pQA_mn); + + Tensor tQAgSFA = thr_s2r.partition_D(gSFA_mn); + Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); + Tensor tQApSFA = thr_s2r.partition_D(pSFA_mn); + + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + float const fp4_max_inv = 1.0f / fp4_max; + float const global_encode_scale = a_global_amax_val > 0.0f + ? 
cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float const global_decode_scale = 1.0f / global_encode_scale; + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + auto sfa_converter = cutlass::NumericConverter{}; + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ) { + auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQApSFA_mn = tQApSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQApQA_mn = tQApQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto barrier_token = mainloop_pipeline.consumer_try_wait( + mainloop_pipe_consumer_state); + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA); + cutlass::arch::fence_view_async_shared(); + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + // static int constexpr NumVecs = size(tQArA) / VectorSize; + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + auto compute_frgs = reinterpret_cast *>(tQArA.data()); + auto output_frgs = reinterpret_cast *>(raw_pointer_cast(tQArQA.data())); + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size(tQArA)/VectorSize; v++) { + auto compute_frgs_up = cutlass::NumericArrayConverter{}(compute_frgs[v]); + auto amax = amax_reduction(ElementAccumulator(0), compute_frgs_up); + // declare pvscales + ElementAccumulator pvscales; + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales = cutlass::multiplies{}(amax, global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales = cutlass::divides{}(amax, fp4_max); + pvscales = cutlass::multiplies{}(pvscales, global_encode_scale); + } + filter(tQArSFA)(v) = sfa_converter(pvscales); + auto qpvscale_ups = cutlass::NumericConverter{}(filter(tQArSFA)(v)); + auto qpvscale_scaled = cutlass::multiplies{}(qpvscale_ups, global_decode_scale); + ElementAccumulator acc_scales; + if constexpr (kUseFastMath) { + // fast math: use reciprocal approximate to replace div + acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // regular path for slower math, use divide to replace div + acc_scales = cutlass::divides{}(1.0, qpvscale_scaled); + } + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales, + cutlass::platform::numeric_limits::max()); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter(cutlass::multiplies>{}(compute_frgs_up, acc_scale), 
*reinterpret_cast*>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs_up, + acc_scale)); + } + } + + Tensor pred_tQApQA = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(tQAgQA_mn), _1{})), [&](auto coord){ + Tensor tQApQA_view = group_modes<1,rank(tQApQA_mn)>(tQApQA_mn); + return elem_less(tQApQA_view(_0{}, coord), shape(mQA)); + }); + copy_if(tiled_r2g_QA, pred_tQApQA, tQArQA, tQAgQA_mn); + // 32bit vectorization copy 4 e4m3 SFA for per 64 or (16,4):(0, 1) element + constexpr int vec_len = 32 / sizeof_bits_v; + Tensor tQArSFA_v = recast>(filter(tQArSFA)); + Tensor tQAgSFA_v = recast>(filter(tQAgSFA_mn)); + copy_if( + [&](auto coord){ + Tensor tQApSFA_view = filter(tQApSFA_mn); + return elem_less(tQApSFA_view(_0{}, coord * vec_len), shape(mSFA)); + }, + tQArSFA_v, tQAgSFA_v); + } + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + scheduler.update_work_tile_info(); + }while (scheduler.is_valid()); + } + } else { + cutlass::arch::warpgroup_reg_dealloc<32>(); + } +} // NOLINT(readability/fn_size) + + +// this function computes RHT-GEMM for +// m = hidden_size, n = sequence_length +// A: m x n: col-major +// B: 16 x 16: row-major +// D: m x n: row-major +// SFD: m x (n/16): row-major +// QA: m x n: col-major +// SFA: m/16 x n: col-major +template +void row_col_rht_gemm_ntt_w_sfc( + int sequence_length, + int hidden_size, + TA const* A, + TB const* B, + TD* D, + TSFD* SFD, + TQA* QA, + TSFA* SFA, + float const* a_global_amax, + float const* d_global_amax, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 1024) { + using namespace cute; + static int constexpr SFVecSize = 16; + static int constexpr RhtTensorSize = 16; + + static_assert(RhtTensorSize == 16, "RhtTensorSize must be 16"); + using LinearSFALayout = decltype(make_layout(make_shape(make_shape(Int{}, 0), 0), make_stride(make_stride(_0{}, _1{}), 0))); + using LinearSFCLayout = decltype(make_layout(make_shape(0, make_shape(Int{}, 0)), make_stride(0, make_stride(_0{}, _1{})))); + + using SwizzledSFALayoutAtom = cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFALayout = decltype(tile_to_shape(SwizzledSFALayoutAtom{}, make_shape(hidden_size,sequence_length), Step<_1,_2>{})); + using SwizzledSFDLayout = decltype(tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(hidden_size,sequence_length), Step<_2,_1>{})); + + using SFALayout = cute::conditional_t; + using SFCLayout = cute::conditional_t; + SFALayout sfa_layout; + SFCLayout sfd_layout; + + if constexpr (kEnableSwizzleSFOutput) { + sfa_layout = tile_to_shape(SwizzledSFALayoutAtom{}, make_shape(hidden_size, sequence_length), Step<_1,_2>{}); + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(hidden_size, sequence_length), Step<_2,_1>{}); + } else { + sfa_layout = make_layout(make_shape(make_shape(Int{}, hidden_size/SFVecSize), sequence_length), make_stride(make_stride(_0{}, _1{}), hidden_size/SFVecSize)); + sfd_layout = make_layout(make_shape(hidden_size, make_shape(Int{}, sequence_length/SFVecSize)), make_stride(sequence_length/SFVecSize, make_stride(_0{}, _1{}))); + } + // Define shapes (dynamic) + auto M = hidden_size; + auto N = sequence_length; + Tensor tensorA = make_tensor(A, make_shape(hidden_size, sequence_length), LayoutLeft{}); + Tensor tensorB = make_tensor(B, 
make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{}); + Tensor tensorD = make_tensor(D, make_shape(hidden_size, sequence_length), LayoutRight{}); + Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, sequence_length), LayoutLeft{}); + Tensor tensorSFD = make_tensor(SFD, sfd_layout); + Tensor tensorSFA = make_tensor(SFA, sfa_layout); + // Define strides (from tensors) + auto dA = stride(tensorA); // (dM,dK) + auto dB = stride(tensorB); // (dN,dK) + auto dD = stride(tensorD); // (dM,dN) + auto dQA = stride(tensorQA); // (dM,dK) + using ClusterShape = Shape< _1, _1, _1>; + auto cluster_shape = ClusterShape{}; + auto cluster_tile_shape = Shape<_128,Int,Int>{}; + auto cluster_tile_mainloop = Shape<_128,Int,_128>{}; + + // Each mainloop / epilogue loads 128 x 64 tiles while each MMA proceeds with 128 x 16 tiles + static int constexpr EpilogueUnrollFactor = + size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape); + // Construct the MMA + auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS(cluster_tile_shape), size<1>(cluster_tile_shape), + UMMA::Major::MN, UMMA::Major::MN>{}, + Layout>{}); + + // Assert that the TiledMMA uses all CTAs in the CGA. + CUTE_STATIC_ASSERT_V(size(cluster_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cluster_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div(shape<0>(cluster_tile_shape), shape_div(shape<0>(cluster_tile_shape), size<0>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div(shape<1>(cluster_tile_shape), shape_div(shape<1>(cluster_tile_shape), size<1>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cluster_tile_shape)); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>()); + + auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + static uint32_t constexpr TotalTmemRows = 128; + static uint32_t constexpr Sm100TmemCapacityColumns = 512; + static uint32_t constexpr TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns; + static uint32_t constexpr AccumulatorPipelineStageCount = + TotalTmem / + (cute::size<0>(cluster_tile_shape) * cute::size<1>(cluster_tile_shape)); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int SchedulerPipelineStageCount = 6; + static int constexpr MainloopPipelineBytes = sizeof(typename cutlass::detail::CustomizedPipelineTmaUmmaAsync< + 1, + Shape<_1,_1,_1>, + Shape<_1, _1, _1>>::SharedStorage); + + static int constexpr ClcResponseBytes = sizeof(CLCResponse) * SchedulerPipelineStageCount; + static int constexpr CLCThrottlePipelineBytes = sizeof(typename cutlass::PipelineAsync::SharedStorage); + static int constexpr CLCPipelineBytes = sizeof(typename 
cutlass::PipelineCLCFetchAsync::SharedStorage); + static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier); + static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB); + static int constexpr AccPipelineBytes = sizeof(typename cutlass::PipelineUmmaAsync>::SharedStorage); + static int constexpr TmemBasePtrsBytes = sizeof(uint32_t); + static int constexpr kBlackwellSmemSize = 232448; // 232KB in bytes + static int constexpr kBytesPerStage = + cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes; + static int constexpr kReservedBytes = ClcResponseBytes + CLCThrottlePipelineBytes + TmemBasePtrsBytes + + CLCPipelineBytes + TmemDeallocBytes+BTensorBytes + AccPipelineBytes; // Reserve for barriers and other uses + static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + auto sA = UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(mma_shape_A, sP), Step<_2,_1,_3>{}); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(mma_shape_B, _1{})); // (MMA,MMA_N,MMA_K, _1) + auto sD = Layout<_1>{}; // XXX Dummy + + auto tma_load_a = make_tma_copy_A_sm100( + SM90_TMA_LOAD{}, + tensorA, + sA(_,_,_,0), + cluster_tile_mainloop, + mma); + auto tma_load_b = make_tma_copy_B_sm100( + SM90_TMA_LOAD{}, + tensorB, + sB(_,_,_,0), + cluster_tile_shape, + mma); + + // Assert checks problem size should be multiple of 64 + assert(M % 64 == 0); + assert(N % 64 == 0); + + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile_shape)))); + uint32_t tiles_in_n = uint32_t(size(ceil_div(N, k_tile_size))); + uint32_t tiles = tiles_in_m * tiles_in_n; + + dim3 dimBlock(512); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(tiles_in_m, tiles_in_n, 1); + + int smem_size = sizeof( + SharedStorage< + TA, + TB, + decltype(sA), + decltype(sB), + ClusterShape, + AccumulatorPipelineStageCount, + EpilogueUnrollFactor, + SchedulerPipelineStageCount>); + + auto* kernel_ptr = &row_col_rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), + decltype(cluster_shape), decltype(cluster_tile_shape), + TA, decltype(dA), decltype(sA), decltype(tma_load_a), + TB, decltype(dB), decltype(sB), decltype(tma_load_b), + TD, decltype(dD), decltype(sD), + TSFD, decltype(sfd_layout), + TQA, decltype(dQA), + TSFA, decltype(sfa_layout), + decltype(mma), + AccumulatorPipelineStageCount, + SchedulerPipelineStageCount, + kEnableStochasticRounding, + kEnableRHTColQuant, + kEnableRowQuant, + kUseFastMath>; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream}; + cutlass::Status status = cutlass::launch_kernel_on_cluster( + params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, + tensorA.data(), dA, sA, tma_load_a, + tensorB.data(), dB, sB, tma_load_b, + tensorD.data(), dD, sD, + tensorSFD.data(), sfd_layout, + tensorQA.data(), dQA, + tensorSFA.data(), sfa_layout, + mma, a_global_amax, d_global_amax, rng_state); + + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed."); + +} + +} // namespace +} // namespace detail + +// clang-format on + +void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, + const Tensor &hadamard_matrix_, QuantizationConfig quant_config, + 
cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform_cast_fusion); + + // Check input and output tensors + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor &input = input_.data; + + // rowwise cast and columnwise cast has different output data pointers + bool has_rowwise_quant = false; + bool has_columnwise_quant = false; + void *rowwise_data_ptr = nullptr; + void *rowwise_scale_inv_ptr = nullptr; + void *rowwise_amax_ptr = nullptr; + void *columnwise_data_ptr = nullptr; + void *columnwise_scale_inv_ptr = nullptr; + void *columnwise_amax_ptr = nullptr; + + // examine the output tensor (single tensor for dense) + if (output_.data.dptr != nullptr) { + has_rowwise_quant = true; + rowwise_data_ptr = output_.data.dptr; + rowwise_scale_inv_ptr = output_.scale_inv.dptr; + rowwise_amax_ptr = output_.amax.dptr; + } + + if (output_.columnwise_data.dptr != nullptr) { + has_columnwise_quant = true; + columnwise_data_ptr = output_.columnwise_data.dptr; + columnwise_scale_inv_ptr = output_.columnwise_scale_inv.dptr; + columnwise_amax_ptr = output_.columnwise_amax.dptr; + } + + NVTE_CHECK(has_rowwise_quant || has_columnwise_quant, + "Output tensor must have rowwise or columnwise quant."); + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TD = cutlass::float_e2m1_t; + using TSFD = cutlass::float_ue4m3_t; + using TQA = TD; + using TSFA = TSFD; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Hadamard matrix must be BF16 tensor, but scaling mode is ", + to_string(hadamard_matrix_.scaling_mode), "."); + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + auto sm_count = transformer_engine::cuda::sm_count(); + + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must 
be divisible by hadamard_dimension"); + + int k_tile_size = 1024; + + // TODO: add support for swizzle sf output + const bool use_swizzle_sf_output = false; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kEnableStochasticRounding, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + has_columnwise_quant, kEnableRhtColQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + has_rowwise_quant, kEnableRowQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_swizzle_sf_output, kEnableSwizzleSFOutput, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config.use_fast_math, kUseFastMath, + + if constexpr (kEnableRhtColQuant || kEnableRowQuant) { + detail::row_col_rht_gemm_ntt_w_sfc< + kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant, + kEnableSwizzleSFOutput, TA, TB, TD, TSFD, TQA, TSFA, kUseFastMath>( + /*sequence_length=*/m, /*hidden_size=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*D=*/reinterpret_cast(columnwise_data_ptr), + /*SFD=*/reinterpret_cast(columnwise_scale_inv_ptr), + /*QA=*/reinterpret_cast(rowwise_data_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_ptr), + /*a_global_amax=*/reinterpret_cast(rowwise_amax_ptr), + /*d_global_amax=*/reinterpret_cast(columnwise_amax_ptr), + /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + } else { + NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", + kEnableRhtColQuant, ", kEnableRowQuant=", kEnableRowQuant, ")."); + } + + ););););); +} + +} // namespace transformer_engine + +void nvte_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform_cast_fusion); + using namespace transformer_engine; + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + hadamard_transform_cast_fusion(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, + stream); +} diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index 13103cc388..27e35dca2c 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -49,6 +49,7 @@ void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int /*! \brief Perform the columnwise hadamard transform cast fusion. * * This function is experimental and the API is not stable. + * This function will later be deprecated and replaced by nvte_hadamard_transform_cast_fusion * * \param[in] input Input tensor to apply Hadamard transform. * \param[in,out] output Output tensor. @@ -61,6 +62,21 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE const NVTEQuantizationConfig quant_config, cudaStream_t stream); +/*! \brief Perform the regular rowwise cast and columnwise hadamard transform cast fusion. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] hadamard_matrix Hadamard matrix. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. 
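+ * \note Grounded in the implementation in this patch: the output tensor must provide rowwise
+ *       data, columnwise data, or both (at least one must be allocated), and the input must be
+ *       a BF16 tensor whose row and column counts are divisible by the 16x16 Hadamard matrix
+ *       dimension.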
+ */ +void nvte_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + /*! \brief Split a tensor along dimension 0 and compute RHT amaxes for each split. * * This function is experimental and the API is not stable. diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 1e1e3326c4..fe90474d14 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -335,6 +335,11 @@ class NVFP4Quantizer : public Quantizer { private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); + void quantize_with_rht_unfused_helper(const TensorWrapper& input, TensorWrapper& out, + TensorWrapper& rht_output_t_cpp, + QuantizationConfigWrapper& quant_config, + QuantizationConfigWrapper& quant_config_columnwise, + cudaStream_t stream); }; std::unique_ptr convert_quantizer(py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 4e5e5223f7..57742b356a 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -840,6 +840,11 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // Enable NVFP4 kernels to use math operations that sacrifice // accuracy for performance. These optimizations are experimental // and inconsistently implemented. + // What math is accelerated? Only the high precision math, so numerical impact is minimal + // 1. replace x / y by x * (1/y) + // 2. replace 1 / x by reciporal_approximate_ftz(x) + // 3. 
when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, + // this will essentially remove a round trip between FP32 to BF16 then FP32 const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math) { for (auto &config : quant_config_list) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index a73efc008a..d1c868bb08 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -7,6 +7,7 @@ #include #include "common.h" +#include "common/util/system.h" #include "pybind.h" #include "torch/torch.h" @@ -1443,6 +1444,82 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( return {std::move(out_cpp), std::move(tensor)}; } +void NVFP4Quantizer::quantize_with_rht_unfused_helper( + const TensorWrapper& input, TensorWrapper& out, TensorWrapper& rht_output_t_cpp, + QuantizationConfigWrapper& quant_config, QuantizationConfigWrapper& quant_config_columnwise, + cudaStream_t stream) { + // only triggered for irregular shapes where RHT cast fusion kernel is not eligible + if (rowwise_usage) { + // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise + TensorWrapper out_identity(out.scaling_mode()); + auto out_identity_data = out.get_rowwise_data(); + auto out_identity_scale_inv = out.get_rowwise_scale_inv(); + auto out_identity_amax = out.get_amax(); + out_identity.set_rowwise_data(out_identity_data.data_ptr, + static_cast(out_identity_data.dtype), + out_identity_data.shape); + out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, + static_cast(out_identity_scale_inv.dtype), + out_identity_scale_inv.shape); + out_identity.set_amax(out_identity_amax.data_ptr, static_cast(out_identity_amax.dtype), + out_identity_amax.shape); + + NVTE_SCOPED_GIL_RELEASE( + { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); }); + } + + if (columnwise_usage) { + // Get the output columnwise data, scale_inv, and amax + auto out_columnwise_data = out.get_columnwise_data(); + auto out_columnwise_scale_inv = out.get_columnwise_scale_inv(); + // NOTE: should already be populated. + auto out_columnwise_amax = out.get_columnwise_amax(); + + // Create a wrapper for the columnwise output, as the rowwise output. + // The reason is due to the input `rht_output_t` is already in the transposed layout. + // Thus, we only need a rowwise quantization to generate the columnwise output. 
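+    // Illustration: for an input x of shape [rows, cols], rht_output_t will hold RHT(x^T) with
+    // shape [cols, rows] (it is filled by the nvte_hadamard_transform call below), so a plain
+    // rowwise NVFP4 quantization of that buffer fills the columnwise (transposed) output
+    // buffers aliased by out_transpose here.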
+ TensorWrapper out_transpose(out.scaling_mode()); + // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail + // need to convert the shape to 2D here + auto colwise_data_shape = out_columnwise_data.shape; + std::vector colwise_data_shape_2d; + // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte + // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again + // so the multiple 2 get cancelled out + colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); + size_t last_dim = 1; + for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { + last_dim *= colwise_data_shape.data[i]; + } + colwise_data_shape_2d.push_back(last_dim); + + out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, + static_cast(out_columnwise_data.dtype), + colwise_data_shape_2d); + out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, + static_cast(out_columnwise_scale_inv.dtype), + out_columnwise_scale_inv.shape); + out_transpose.set_amax(out_columnwise_amax.data_ptr, + static_cast(out_columnwise_amax.dtype), + out_columnwise_amax.shape); + + // Invoking fallback RHT kernel unfused. + + NVTE_SCOPED_GIL_RELEASE({ + // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. + nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + + // Quantize kernel will treat everything as rowwise input/output, which is + // intended. + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config_columnwise, + stream); + }); + } +} + void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax) { @@ -1454,8 +1531,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou auto stream = at::cuda::getCurrentCUDAStream(); QuantizationConfigWrapper quant_config; + QuantizationConfigWrapper quant_config_columnwise; if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); + quant_config_columnwise.set_noop_tensor(noop_flag->data()); } quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); quant_config.set_stochastic_rounding(this->stochastic_rounding); @@ -1468,14 +1547,25 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } size_t cols = input.size(input.ndim() - 1); + // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT + bool eligible_for_rht_cast_fusion = + input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; + // Stochastic rounding // When both rowwise and columnwise quantization are used with RHT, // we need separate RNG states for each to ensure they use different random numbers. TensorWrapper te_rng_state; TensorWrapper te_rng_state_columnwise; - QuantizationConfigWrapper quant_config_columnwise; - const bool need_separate_columnwise_rng = - this->stochastic_rounding && this->with_rht && this->columnwise_usage; + + // Only need a separate rng state when: + // 1. Stochastic rounding is enabled + // 2. RHT is enabled + // 3. Columnwise usage is enabled + // 4. 
Rowwise and columnwise quantization are not fused, + // because within a single kernel we can generate two different random numbers for rowwise and columnwise + const bool need_separate_columnwise_rng = this->stochastic_rounding && this->with_rht && + this->columnwise_usage && + (!eligible_for_rht_cast_fusion); if (this->stochastic_rounding) { const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened @@ -1498,13 +1588,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou te_rng_state_columnwise = makeTransformerEngineTensor(rng_state_columnwise); quant_config_columnwise.set_stochastic_rounding(true); quant_config_columnwise.set_rng_state(te_rng_state_columnwise.data()); + quant_config_columnwise.set_nvfp4_2d_quantization(this->with_2d_quantization); } } - // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT - bool eligible_for_rht_cast_fusion = - input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; - // Compute amax. if (this->with_rht) { if (input.dtype() != DType::kBFloat16) { @@ -1570,103 +1657,47 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); }); } - if (this->with_rht) { - if (rowwise_usage) { - // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise - TensorWrapper out_identity(out.scaling_mode()); - auto out_identity_data = out.get_rowwise_data(); - auto out_identity_scale_inv = out.get_rowwise_scale_inv(); - auto out_identity_amax = out.get_amax(); - out_identity.set_rowwise_data(out_identity_data.data_ptr, - static_cast(out_identity_data.dtype), - out_identity_data.shape); - out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, - static_cast(out_identity_scale_inv.dtype), - out_identity_scale_inv.shape); - out_identity.set_amax(out_identity_amax.data_ptr, static_cast(out_identity_amax.dtype), - out_identity_amax.shape); - - NVTE_SCOPED_GIL_RELEASE( - { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); }); - } - - if (columnwise_usage) { - // Get the output columnwise data, scale_inv, and amax - auto out_columnwise_data = out.get_columnwise_data(); - auto out_columnwise_scale_inv = out.get_columnwise_scale_inv(); - // NOTE: should already be populated. - auto out_columnwise_amax = out.get_columnwise_amax(); - - // Create a wrapper for the columnwise output, as the rowwise output. - // The reason is due to the input `rht_output_t` is already in the transposed layout. - // Thus, we only need a rowwise quantization to generate the columnwise output. 
- TensorWrapper out_transpose(out.scaling_mode()); - // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail - // need to convert the shape to 2D here - auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; - // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte - // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again - // so the multiple 2 get cancelled out - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); - size_t last_dim = 1; - for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { - last_dim *= colwise_data_shape.data[i]; - } - colwise_data_shape_2d.push_back(last_dim); - - out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, - static_cast(out_columnwise_data.dtype), - colwise_data_shape_2d); - out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, - static_cast(out_columnwise_scale_inv.dtype), - out_columnwise_scale_inv.shape); - out_transpose.set_amax(out_columnwise_amax.data_ptr, - static_cast(out_columnwise_amax.dtype), - out_columnwise_amax.shape); + // Fast math toggle: RHT transform can be accelerated + // What math is accelerated? Only the high precision math, so numerical impact is minimal + // 1. replace x / y by x * (1/y) + // 2. replace 1 / x by reciporal_approximate_ftz(x) + // 3. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, + // this will essentially remove a round trip between FP32 to BF16 then FP32 + const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); + if (use_fast_math) { + quant_config.set_use_fast_math(true); + quant_config_columnwise.set_use_fast_math(true); + } + if (this->with_rht) { + if (eligible_for_rht_cast_fusion) { + // fusion kernel requires passing in RHT matrix directly for maximum performance + auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); + // Fusion kernel that does the following: + // 1. Rowwise quantization + // 2. RHT followed by columnwise quantization & transpose + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_cast_fusion(input.data(), out.data(), rht_matrix_nvte.data(), + quant_config, stream); + }); + } else { // Use separate RNG state for columnwise to ensure different random numbers than rowwise - auto& columnwise_quant_config = + // This is only necessary because it's the unfused path where rowwise and columnwise + // are separate kernel launches + auto& columnwise_quant_config_to_use = need_separate_columnwise_rng ? quant_config_columnwise : quant_config; - - if (!eligible_for_rht_cast_fusion) { - // Invoking fallback RHT kernel. - - // If using RHT, then amax will be computed in the RHT step - // If not using RHT, then amax will be computed based on input x - at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout - // This wrapper is going to be passed as input to the quantization kernel. - TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs - rht_output_t = - allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); - // NOTE (frsun): This is non-intuitive, we are writing the - // result of transposed RHT to the output of rowwise. - rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), - std::vector{cols, rows}); - - NVTE_SCOPED_GIL_RELEASE({ - // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. 
- nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, - this->rht_matrix_random_sign_mask_t, stream); - }); - - // Quantize kernel will treat everything as rowwise input/output, which is - // intended. - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), columnwise_quant_config, - stream); - }); - } else { - // RHT cast fusion kernel. - NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, - "RHT matrix is not set"); - auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); - NVTE_SCOPED_GIL_RELEASE({ - nvte_hadamard_transform_cast_fusion_columnwise(input.data(), out_transpose.data(), - rht_matrix_nvte.data(), - columnwise_quant_config, stream); - }); - } + // unfused path also needs memory allocation for intermediate buffer for RHT output + at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout + // This wrapper is going to be passed as input to the quantization kernel. + TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs + rht_output_t = + allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); + // NOTE (frsun): This is non-intuitive, we are writing the + // result of transposed RHT to the output of rowwise. + rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), + std::vector{cols, rows}); + this->quantize_with_rht_unfused_helper(input, out, rht_output_t_cpp, quant_config, + columnwise_quant_config_to_use, stream); } } else { NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
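+    // Note: without RHT, the single nvte_quantize_v2 call above quantizes the input directly
+    // into whichever rowwise and/or columnwise outputs are configured on `out`; no separate
+    // Hadamard-transform or transpose kernel is launched on this path.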