|
28 | 28 | namespace torch_ext |
29 | 29 | { |
30 | 30 |
|
31 | | -namespace mnnvl_throughput |
| 31 | +namespace moe_comm |
32 | 32 | { |
33 | 33 |
|
34 | 34 | // TODO: Is Alignment necessary? |
@@ -78,13 +78,13 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens) |
78 | 78 | // topk_target_ranks: [maxNumTokens, kMaxTopK] |
79 | 79 | offset = alignOffset(offset, CACHELINE_ALIGNMENT); |
80 | 80 | offsets[TOPK_TARGET_RANKS_OFFSET_INDEX] = offset; |
81 | | - offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK) |
| 81 | + offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::moe_comm::kMaxTopK) |
82 | 82 | * SIZEOF_INT32; |
83 | 83 |
|
84 | 84 | // topk_send_indices: [maxNumTokens, kMaxTopK] |
85 | 85 | offset = alignOffset(offset, CACHELINE_ALIGNMENT); |
86 | 86 | offsets[TOPK_SEND_INDICES_OFFSET_INDEX] = offset; |
87 | | - offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK) |
| 87 | + offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::moe_comm::kMaxTopK) |
88 | 88 | * SIZEOF_INT32; |
89 | 89 |
|
90 | 90 | // payload data |
@@ -165,11 +165,11 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c |
165 | 165 | std::vector<torch::Tensor> const& inputPayloads, torch::Tensor const& workspace, torch::Tensor const& metainfo, |
166 | 166 | int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, int64_t numExperts) |
167 | 167 | { |
168 | | - using tensorrt_llm::kernels::mnnvl_throughput::PayloadDescriptor; |
169 | | - using tensorrt_llm::kernels::mnnvl_throughput::MoeA2ADispatchParams; |
170 | | - using tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_dispatch_launch; |
171 | | - using tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK; |
172 | | - using tensorrt_llm::kernels::mnnvl_throughput::kMaxPayloads; |
| 168 | + using tensorrt_llm::kernels::moe_comm::PayloadDescriptor; |
| 169 | + using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams; |
| 170 | + using tensorrt_llm::kernels::moe_comm::moe_a2a_dispatch_launch; |
| 171 | + using tensorrt_llm::kernels::moe_comm::kMaxTopK; |
| 172 | + using tensorrt_llm::kernels::moe_comm::kMaxPayloads; |
173 | 173 |
|
174 | 174 | // Validate inputs |
175 | 175 | CHECK_INPUT(tokenSelectedExperts, torch::kInt32); |
@@ -344,9 +344,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke |
344 | 344 | torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, |
345 | 345 | int64_t combinePayloadOffset, bool payloadInWorkspace) |
346 | 346 | { |
347 | | - using tensorrt_llm::kernels::mnnvl_throughput::MoeA2ACombineParams; |
348 | | - using tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_combine_launch; |
349 | | - using tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK; |
| 347 | + using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams; |
| 348 | + using tensorrt_llm::kernels::moe_comm::moe_a2a_combine_launch; |
| 349 | + using tensorrt_llm::kernels::moe_comm::kMaxTopK; |
350 | 350 |
|
351 | 351 | // Validate inputs |
352 | 352 | CHECK_TH_CUDA(payload); |
@@ -474,8 +474,8 @@ void moeA2ASanitizeExpertIdsOp(torch::Tensor& expert_ids, torch::Tensor& workspa |
474 | 474 | uint8_t* rankWorkSpacePtr = workspace.data_ptr<uint8_t>() + epRank * workspace.stride(0); |
475 | 475 | int* recv_counters = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]); |
476 | 476 |
|
477 | | - tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_sanitize_expert_ids_launch(expert_ids.data_ptr<int32_t>(), |
478 | | - recv_counters, static_cast<int32_t>(invalid_expert_id), ep_size, runtime_max_tokens_per_rank, top_k, |
| 477 | + tensorrt_llm::kernels::moe_comm::moe_a2a_sanitize_expert_ids_launch(expert_ids.data_ptr<int32_t>(), recv_counters, |
| 478 | + static_cast<int32_t>(invalid_expert_id), ep_size, runtime_max_tokens_per_rank, top_k, |
479 | 479 | at::cuda::getCurrentCUDAStream()); |
480 | 480 | } |
481 | 481 |
|
@@ -508,7 +508,7 @@ torch::Tensor moeA2AGetCombinePayloadTensorOp(torch::Tensor const& workspace, in |
508 | 508 | return t; |
509 | 509 | } |
510 | 510 |
|
511 | | -} // namespace mnnvl_throughput |
| 511 | +} // namespace moe_comm |
512 | 512 |
|
513 | 513 | } // namespace torch_ext |
514 | 514 |
|
@@ -540,9 +540,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module) |
540 | 540 |
|
541 | 541 | TORCH_LIBRARY_IMPL(trtllm, CUDA, module) |
542 | 542 | { |
543 | | - module.impl("moe_a2a_dispatch", &torch_ext::mnnvl_throughput::moeA2ADispatchOp); |
544 | | - module.impl("moe_a2a_combine", &torch_ext::mnnvl_throughput::moeA2ACombineOp); |
545 | | - module.impl("moe_a2a_initialize", &torch_ext::mnnvl_throughput::moeA2AInitializeOp); |
546 | | - module.impl("moe_a2a_sanitize_expert_ids", &torch_ext::mnnvl_throughput::moeA2ASanitizeExpertIdsOp); |
547 | | - module.impl("moe_a2a_get_combine_payload_tensor", &torch_ext::mnnvl_throughput::moeA2AGetCombinePayloadTensorOp); |
| 543 | + module.impl("moe_a2a_dispatch", &torch_ext::moe_comm::moeA2ADispatchOp); |
| 544 | + module.impl("moe_a2a_combine", &torch_ext::moe_comm::moeA2ACombineOp); |
| 545 | + module.impl("moe_a2a_initialize", &torch_ext::moe_comm::moeA2AInitializeOp); |
| 546 | + module.impl("moe_a2a_sanitize_expert_ids", &torch_ext::moe_comm::moeA2ASanitizeExpertIdsOp); |
| 547 | + module.impl("moe_a2a_get_combine_payload_tensor", &torch_ext::moe_comm::moeA2AGetCombinePayloadTensorOp); |
548 | 548 | } |