From 932b3e84e51581508d80859ff89b95ad9c0ca883 Mon Sep 17 00:00:00 2001 From: benzh Date: Fri, 5 Dec 2025 08:05:34 +0000 Subject: [PATCH 1/4] add fp4 gemm + allreduce Signed-off-by: benzh --- .../allreduce_gemm_impl_sm100.h | 2 +- cpp/tensorrt_llm/thop/CMakeLists.txt | 3 +- .../thop/fusedGemmAllreduceOp.cpp | 300 ++++++++++++++++++ .../_torch/custom_ops/torch_custom_ops.py | 96 ++++++ tensorrt_llm/_torch/models/modeling_llama.py | 35 +- tensorrt_llm/_torch/modules/linear.py | 75 ++++- .../unittest/_torch/multi_gpu/test_linear.py | 158 ++++++++- 7 files changed, 651 insertions(+), 18 deletions(-) create mode 100644 cpp/tensorrt_llm/thop/fusedGemmAllreduceOp.cpp diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm100.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm100.h index 8ea96d0b6af..9ff51af6c77 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm100.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm100.h @@ -141,7 +141,7 @@ class GemmAllReduceImplTwoshot_Sm100 : public GemmAllReduceImplInterface // Epilogue //////////////// using FusionCallbacks = cutlass::epilogue::fusion::LinearCombination; - using TileBarrierType = cutlass::MulticastSystemBarrier; + using TileBarrierType = cutlass::MulticastSystemBarrier; using EpilogueScheduleType = typename MmaAdapter::EpilogueSchedule; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; using FusionOp diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt index 20fcc35b829..4c9fafd226c 100644 --- a/cpp/tensorrt_llm/thop/CMakeLists.txt +++ b/cpp/tensorrt_llm/thop/CMakeLists.txt @@ -104,7 +104,8 @@ add_library( loraOp.cpp finegrained_mixed_dtype_gemm_thop.cpp tinygemm2.cpp - dsv3RopeOp.cpp) + dsv3RopeOp.cpp + fusedGemmAllreduceOp.cpp) set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries( th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES} diff --git a/cpp/tensorrt_llm/thop/fusedGemmAllreduceOp.cpp b/cpp/tensorrt_llm/thop/fusedGemmAllreduceOp.cpp new file mode 100644 index 00000000000..d810ea039ee --- /dev/null +++ b/cpp/tensorrt_llm/thop/fusedGemmAllreduceOp.cpp @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cutlass_extensions/gemm_configs.h" + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h" +#include "tensorrt_llm/runtime/ipcNvlsMemory.h" +#include "tensorrt_llm/thop/thUtils.h" + +#include + +#include +#include + +#include +#include +#include +#include + +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner; +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface; +using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes; +using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface; + +namespace +{ +struct AllocationKey +{ + int64_t device_index; + std::set group; + + bool operator==(AllocationKey const& other) const + { + return device_index == other.device_index && group == other.group; + } + + std::string toString() const + { + std::stringstream ss; + ss << "AllocationKey(device: " << device_index << ", group: ["; + for (int rank : group) + { + ss << rank << ", "; + } + ss << "])"; + return ss.str(); + } +}; + +struct AllocationKeyHash +{ + size_t operator()(AllocationKey const& key) const + { + size_t seed = 0; + + // Hash the device index + hash_combine(seed, key.device_index); + + // Hash the set elements + for (auto const& elem : key.group) + { + hash_combine(seed, elem); + } + + return seed; + } + +private: + template + static void hash_combine(size_t& seed, T const& val) + { + seed ^= std::hash()(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } +}; + +class IpcNvlsHandleWrapper +{ +public: + IpcNvlsHandleWrapper(size_t size, std::set groups) + : mSize(size) + { + mHandle = tensorrt_llm::runtime::ipcNvlsAllocate(size, groups); + } + + tensorrt_llm::runtime::IpcNvlsHandle* getHandle() const + { + return mHandle; + } + + size_t getSize() const + { + return mSize; + } + + ~IpcNvlsHandleWrapper() + { + tensorrt_llm::runtime::ipcNvlsFree(mHandle); + } + +private: + size_t mSize; + tensorrt_llm::runtime::IpcNvlsHandle* mHandle; +}; + +std::once_flag init_flag; + +size_t getPreferredWorkspaceSize() +{ + // 128MB + static size_t preferredWorkspaceSize = 134217728; + std::call_once(init_flag, + [&]() + { + char const* envWorkspaceSize = std::getenv("TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE"); + size_t workspaceSize = 0; + if (envWorkspaceSize != nullptr) + { + workspaceSize = std::atoi(envWorkspaceSize); + } + preferredWorkspaceSize = std::max(preferredWorkspaceSize, workspaceSize); + }); + return preferredWorkspaceSize; +} + +class GemmAllreduceNvlsMemoryManager +{ +public: + GemmAllreduceNvlsMemoryManager() + { + TLLM_LOG_INFO("GemmAllreduceNvlsMemoryManager constructor"); + } + + ~GemmAllreduceNvlsMemoryManager() + { + TLLM_LOG_INFO("GemmAllreduceNvlsMemoryManager destructor"); + } + + std::pair getWorkspace( + GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs const& problem, + AllocationKey const& key) + { + int M = std::get<0>(problem.problem_size); + int N = std::get<1>(problem.problem_size); + size_t requiredSize = M * N * 2; + size_t preferredWorkspaceSize = getPreferredWorkspaceSize(); + if (requiredSize > preferredWorkspaceSize) + { + std::stringstream ss; + ss << "Please set TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE to at least " << requiredSize << " bytes"; + C10_THROW_ERROR(ErrorAlwaysShowCppStacktrace, ss.str().c_str()); + } + + auto handle = mHandles[key]; + if (handle == nullptr) + { + TLLM_LOG_INFO("Creating allreduce workspace for %s", key.toString().c_str()); + handle = 
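            // First request for this (device, TP-group) key: lazily allocate the NVLS
            // multicast buffer plus a persistent CUTLASS workspace, and cache both so that
            // every later fused GEMM+allreduce call with the same key reuses them.
            // Rough usage sketch (illustrative only; runner/problem/key come from the
            // caller, as in runGemmImpl() below):
            //   auto* mgr = getGemmAllreduceNvlsMemoryManager();
            //   auto [ws, nvls]   = mgr->getWorkspace(runner, problem, key); // allocates
            //   auto [ws2, nvls2] = mgr->getWorkspace(runner, problem, key); // cache hit
            // The NVLS buffer defaults to 128 MB and can be grown via the
            // TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE environment variable; getWorkspace()
            // throws if the current GEMM needs more than that (M * N * 2 bytes for the
            // fp16/bf16 output).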
std::make_shared(preferredWorkspaceSize, key.group); + GemmAllReduceImplInterface::ProblemArgs tmpArgs; + int maxN = 16384; + int maxM = preferredWorkspaceSize / (maxN * 2); + tmpArgs.argProblemShape(maxM, maxN, 512, 1) + .argRanks(problem.rank, problem.ranks) + .argLaunchConfig(runner->getSupportedLaunchConfigs()[0]); + auto workspace = runner->getPersistentWorkspace(tmpArgs); + workspace->allocate(); + mWorkspaces[key] = workspace; + mHandles[key] = handle; + } + return std::make_pair(mWorkspaces[key].get(), mHandles[key]->getHandle()); + } + +private: + std::unordered_map, AllocationKeyHash> mWorkspaces; + std::unordered_map, AllocationKeyHash> mHandles; +}; + +GemmAllreduceNvlsMemoryManager* getGemmAllreduceNvlsMemoryManager() +{ + static GemmAllreduceNvlsMemoryManager gNvlsMemoryManager; + return &gNvlsMemoryManager; +} + +at::Tensor runGemmImpl(GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs& problem, + at::ScalarType outputDtype, c10::cuda::CUDAStream stream) +{ + AllocationKey key{stream.device_index(), problem.ranks}; + auto [workspace, handle] = getGemmAllreduceNvlsMemoryManager()->getWorkspace(runner, problem, key); + problem.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data()); + problem.argWorkspace(workspace); + runner->run(problem, stream); + size_t dSize + = std::get<0>(problem.problem_size) * std::get<1>(problem.problem_size) * c10::elementSize(outputDtype); + auto D = at::detail::empty_cuda({std::get<0>(problem.problem_size), std::get<1>(problem.problem_size)}, outputDtype, + stream.device(), std::nullopt); + TLLM_CUDA_CHECK(cudaMemcpyAsync( + D.data_ptr(), reinterpret_cast(handle->uc_ptr), dSize, cudaMemcpyDeviceToDevice, stream)); + return D; +} +} // namespace + +namespace torch_ext +{ + +class Fp4GemmAllreduceRunner : public torch::CustomClassHolder +{ +public: + explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List group) + : mOutputDtype(outputDtype) + , mRank(rank) + { + for (int64_t rank : group) + { + mGroup.insert(static_cast(rank)); + } + + if (outputDtype == at::ScalarType::Half) + { + using Traits = GemmTypes; + mRunner = std::make_shared>(); + } + else if (outputDtype == at::ScalarType::BFloat16) + { + using Traits = GemmTypes; + mRunner = std::make_shared>(); + } + else + { + C10_THROW_ERROR(NotImplementedError, "Unsupported input or output dtype"); + } + + mConfigs = mRunner->getSupportedLaunchConfigs(); + } + + at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, + at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const + { + if (configIdx < 0) + configIdx = 0; + + TORCH_CHECK(configIdx < int64_t(mConfigs.size()), "configIdx out of bounds"); + const int64_t M = mat1.size(0); + const int64_t N = mat2.size(0); + const int64_t K = mat1.size(1) * 2; + + GemmAllReduceImplInterface::ProblemArgs problemArgs; + problemArgs.argProblemShape(M, N, K, 1); + problemArgs.argA(mat1.data_ptr()); + problemArgs.argB(mat2.data_ptr()); + problemArgs.argAScale(mat1Scale.data_ptr()); + problemArgs.argBScale(mat2Scale.data_ptr()); + problemArgs.argC(nullptr); + problemArgs.argAlphaPtr(reinterpret_cast(alpha.const_data_ptr())); + problemArgs.argBeta(0.f); + problemArgs.argRanks(mRank, mGroup); + problemArgs.argLaunchConfig(mConfigs[configIdx]); + + auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device()); + return runGemmImpl(mRunner.get(), problemArgs, mOutputDtype, stream); + } + + int64_t getNumConfigs() 
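    // Exposes how many CUTLASS launch configs this runner supports so the Python-side
    // TunableRunner (torch_custom_ops.py below) can enumerate them as tactics; runGemm()
    // indexes mConfigs directly with the chosen tactic. Note that runGemm() reconstructs
    // the logical K as mat1.size(1) * 2, since the packed NVFP4 activation stores two
    // 4-bit values per byte.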
const + { + return static_cast(mConfigs.size()); + } + +private: + at::ScalarType mOutputDtype; + int mRank; + std::set mGroup; + std::shared_ptr mRunner{nullptr}; + std::vector mConfigs; +}; + +} // namespace torch_ext + +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.class_("Fp4GemmAllreduceRunner") + .def(torch::init>()) + .def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm) + .def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs); +} diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 2ee8d29ccca..6b5088f3dba 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -1869,3 +1869,99 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None: stream = get_stream(stream_id) assert stream is not None tensor.record_stream(stream) + + +class Fp4GemmAllreduceRunner(TunableRunner): + runner_dict = dict() + tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2), ), + constraint_specs=(ConstraintSpec( + 2, 0, fp4_scale_infer_shape), )) + + def __init__( + self, + output_dtype: torch.dtype, + tp_rank: int, + tp_group: List[int], + ): + self.output_dtype = output_dtype + self.tp_rank = tp_rank + self.tp_group_str = '-'.join(str(g) for g in tp_group) + instance_key = (output_dtype, self.tp_group_str) + if instance_key not in Fp4GemmAllreduceRunner.runner_dict: + Fp4GemmAllreduceRunner.runner_dict[ + instance_key] = torch.classes.trtllm.Fp4GemmAllreduceRunner( + output_dtype, tp_rank, tp_group) + self.fp4_gemm_all_reduce_runner = Fp4GemmAllreduceRunner.runner_dict[ + instance_key] + + def unique_id(self): + return (self.output_dtype, self.tp_group_str) + + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: + return list(range(self.fp4_gemm_all_reduce_runner.get_num_configs())) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = 0, + ) -> torch.Tensor: + mat1, mat2, mat1_scale, mat2_scale, global_scale = inputs + return self.fp4_gemm_all_reduce_runner.run_gemm( + mat1, + mat2, + mat1_scale, + mat2_scale, + global_scale, + tactic, + ) + + +@torch.library.custom_op("trtllm::nvfp4_gemm_allreduce", mutates_args=()) +def nvfp4_gemm_allreduce( + act_fp4: torch.Tensor, + weight: torch.Tensor, + act_sf: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, + output_dtype: torch.dtype, + tp_rank: int, + tp_group: List[int], +) -> torch.Tensor: + AutoTuner.get() + + # Use Cutlass runner with predefined configs + nvfp4_gemm_allreduce_runner = Fp4GemmAllreduceRunner( + output_dtype, tp_rank, tp_group) + + # TODO: Enable auto-tuning + # runner_type = type(nvfp4_gemm_allreduce_runner).__name__ + # _, best_tactic = tuner.choose_one( + # f"trtllm::nvfp4_gemm_allreduce::{runner_type}", + # [nvfp4_gemm_allreduce_runner], + # nvfp4_gemm_allreduce_runner.tuning_config, + # [act_fp4, weight, act_sf, weight_scale, alpha], + # ) + + best_tactic = -1 + + return nvfp4_gemm_allreduce_runner( + inputs=[act_fp4, weight, act_sf, weight_scale, alpha], + tactic=best_tactic) + + +@nvfp4_gemm_allreduce.register_fake +def _( + act_fp4: torch.Tensor, + weight: torch.Tensor, + act_sf: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, + output_dtype: torch.dtype, + tp_rank: int, + tp_group: List[int], +) -> torch.Tensor: + return act_fp4.new_empty((act_fp4.size(0), 
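    # Meta/"fake" kernel used for tracing (torch.compile, fake tensor mode): it performs no
    # GEMM or communication, it only reports the output shape.  The real op allreduces
    # across tp_group, so every rank returns the full [M, N] tensor in output_dtype, where
    # M is the token count (rows of the packed FP4 activation) and N is out_features
    # (rows of the packed FP4 weight).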
weight.size(0)), + dtype=output_dtype) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 464a446cb37..8fdd765cd31 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -14,7 +14,7 @@ AllReduceParams, MoEAllReduce) from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ BaseWeightMapper -from tensorrt_llm._utils import get_sm_version +from tensorrt_llm._utils import get_sm_version, mpi_disabled from tensorrt_llm.functional import PositionEmbeddingType from tensorrt_llm.inputs.multimodal import MultimodalParams from tensorrt_llm.logger import logger @@ -673,9 +673,38 @@ def __init__( # Disable fusion for small models due to accuracy issues self.enable_fusion &= config.hidden_size > 4096 - self.PRE_MLP_FUSION = self.mapping.has_tp( + use_fused_gemm_allreduce = True + use_fused_gemm_allreduce &= (not mpi_disabled()) + use_fused_gemm_allreduce &= (self.mapping.tp_size > 1) + use_fused_gemm_allreduce &= (config.torch_dtype + in (torch.float16, torch.bfloat16)) + use_fused_gemm_allreduce &= (self.is_nvfp4 is not None + and self.is_nvfp4) + + num_heads = config.num_attention_heads + head_dim = getattr(config, 'head_dim', None) + if not isinstance(head_dim, int): + head_dim = config.hidden_size // num_heads + + in_features = num_heads * head_dim + out_features = config.hidden_size + in_features_div_by = 128 + attn_fused_gemm_allreduce = use_fused_gemm_allreduce and in_features % in_features_div_by == 0 and in_features >= 1024 + attn_fused_gemm_allreduce &= (out_features % 64 == 0 + and out_features >= 1024) + + self.PRE_MLP_FUSION = not attn_fused_gemm_allreduce and self.mapping.has_tp( ) and not self.enable_attention_dp and self.enable_fusion - self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion + + in_features = config.intermediate_size + out_features = config.hidden_size + in_features_div_by = 128 * self.mapping.tp_size + mlp_fused_gemm_allreduce = use_fused_gemm_allreduce and in_features % in_features_div_by == 0 and in_features >= 1024 + mlp_fused_gemm_allreduce &= (out_features % 64 == 0 + and out_features >= 1024) + + self.POST_MLP_FUSION = not mlp_fused_gemm_allreduce and self.mapping.has_tp( + ) and self.enable_fusion if self.is_nvfp4: self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4 diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 44daa25eb3c..8197b0cc2f4 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -14,7 +14,7 @@ import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils from tensorrt_llm._torch.peft.lora.layer import LoraLayer -from tensorrt_llm._utils import is_device_integrated +from tensorrt_llm._utils import is_device_integrated, mpi_disabled from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, AllReduceStrategy) from tensorrt_llm.logger import logger @@ -307,6 +307,11 @@ def apply(self, module: Linear, input: torch.Tensor, bias: Optional[torch.Tensor], *args, **kwargs): raise NotImplementedError + def apply_linear_allreduce(self, module: Linear, input: torch.Tensor, + bias: Optional[torch.Tensor], tp_rank: int, + tp_group: List[int], *args, **kwargs): + raise NotImplementedError + def load_weights(self, module: Linear, weights: List[Dict], @@ -908,8 +913,7 @@ def create_weights(self, module: Linear, in_features: int, else: module.register_parameter("bias", None) - def apply(self, module: 
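    # The original apply() is split below: _input_prepare() handles the NVFP4 activation
    # quantization (or unpacks an already-quantized Fp4QuantizedTensor) and is shared by
    # both the plain GEMM path (apply) and the new fused GEMM+allreduce path
    # (apply_linear_allreduce), so both paths quantize the activation identically.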
Linear, input: torch.Tensor, - bias: Optional[torch.Tensor]): + def _input_prepare(self, module: Linear, input: torch.Tensor): if isinstance(input, Fp4QuantizedTensor): # Input is already quantized - this should not happen if pre_quant_scale exists # because we disable FP4 output for attention output when pre_quant_scale is present @@ -935,7 +939,11 @@ def apply(self, module: Linear, input: torch.Tensor, act_fp4, act_sf = torch.ops.trtllm.fp4_quantize( input, module.input_scale, module.scaling_vector_size, False) + return act_fp4, act_sf + def apply(self, module: Linear, input: torch.Tensor, + bias: Optional[torch.Tensor]): + act_fp4, act_sf = self._input_prepare(module, input) # Use unified interface - supports CUTLASS, cuBLASLt, CuteDSL # Convert list to comma-separated string for torch.compile compatibility allowed_backends_str = ','.join(module.nvfp4_allowed_backends) @@ -956,6 +964,21 @@ def apply(self, module: Linear, input: torch.Tensor, output = output + bias return output + def apply_linear_allreduce(self, module: Linear, input: torch.Tensor, + bias: Optional[torch.Tensor], tp_rank: int, + tp_group: List[int]): + act_fp4, act_sf = self._input_prepare(module, input) + output = torch.ops.trtllm.nvfp4_gemm_allreduce( + act_fp4, module.weight, act_sf, module.weight_scale, module.alpha, + module.dtype, tp_rank, tp_group) + # Take the dim of out_features if padded. Make sure the output is contiguous + if output.shape[-1] > module.out_features: + output = output[..., :module.out_features].contiguous() + + if bias is not None: + output = output + bias + return output + def load_kv_scales(self, weights: List[Dict]): k_scale, v_scale = [], [] for w in weights: @@ -2133,6 +2156,20 @@ def __init__( self.use_custom_cublas_mm = use_custom_cublas_mm self.lora = lora + use_fused_gemm_allreduce = True + use_fused_gemm_allreduce &= (not mpi_disabled()) + use_fused_gemm_allreduce &= self.dtype in (torch.float16, + torch.bfloat16) + use_fused_gemm_allreduce &= (self.in_features % 128 == 0) + use_fused_gemm_allreduce &= (self.tp_mode is not None + and self.tp_mode == TensorParallelMode.ROW) + use_fused_gemm_allreduce &= (self.tp_size > 1 and self.reduce_output) + use_fused_gemm_allreduce &= (self.out_features % 64 == 0) + use_fused_gemm_allreduce &= ( + self.quant_config is not None + and self.quant_config.layer_quant_mode.has_nvfp4()) + self.use_fused_gemm_allreduce = use_fused_gemm_allreduce + self.enable_cuda_core = False if torch.cuda.is_available(): capability = torch.cuda.get_device_capability( @@ -2224,13 +2261,20 @@ def apply_linear(self, lora_params: Optional[dict] | None = None, layer_idx: Optional[int] | None = None): output = self.quant_method.apply(self, input, bias) - if self.lora is not None and bool(lora_params): lora_result = self.lora(input, lora_params, layer_idx) if lora_result is not None: output = output + lora_result return output + def apply_linear_allreduce(self, + input, + bias, + layer_idx: Optional[int] | None = None): + output = self.quant_method.apply_linear_allreduce( + self, input, bias, self.tp_rank, self.mapping.tp_group) + return output + def _maybe_fuse_bias_into_allreduce( self, bias: Optional[torch.Tensor], @@ -2257,16 +2301,23 @@ def forward( layer_idx: Optional[int] = None, ) -> torch.Tensor: if self.tp_mode == TensorParallelMode.ROW: + use_fused_gemm_allreduce = self.use_fused_gemm_allreduce and lora_params is None + if use_fused_gemm_allreduce and all_reduce_params is not None: + use_fused_gemm_allreduce = all_reduce_params.enable_allreduce and 
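                # The fused kernel only performs a plain allreduce in the GEMM epilogue, so
                # the fused path is used only when no LoRA adapter is active and the
                # requested allreduce carries no extra fusion (fusion_op == NONE, i.e. no
                # residual/RMSNorm/quant fusion); otherwise fall back to apply_linear()
                # followed by self.all_reduce().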
all_reduce_params.fusion_op == AllReduceFusionOp.NONE + bias = None if (self.tp_rank > 0) else self.bias if self.reduce_output: - fuse_bias = self._maybe_fuse_bias_into_allreduce( - bias, all_reduce_params) - bias = None if fuse_bias else bias - output = self.apply_linear(input, bias, lora_params, layer_idx) - output = self.all_reduce( - output, - all_reduce_params=all_reduce_params, - ) + if use_fused_gemm_allreduce: + output = self.apply_linear_allreduce( + input, self.bias, layer_idx) + else: + fuse_bias = self._maybe_fuse_bias_into_allreduce( + bias, all_reduce_params) + bias = None if fuse_bias else bias + output = self.apply_linear(input, bias, lora_params, + layer_idx) + output = self.all_reduce( + output, all_reduce_params=all_reduce_params) else: output = self.apply_linear(input, bias, lora_params, layer_idx) elif self.tp_mode == TensorParallelMode.COLUMN: diff --git a/tests/unittest/_torch/multi_gpu/test_linear.py b/tests/unittest/_torch/multi_gpu/test_linear.py index d78dc4defa2..839ae6eb6d9 100644 --- a/tests/unittest/_torch/multi_gpu/test_linear.py +++ b/tests/unittest/_torch/multi_gpu/test_linear.py @@ -2,16 +2,26 @@ import sys import traceback +try: + pass +except ImportError: + pass + import cloudpickle import pytest import torch from mpi4py import MPI from torch import nn +from utils.util import skip_pre_blackwell import tensorrt_llm +import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils +from tensorrt_llm._torch.autotuner import autotune from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams from tensorrt_llm.mapping import Mapping +from tensorrt_llm.math_utils import pad_up +from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig cloudpickle.register_pickle_by_value(sys.modules[__name__]) MPI.pickle.__init__( @@ -161,7 +171,6 @@ def row_linear_forward(x, hidden_size, dtype, tensor_parallel_size, ) l0.load_weights([dict(weight=weights[0])]) l0.cuda() - xs = torch.chunk(x, 2, dim=-1) l0 = torch.compile(l0, fullgraph=True) output = l0.forward(xs[tensor_parallel_rank]) @@ -333,3 +342,150 @@ def test_row_linear_norm_fusion(seq_len, hidden_size, mpi_pool_executor): [l0_weight], hidden_size, dtype)] * 2)) for r in results: assert r is True + + +def check_accuracy(a, b, atol, rtol, percent): + assert a.shape == b.shape + assert a.dtype == b.dtype + a = a.to(torch.float32) + b = b.to(torch.float32) + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if not (mismatch_percent < 1 - percent): + raise Exception("Mismatch percentage is %f for rtol %f" % + (mismatch_percent, rtol)) + + +@torch.inference_mode +def fp4_row_linear_allreduce(tp_size, local_rank, seq_len, output_size, + hidden_size, dtype, output_ref, x_sf_global, + w_sf_global, x_fp4s, w_fp4, x_sf_blocks, + w_sf_block_unswizzled): + output_ref = output_ref.cuda() + x_sf_global = x_sf_global.cuda() + w_sf_global = w_sf_global.cuda() + x_fp4 = x_fp4s[local_rank].cuda() + w_fp4 = w_fp4.cuda() + x_sf_block = x_sf_blocks[local_rank].cuda() + w_sf_block_unswizzled = w_sf_block_unswizzled.cuda() + + qc = QuantConfig(quant_algo=QuantAlgo.NVFP4) + l0 = Linear( + in_features=hidden_size, + out_features=output_size, + bias=False, + dtype=dtype, + quant_config=qc, + mapping=Mapping( + world_size=tp_size, + tp_size=tp_size, + rank=local_rank, + ), + tensor_parallel_mode=TensorParallelMode.ROW, + ) + + l0.load_weights([{ + 'input_scale': + 
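        # Scale convention mirrored from NVFP4 checkpoints: the global scales are stored as
        # reciprocals of the quantization scales computed above (input_scale = 1/x_sf_global,
        # weight_scale_2 = 1/w_sf_global), and the per-block weight scale is the unswizzled
        # block-scale tensor reinterpreted as float8_e4m3fn.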
1.0 / x_sf_global.cpu(), + 'weight': + w_fp4.cpu(), + 'weight_scale': + w_sf_block_unswizzled.view(torch.float8_e4m3fn), + 'weight_scale_2': + 1.0 / w_sf_global.cpu() + }]) + + l0.cuda() + # TODO: parameters['weight']' size mismatch at index 0 + # l0 = torch.compile(l0) + with torch.inference_mode(), autotune(): + output = l0.forward((x_fp4, x_sf_block)) + + torch.cuda.synchronize() + check_accuracy(output, output_ref, atol=0.05, rtol=0.05, percent=0.99) + + +def fp4_row_linear_allreduce_run_single_rank(func, tp_size, seq_len, + output_size, hidden_size, dtype, + output_ref, x_sf_global, + w_sf_global, x_fp4s, w_fp4, + x_sf_blocks, + w_sf_block_unswizzled): + local_rank = tensorrt_llm.mpi_rank() + torch.cuda.set_device(local_rank) + + try: + func(tp_size, local_rank, seq_len, output_size, hidden_size, dtype, + output_ref, x_sf_global, w_sf_global, x_fp4s, w_fp4, x_sf_blocks, + w_sf_block_unswizzled) + except Exception: + traceback.print_exc() + raise + return True + + +@skip_pre_blackwell +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason='needs 2 GPUs to run this test') +@pytest.mark.parametrize("seq_len", [256], ids=lambda x: f"seqlen:{x}") +@pytest.mark.parametrize("output_size", [32, 64], ids=lambda x: f"output:{x}") +@pytest.mark.parametrize("hidden_size", [128, 256], ids=lambda x: f"hidden:{x}") +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], + ids=lambda x: f"dtype:{x}") +@pytest.mark.parametrize("mpi_pool_executor", [2], + indirect=True, + ids=lambda x: f"tp_size:{x}") +def test_fp4_row_linear_allreduce(seq_len, output_size, hidden_size, dtype, + mpi_pool_executor): + torch.manual_seed(42) + tp_size = mpi_pool_executor.num_workers + + x = torch.randn((seq_len, hidden_size), dtype=dtype).cuda() + w = torch.randn((output_size, hidden_size), dtype=dtype).cuda() + + scaling_vector_size = 16 + x_sf_global = (448 * 6) / x.abs().max().float() + x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(x, x_sf_global, + scaling_vector_size, + False) + w_sf_global = (448 * 6) / w.abs().max().float() + w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global, + scaling_vector_size, + False) + w_sf_block_unswizzled = (torch.ops.trtllm.block_scale_interleave_reverse( + w_sf_block.cpu().view(pad_up(output_size, 128), -1))) + + with torch.inference_mode(): + alpha_ref = 1.0 / (w_sf_global * x_sf_global) + output_ref = torch.ops.trtllm.fp4_gemm( + x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_ref, + fp4_utils.FP4GemmType.W4A4_NVFP4_NVFP4, dtype) + + torch.cuda.synchronize() + + xs = [x.contiguous().cuda() for x in torch.chunk(x, tp_size, dim=-1)] + x_fp4s = [] + x_sf_blocks = [] + for i in range(tp_size): + _fp4, _sf_block = torch.ops.trtllm.fp4_quantize(xs[i], x_sf_global, + scaling_vector_size, + False) + x_fp4s.append(_fp4.cpu()) + x_sf_blocks.append(_sf_block.cpu()) + + output_ref = output_ref.cpu() + x_sf_global = x_sf_global.cpu() + w_sf_global = w_sf_global.cpu() + w_fp4 = w_fp4.cpu() + w_sf_block_unswizzled = w_sf_block_unswizzled.cpu() + + results = mpi_pool_executor.map( + fp4_row_linear_allreduce_run_single_rank, + *zip(*[(fp4_row_linear_allreduce, tp_size, seq_len, output_size, + hidden_size, dtype, output_ref, x_sf_global, w_sf_global, + x_fp4s, w_fp4, x_sf_blocks, w_sf_block_unswizzled)] * tp_size)) + + for r in results: + assert r is True From 1246ddb7ace27acb386e8b267ea745d7c4dbc07e Mon Sep 17 00:00:00 2001 From: benzh-2025 Date: Wed, 24 Dec 2025 05:07:53 +0000 Subject: [PATCH 2/4] refactor logical check for gemm+allreduce fusion Signed-off-by: 
benzh-2025 --- tensorrt_llm/_torch/models/modeling_llama.py | 37 +++++++++++-------- tensorrt_llm/_torch/modules/linear.py | 24 ++++++------ .../unittest/_torch/multi_gpu/test_linear.py | 6 +-- 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 8fdd765cd31..fc42318661f 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -673,13 +673,17 @@ def __init__( # Disable fusion for small models due to accuracy issues self.enable_fusion &= config.hidden_size > 4096 - use_fused_gemm_allreduce = True - use_fused_gemm_allreduce &= (not mpi_disabled()) - use_fused_gemm_allreduce &= (self.mapping.tp_size > 1) - use_fused_gemm_allreduce &= (config.torch_dtype - in (torch.float16, torch.bfloat16)) - use_fused_gemm_allreduce &= (self.is_nvfp4 is not None - and self.is_nvfp4) + mpi_enabled = not mpi_disabled() + dtype_supported = config.torch_dtype in (torch.float16, torch.bfloat16) + tp_valid = self.mapping.tp_size > 1 + quant_valid = self.is_nvfp4 is not None and self.is_nvfp4 + use_fused_gemm_allreduce = all( + [mpi_enabled, dtype_supported, tp_valid, quant_valid]) + + def check_in_out_features(in_features, out_features): + in_feature_valid = in_features % 128 == 0 and in_features >= 1024 + out_feature_valid = out_features % 64 == 0 and out_features >= 1024 + return all([in_feature_valid, out_feature_valid]) num_heads = config.num_attention_heads head_dim = getattr(config, 'head_dim', None) @@ -688,21 +692,22 @@ def __init__( in_features = num_heads * head_dim out_features = config.hidden_size - in_features_div_by = 128 - attn_fused_gemm_allreduce = use_fused_gemm_allreduce and in_features % in_features_div_by == 0 and in_features >= 1024 - attn_fused_gemm_allreduce &= (out_features % 64 == 0 - and out_features >= 1024) + in_out_features_valid = check_in_out_features(in_features, out_features) + attn_fused_gemm_allreduce = all( + [use_fused_gemm_allreduce, in_out_features_valid]) self.PRE_MLP_FUSION = not attn_fused_gemm_allreduce and self.mapping.has_tp( ) and not self.enable_attention_dp and self.enable_fusion in_features = config.intermediate_size out_features = config.hidden_size - in_features_div_by = 128 * self.mapping.tp_size - mlp_fused_gemm_allreduce = use_fused_gemm_allreduce and in_features % in_features_div_by == 0 and in_features >= 1024 - mlp_fused_gemm_allreduce &= (out_features % 64 == 0 - and out_features >= 1024) - + in_features_aligned_with_tp = in_features % self.mapping.tp_size == 0 + in_out_features_valid = check_in_out_features( + in_features // self.mapping.tp_size, out_features) + mlp_fused_gemm_allreduce = all([ + use_fused_gemm_allreduce, in_features_aligned_with_tp, + in_out_features_valid + ]) self.POST_MLP_FUSION = not mlp_fused_gemm_allreduce and self.mapping.has_tp( ) and self.enable_fusion diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 8197b0cc2f4..43ef695ee33 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -2156,19 +2156,17 @@ def __init__( self.use_custom_cublas_mm = use_custom_cublas_mm self.lora = lora - use_fused_gemm_allreduce = True - use_fused_gemm_allreduce &= (not mpi_disabled()) - use_fused_gemm_allreduce &= self.dtype in (torch.float16, - torch.bfloat16) - use_fused_gemm_allreduce &= (self.in_features % 128 == 0) - use_fused_gemm_allreduce &= (self.tp_mode is not None - and self.tp_mode == 
TensorParallelMode.ROW) - use_fused_gemm_allreduce &= (self.tp_size > 1 and self.reduce_output) - use_fused_gemm_allreduce &= (self.out_features % 64 == 0) - use_fused_gemm_allreduce &= ( - self.quant_config is not None - and self.quant_config.layer_quant_mode.has_nvfp4()) - self.use_fused_gemm_allreduce = use_fused_gemm_allreduce + mpi_enabled = not mpi_disabled() + dtype_supported = self.dtype in (torch.float16, torch.bfloat16) + in_features_aligned = self.in_features % 128 == 0 + out_features_aligned = self.out_features % 64 == 0 + tp_valid = self.tp_mode is not None and self.tp_mode == TensorParallelMode.ROW and self.tp_size > 1 + quant_valid = self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4( + ) + self.use_fused_gemm_allreduce = all([ + self.reduce_output, mpi_enabled, dtype_supported, + in_features_aligned, out_features_aligned, tp_valid, quant_valid + ]) self.enable_cuda_core = False if torch.cuda.is_available(): diff --git a/tests/unittest/_torch/multi_gpu/test_linear.py b/tests/unittest/_torch/multi_gpu/test_linear.py index 839ae6eb6d9..3593af46c73 100644 --- a/tests/unittest/_torch/multi_gpu/test_linear.py +++ b/tests/unittest/_torch/multi_gpu/test_linear.py @@ -420,8 +420,8 @@ def fp4_row_linear_allreduce_run_single_rank(func, tp_size, seq_len, func(tp_size, local_rank, seq_len, output_size, hidden_size, dtype, output_ref, x_sf_global, w_sf_global, x_fp4s, w_fp4, x_sf_blocks, w_sf_block_unswizzled) - except Exception: - traceback.print_exc() + except Exception as e: + print(f"Error: {e}") raise return True @@ -429,7 +429,7 @@ def fp4_row_linear_allreduce_run_single_rank(func, tp_size, seq_len, @skip_pre_blackwell @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='needs 2 GPUs to run this test') -@pytest.mark.parametrize("seq_len", [256], ids=lambda x: f"seqlen:{x}") +@pytest.mark.parametrize("seq_len", [256, 400], ids=lambda x: f"seqlen:{x}") @pytest.mark.parametrize("output_size", [32, 64], ids=lambda x: f"output:{x}") @pytest.mark.parametrize("hidden_size", [128, 256], ids=lambda x: f"hidden:{x}") @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], From c076853f58f5713efb020d081c946311ca71aeed Mon Sep 17 00:00:00 2001 From: benzh-2025 Date: Tue, 30 Dec 2025 06:44:23 +0000 Subject: [PATCH 3/4] support gemm+allreduce only on arch >= blackwell Signed-off-by: benzh-2025 --- tensorrt_llm/_torch/models/modeling_llama.py | 15 +++++++++++++-- tensorrt_llm/_torch/modules/linear.py | 12 +++++++++++- tests/unittest/_torch/multi_gpu/test_linear.py | 5 ----- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index fc42318661f..d4c63efe6fb 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -677,8 +677,19 @@ def __init__( dtype_supported = config.torch_dtype in (torch.float16, torch.bfloat16) tp_valid = self.mapping.tp_size > 1 quant_valid = self.is_nvfp4 is not None and self.is_nvfp4 - use_fused_gemm_allreduce = all( - [mpi_enabled, dtype_supported, tp_valid, quant_valid]) + + device_supported = False + if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability( + torch.device('cuda:0')) + sm_version = capability[0] * 10 + capability[1] + if sm_version >= 100: + device_supported = True + + use_fused_gemm_allreduce = all([ + mpi_enabled, dtype_supported, tp_valid, quant_valid, + device_supported + ]) def check_in_out_features(in_features, out_features): 
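            # Shape gate for the fused GEMM+allreduce kernel: in_features must be a multiple
            # of 128, out_features a multiple of 64, and both at least 1024 -- presumably
            # kernel tile / NVFP4 scale-block alignment plus a size threshold below which
            # the fusion is not worthwhile (the patch does not spell out the rationale).
            # Projections that pass this gate skip the RESIDUAL_RMS_NORM allreduce fusion
            # (PRE_MLP_FUSION / POST_MLP_FUSION) so the plain allreduce can instead be
            # folded into the GEMM epilogue.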
in_feature_valid = in_features % 128 == 0 and in_features >= 1024 diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 43ef695ee33..2a413e301aa 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -2163,9 +2163,19 @@ def __init__( tp_valid = self.tp_mode is not None and self.tp_mode == TensorParallelMode.ROW and self.tp_size > 1 quant_valid = self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4( ) + + device_supported = False + if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability( + torch.device('cuda:0')) + sm_version = capability[0] * 10 + capability[1] + if sm_version >= 100: + device_supported = True + self.use_fused_gemm_allreduce = all([ self.reduce_output, mpi_enabled, dtype_supported, - in_features_aligned, out_features_aligned, tp_valid, quant_valid + in_features_aligned, out_features_aligned, tp_valid, quant_valid, + device_supported ]) self.enable_cuda_core = False diff --git a/tests/unittest/_torch/multi_gpu/test_linear.py b/tests/unittest/_torch/multi_gpu/test_linear.py index 3593af46c73..11466818cf2 100644 --- a/tests/unittest/_torch/multi_gpu/test_linear.py +++ b/tests/unittest/_torch/multi_gpu/test_linear.py @@ -2,11 +2,6 @@ import sys import traceback -try: - pass -except ImportError: - pass - import cloudpickle import pytest import torch From 7596882deda9f29b6142a0f762ae5e80eb2ce62e Mon Sep 17 00:00:00 2001 From: benzh-2025 Date: Wed, 7 Jan 2026 03:07:38 +0000 Subject: [PATCH 4/4] add sm103 support Signed-off-by: benzh-2025 --- .../cutlass_kernels/allreduce_gemm/allreduce_gemm_runner.cu | 1 + tensorrt_llm/_torch/models/modeling_llama.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_runner.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_runner.cu index 2bca57c2297..26dc3116f61 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_runner.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_runner.cu @@ -191,6 +191,7 @@ GemmAllReduceImplRunner::GemmAllReduceImplRunner() break; // Blackwell case 100: + case 103: registry_builder.addSm100(); break; diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index d4c63efe6fb..78d4e63ba1c 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -673,6 +673,8 @@ def __init__( # Disable fusion for small models due to accuracy issues self.enable_fusion &= config.hidden_size > 4096 + enable_gemm_allreduce_fusion = (os.environ.get( + "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "0") == "1") mpi_enabled = not mpi_disabled() dtype_supported = config.torch_dtype in (torch.float16, torch.bfloat16) tp_valid = self.mapping.tp_size > 1 @@ -687,8 +689,8 @@ def __init__( device_supported = True use_fused_gemm_allreduce = all([ - mpi_enabled, dtype_supported, tp_valid, quant_valid, - device_supported + enable_gemm_allreduce_fusion, mpi_enabled, dtype_supported, + tp_valid, quant_valid, device_supported ]) def check_in_out_features(in_features, out_features):
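
Illustrative usage note (not part of the patch series): the sketch below shows how the fused
NVFP4 GEMM+allreduce path added here would be switched on end to end. The two environment
variables come from patch 1/4 and patch 4/4; the checkpoint path and the LLM-API calls are
assumed placeholders, not something this series prescribes.

# Assumes a multi-GPU Blackwell (SM100/SM103) node, MPI launch mode, and an NVFP4-quantized
# Llama checkpoint; adjust paths and sizes to your setup.
import os

# Opt-in switch added in patch 4/4 (off by default).
os.environ["TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED"] = "1"
# Optional: grow the persistent NVLS workspace beyond the 128 MB default (in bytes); patch
# 1/4 raises an error if the largest fused GEMM needs more than the configured size.
os.environ["TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE"] = str(512 * 1024 * 1024)

from tensorrt_llm import LLM

# With tp_size > 1, fp16/bf16 activations and NVFP4 weights, eligible ROW-parallel
# projections (attention o_proj, MLP down_proj) route through
# torch.ops.trtllm.nvfp4_gemm_allreduce instead of an FP4 GEMM followed by a separate
# allreduce.
llm = LLM(model="/path/to/llama-nvfp4-checkpoint", tensor_parallel_size=2)
print(llm.generate(["Hello, world"])[0].outputs[0].text)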