Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <cstdint>
#include <type_traits>

namespace tensorrt_llm::kernels::mnnvl_throughput
namespace tensorrt_llm::kernels::moe_comm
{

#define ENABLE_DEBUG_PRINT 0
Expand Down Expand Up @@ -964,4 +964,4 @@ void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv
expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id);
}

} // namespace tensorrt_llm::kernels::mnnvl_throughput
} // namespace tensorrt_llm::kernels::moe_comm
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace tensorrt_llm::kernels::mnnvl_throughput
namespace tensorrt_llm::kernels::moe_comm
{

// Configuration constants
Expand Down Expand Up @@ -177,4 +177,4 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params);
void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters, int32_t invalid_id,
int ep_size, int max_tokens_per_rank, int top_k, cudaStream_t stream);

} // namespace tensorrt_llm::kernels::mnnvl_throughput
} // namespace tensorrt_llm::kernels::moe_comm
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/nanobind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tensorrt_llm::nanobind::thop
void initBindings(nb::module_& m)
{
// Export MoE A2A constants
for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs())
for (auto const& kv : torch_ext::moe_comm::getMoeA2AMetaInfoIndexPairs())
{
m.attr(kv.first) = kv.second;
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/pybind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tensorrt_llm::pybind::thop
void initBindings(pybind11::module_& m)
{
// Export MoE A2A constants
for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs())
for (auto const& kv : torch_ext::moe_comm::getMoeA2AMetaInfoIndexPairs())
{
m.attr(kv.first) = py::int_(kv.second);
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/thop/moeAlltoAllMeta.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

namespace torch_ext
{
namespace mnnvl_throughput
namespace moe_comm
{

// Enum for indexing into moe_a2a_metainfo tensor
Expand Down Expand Up @@ -61,5 +61,5 @@ inline std::vector<std::pair<char const*, int64_t>> getMoeA2AMetaInfoIndexPairs(
};
}

} // namespace mnnvl_throughput
} // namespace moe_comm
} // namespace torch_ext
38 changes: 19 additions & 19 deletions cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
namespace torch_ext
{

namespace mnnvl_throughput
namespace moe_comm
{

// TODO: Is Alignment necessary?
Expand Down Expand Up @@ -78,13 +78,13 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
// topk_target_ranks: [maxNumTokens, kMaxTopK]
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
offsets[TOPK_TARGET_RANKS_OFFSET_INDEX] = offset;
offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK)
offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::moe_comm::kMaxTopK)
* SIZEOF_INT32;

// topk_send_indices: [maxNumTokens, kMaxTopK]
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
offsets[TOPK_SEND_INDICES_OFFSET_INDEX] = offset;
offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK)
offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::moe_comm::kMaxTopK)
* SIZEOF_INT32;

// payload data
Expand Down Expand Up @@ -165,11 +165,11 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
std::vector<torch::Tensor> const& inputPayloads, torch::Tensor const& workspace, torch::Tensor const& metainfo,
int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, int64_t numExperts)
{
using tensorrt_llm::kernels::mnnvl_throughput::PayloadDescriptor;
using tensorrt_llm::kernels::mnnvl_throughput::MoeA2ADispatchParams;
using tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_dispatch_launch;
using tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK;
using tensorrt_llm::kernels::mnnvl_throughput::kMaxPayloads;
using tensorrt_llm::kernels::moe_comm::PayloadDescriptor;
using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams;
using tensorrt_llm::kernels::moe_comm::moe_a2a_dispatch_launch;
using tensorrt_llm::kernels::moe_comm::kMaxTopK;
using tensorrt_llm::kernels::moe_comm::kMaxPayloads;

// Validate inputs
CHECK_INPUT(tokenSelectedExperts, torch::kInt32);
Expand Down Expand Up @@ -344,9 +344,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK,
int64_t combinePayloadOffset, bool payloadInWorkspace)
{
using tensorrt_llm::kernels::mnnvl_throughput::MoeA2ACombineParams;
using tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_combine_launch;
using tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK;
using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams;
using tensorrt_llm::kernels::moe_comm::moe_a2a_combine_launch;
using tensorrt_llm::kernels::moe_comm::kMaxTopK;

// Validate inputs
CHECK_TH_CUDA(payload);
Expand Down Expand Up @@ -474,8 +474,8 @@ void moeA2ASanitizeExpertIdsOp(torch::Tensor& expert_ids, torch::Tensor& workspa
uint8_t* rankWorkSpacePtr = workspace.data_ptr<uint8_t>() + epRank * workspace.stride(0);
int* recv_counters = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]);

tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_sanitize_expert_ids_launch(expert_ids.data_ptr<int32_t>(),
recv_counters, static_cast<int32_t>(invalid_expert_id), ep_size, runtime_max_tokens_per_rank, top_k,
tensorrt_llm::kernels::moe_comm::moe_a2a_sanitize_expert_ids_launch(expert_ids.data_ptr<int32_t>(), recv_counters,
static_cast<int32_t>(invalid_expert_id), ep_size, runtime_max_tokens_per_rank, top_k,
at::cuda::getCurrentCUDAStream());
}

Expand Down Expand Up @@ -508,7 +508,7 @@ torch::Tensor moeA2AGetCombinePayloadTensorOp(torch::Tensor const& workspace, in
return t;
}

} // namespace mnnvl_throughput
} // namespace moe_comm

} // namespace torch_ext

Expand Down Expand Up @@ -540,9 +540,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module)

TORCH_LIBRARY_IMPL(trtllm, CUDA, module)
{
module.impl("moe_a2a_dispatch", &torch_ext::mnnvl_throughput::moeA2ADispatchOp);
module.impl("moe_a2a_combine", &torch_ext::mnnvl_throughput::moeA2ACombineOp);
module.impl("moe_a2a_initialize", &torch_ext::mnnvl_throughput::moeA2AInitializeOp);
module.impl("moe_a2a_sanitize_expert_ids", &torch_ext::mnnvl_throughput::moeA2ASanitizeExpertIdsOp);
module.impl("moe_a2a_get_combine_payload_tensor", &torch_ext::mnnvl_throughput::moeA2AGetCombinePayloadTensorOp);
module.impl("moe_a2a_dispatch", &torch_ext::moe_comm::moeA2ADispatchOp);
module.impl("moe_a2a_combine", &torch_ext::moe_comm::moeA2ACombineOp);
module.impl("moe_a2a_initialize", &torch_ext::moe_comm::moeA2AInitializeOp);
module.impl("moe_a2a_sanitize_expert_ids", &torch_ext::moe_comm::moeA2ASanitizeExpertIdsOp);
module.impl("moe_a2a_get_combine_payload_tensor", &torch_ext::moe_comm::moeA2AGetCombinePayloadTensorOp);
}
20 changes: 10 additions & 10 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@ def __init__(
self.use_low_precision_combine = model_config.use_low_precision_moe_combine

if self.alltoall_method_type == AlltoallMethodType.MNNVL:
if self.moe_alltoall_backend == "mnnvllatency":
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.moe_alltoall_backend == "mnnvlthroughput":
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
workspace_mb = int(
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
self.moe_a2a = MoeAlltoAll(
Expand Down Expand Up @@ -253,9 +253,9 @@ def enable_alltoall(self):

@cached_property
def moe_alltoall_backend(self):
# "mnnvlthroughput" (default) or "mnnvllatency"
# "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
"mnnvlthroughput").strip().lower()
"NVLINK_ONE_SIDED").strip().upper()

def _supports_load_balancer(self) -> bool:
"""CutlassFusedMoE supports load balancer."""
Expand Down Expand Up @@ -328,7 +328,7 @@ def forward_chunk(

if self.layer_load_balancer:
self._load_balancer_done_wait_gpu_stage(is_first_call)
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency"
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
self._load_balancer_update_statistic(
token_selected_experts,
is_first_call,
Expand Down Expand Up @@ -439,7 +439,7 @@ def forward_chunk(
token_final_scales = torch.ones_like(token_selected_slots,
dtype=torch.float32)

if self.moe_alltoall_backend == "mnnvllatency":
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
if is_last_call:
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
Expand Down Expand Up @@ -472,7 +472,7 @@ def forward_chunk(
token_selected_slots, alltoall_info.recv_rank_count_cumsum,
runtime_max_tokens_per_rank, top_k, self.num_slots,
self.ep_size)
elif self.moe_alltoall_backend == "mnnvlthroughput":
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
# Python MoeAlltoAll path
if x_sf is not None:
x_sf = x_sf.view(x_row,
Expand Down Expand Up @@ -532,7 +532,7 @@ def forward_chunk(

# Optionally provide an output tensor to fused_moe so it writes directly to our buffer
moe_output: Optional[torch.Tensor] = None
if self.enable_alltoall and self.moe_alltoall_backend == "mnnvlthroughput":
if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
# Retrieve a workspace-backed output tensor sized by runtime tokens
runtime_max_tokens_per_rank = max(
all_rank_num_tokens) if all_rank_num_tokens else x.shape[0]
Expand Down Expand Up @@ -583,7 +583,7 @@ def forward_chunk(

# Combine results if using alltoall
if self.enable_alltoall:
if self.moe_alltoall_backend == "mnnvllatency":
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
if alltoall_info is not None:
top_k = self.routing_method.experts_per_token
final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
Expand All @@ -596,7 +596,7 @@ def forward_chunk(
use_low_precision_combine=self.
use_low_precision_combine,
token_count=token_count)
elif self.moe_alltoall_backend == "mnnvlthroughput":
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
output_hidden_size = final_hidden_states.shape[-1]
runtime_max_tokens_per_rank = max(
all_rank_num_tokens) if all_rank_num_tokens else token_count
Expand Down
20 changes: 10 additions & 10 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def __init__(
self.use_low_precision_combine = model_config.use_low_precision_moe_combine

if self.alltoall_method_type == AlltoallMethodType.MNNVL:
if self.moe_alltoall_backend == "mnnvllatency":
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.moe_alltoall_backend == "mnnvlthroughput":
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
workspace_mb = int(
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
self.moe_a2a = MoeAlltoAll(
Expand Down Expand Up @@ -198,9 +198,9 @@ def enable_alltoall(self):

@cached_property
def moe_alltoall_backend(self):
# "mnnvlthroughput" (default) or "mnnvllatency"
# "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
"mnnvlthroughput").strip().lower()
"NVLINK_ONE_SIDED").strip().upper()

def _check_configs(self):
assert self.has_deepseek_fp8_block_scales \
Expand Down Expand Up @@ -362,7 +362,7 @@ def forward_impl(

self._load_balancer_done_wait_gpu_stage(is_first_call)

ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency"
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
self._load_balancer_update_statistic(
token_selected_experts,
is_first_call,
Expand Down Expand Up @@ -394,7 +394,7 @@ def forward_impl(
else:
token_final_scales = token_final_scales.to(torch.float32)

if self.moe_alltoall_backend == "mnnvllatency":
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
if is_last_call:
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
Expand Down Expand Up @@ -444,7 +444,7 @@ def forward_impl(

if token_final_scales is not None:
token_final_scales = token_final_scales.to(torch.bfloat16)
elif self.moe_alltoall_backend == "mnnvlthroughput":
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
if x_sf is not None:
x_sf = x_sf.view(x_row,
ceil_div(x_col, self.scaling_vector_size))
Expand Down Expand Up @@ -510,7 +510,7 @@ def forward_impl(
moe_output: Optional[torch.Tensor] = None
use_workspace_output = False
# TODO: use_workspace_output only supports w4a8_mxfp4_mxfp8 (gpt-oss) for now
if self.enable_alltoall and self.moe_alltoall_backend == "mnnvlthroughput" and self.has_w4a8_mxfp4_mxfp8:
if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED" and self.has_w4a8_mxfp4_mxfp8:
moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace(
runtime_max_tokens_per_rank, self.hidden_size, torch.bfloat16)
use_workspace_output = True
Expand Down Expand Up @@ -774,7 +774,7 @@ def forward_impl(

# Combine results if using alltoall
if self.enable_alltoall:
if self.moe_alltoall_backend == "mnnvllatency":
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
if alltoall_info is not None:
final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
final_hidden_states,
Expand All @@ -787,7 +787,7 @@ def forward_impl(
use_low_precision_combine,
token_count=token_count,
)
elif self.moe_alltoall_backend == "mnnvlthroughput":
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
# If use_workspace_output=True, the MoE result is already in workspace
# Otherwise, we need to reshape and pass it
if use_workspace_output:
Expand Down
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ def enable_alltoall(self):

@cached_property
def moe_alltoall_backend(self):
# "mnnvllatency" (default) or "mnnvlthroughput"
# "NVLINK_TWO_SIDED" (default) or "NVLINK_ONE_SIDED"
return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
"mnnvllatency").strip().lower()
"NVLINK_TWO_SIDED").strip().upper()

def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
num_rows = sum(all_rank_num_tokens)
Expand Down Expand Up @@ -436,7 +436,7 @@ def forward_chunk(

if self.layer_load_balancer:
self._load_balancer_done_wait_gpu_stage(is_first_call)
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency"
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
self._load_balancer_update_statistic(token_selected_experts,
is_first_call, is_last_call,
ignore_allreduce)
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/modules/fused_moe/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def _load_balancer_update_statistic(self,
token_selected_experts: The selected experts of all tokens, has shape of [tokenCount * topK]
is_first_call: Whether this is the first call for the same weights
is_last_call: Whether this is the last call for the same weights
ignore_allreduce: Whether to ignore allreduce, if True, only update local statistics, need call _load_balancer_get_local_statistic_tensor to get the local statistic tensor and then do external allgather and then call _load_balancer_update_statistic_with_gathered_statistic to update the global statistics. MnnvlLatency supports this.
ignore_allreduce: Whether to ignore allreduce, if True, only update local statistics, need call _load_balancer_get_local_statistic_tensor to get the local statistic tensor and then do external allgather and then call _load_balancer_update_statistic_with_gathered_statistic to update the global statistics. NVLINK_TWO_SIDED supports this.
"""
if self._using_dynamic_load_balancer():
if ignore_allreduce:
Expand Down
Loading