
Commit 6f3922f

feat: Low Precision Allreduce for PCIe based GPU (NVIDIA#4344)
This PR adds a customized allreduce to TensorRT-LLM. The new allreduce is used for communication on PCIe-based GPUs via low-precision quantization, which can accelerate the PCIe allreduce process.

Signed-off-by: Hui Kang <hkang@nvidia.com>
Co-authored-by: Hui Kang <hkang@nvidia.com>
1 parent c8e062b commit 6f3922f

File tree

12 files changed: +2301 -18 lines


cpp/tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.cu

Lines changed: 1635 additions & 0 deletions
Large diffs are not rendered by default.
cpp/tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include <NvInferRuntime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <vector>

namespace tensorrt_llm::kernels
{

constexpr int LP_ALLREDUCE_MAX_BLOCKS = 8;
constexpr int LP_ALLREDUCE_WARPSIZE = 32;
constexpr int LP_ALLREDUCE_DEFAULT_BLOCK_SIZE = 512;
constexpr int LP_ALLREDUCE_WARP_NUM_PER_BLOCK = 16;
constexpr int LP_ALLREDUCE_BYTES_PER_LOAD = 16;
constexpr int LP_ALLREDUCE_NUMA_NUM = 2;
constexpr int LP_ALLREDUCE_MAX_RANKS_PER_NUMA = 4;
constexpr int LP_ALLREDUCE_BUFFER_DUPLICATE = 16;
constexpr int LP_ALLREDUCE_BUFFER_CHUNKS = 8;
constexpr int LP_ALLREDUCE_HIER_STAGE_NUM = 3;
constexpr int LP_ALLREDUCE_RANKS_PER_NUMA = 4;
constexpr int LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE = 32 * 1024 * 1024;
constexpr int LP_ALLREDUCE_MIN_ELTS_THRESHOLD = 8 * 1024 * 1024;
constexpr int LP_ALLREDUCE_MAX_TP_SIZE = 8;
constexpr int LP_ALLREDUCE_MAX_RANKS_PER_NODE = 16;

struct StaticLowPrecisionBuffers
{
    void* peer_comm_buffer_ptrs[LP_ALLREDUCE_MAX_TP_SIZE * 2];
    uint64_t* peer_barrier_ptrs_in[LP_ALLREDUCE_MAX_TP_SIZE];
    uint64_t* peer_barrier_ptrs_out[LP_ALLREDUCE_MAX_TP_SIZE];
    int64_t* flag_ptr;
    bool initialized = false;
    size_t tpSize = 0;
};

void initialize_static_lowprecision_buffers(int64_t* buffer, size_t tpSize);

std::vector<size_t> splitNumber(size_t number);

struct LowPrecisionAllReduceParams
{
    size_t elts_total;
    size_t elts_per_rank;
    size_t elts_per_block;
    size_t rank_offset;
    int32_t ranks_per_node, rank, local_rank;
    uint64_t barrier_flag;
    uint64_t* peer_barrier_ptrs_in[LP_ALLREDUCE_MAX_RANKS_PER_NODE];
    uint64_t* peer_barrier_ptrs_out[LP_ALLREDUCE_MAX_RANKS_PER_NODE];
    void* peer_comm_buffer_ptrs[LP_ALLREDUCE_MAX_RANKS_PER_NODE];
    void* local_output_buffer_ptr;
    void const* local_input_buffer_ptr;

    // for low precision
    size_t buffer_elts_per_rank;
    size_t buffer_offset;

    // for low precision hier
    uint32_t num_rounds = 0;
    uint32_t num_rounds_fence = 0;
    uint32_t block_num = 0;
    int32_t numa_rank = -1;

    void* inputs_inside_numa[4];

    void* rs_buffers[LP_ALLREDUCE_MAX_BLOCKS];
    void* ar_buffers[LP_ALLREDUCE_MAX_BLOCKS];
    void* ar_peer_buffers_cross_numa[LP_ALLREDUCE_MAX_BLOCKS];
    void* ag_peer_buffers_inside_numa[LP_ALLREDUCE_MAX_BLOCKS * 4];

    // for low precision hier handshake rs stage
    uint64_t* rs_send_flags[LP_ALLREDUCE_MAX_BLOCKS];
    uint64_t* rs_ack_flags[LP_ALLREDUCE_MAX_BLOCKS]; // 2*flags
    uint64_t* rs_notify_local_flags[LP_ALLREDUCE_MAX_BLOCKS];
    uint64_t* rs_notify_remote_flags[LP_ALLREDUCE_MAX_BLOCKS];

    // for low precision hier handshake ar stage
    uint64_t* ar_send_flags[LP_ALLREDUCE_MAX_BLOCKS];
    uint64_t* ar_ack_peer_rs_flags[LP_ALLREDUCE_MAX_BLOCKS];
    uint64_t* ar_ack_flags[LP_ALLREDUCE_MAX_BLOCKS];
    uint64_t* ar_notify_rs_local_flags[LP_ALLREDUCE_MAX_BLOCKS];
    uint64_t* ar_notify_rs_remote_flags[LP_ALLREDUCE_MAX_BLOCKS];
    uint64_t* ar_notify_ag_flags[LP_ALLREDUCE_MAX_BLOCKS];

    // for low precision hier handshake ag stage
    uint64_t* ag_send_flags[LP_ALLREDUCE_MAX_BLOCKS];
    uint64_t* ag_ack_peer_inside_numa_flags[LP_ALLREDUCE_MAX_BLOCKS];        // 3*flags, 3 is other ranks inside NUMA
    uint64_t* ag_notify_peer_inside_numa_flags[LP_ALLREDUCE_MAX_BLOCKS * 4]; // 3*flags, 3 is other ranks inside NUMA

    static LowPrecisionAllReduceParams deserialize(
        size_t tpSize, size_t tpRank, nvinfer1::DataType dataType, int token_num, int hidden_size);
    static LowPrecisionAllReduceParams deserialize_hier(
        size_t tpSize, size_t tpRank, nvinfer1::DataType dataType, int token_num, int hidden_size);
};

bool lowPrecisionConfigurationSupported(size_t msg_size, size_t n_ranks);

void customLowPrecisionAllReduce(
    kernels::LowPrecisionAllReduceParams& params, nvinfer1::DataType dataType, cudaStream_t stream);

int32_t max_workspace_size_lowprecision(int32_t tp_size);
} // namespace tensorrt_llm::kernels

cpp/tensorrt_llm/kernels/customAllReduceKernels.h

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ enum class AllReduceStrategyType : int8_t
     AUTO = 3,
     ONESHOT = 4,
     TWOSHOT = 5,
+    LOWPRECISION = 6,
 };

 enum class AllReduceStrategyConfig : int8_t

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 5 additions & 0 deletions
@@ -18,6 +18,7 @@
 #include "bindings.h"
 #include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
 #include "tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h"
+#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
 #include "tensorrt_llm/kernels/customAllReduceKernels.h"
 #include "tensorrt_llm/kernels/delayStream.h"
 #include "tensorrt_llm/runtime/cudaEvent.h"
@@ -393,6 +394,10 @@ void initBindings(pybind11::module_& m)
             tensorrt_llm::kernels::invokeDelayStreamKernel(delay_micro_secs, stream);
         },
         "Delay kernel launch on the default stream");
+    m.def(
+        "max_workspace_size_lowprecision",
+        [](int32_t tp_size) { return tensorrt_llm::kernels::max_workspace_size_lowprecision(tp_size); },
+        "Calculate the maximum workspace size needed for low precision all-reduce operations");

     py::enum_<tensorrt_llm::kernels::AllReduceFusionOp>(m, "AllReduceFusionOp")
         .value("NONE", tensorrt_llm::kernels::AllReduceFusionOp::NONE)

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 119 additions & 0 deletions
@@ -20,6 +20,7 @@
 #include "tensorrt_llm/common/dataType.h"
 #include "tensorrt_llm/common/opUtils.h"
 #include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
+#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
 #include "tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h"
 #include "tensorrt_llm/kernels/customAllReduceKernels.h"
 #include "tensorrt_llm/kernels/internal_cutlass_kernels/include/fp4_gemm.h"
@@ -177,6 +178,8 @@ class AllreduceOp
         case AllReduceStrategyType::ONESHOT:
         case AllReduceStrategyType::TWOSHOT:
             return runFusionAllReduce(input, residual, norm_weight, scale, bias, workspace, runtime_strategy);
+        case AllReduceStrategyType::LOWPRECISION:
+            return runLowPrecisionAllReduce(input, residual, norm_weight, scale, bias);
         default: TORCH_CHECK(false, "Invalid runtime strategy"); return {};
         }
     }
@@ -296,6 +299,73 @@ class AllreduceOp
         return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
     }

+    std::vector<torch::Tensor> runLowPrecisionAllReduce(torch::Tensor const& input,
+        torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
+        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
+    {
+#ifdef ENABLE_FP8
+        auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
+        int size = input.numel();
+        int hidden_size = input.size(-1);
+
+        auto const tp_size = mGroup.size();
+        auto const cur_rank = COMM_SESSION.getRank();
+        int tp_rank = 0;
+
+        for (auto const& currentRank : mGroup)
+        {
+            if (cur_rank == currentRank)
+                break;
+            ++tp_rank;
+        }
+
+        int bytes_per_element = input.element_size();
+
+        int token_num = size / hidden_size;
+
+        auto parts = tensorrt_llm::kernels::splitNumber(size);
+
+        torch::Tensor reduce_output = torch::empty_like(input);
+
+        size_t global_offset = 0;
+        for (size_t i = 0; i < parts.size(); ++i)
+        {
+            size_t tmp_size = parts[i];
+            tensorrt_llm::kernels::LowPrecisionAllReduceParams tmp_param;
+            if (tp_size <= 4)
+            {
+                tmp_param = tensorrt_llm::kernels::LowPrecisionAllReduceParams::deserialize(
+                    tp_size, tp_rank, mType, token_num, hidden_size);
+            }
+            else
+            {
+                tmp_param = tensorrt_llm::kernels::LowPrecisionAllReduceParams::deserialize_hier(
+                    tp_size, tp_rank, mType, token_num, hidden_size);
+            }
+
+            tmp_param.local_input_buffer_ptr = reinterpret_cast<void const*>(
+                reinterpret_cast<char const*>(input.data_ptr()) + global_offset * bytes_per_element);
+            tmp_param.local_output_buffer_ptr = reinterpret_cast<void*>(
+                reinterpret_cast<char*>(reduce_output.mutable_data_ptr()) + global_offset * bytes_per_element);
+            tmp_param.elts_total = tmp_size;
+            tensorrt_llm::kernels::customLowPrecisionAllReduce(tmp_param, mType, stream);
+
+            global_offset += tmp_size;
+        }
+
+        if (mOp == AllReduceFusionOp::NONE)
+        {
+            return {reduce_output};
+        }
+
+        // Treat any other patterns as fallback cases.
+        return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
+
+#else
+        C10_THROW_ERROR(NotImplementedError, "Can't use LOWPRECISION without compile with ENABLE FP8.");
+#endif
+    }
+
     std::vector<torch::Tensor> runFusionAllReduce(torch::Tensor const& input,
         torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
         torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
@@ -594,6 +664,11 @@ class AllreduceOp
         TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UB", rank);
         break;
     }
+    case AllReduceStrategyType::LOWPRECISION:
+    {
+        TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank);
+        break;
+    }
     default: break;
     }
 }
@@ -766,7 +841,21 @@ class AllreduceOp
     AllReduceStrategyType selectImplementation(
         size_t seq_len, size_t message_size, int world_size, nvinfer1::DataType type) noexcept
     {
+
+        if (isUsingLowPrecision(message_size))
+        {
+            return AllReduceStrategyType::LOWPRECISION;
+        }
+        else
+        {
+            if (mStrategy == AllReduceStrategyType::LOWPRECISION)
+            {
+                mStrategy = AllReduceStrategyType::AUTO;
+            }
+        }
+
         // Check that heuristic is only applied when AUTO is set.
+        // Use Auto select
         bool const is_auto = (mStrategy == AllReduceStrategyType::AUTO);
         auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(type);
         auto const max_workspace_size
@@ -847,6 +936,24 @@ class AllreduceOp
         return strategy;
     }

+    bool isUsingLowPrecision(size_t message_size) const noexcept
+    {
+        static char* force_low_precision_allreduce_strategy_char
+            = std::getenv("FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY");
+        bool force_low_precision = (force_low_precision_allreduce_strategy_char != nullptr)
+            || (mStrategy == AllReduceStrategyType::LOWPRECISION);
+
+#ifdef ENABLE_FP8
+        // Use LowPrecision if PCIe and p2p support and message size is larger than 2MB
+        constexpr int LowPrecisionMinMessageSize = 2 * 1024 * 1024;
+        return force_low_precision && !mIsNVLINKSupported && mIsP2PSupported
+            && message_size >= LowPrecisionMinMessageSize;
+#else
+        // Low precision is not available when FP8 is not enabled
+        return false;
+#endif
+    }
+
 private:
     std::set<int> mGroup;
     bool mIsNVLINKSupported;
@@ -966,10 +1073,22 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
         "int rank,"
         "int nranks,"
         "float eps) -> Tensor[]");
+    m.def("initialize_static_lowprecision_buffers(Tensor workspace, int tp_size) -> Tensor[]");
 }

 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
 {
     m.impl("allreduce", &torch_ext::allreduce);
     m.impl("moe_allreduce", &torch_ext::moe_allreduce);
 }
+
+TORCH_LIBRARY_IMPL(trtllm, CPU, m)
+{
+    m.impl("initialize_static_lowprecision_buffers",
+        [](at::Tensor const& workspace, int64_t tp_size)
+        {
+            tensorrt_llm::kernels::initialize_static_lowprecision_buffers(
+                reinterpret_cast<int64_t*>(workspace.data_ptr()), (int) tp_size);
+            return std::vector<at::Tensor>{};
+        });
+}
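Because the schema above is registered under `TORCH_LIBRARY_FRAGMENT(trtllm, ...)`, the CPU op becomes callable as `torch.ops.trtllm.initialize_static_lowprecision_buffers`. A minimal sketch of invoking it, assuming the workspace tensor (an int64 tensor presumably holding the peer buffer and barrier pointers) has already been produced by TensorRT-LLM's allreduce workspace setup:

```python
# Sketch only: how the newly registered op could be invoked from Python.
# The zero-filled tensor below is just a placeholder; in real use the workspace
# comes from the runtime's IPC buffer allocation, not from torch.zeros.
import torch

tp_size = 8
workspace = torch.zeros(4096, dtype=torch.int64)  # placeholder, not a functional workspace

# One-time initialization of the static low-precision allreduce buffers.
torch.ops.trtllm.initialize_static_lowprecision_buffers(workspace, tp_size)
```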

cpp/tensorrt_llm/thop/thUtils.h

Lines changed: 1 addition & 0 deletions
@@ -89,4 +89,5 @@ int nextPowerOfTwo(int v);
 std::optional<float> getFloatEnv(char const* name);

 cudaDataType_t convert_torch_dtype(torch::ScalarType dtype);
+
 } // namespace torch_ext
Binary image file (261 KB, not rendered): the 8x L20/L40S node architecture diagram referenced in the documentation below.

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# Low-Precision-AllReduce

```{note}
Note:
This feature is optimized for PCIe-based GPU topologies and may affect model accuracy. Please evaluate precision impact for your specific workload.
```

TRT-LLM supports `Low-Precision-AllReduce`, a communication optimization that accelerates AllReduce operations in PCIe-based GPU environments. This feature quantizes FP16/BF16 data to FP8 during network transmission, reducing communication volume and improving performance.

## Algorithm

The Low-Precision-AllReduce algorithm works by:

1. Quantizing input FP16/BF16 tensors to FP8 format before network transmission

   **Quantization details**: We use a "per-warp" quantization approach where each CUDA warp (32 threads) processes a batch of data. In each warp, 31 threads quantize FP16/BF16 values to FP8 e4m3 format (16 bytes per thread), while the last thread transmits a scalar value. This results in each warp collectively quantizing 496 elements plus one scalar at a time.

2. Transmitting the quantized data through the network
3. Dequantizing received data back to the original precision
4. Performing the reduction operation

In 8-GPU scenarios, this approach shifts the communication bottleneck from cross-NUMA QPI to the PCIe switch, resulting in better overall performance.
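A back-of-the-envelope sketch of the per-warp payload described in step 1: with 16 bytes moved per thread and one lane reserved for the scale, each warp carries 496 FP8 values plus a scalar, roughly halving the bytes on the wire relative to FP16/BF16 (assuming the scale lane also moves a full 16-byte load).

```python
# Rough payload math for one warp, following the description above.
WARP_SIZE = 32
BYTES_PER_LOAD = 16            # each thread moves 16 bytes (16 FP8 values)
data_lanes = WARP_SIZE - 1     # one lane is reserved for the scale value

fp8_elems_per_warp = data_lanes * BYTES_PER_LOAD       # 31 * 16 = 496 FP8 values
fp16_bytes_unquantized = fp8_elems_per_warp * 2        # the same elements in FP16/BF16
fp8_bytes_on_wire = WARP_SIZE * BYTES_PER_LOAD         # 496 data bytes + 16-byte scale lane

print(fp8_elems_per_warp)                              # 496
print(fp16_bytes_unquantized / fp8_bytes_on_wire)      # ~1.94x fewer bytes transmitted
```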
## Topology Requirements

![8x L20/L40s Node Architecture](images/8x_l20_L40S_node_architecture.png)

Low-Precision-AllReduce is specifically designed for the topology shown above, where:

- Each node contains 2 NUMA domains
- Each NUMA domain has 4 GPUs connected via PCIe switch
- GPUs within the same NUMA node communicate via the PCIe switch

**Important:** This optimization will not accelerate performance in different topologies (e.g., where each GPU is in a separate NUMA domain).
## Usage

The Low-Precision-AllReduce algorithm can be enabled in two ways:

1. **Direct specification** in your code:
   ```
   allreduce = AllReduce(mapping=mapping, strategy=AllReduceStrategy.LOWPRECISION)
   ```
2. **Environment variable control** with the AUTO strategy:
   ```
   # In your code
   allreduce = AllReduce(mapping=mapping, strategy=AllReduceStrategy.AUTO)

   # Set the environment variable before running
   export FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY=1
   ```
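Putting the two options together, a minimal end-to-end sketch for the PyTorch flow might look like the following; the import paths, `Mapping` arguments, and the `AllReduce` constructor/forward signatures are assumptions based on the snippets above and may differ across TensorRT-LLM versions:

```python
import os
import torch

# Assumed import locations; adjust to your TensorRT-LLM version.
from tensorrt_llm._torch.distributed import AllReduce, AllReduceStrategy
from tensorrt_llm.mapping import Mapping

# Option 2: keep AUTO and opt in via the environment variable.
os.environ["FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY"] = "1"

mapping = Mapping(world_size=8, tp_size=8, rank=0)   # hypothetical 8-GPU TP group
allreduce = AllReduce(mapping=mapping, strategy=AllReduceStrategy.AUTO)

hidden = torch.randn(4, 8192, dtype=torch.float16, device="cuda")
reduced = allreduce(hidden)  # falls back to other strategies when low precision gives no benefit
```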
## Performance and Accuracy Considerations

Low-Precision-AllReduce reduces communication volume by using the FP8 data format for transmission. This optimization:

- Improves performance for large message sizes in PCIe-based topologies
- May slightly reduce numerical precision
- Automatically falls back to other strategies when no performance benefit is expected (e.g., with NVLink or small messages)

Users should evaluate the precision impact on their specific models and workloads.

## Environment Variables

- `FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY`: When set to `1`, forces use of the low-precision algorithm with the AUTO strategy. If the algorithm determines it cannot provide a performance benefit, it automatically falls back to other strategies.

**Note**: If TensorRT-LLM is compiled without the `ENABLE_FP8` option, selecting the low-precision AllReduce strategy has no effect.
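For reference, the conditions under which the runtime actually selects the low-precision path come from the `isUsingLowPrecision` check added to `allreduceOp.cpp` in this commit. The sketch below restates that C++ logic in Python for clarity; it is a summary, not an API:

```python
# Python restatement of the isUsingLowPrecision() heuristic from allreduceOp.cpp.
def is_using_low_precision(message_size: int,
                           strategy_is_lowprecision: bool,
                           env_force_set: bool,
                           nvlink_supported: bool,
                           p2p_supported: bool,
                           enable_fp8_build: bool) -> bool:
    # Matches LowPrecisionMinMessageSize in the C++ code (2 * 1024 * 1024).
    LOW_PRECISION_MIN_MESSAGE_SIZE = 2 * 1024 * 1024
    if not enable_fp8_build:
        return False  # low precision requires an FP8-enabled build
    forced = env_force_set or strategy_is_lowprecision
    return (forced
            and not nvlink_supported    # PCIe-only topologies (no NVLink)
            and p2p_supported           # peer-to-peer access required
            and message_size >= LOW_PRECISION_MIN_MESSAGE_SIZE)
```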
