7 changes: 4 additions & 3 deletions cpp/tensorrt_llm/kernels/mlaKernels.cu
@@ -351,7 +351,8 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
int* seqQOffset, uint32_t* fmha_tile_counter, int32_t const* kv_cache_lengths, int* seqKVOffsets, int q_pe_ld,
int q_pe_stride, KvCacheDataType cache_type, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o,
float const* quant_scale_q, float const* quant_scale_kv, float const* dequant_scale_q,
float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets)
float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets,
bool const* helix_is_inactive_rank)
{

// Constants.
@@ -514,7 +515,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
auto local_token_idx = global_token_idx % seq_len;
bool valid_token = global_token_idx < total_s_len;

if (valid_token)
if (valid_token && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx]))
{
if (head_dim_vec_idx == 0)
{
@@ -1047,7 +1048,7 @@ void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer
params.seqQOffset, params.fmha_tile_counter, params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld,
params.q_pe_stride, params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o,
params.quant_scale_q, params.quant_scale_kv, params.dequant_scale_q, params.dequant_scale_kv,
params.host_bmm1_scale, params.helix_position_offsets);
params.host_bmm1_scale, params.helix_position_offsets, params.helix_is_inactive_rank);
}

template <typename T, typename TCache>
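To summarize the gating this kernel change adds: a token's KV entry is written only when the token is valid and, if per-request helix flags are provided, the current rank is active for that request. A minimal Python sketch of the same condition (the function name and arguments are illustrative, not part of the kernel):

def should_write_kv(valid_token, helix_is_inactive_rank, batch_idx):
    # No flags provided (nullptr in the kernel): every valid token is written.
    if helix_is_inactive_rank is None:
        return valid_token
    # Inactive helix ranks skip appending the query token to their KV cache.
    return valid_token and not helix_is_inactive_rank[batch_idx]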
4 changes: 4 additions & 0 deletions cpp/tensorrt_llm/kernels/mlaKernels.h
@@ -107,6 +107,10 @@ struct MlaParams

// for Helix parallelism: the rotary position offsets [b]
int32_t const* helix_position_offsets{nullptr};

// for Helix parallelism: whether the current rank is inactive, shape [b]
// (the current query tokens are not appended to this rank's KV cache)
bool const* helix_is_inactive_rank{nullptr};
};

template <typename T, typename KVCacheBuffer>
@@ -213,6 +213,12 @@ void TrtllmGenGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k)
return optionsA.mNumSlicesForSplitK > optionsB.mNumSlicesForSplitK;
}

// then by tileN, if N is large enough
if (gemmData.mProblemDimensions.mN > 256 && optionsA.mTileN != optionsB.mTileN)
{
return optionsA.mTileN > optionsB.mTileN;
}

return true;
});

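The comparator above orders candidate configs first by the number of split-K slices and then, when N is larger than 256, by tileN, both descending. A rough Python sketch of an equivalent sort key, assuming config objects with num_slices_for_split_k and tile_n attributes (illustrative names, not the runner's fields):

def config_sort_key(cfg, n):
    # Prefer more split-K slices; larger tileN breaks ties only when N > 256.
    tile_n_rank = cfg.tile_n if n > 256 else 0
    return (-cfg.num_slices_for_split_k, -tile_n_rank)

# Usage: configs.sort(key=lambda cfg: config_sort_key(cfg, n))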
32 changes: 18 additions & 14 deletions examples/disaggregated/slurm/benchmark/disaggr_torch.slurm
@@ -178,20 +178,24 @@ for i in $(seq 0 $((num_gen_servers - 1))); do
&> ${full_logdir}/output_gen_${i}.log &
done

# start the ctx workers
echo "Starting ctx workers..."
for i in $(seq 0 $((num_ctx_servers - 1))); do
srun -l -N ${ctx_nodes_num_in_single_server} \
--ntasks=$((ctx_world_size)) \
--ntasks-per-node=${gpus_per_node} \
--container-image=${container_image} \
--container-name=${container_name} \
--container-mounts=${container_mount} \
--mpi=pmix \
bash ${work_dir}/start_worker.sh \
"CTX" ${i} ${model_path} "8336" "${benchmark_mode}" "${concurrency_list}" "${enable_pdl}" "${numa_bind}" "${full_logdir}" "${nsys_on}" "${ctx_config_path}" \
&> ${full_logdir}/output_ctx_${i}.log &
done
# start the ctx workers (skip if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set).
if [ "${TRTLLM_DISAGG_BENCHMARK_GEN_ONLY:-0}" != "1" ]; then
echo "Starting ctx workers..."
for i in $(seq 0 $((num_ctx_servers - 1))); do
srun -l -N ${ctx_nodes_num_in_single_server} \
--ntasks=${ctx_tp_size} \
--ntasks-per-node=${gpus_per_node} \
--container-image=${container_image} \
--container-name=${container_name} \
--container-mounts=${container_mount} \
--mpi=pmix \
bash ${work_dir}/start_worker.sh \
"CTX" ${i} ${model_path} "8336" "${benchmark_mode}" "${concurrency_list}" "${enable_pdl}" "${numa_bind}" "${full_logdir}" "${nsys_on}" "${ctx_config_path}" \
&> ${full_logdir}/output_ctx_${i}.log &
done
else
echo "Skipping context workers startup (TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set)"
fi

# start the server
echo "Starting server..."
18 changes: 13 additions & 5 deletions examples/disaggregated/slurm/benchmark/gen_server_config.py
@@ -39,12 +39,17 @@
time.sleep(10)
print(f"Waiting for hostnames folder {hostnames_folder} to be found")
hostnames = os.listdir(hostnames_folder)
# check length of hostnames is equal to num_ctx_servers + num_gen_servers, if not, sleep 10 seconds and check again
while len(hostnames) != args.num_ctx_servers + args.num_gen_servers:

# Skip context servers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set
gen_only = os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1"
expected_hostnames = args.num_gen_servers if gen_only else args.num_ctx_servers + args.num_gen_servers

# check length of hostnames is equal to expected count, if not, sleep 10 seconds and check again
while len(hostnames) != expected_hostnames:
time.sleep(10)
hostnames = os.listdir(hostnames_folder)
print(
f"Waiting for hostnames to be found in {hostnames_folder}, current length: {len(hostnames)}, expected length: {args.num_ctx_servers + args.num_gen_servers}"
f"Waiting for hostnames to be found in {hostnames_folder}, current length: {len(hostnames)}, expected length: {expected_hostnames}"
)
print(f"All hostnames found in {hostnames_folder}")

@@ -69,13 +74,16 @@
hostname = socket.gethostname()
print(f"Current hostname: {hostname}")

# Skip context servers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set
gen_only = os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1"

server_config = {
'hostname': hostname,
'port': args.server_port,
'backend': 'pytorch',
'context_servers': {
'num_instances': args.num_ctx_servers,
'urls': [f'{host}:{args.worker_port}' for host in ctx_hostnames]
'num_instances': 0 if gen_only else args.num_ctx_servers,
'urls': [] if gen_only else [f'{host}:{args.worker_port}' for host in ctx_hostnames]
},
'generation_servers': {
'num_instances': args.num_gen_servers,
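With TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1, the script expects hostnames only from the generation servers and emits an empty context_servers block. A minimal sketch of the resulting config shape, with made-up server counts, hostnames, and port:

gen_only = True                                   # TRTLLM_DISAGG_BENCHMARK_GEN_ONLY == "1"
num_ctx_servers, num_gen_servers = 1, 2
ctx_hostnames = []                                # no ctx workers are started in gen-only mode
gen_hostnames = ["node-a", "node-b"]              # hypothetical hosts
worker_port = 8336

expected_hostnames = num_gen_servers if gen_only else num_ctx_servers + num_gen_servers  # -> 2

server_config = {
    "context_servers": {
        "num_instances": 0 if gen_only else num_ctx_servers,
        "urls": [f"{h}:{worker_port}" for h in ctx_hostnames],   # [] in gen-only mode
    },
    "generation_servers": {
        "num_instances": num_gen_servers,
        "urls": [f"{h}:{worker_port}" for h in gen_hostnames],
    },
}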
37 changes: 34 additions & 3 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -185,6 +185,7 @@ def plan(
mrope_config: Optional[dict] = None,
softmax_stats_tensor: Optional[torch.Tensor] = None,
helix_position_offsets: Optional[torch.Tensor] = None,
helix_is_inactive_rank: Optional[torch.Tensor] = None,
is_spec_decoding_enabled: bool = False,
use_spec_decoding: bool = False,
is_spec_dec_tree: bool = False,
@@ -235,6 +236,7 @@ def plan(
mrope_config (dict): The dictionary containing the mRope configuration.
softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum)
helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
helix_is_inactive_rank (torch.Tensor): The tensor to store the inactive helix rank flags, with shape (num_requests) on GPU.
attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU.
chunked_prefill_buffer_batch_size (int): Used to allocate the buffer for k and v in fp8 context MLA; the max input KV length is not max_num_tokens in this case but chunked_prefill_buffer_batch_size * max_num_tokens.
sparse_kv_indices (torch.Tensor): The sparse indices for the KV cache, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
@@ -278,6 +280,10 @@ def plan(
self.block_ids_per_seq = block_ids_per_seq
self.softmax_stats_tensor = softmax_stats_tensor
self.helix_position_offsets = helix_position_offsets
self.helix_is_inactive_rank = helix_is_inactive_rank
if self.helix_is_inactive_rank is not None:
self.helix_is_inactive_rank = torch.tensor(
self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)
self.attention_sinks = attention_sinks
self.sparse_kv_indices = sparse_kv_indices
self.sparse_kv_offsets = sparse_kv_offsets
@@ -447,7 +453,7 @@ def run(
self.spec_decoding_generation_lengths,
self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
]
mla_tensor_params = [self.helix_position_offsets]
mla_tensor_params = [self.helix_position_offsets, self.helix_is_inactive_rank]

thop.attention(
q,
@@ -596,6 +602,13 @@ class TrtllmAttentionMetadata(AttentionMetadata):
spec_decoding_packed_mask: Optional[torch.Tensor] = None
spec_decoding_generation_lengths: Optional[torch.Tensor] = None

# Whether the current rank is inactive for helix parallelism.
# In helix parallelism, only the active rank appends KV cache for the query token
# and attends to the previously cached tokens as well as the query token. Inactive
# ranks do not append KV cache for the query token and attend to the previously
# cached tokens only.
helix_is_inactive_rank: Optional[torch.Tensor] = None

@property
def max_seq_len(self) -> int:
"""
@@ -817,8 +830,24 @@ def prepare(self) -> None:

if self.enable_flash_mla:
self.prepare_flash_mla()
# number of tokens needed in the kv cache for each sequence after the next pass
kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv
# kv_lens is the number of tokens needed in the kv cache for each sequence after the next pass.
if self.helix_is_inactive_rank is not None and len(
self.helix_is_inactive_rank):
# If helix is inactive, attend to the previously cached tokens only.
# This gets further complicated with multiple requests as each request might
# have a different active helix rank.
assert cached_token_lens is not None, "cached_token_lens should be set for helix"
kv_lens = cached_token_lens
helix_is_inactive_rank_cpu = torch.tensor(
self.helix_is_inactive_rank,
dtype=torch.bool,
device='cpu',
)
active_rank = ~helix_is_inactive_rank_cpu
kv_lens[active_rank] += self.seq_lens_kv[active_rank]
else:
kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv

# self.kv_lens is the valid kv cache length, while the self.kv_lens_cuda is
# the sequence length including the cached tokens and the input tokens.
self.kv_lens[:self.num_seqs].copy_(
@@ -1270,6 +1299,7 @@ def forward(
softmax_stats_tensor: Optional[torch.Tensor] = None,
helix_position_offsets: Optional[torch.Tensor] = None,
enable_attn_nvfp4_output: bool = True,
helix_is_inactive_rank: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None,
output_sf: Optional[torch.Tensor] = None,
attention_sinks: Optional[torch.Tensor] = None,
@@ -1348,6 +1378,7 @@
mrope_config=mrope_config,
softmax_stats_tensor=softmax_stats_tensor,
helix_position_offsets=helix_position_offsets,
helix_is_inactive_rank=metadata.helix_is_inactive_rank,
is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
use_spec_decoding=metadata.use_spec_decoding,
is_spec_dec_tree=metadata.is_spec_dec_tree,
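The kv_lens branch added in prepare() can be read as a simple masked add: inactive helix ranks keep only their cached length, while active ranks also count the new query tokens. A small standalone sketch with made-up lengths:

import torch

cached_token_lens = torch.tensor([128, 96, 64])        # tokens already in each request's KV cache
seq_lens_kv = torch.tensor([1, 1, 1])                   # new query tokens this step
helix_is_inactive_rank = torch.tensor([False, True, False])

kv_lens = cached_token_lens.clone()
active_rank = ~helix_is_inactive_rank
kv_lens[active_rank] += seq_lens_kv[active_rank]
# kv_lens -> tensor([129, 96, 65]): request 1 is inactive on this rank, so its
# query token is not appended here and its KV length stays at the cached value.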
45 changes: 45 additions & 0 deletions tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -510,6 +510,51 @@ def reducescatter(input, sizes, group):
def _(input, sizes, group, process_group):
return reducescatter(input, sizes, group)

@torch.library.register_fake("trtllm::fp4_block_scale_moe_runner")
def _(
routing_logits,
routing_bias,
hidden_states,
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
gemm2_weights,
gemm2_weights_scale,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,
num_experts,
top_k,
n_group,
topk_group,
intermediate_size,
local_expert_offset,
local_num_experts,
routed_scaling_factor,
tile_tokens_dim,
routing_method_type,
do_finalize,
) -> List[torch.Tensor]:
num_tokens = hidden_states.shape[0]
hidden_size = hidden_states.shape[1] * 2
if do_finalize:
return [
hidden_states.new_empty((num_tokens, hidden_size),
dtype=torch.bfloat16)
]

expanded_row_count = num_tokens * top_k
max_padding_required = (tile_tokens_dim - 1) * num_experts
max_num_padded_tokens = fp4_utils.pad_up(
expanded_row_count + max_padding_required, tile_tokens_dim)
wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
return [
hidden_states.new_empty((max_num_padded_tokens, hidden_size),
dtype=torch.bfloat16),
hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
]

@torch.library.register_fake("trtllm::block_scale_interleave")
def _(sf: torch.Tensor):
rows = sf.shape[-2]
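The shape arithmetic in the fake op is worth spelling out: when do_finalize is false, the first output is padded for the worst case in which every expert's token group needs up to tile_tokens_dim - 1 rows of padding. A worked example with made-up sizes (pad_up below is a local stand-in for fp4_utils.pad_up):

def pad_up(x, multiple):
    # Round x up to the next multiple of `multiple` (stand-in for fp4_utils.pad_up).
    return ((x + multiple - 1) // multiple) * multiple

num_tokens, top_k = 100, 8
num_experts, tile_tokens_dim = 256, 8

expanded_row_count = num_tokens * top_k                      # 800 expert-token rows
max_padding_required = (tile_tokens_dim - 1) * num_experts   # 1792 rows in the worst case
max_num_padded_tokens = pad_up(expanded_row_count + max_padding_required, tile_tokens_dim)
print(max_num_padded_tokens)                                 # 2592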
29 changes: 28 additions & 1 deletion tensorrt_llm/_torch/distributed/communicator.py
@@ -341,9 +341,36 @@ class MPIDist(Distributed):

def __init__(self, mapping: Mapping):
super().__init__(mapping)
self.create_cp_comm()

# For helix, repurpose CP ranks to TP so that the right comms are created.
mapping_with_helix = None
if self.mapping.has_cp_helix():
logger.info(f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
mapping_with_helix = copy.deepcopy(self.mapping)
mapping_without_helix = Mapping(
world_size=self.mapping.world_size,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
cp_size=1,
cp_config={},
tp_size=self.mapping.tp_size * self.mapping.cp_size,
pp_size=self.mapping.pp_size,
moe_cluster_size=self.mapping.moe_cluster_size,
moe_tp_size=self.mapping.moe_tp_size,
moe_ep_size=self.mapping.moe_ep_size,
attn_tp_size=self.mapping.attn_tp_size,
attn_cp_size=self.mapping.attn_cp_size,
enable_attention_dp=self.mapping.enable_attention_dp,
enable_lm_head_tp_in_adp=self.mapping.enable_lm_head_tp_in_adp)
self.mapping = mapping_without_helix
self.create_tp_comm()
self.create_pp_comm()
self.create_cp_comm()

# Restore the original mapping if it was rearranged for helix.
if mapping_with_helix is not None:
logger.info(f"[MPIDist::__init__] Restoring original mapping.")
self.mapping = mapping_with_helix

def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
comm = mpi_comm()
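The repurposing above folds the CP dimension into TP only while the communicators are created, then restores the original mapping. A minimal sketch of the size bookkeeping, using a simplified stand-in for the tensorrt_llm Mapping class:

from dataclasses import dataclass, replace

@dataclass
class SimpleMapping:          # simplified stand-in for tensorrt_llm's Mapping
    tp_size: int
    cp_size: int
    pp_size: int

helix = SimpleMapping(tp_size=2, cp_size=4, pp_size=1)
# Fold CP into TP so the TP communicator spans tp_size * cp_size ranks, then drop CP.
flattened = replace(helix, tp_size=helix.tp_size * helix.cp_size, cp_size=1)
assert (flattened.tp_size, flattened.cp_size) == (8, 1)
# Communicators would be created against `flattened`; the helix mapping is restored afterwards.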