
Commit 6aa875f

bobboli and fredricz-20070104 authored and committed
[None][feat] Integrate MnnvlThroughput into TRTLLM MoE. (NVIDIA#8728)
Signed-off-by: Bo Li <[email protected]>
Signed-off-by: FredricZ-2007 <[email protected]>
1 parent 30892da commit 6aa875f

File tree

12 files changed: +736 additions, -648 deletions


cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
 #include <cstdint>
 #include <type_traits>
 
-namespace tensorrt_llm::kernels::moe_a2a
+namespace tensorrt_llm::kernels::mnnvl_throughput
 {
 
 #define ENABLE_DEBUG_PRINT 0
@@ -506,7 +506,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
     TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);
 
     // Prepare kernel pointers struct
-    DispatchKernelPointers kernel_ptrs = {}; // Zero-initialize
+    DispatchKernelPointers kernel_ptrs = {};
 
     // Fill source data pointers and payload sizes
     for (int i = 0; i < params.num_payloads; i++)
@@ -958,4 +958,4 @@ void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv
         expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id);
 }
 
-} // namespace tensorrt_llm::kernels::moe_a2a
+} // namespace tensorrt_llm::kernels::mnnvl_throughput

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

Lines changed: 34 additions & 32 deletions
@@ -19,7 +19,7 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
-namespace tensorrt_llm::kernels::moe_a2a
+namespace tensorrt_llm::kernels::mnnvl_throughput
 {
 
 // Configuration constants
@@ -91,7 +91,7 @@ struct MoeA2ADispatchParams
 
     // Token configuration
     int local_num_tokens;    // Number of tokens on this rank
-    int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation
+    int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to runtime_max_tokens_per_rank
     int top_k;               // Number of experts per token
 
     // Expert routing information
@@ -101,23 +101,22 @@ struct MoeA2ADispatchParams
     int num_payloads;                         // Number of different payload types
     PayloadDescriptor payloads[kMaxPayloads]; // Array of payload descriptors
 
-    // Receive buffers and synchronization
-    void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload
+    // Local aux data
+    uint32_t* flag_val;       // The value of the flag for this round (stored on the local rank)
+    int* local_token_counter; // Atomic counter for completed tokens on this rank
+    int* send_counters;       // [ep_size] atomic counters - tracks tokens sent to each target rank
+    int* topk_target_ranks;   // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank
+                              // per k, -1 for duplicates
+    int* topk_send_indices;   // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), dst index
+                              // per k, -1 for duplicates
 
-    // Synchronization
+    // Distributed aux data and recv buffers
+    int* recv_counters[kMaxRanks];         // tracks tokens received from each source rank. Each rank has [ep_size] counters
     uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
                                            // rank has signaled the target rank
-    uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
-
-    // Communication tracking
-    int* send_counters;            // [ep_size] atomic counters - tracks tokens sent to each target rank
-    int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. Each rank has [ep_size] counters
-    int* local_token_counter;      // Atomic counter for completed tokens on this rank
-
-    // Top-K compact routing info per local token (size: [local_num_tokens, top_k])
-    int* topk_target_ranks; // target rank per k, -1 for duplicates
-    int* topk_send_indices; // dst index per k, -1 for duplicates
+    void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload
 
+    // CUDA stream
     cudaStream_t stream;
 };
 
@@ -137,30 +136,33 @@ struct MoeA2ACombineParams
 
     // Token configuration
     int local_num_tokens;    // Number of tokens on this rank
-    int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation
+    int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to runtime_max_tokens_per_rank
     int top_k;               // Number of experts per token
 
-    // Expert routing information
-    int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target
-
-    // Top-K compact routing info per local token (size: [local_num_tokens, top_k])
-    int const* topk_target_ranks; // target rank per k, -1 for duplicates
-    int const* topk_send_indices; // dst index per k, -1 for duplicates
+    // Prepare-only field: original payload tensor pointer used to stage into workspace
+    void const* prepare_payload;
 
-    // Single payload information
-    void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)
-    void* output_data;        // Output buffer [local_num_tokens, elements_per_token]
-    int elements_per_token;   // Number of elements per token
-    nvinfer1::DataType dtype; // Data type for proper summation
+    // Output tensor
+    void* output_data; // Output buffer [local_num_tokens, elements_per_token]
+    // Payload information
+    int elements_per_token;   // Number of elements per token
+    nvinfer1::DataType dtype; // Data type for proper summation
+
+    // Local aux data
+    uint32_t* flag_val;       // The value of the flag for this round (stored on the local rank)
+    int* topk_target_ranks;   // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank
+                              // per k, -1 for duplicates
+    int* topk_send_indices;   // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), dst index
+                              // per k, -1 for duplicates
+    int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target
 
-    // Synchronization
+    // Distributed aux data and recv buffers
     uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
                                            // rank has signaled the target rank
-    uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
+    void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)
 
+    // CUDA stream
     cudaStream_t stream;
-    // Prepare-only field: original payload tensor pointer used to stage into workspace
-    void const* prepare_payload;
 };
 
 // Combine kernels
@@ -175,4 +177,4 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params);
 void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters, int32_t invalid_id,
     int ep_size, int max_tokens_per_rank, int top_k, cudaStream_t stream);
 
-} // namespace tensorrt_llm::kernels::moe_a2a
+} // namespace tensorrt_llm::kernels::mnnvl_throughput
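As a reading aid for the reorganized MoeA2ADispatchParams above, here is a minimal, hypothetical sketch of how a caller might fill the new field groups before launching dispatch. Only the struct fields, kMaxPayloads, and moe_a2a_dispatch_launch() come from this diff; the helper name, its parameters, and the include path are illustrative assumptions, and fields not touched by the diff (expert routing, payload descriptors) would still need to be set by the real caller.

#include <cstdint>
#include <cuda_runtime.h>

#include "tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h" // include path is an assumption

namespace mt = tensorrt_llm::kernels::mnnvl_throughput;

// Hypothetical helper: gather pre-allocated buffers into MoeA2ADispatchParams and launch dispatch.
void launchDispatchSketch(int ep_size, int local_num_tokens, int max_tokens_per_rank, int top_k, int num_payloads,
    uint32_t* flag_val, int* local_token_counter, int* send_counters, int* topk_target_ranks, int* topk_send_indices,
    int** peer_recv_counters, uint32_t** peer_completion_flags, void* (*peer_recv_buffers)[mt::kMaxPayloads],
    cudaStream_t stream)
{
    mt::MoeA2ADispatchParams params = {}; // zero-initialize, matching the kernel-side initialization style

    // Token configuration
    params.local_num_tokens = local_num_tokens;
    params.max_tokens_per_rank = max_tokens_per_rank;
    params.top_k = top_k;

    // Local aux data (buffers owned by this rank)
    params.flag_val = flag_val;
    params.local_token_counter = local_token_counter;
    params.send_counters = send_counters;
    params.topk_target_ranks = topk_target_ranks;
    params.topk_send_indices = topk_send_indices;

    // Distributed aux data and recv buffers (one entry per peer rank)
    params.num_payloads = num_payloads;
    for (int r = 0; r < ep_size; ++r)
    {
        params.recv_counters[r] = peer_recv_counters[r];
        params.completion_flags[r] = peer_completion_flags[r];
        for (int p = 0; p < num_payloads; ++p)
        {
            params.recv_buffers[r][p] = peer_recv_buffers[r][p];
        }
    }

    // Expert routing fields and the payload descriptors (unchanged by this diff) must also be filled here.
    params.stream = stream;
    mt::moe_a2a_dispatch_launch(params);
}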

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ namespace tensorrt_llm::nanobind::thop
 void initBindings(nb::module_& m)
 {
     // Export MoE A2A constants
-    for (auto const& kv : torch_ext::getMoeA2AMetaInfoIndexPairs())
+    for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs())
     {
         m.attr(kv.first) = kv.second;
     }

cpp/tensorrt_llm/pybind/thop/bindings.cpp

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ namespace tensorrt_llm::pybind::thop
 void initBindings(pybind11::module_& m)
 {
     // Export MoE A2A constants
-    for (auto const& kv : torch_ext::getMoeA2AMetaInfoIndexPairs())
+    for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs())
     {
         m.attr(kv.first) = py::int_(kv.second);
     }

cpp/tensorrt_llm/thop/moeAlltoAllMeta.h

Lines changed: 22 additions & 11 deletions
@@ -16,15 +16,18 @@
 
 #pragma once
 
+#include <array>
 #include <cstdint>
 #include <utility>
 #include <vector>
 
 namespace torch_ext
 {
+namespace mnnvl_throughput
+{
 
 // Enum for indexing into moe_a2a_metainfo tensor
-enum MoeA2AMetaInfoIndex
+enum MoeA2AMetaInfoIndex : int64_t
 {
     FLAG_VAL_OFFSET_INDEX = 0,
     LOCAL_TOKEN_COUNTER_OFFSET_INDEX = 1,
@@ -34,21 +37,29 @@ enum MoeA2AMetaInfoIndex
     DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = 4,
     // Combine completion flags offset
     COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = 5,
-    PAYLOAD_DATA_OFFSET_INDEX = 6,
-    NUM_METAINFO_FIELDS = 7
+    TOPK_TARGET_RANKS_OFFSET_INDEX = 6,
+    TOPK_SEND_INDICES_OFFSET_INDEX = 7,
+    PAYLOAD_DATA_OFFSET_INDEX = 8,
+    NUM_METAINFO_FIELDS = 9
 };
 
+using MoeA2ADataOffsets = std::array<int64_t, NUM_METAINFO_FIELDS>;
+
 inline std::vector<std::pair<char const*, int64_t>> getMoeA2AMetaInfoIndexPairs()
 {
     return {
-        {"MOE_A2A_FLAG_VAL_OFFSET_INDEX", static_cast<int64_t>(FLAG_VAL_OFFSET_INDEX)},
-        {"MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX", static_cast<int64_t>(LOCAL_TOKEN_COUNTER_OFFSET_INDEX)},
-        {"MOE_A2A_SEND_COUNTERS_OFFSET_INDEX", static_cast<int64_t>(SEND_COUNTERS_OFFSET_INDEX)},
-        {"MOE_A2A_RECV_COUNTERS_OFFSET_INDEX", static_cast<int64_t>(RECV_COUNTERS_OFFSET_INDEX)},
-        {"MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX",
-            static_cast<int64_t>(DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX)},
-        {"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", static_cast<int64_t>(COMBINE_COMPLETION_FLAGS_OFFSET_INDEX)},
-        {"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", static_cast<int64_t>(PAYLOAD_DATA_OFFSET_INDEX)},
+        {"MOE_A2A_FLAG_VAL_OFFSET_INDEX", FLAG_VAL_OFFSET_INDEX},
+        {"MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX", LOCAL_TOKEN_COUNTER_OFFSET_INDEX},
+        {"MOE_A2A_SEND_COUNTERS_OFFSET_INDEX", SEND_COUNTERS_OFFSET_INDEX},
+        {"MOE_A2A_RECV_COUNTERS_OFFSET_INDEX", RECV_COUNTERS_OFFSET_INDEX},
+        {"MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX", DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX},
+        {"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", COMBINE_COMPLETION_FLAGS_OFFSET_INDEX},
+        {"MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX", TOPK_TARGET_RANKS_OFFSET_INDEX},
+        {"MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX", TOPK_SEND_INDICES_OFFSET_INDEX},
+        {"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", PAYLOAD_DATA_OFFSET_INDEX},
+        {"MOE_A2A_NUM_METAINFO_FIELDS", NUM_METAINFO_FIELDS},
     };
 }
+
+} // namespace mnnvl_throughput
 } // namespace torch_ext
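The enum above names slots in the moe_a2a_metainfo offsets table, and MoeA2ADataOffsets is the new fixed-size array of those offsets. Below is a minimal, hypothetical sketch of how they might be consumed on the C++ side, under the assumption (not stated in this diff) that each entry is a byte offset into a single per-rank workspace buffer; the helper names and include path are illustrative, while the enum, MoeA2ADataOffsets, and getMoeA2AMetaInfoIndexPairs() come from the header above.

#include <cstdint>
#include <cstdio>

#include "tensorrt_llm/thop/moeAlltoAllMeta.h" // include path is an assumption

using namespace torch_ext::mnnvl_throughput;

// Resolve a typed pointer to one auxiliary region of a per-rank workspace buffer,
// assuming each MoeA2ADataOffsets entry stores a byte offset from the workspace base.
template <typename T>
T* regionPtr(uint8_t* workspace_base, MoeA2ADataOffsets const& offsets, MoeA2AMetaInfoIndex index)
{
    return reinterpret_cast<T*>(workspace_base + offsets[index]);
}

// Print the index constants exactly as the nanobind/pybind bindings above export them to Python.
inline void printExportedA2AIndices()
{
    for (auto const& kv : getMoeA2AMetaInfoIndexPairs())
    {
        std::printf("%s = %lld\n", kv.first, static_cast<long long>(kv.second));
    }
}

For example, regionPtr<uint32_t>(workspace, offsets, FLAG_VAL_OFFSET_INDEX) would resolve the dispatch flag-value slot under that layout assumption, and the Python side can look up the same indices by the MOE_A2A_* names exported by the bindings changed in this commit.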
