 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
-namespace tensorrt_llm::kernels::moe_a2a
+namespace tensorrt_llm::kernels::mnnvl_throughput
 {
 
 // Configuration constants
@@ -91,7 +91,7 @@ struct MoeA2ADispatchParams
 
     // Token configuration
     int local_num_tokens; // Number of tokens on this rank
-    int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation
+    int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to runtime_max_tokens_per_rank
     int top_k; // Number of experts per token
 
     // Expert routing information
@@ -101,23 +101,22 @@ struct MoeA2ADispatchParams
     int num_payloads; // Number of different payload types
     PayloadDescriptor payloads[kMaxPayloads]; // Array of payload descriptors
 
-    // Receive buffers and synchronization
-    void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload
+    // Local aux data
+    uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
+    int* local_token_counter; // Atomic counter for completed tokens on this rank
+    int* send_counters; // [ep_size] atomic counters - tracks tokens sent to each target rank
+    int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank
+                            // per k, -1 for duplicates
+    int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), dst index
+                            // per k, -1 for duplicates
 
-    // Synchronization
+    // Distributed aux data and recv buffers
+    int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. Each rank has [ep_size] counters
     uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
                                            // rank has signaled the target rank
-    uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
-
-    // Communication tracking
-    int* send_counters; // [ep_size] atomic counters - tracks tokens sent to each target rank
-    int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. Each rank has [ep_size] counters
-    int* local_token_counter; // Atomic counter for completed tokens on this rank
-
-    // Top-K compact routing info per local token (size: [local_num_tokens, top_k])
-    int* topk_target_ranks; // target rank per k, -1 for duplicates
-    int* topk_send_indices; // dst index per k, -1 for duplicates
+    void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload
 
+    // CUDA stream
     cudaStream_t stream;
 };
 
@@ -137,30 +136,33 @@ struct MoeA2ACombineParams
 
     // Token configuration
     int local_num_tokens; // Number of tokens on this rank
-    int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation
+    int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to runtime_max_tokens_per_rank
     int top_k; // Number of experts per token
 
-    // Expert routing information
-    int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target
-
-    // Top-K compact routing info per local token (size: [local_num_tokens, top_k])
-    int const* topk_target_ranks; // target rank per k, -1 for duplicates
-    int const* topk_send_indices; // dst index per k, -1 for duplicates
+    // Prepare-only field: original payload tensor pointer used to stage into workspace
+    void const* prepare_payload;
 
-    // Single payload information
-    void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)
-    void* output_data; // Output buffer [local_num_tokens, elements_per_token]
-    int elements_per_token; // Number of elements per token
-    nvinfer1::DataType dtype; // Data type for proper summation
+    // Output tensor
+    void* output_data; // Output buffer [local_num_tokens, elements_per_token]
+    // Payload information
+    int elements_per_token; // Number of elements per token
+    nvinfer1::DataType dtype; // Data type for proper summation
+
+    // Local aux data
+    uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
+    int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank
+                            // per k, -1 for duplicates
+    int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), dst index
+                            // per k, -1 for duplicates
+    int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target
 
-    // Synchronization
+    // Distributed aux data and recv buffers
     uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
                                            // rank has signaled the target rank
-    uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
+    void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)
 
+    // CUDA stream
     cudaStream_t stream;
-    // Prepare-only field: original payload tensor pointer used to stage into workspace
-    void const* prepare_payload;
 };
 
 // Combine kernels
@@ -175,4 +177,4 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params);
 void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters, int32_t invalid_id,
     int ep_size, int max_tokens_per_rank, int top_k, cudaStream_t stream);
 
-} // namespace tensorrt_llm::kernels::moe_a2a
+} // namespace tensorrt_llm::kernels::mnnvl_throughput
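
Note on the synchronization fields touched by this diff: the comments on `send_counters`, `recv_counters`, `completion_flags`, and `flag_val` describe a counter-plus-flag handshake. A source rank counts the tokens it stages into each target's receive buffers and, once done, publishes the current round's `*flag_val` into `completion_flags[target_rank][source_rank]`; a rank may only consume its receive buffers after every source has published that flag. The sketch below is only an illustration of that handshake under the stated field semantics; the device function names (`signal_completion`, `wait_for_all_sources`) and the `source_rank`/`ep_size` parameters are hypothetical and are not part of this header or of the actual kernels.

```cuda
// Illustrative sketch of the completion-flag handshake implied by the struct
// comments above. It is NOT the TensorRT-LLM implementation; only the field
// semantics (completion_flags[target][source] == *flag_val means "source has
// signaled target") are taken from the header.
#include <cstdint>
#include <cuda_runtime.h>

// Dispatch side, run on `source_rank` after all of its tokens have been
// written into the peers' receive buffers.
__device__ void signal_completion(
    uint32_t* const completion_flags[], // completion_flags[r] points into rank r's flag array (one slot per source)
    uint32_t const* flag_val,           // this round's flag value, stored on the local rank
    int source_rank, int ep_size)
{
    uint32_t const flag = *flag_val;
    __threadfence_system(); // make payload writes visible to peers before raising the flag
    for (int target = threadIdx.x; target < ep_size; target += blockDim.x)
    {
        atomicExch(&completion_flags[target][source_rank], flag); // atomic store of the round's flag
    }
}

// Consumer side (e.g. before combine reads recv_buffers): spin until every
// source rank has published the current round's flag on this rank.
__device__ void wait_for_all_sources(
    uint32_t* local_completion_flags, // this rank's flag array, indexed by source rank
    uint32_t const* flag_val, int ep_size)
{
    uint32_t const flag = *flag_val;
    for (int source = threadIdx.x; source < ep_size; source += blockDim.x)
    {
        while (atomicAdd(&local_completion_flags[source], 0u) != flag) // atomic read-back, no modification
        {
        }
    }
    __syncthreads();
    __threadfence_system(); // order subsequent reads of recv_buffers after the flag observation
}
```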