From d2d3f35ea559984f8ef8dd85778c2cc230c25261 Mon Sep 17 00:00:00 2001 From: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com> Date: Mon, 3 Nov 2025 09:29:38 -0800 Subject: [PATCH] save initial changes to test_disaggregated.py --- cpp/tensorrt_llm/kernels/mlaKernels.cu | 7 +- cpp/tensorrt_llm/kernels/mlaKernels.h | 4 + .../trtllmGenKernels/gemm/KernelRunner.cpp | 6 ++ .../slurm/benchmark/disaggr_torch.slurm | 32 ++++---- .../slurm/benchmark/gen_server_config.py | 18 +++-- .../_torch/attention_backend/trtllm.py | 37 +++++++++- .../_torch/custom_ops/cpp_custom_ops.py | 45 ++++++++++++ .../_torch/distributed/communicator.py | 29 +++++++- .../_torch/models/modeling_deepseekv3.py | 73 +++++++++++++++++-- tensorrt_llm/_torch/modules/attention.py | 9 ++- tensorrt_llm/_torch/modules/gated_mlp.py | 5 +- tensorrt_llm/_torch/pyexecutor/llm_request.py | 1 + .../_torch/pyexecutor/model_engine.py | 34 ++++++--- tensorrt_llm/_torch/pyexecutor/py_executor.py | 3 +- .../_torch/pyexecutor/resource_manager.py | 16 ++++ tensorrt_llm/llmapi/disagg_utils.py | 2 +- tensorrt_llm/mapping.py | 5 ++ tensorrt_llm/serve/openai_disagg_server.py | 10 ++- ...tp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml | 36 +++++++++ ...ntp1cp2_deepseek_v3_lite_fp8_tllm_gen.yaml | 40 ++++++++++ .../defs/disaggregated/test_disaggregated.py | 50 +++++++++++++ 21 files changed, 414 insertions(+), 48 deletions(-) create mode 100644 tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml create mode 100644 tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_fp8_tllm_gen.yaml diff --git a/cpp/tensorrt_llm/kernels/mlaKernels.cu b/cpp/tensorrt_llm/kernels/mlaKernels.cu index 2897352639e..94c85bc7b2a 100644 --- a/cpp/tensorrt_llm/kernels/mlaKernels.cu +++ b/cpp/tensorrt_llm/kernels/mlaKernels.cu @@ -351,7 +351,8 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe, int* seqQOffset, uint32_t* fmha_tile_counter, int32_t const* kv_cache_lengths, int* seqKVOffsets, int q_pe_ld, int q_pe_stride, KvCacheDataType cache_type, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o, float const* quant_scale_q, float const* quant_scale_kv, float const* dequant_scale_q, - float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets) + float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets, + bool const* helix_is_inactive_rank) { // Constants. 
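A minimal Python sketch (not part of the patch) of the per-request gating that helix_is_inactive_rank threads into the generation RoPE/KV-append kernel above and in the hunk that continues below; the real check runs inside the CUDA kernel, and the function name here is illustrative.

    def should_append_kv(batch_idx, helix_is_inactive_rank):
        # Non-helix path: the pointer is null, so every valid token appends K/V.
        if helix_is_inactive_rank is None:
            return True
        # Helix path: only the rank that is active for this request appends the
        # query token's K/V to its cache; inactive ranks skip the write.
        return not helix_is_inactive_rank[batch_idx]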
@@ -514,7 +515,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe, auto local_token_idx = global_token_idx % seq_len; bool valid_token = global_token_idx < total_s_len; - if (valid_token) + if (valid_token && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx])) { if (head_dim_vec_idx == 0) { @@ -1047,7 +1048,7 @@ void invokeMLARopeGeneration(MlaParams& params, KVCacheBuffer kv_cache_buffer params.seqQOffset, params.fmha_tile_counter, params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld, params.q_pe_stride, params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.quant_scale_q, params.quant_scale_kv, params.dequant_scale_q, params.dequant_scale_kv, - params.host_bmm1_scale, params.helix_position_offsets); + params.host_bmm1_scale, params.helix_position_offsets, params.helix_is_inactive_rank); } template diff --git a/cpp/tensorrt_llm/kernels/mlaKernels.h b/cpp/tensorrt_llm/kernels/mlaKernels.h index 1775b992cc7..ce6f4b1bfa0 100644 --- a/cpp/tensorrt_llm/kernels/mlaKernels.h +++ b/cpp/tensorrt_llm/kernels/mlaKernels.h @@ -107,6 +107,10 @@ struct MlaParams // for Helix parallelism: the rotary position offsets [b] int32_t const* helix_position_offsets{nullptr}; + + // for Helix parallelism: whether the current rank is inactive, shape [b] + // (the current query tokens are not appended to this rank's KV cache) + bool const* helix_is_inactive_rank{nullptr}; }; template diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp b/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp index 726a2aea7ea..479225dca69 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp @@ -213,6 +213,12 @@ void TrtllmGenGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k) return optionsA.mNumSlicesForSplitK > optionsB.mNumSlicesForSplitK; } + // then by tileN, if N is large enough + if (gemmData.mProblemDimensions.mN > 256 && optionsA.mTileN != optionsB.mTileN) + { + return optionsA.mTileN > optionsB.mTileN; + } + return true; }); diff --git a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm index 21b8eb2ff51..9b6d913a99a 100644 --- a/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm +++ b/examples/disaggregated/slurm/benchmark/disaggr_torch.slurm @@ -178,20 +178,24 @@ for i in $(seq 0 $((num_gen_servers - 1))); do &> ${full_logdir}/output_gen_${i}.log & done -# start the ctx workers -echo "Starting ctx workers..." -for i in $(seq 0 $((num_ctx_servers - 1))); do - srun -l -N ${ctx_nodes_num_in_single_server} \ - --ntasks=$((ctx_world_size)) \ - --ntasks-per-node=${gpus_per_node} \ - --container-image=${container_image} \ - --container-name=${container_name} \ - --container-mounts=${container_mount} \ - --mpi=pmix \ - bash ${work_dir}/start_worker.sh \ - "CTX" ${i} ${model_path} "8336" "${benchmark_mode}" "${concurrency_list}" "${enable_pdl}" "${numa_bind}" "${full_logdir}" "${nsys_on}" "${ctx_config_path}" \ - &> ${full_logdir}/output_ctx_${i}.log & -done +# start the ctx workers (skip if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set). +if [ "${TRTLLM_DISAGG_BENCHMARK_GEN_ONLY:-0}" != "1" ]; then + echo "Starting ctx workers..." 
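The slurm change above gates context-worker startup on TRTLLM_DISAGG_BENCHMARK_GEN_ONLY, and gen_server_config.py (continued below) mirrors the same switch. A small standalone sketch of that gen-only accounting, with illustrative server counts:

    import os

    gen_only = os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1"
    num_ctx_servers, num_gen_servers = 1, 2        # illustrative values
    ctx_hostnames, worker_port = ["ctx-0"], 8336   # illustrative values

    # Only generation hostnames are waited on in gen-only mode ...
    expected_hostnames = num_gen_servers if gen_only else num_ctx_servers + num_gen_servers
    # ... and the context_servers section of the server config is emitted empty.
    context_servers = {
        "num_instances": 0 if gen_only else num_ctx_servers,
        "urls": [] if gen_only else [f"{h}:{worker_port}" for h in ctx_hostnames],
    }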
+ for i in $(seq 0 $((num_ctx_servers - 1))); do + srun -l -N ${ctx_nodes_num_in_single_server} \ + --ntasks=${ctx_tp_size} \ + --ntasks-per-node=${gpus_per_node} \ + --container-image=${container_image} \ + --container-name=${container_name} \ + --container-mounts=${container_mount} \ + --mpi=pmix \ + bash ${work_dir}/start_worker.sh \ + "CTX" ${i} ${model_path} "8336" "${benchmark_mode}" "${concurrency_list}" "${enable_pdl}" "${numa_bind}" "${full_logdir}" "${nsys_on}" "${ctx_config_path}" \ + &> ${full_logdir}/output_ctx_${i}.log & + done +else + echo "Skipping context workers startup (TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set)" +fi # start the server echo "Starting server..." diff --git a/examples/disaggregated/slurm/benchmark/gen_server_config.py b/examples/disaggregated/slurm/benchmark/gen_server_config.py index c427f5d42b4..1901f118431 100644 --- a/examples/disaggregated/slurm/benchmark/gen_server_config.py +++ b/examples/disaggregated/slurm/benchmark/gen_server_config.py @@ -39,12 +39,17 @@ time.sleep(10) print(f"Waiting for hostnames folder {hostnames_folder} to be found") hostnames = os.listdir(hostnames_folder) - # check length of hostnames is equal to num_ctx_servers + num_gen_servers, if not, sleep 10 seconds and check again - while len(hostnames) != args.num_ctx_servers + args.num_gen_servers: + + # Skip context servers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set + gen_only = os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1" + expected_hostnames = args.num_gen_servers if gen_only else args.num_ctx_servers + args.num_gen_servers + + # check length of hostnames is equal to expected count, if not, sleep 10 seconds and check again + while len(hostnames) != expected_hostnames: time.sleep(10) hostnames = os.listdir(hostnames_folder) print( - f"Waiting for hostnames to be found in {hostnames_folder}, current length: {len(hostnames)}, expected length: {args.num_ctx_servers + args.num_gen_servers}" + f"Waiting for hostnames to be found in {hostnames_folder}, current length: {len(hostnames)}, expected length: {expected_hostnames}" ) print(f"All hostnames found in {hostnames_folder}") @@ -69,13 +74,16 @@ hostname = socket.gethostname() print(f"Current hostname: {hostname}") + # Skip context servers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set + gen_only = os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1" + server_config = { 'hostname': hostname, 'port': args.server_port, 'backend': 'pytorch', 'context_servers': { - 'num_instances': args.num_ctx_servers, - 'urls': [f'{host}:{args.worker_port}' for host in ctx_hostnames] + 'num_instances': 0 if gen_only else args.num_ctx_servers, + 'urls': [] if gen_only else [f'{host}:{args.worker_port}' for host in ctx_hostnames] }, 'generation_servers': { 'num_instances': args.num_gen_servers, diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 45b4d4131af..5c1984c2c14 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -185,6 +185,7 @@ def plan( mrope_config: Optional[dict] = None, softmax_stats_tensor: Optional[torch.Tensor] = None, helix_position_offsets: Optional[torch.Tensor] = None, + helix_is_inactive_rank: Optional[torch.Tensor] = None, is_spec_decoding_enabled: bool = False, use_spec_decoding: bool = False, is_spec_dec_tree: bool = False, @@ -235,6 +236,7 @@ def plan( mrope_config (dict): The dictionary containing the mRope configuration. 
softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum) helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU. + helix_is_inactive_rank (torch.Tensor): The tensor to store the inactive helix rank flags, with shape (num_requests) on GPU. attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU. chunked_prefill_buffer_batch_size (int): used for malloc buffer for k and v in fp8 context mla. the max input kv length is not max_num_tokens in this case. It is chunked_prefill_buffer_batch_size * max_num_tokens. sparse_kv_indices (torch.Tensor): The sparse indices for the KV cache, with shape of (num_heads_kv, num_sparse_tokens) on GPU. @@ -278,6 +280,10 @@ def plan( self.block_ids_per_seq = block_ids_per_seq self.softmax_stats_tensor = softmax_stats_tensor self.helix_position_offsets = helix_position_offsets + self.helix_is_inactive_rank = helix_is_inactive_rank + if self.helix_is_inactive_rank is not None: + self.helix_is_inactive_rank = torch.tensor( + self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True) self.attention_sinks = attention_sinks self.sparse_kv_indices = sparse_kv_indices self.sparse_kv_offsets = sparse_kv_offsets @@ -447,7 +453,7 @@ def run( self.spec_decoding_generation_lengths, self.spec_decoding_position_offsets, self.spec_decoding_packed_mask ] - mla_tensor_params = [self.helix_position_offsets] + mla_tensor_params = [self.helix_position_offsets, self.helix_is_inactive_rank] thop.attention( q, @@ -596,6 +602,13 @@ class TrtllmAttentionMetadata(AttentionMetadata): spec_decoding_packed_mask: Optional[torch.Tensor] = None spec_decoding_generation_lengths: Optional[torch.Tensor] = None + # Whether the current rank is inactive for helix parallelism. + # In helix parallelism, only the active rank appends KV cache for the query token + # and attends to the previously cached tokens as well as the query token. Inactive + # ranks do not append KV cache for the query token and attend to the previously + # cached tokens only. + helix_is_inactive_rank: Optional[torch.Tensor] = None + @property def max_seq_len(self) -> int: """ @@ -817,8 +830,24 @@ def prepare(self) -> None: if self.enable_flash_mla: self.prepare_flash_mla() - # number of tokens needed in the kv cache for each sequence after the next pass - kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv + # kv_lens is the number of tokens needed in the kv cache for each sequence after the next pass. + if self.helix_is_inactive_rank is not None and len( + self.helix_is_inactive_rank): + # If helix is inactive, attend to the previously cached tokens only. + # This gets further complicated with multiple requests as each request might + # have a different active helix rank. + assert cached_token_lens is not None, "cached_token_lens should be set for helix" + kv_lens = cached_token_lens + helix_is_inactive_rank_cpu = torch.tensor( + self.helix_is_inactive_rank, + dtype=torch.bool, + device='cpu', + ) + active_rank = ~helix_is_inactive_rank_cpu + kv_lens[active_rank] += self.seq_lens_kv[active_rank] + else: + kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv + # self.kv_lens is the valid kv cache length, while the self.kv_lens_cuda is # the sequence length including the cached tokens and the input tokens. 
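A self-contained sketch of the kv_lens adjustment at the end of the hunk above, with illustrative lengths: inactive helix ranks keep only their cached length, while active ranks also count the incoming query tokens.

    import torch

    cached_token_lens = torch.tensor([10, 12, 7])    # cached tokens per request
    seq_lens_kv = torch.tensor([1, 1, 1])            # new tokens this step
    helix_is_inactive_rank = [True, False, True]     # host-side flags, one per request

    kv_lens = cached_token_lens.clone()
    active = ~torch.tensor(helix_is_inactive_rank, dtype=torch.bool)
    kv_lens[active] += seq_lens_kv[active]
    # kv_lens is now tensor([10, 13, 7]): only the request for which this rank
    # is active grows its KV length.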
self.kv_lens[:self.num_seqs].copy_( @@ -1348,6 +1378,7 @@ mrope_config=mrope_config, softmax_stats_tensor=softmax_stats_tensor, helix_position_offsets=helix_position_offsets, + helix_is_inactive_rank=metadata.helix_is_inactive_rank, is_spec_decoding_enabled=metadata.is_spec_decoding_enabled, use_spec_decoding=metadata.use_spec_decoding, is_spec_dec_tree=metadata.is_spec_dec_tree, diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 8f4c23cfa68..7710ac67b27 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -510,6 +510,51 @@ def reducescatter(input, sizes, group): def _(input, sizes, group, process_group): return reducescatter(input, sizes, group) + @torch.library.register_fake("trtllm::fp4_block_scale_moe_runner") + def _( + routing_logits, + routing_bias, + hidden_states, + hidden_states_scale, + gemm1_weights, + gemm1_weights_scale, + gemm2_weights, + gemm2_weights_scale, + output1_scale_scalar, + output1_scale_gate_scalar, + output2_scale_scalar, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routed_scaling_factor, + tile_tokens_dim, + routing_method_type, + do_finalize, + ) -> List[torch.Tensor]: + num_tokens = hidden_states.shape[0] + hidden_size = hidden_states.shape[1] * 2 + if do_finalize: + return [ + hidden_states.new_empty((num_tokens, hidden_size), + dtype=torch.bfloat16) + ] + + expanded_row_count = num_tokens * top_k + max_padding_required = (tile_tokens_dim - 1) * num_experts + max_num_padded_tokens = fp4_utils.pad_up( + expanded_row_count + max_padding_required, tile_tokens_dim) + wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16 + return [ + hidden_states.new_empty((max_num_padded_tokens, hidden_size), + dtype=torch.bfloat16), + hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype), + hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32) + ] + @torch.library.register_fake("trtllm::block_scale_interleave") def _(sf: torch.Tensor): rows = sf.shape[-2] diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py index 07f2b4227f9..2a885769166 100644 --- a/tensorrt_llm/_torch/distributed/communicator.py +++ b/tensorrt_llm/_torch/distributed/communicator.py @@ -341,9 +341,36 @@ class MPIDist(Distributed): def __init__(self, mapping: Mapping): super().__init__(mapping) + self.create_cp_comm() + + # For helix, repurpose CP ranks to TP so that the right comms are created.
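The fake registration for trtllm::fp4_block_scale_moe_runner above sizes the non-finalized outputs with a padded-token upper bound. A short sketch of that arithmetic with a locally defined pad_up (the patch itself calls fp4_utils.pad_up) and illustrative sizes:

    def pad_up(x, multiple):
        # Round x up to the next multiple of `multiple`.
        return ((x + multiple - 1) // multiple) * multiple

    num_tokens, top_k, num_experts, tile_tokens_dim = 8, 2, 16, 8
    expanded_row_count = num_tokens * top_k                      # 16
    max_padding_required = (tile_tokens_dim - 1) * num_experts   # 112
    max_num_padded_tokens = pad_up(
        expanded_row_count + max_padding_required, tile_tokens_dim)  # 128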
+ mapping_with_helix = None + if self.mapping.has_cp_helix(): + logger.info(f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.") + mapping_with_helix = copy.deepcopy(self.mapping) + mapping_without_helix = Mapping( + world_size=self.mapping.world_size, + rank=self.mapping.rank, + gpus_per_node=self.mapping.gpus_per_node, + cp_size=1, + cp_config={}, + tp_size=self.mapping.tp_size * self.mapping.cp_size, + pp_size=self.mapping.pp_size, + moe_cluster_size=self.mapping.moe_cluster_size, + moe_tp_size=self.mapping.moe_tp_size, + moe_ep_size=self.mapping.moe_ep_size, + attn_tp_size=self.mapping.attn_tp_size, + attn_cp_size=self.mapping.attn_cp_size, + enable_attention_dp=self.mapping.enable_attention_dp, + enable_lm_head_tp_in_adp=self.mapping.enable_lm_head_tp_in_adp) + self.mapping = mapping_without_helix self.create_tp_comm() self.create_pp_comm() - self.create_cp_comm() + + # Restore the original mapping if it was rearranged for helix. + if mapping_with_helix is not None: + logger.info(f"[MPIDist::__init__] Restoring original mapping.") + self.mapping = mapping_with_helix def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024): comm = mpi_comm() diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 4d3d50abb33..7d9dac33049 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -258,6 +258,14 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, [local_num_heads, local_qk_nope_head_dim, -1]) v_b_proj = v_b_proj.view([local_num_heads, local_v_head_dim, -1]) + # TODO: Verify if we really need this given the repurposing of CP ranks to TP. + if cp_size > 1: + local_cp_heads = local_num_heads // cp_size + k_b_proj = k_b_proj[cp_rank * local_cp_heads:(cp_rank + 1) * + local_cp_heads] + v_b_proj = v_b_proj[cp_rank * local_cp_heads:(cp_rank + 1) * + local_cp_heads] + return k_b_proj, v_b_proj is_lite = self.config.q_lora_rank is None @@ -268,6 +276,8 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, tp_rank = self.model_config.mapping.tp_rank tp_size = self.model_config.mapping.tp_size + cp_rank = self.model_config.mapping.cp_rank + cp_size = self.model_config.mapping.cp_size params_map = {'gate_up_proj': ['gate_proj', 'up_proj']} all_named_modules = dict(self.model.named_modules()) @@ -509,6 +519,7 @@ def __init__( model_config: ModelConfig[PretrainedConfig], layer_idx: Optional[int] = None, aux_stream: Optional[torch.cuda.Stream] = None, + mapping_with_cp: Optional[Mapping] = None, ): config = model_config.pretrained_config predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1 @@ -531,7 +542,10 @@ def __init__( layer_idx=layer_idx, dtype=config.torch_dtype, config=model_config, - aux_stream=aux_stream) + aux_stream=aux_stream, + mapping_with_cp=mapping_with_cp) + # @B: Does this layer need to know about mapping_with_cp? + # Likely no because no use of mapping. 
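A toy, dict-based sketch of the fold that MPIDist.__init__ performs at the top of this chunk (and that DeepseekV3ForCausalLM repeats later in the patch): CP is temporarily collapsed into TP so the communicators span the right ranks, and the CP-aware mapping is kept aside and restored afterwards. The real code builds a new Mapping object; the dict below is only a stand-in.

    import copy

    mapping = {"world_size": 2, "rank": 0, "tp_size": 1, "pp_size": 1,
               "cp_size": 2, "cp_config": {"cp_type": "helix"}}

    mapping_with_helix = copy.deepcopy(mapping)     # kept aside for attention
    mapping_without_helix = dict(mapping,
                                 tp_size=mapping["tp_size"] * mapping["cp_size"],
                                 cp_size=1, cp_config={})
    # TP/PP communicators are created from mapping_without_helix, then the
    # CP-aware mapping is put back so helix-specific code still sees cp_size=2.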
self.kv_a_proj_with_mqa = DeepseekV3Linear( config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim + @@ -933,7 +947,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int, aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], - is_separate_draft_engine: bool = False): + is_separate_draft_engine: bool = False, + mapping_with_cp: Optional[Mapping] = None): super().__init__() self.model_config = model_config self.config = model_config.pretrained_config @@ -955,7 +970,8 @@ def __init__(self, self.self_attn = DeepseekV3Attention( model_config, layer_idx=layer_idx_for_attention, - aux_stream=aux_stream_dict[AuxStreamType.Attention]) + aux_stream=aux_stream_dict[AuxStreamType.Attention], + mapping_with_cp=mapping_with_cp) self.enable_attention_dp = mapping.enable_attention_dp self.mlp_tp_size = mapping.tp_size @@ -1412,7 +1428,7 @@ def norm_hidden(): class DeepseekV3Model(DecoderModel): - def __init__(self, model_config: ModelConfig[PretrainedConfig]): + def __init__(self, model_config: ModelConfig[PretrainedConfig], mapping_with_cp: Optional[Mapping] = None): super().__init__(model_config) config = model_config.pretrained_config self.vocab_size = config.vocab_size @@ -1433,7 +1449,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): self.layers = nn.ModuleList([ DeepseekV3DecoderLayer(model_config, layer_idx, - self.aux_stream_dict) + self.aux_stream_dict, + mapping_with_cp=mapping_with_cp) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(hidden_size=config.hidden_size, @@ -1478,6 +1495,39 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model, PretrainedConfig]): def __init__(self, model_config: ModelConfig[PretrainedConfig]): + + ############################################################################### + self.mapping_with_cp = None + # Note: Currently the usage of mapping is all over the place making its usage brittle + # in this file. As a temporary WAR, we hold on to an original copy of mapping when CP + # is in action. This shall be passed on to attention which is the only layer that's + # affected by CP. For other layers, CP ranks are repurposed to TP. This shall be undone + # at the end of __init__. + if model_config.mapping.cp_size > 1: + logger.info( + f"[DeepseekV3ForCausalLM::__init__] Repurposing KVP ranks to TP while keeping other details the same." + ) + self.mapping_with_cp = copy.deepcopy(model_config.mapping) + + original_tp_size = self.mapping_with_cp.tp_size + original_cp_size = self.mapping_with_cp.cp_size + + # Repurpose KVP ranks to TP while keeping other details the same. 
+ model_config._frozen = False + model_config.mapping = Mapping( + world_size=model_config.mapping.world_size, + rank=model_config.mapping.rank, + gpus_per_node=model_config.mapping.gpus_per_node, + cp_size=1, + cp_config={}, + tp_size=original_tp_size * original_cp_size, + pp_size=model_config.mapping.pp_size, + moe_ep_size=model_config.mapping.moe_ep_size, + auto_parallel=model_config.mapping.auto_parallel, + enable_attention_dp=model_config.mapping.enable_attention_dp) + model_config._frozen = True + ############################################################################### + # Rename some keys of quant_config_dict to support legacy checkpoints if model_config.quant_config_dict is not None: model_config = copy.deepcopy(model_config) @@ -1491,7 +1541,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): model_config.quant_config_dict = quant_config_dict model_config._frozen = True - super().__init__(model=DeepseekV3Model(model_config), + super().__init__(model=DeepseekV3Model(model_config, mapping_with_cp=self.mapping_with_cp), model_config=model_config) self.model_nextn = 0 @@ -1525,6 +1575,17 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): self.epilogue.extend(self.draft_model.mtp_layers) self.epilogue.append(self.spec_worker) + ############################################################################### + # Undo any manipulations done to mapping. + if self.mapping_with_cp is not None: + logger.info( + f"[DeepseekV3ForCausalLM::__init__] Restoring original mapping." + ) + model_config._frozen = False + model_config.mapping = self.mapping_with_cp + model_config._frozen = True + ############################################################################### + def forward( self, attn_metadata: AttentionMetadata, diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 06e7f3e231a..35d84b54865 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -687,6 +687,7 @@ def __init__( dense_bias: Optional[bool] = None, config: Optional[ModelConfig] = None, enable_unit_test: bool = False, + mapping_with_cp: Optional[Mapping] = None, ): """ Initialize the MLA module. @@ -758,7 +759,11 @@ def __init__( # tensor parallel config = config or ModelConfig() - self.mapping = config.mapping + if mapping_with_cp is not None: + logger.info("[MLA::__init__] OVERRIDING MAPPING WITH CP DETECTED.") + self.mapping = mapping_with_cp + else: + self.mapping = config.mapping tp_size = self.mapping.tp_size pp_size = self.mapping.pp_size cp_size = self.mapping.cp_size @@ -766,6 +771,8 @@ def __init__( tp_size = 1 if self.mapping.has_cp_ulysses(): raise NotImplementedError("MLA doesn't support CP Ulyssees yet") + if self.mapping.cp_size > 1: + assert self.mapping.cp_config['cp_type'] == CpType.HELIX, "MLA only supports CP Helix parallelism for now." mapping = Mapping( world_size=tp_size * pp_size * cp_size, diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index 90af4440c36..ee38f96d7d5 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -47,12 +47,15 @@ def __init__( tp_size = overridden_tp_size # "Misuse" pp_size here to perform all-reduce within smaller groups pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size + # TODO: Figure if this change is actually needed given the repurposing of CP ranks to TP. 
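The GatedMLP hunk that continues just below widens the sub-mapping's world size by cp_size when tp_size is overridden and pp_size is "misused" to absorb the difference. A quick check of the invariant that change preserves, with illustrative sizes:

    tp_size, pp_size, cp_size = 4, 1, 2
    overridden_tp_size = 2

    new_tp = overridden_tp_size
    new_pp = pp_size * tp_size // overridden_tp_size
    # Including cp_size keeps the product equal to the original world size.
    assert new_tp * new_pp * cp_size == tp_size * pp_size * cp_size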
mapping = Mapping( - world_size=tp_size * pp_size, + world_size=tp_size * pp_size * self.mapping.cp_size, rank=self.mapping.rank, gpus_per_node=self.mapping.gpus_per_node, tp_size=tp_size, pp_size=pp_size, + cp_size=self.mapping.cp_size, + cp_config=self.mapping.cp_config, ) else: mapping = config.mapping diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index fb720838126..37bd17c3f97 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -472,6 +472,7 @@ def __init__( self.py_prompt_len = self.prompt_len self.py_orig_prompt_len = self.orig_prompt_len self.py_max_new_tokens = self.max_new_tokens + self.py_helix_is_inactive_rank = False self.py_min_length = self.sampling_config.min_length self.py_batch_idx = None self.py_draft_pages_allocated = 0 diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e6da9fc216a..8e5b2cd6389 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -517,9 +517,9 @@ def warmup(self, resource_manager: ResourceManager) -> None: # TODO: current warmup_request is not suitable for context parallelism. cp_type = self.mapping.cp_config.get('cp_type', None) if cp_type is not None: - logger.info("[ModelEngine::warmup] Skipping warmup for cp_type: ", - cp_type.name) - return + if cp_type in [CpType.ULYSSES, CpType.STAR]: + logger.info("[ModelEngine::warmup] Early return since cp_type is ", cp_type) + return self._run_torch_compile_warmup(resource_manager) self._run_autotuner_warmup(resource_manager) @@ -992,10 +992,15 @@ def _init_max_seq_len(self): # NOTE: py_executor_creator makes sure that the executor uses this # smaller value as its max_seq_len too. logger.warning( - f"Specified {self.max_seq_len=} is larger than what the model can support " - f"({inferred_max_seq_len}). Setting max_seq_len to {inferred_max_seq_len}. " + f"\n*******************************************************\n" + f"Specified {self.max_seq_len=} is larger than what the model can support\n" + f"({inferred_max_seq_len}). NOT Setting max_seq_len to {inferred_max_seq_len}. " + f"ARE YOU SURE ABOUT THIS?\n" + f"*******************************************************\n" ) - self.max_seq_len = inferred_max_seq_len + # TODO: Undo this change before merging. + # self.max_seq_len = inferred_max_seq_len + pass def _infer_max_seq_len_from_config(self) -> int: @@ -1450,6 +1455,7 @@ def _prepare_tp_inputs( # update batch index request.py_batch_idx = request.py_seq_slot + helix_is_inactive_rank = [] if self.mapping.cp_size > 1 else None for request in generation_requests: request_ids.append(request.py_request_id) beam_width = request.sampling_config.beam_width @@ -1474,12 +1480,19 @@ def _prepare_tp_inputs( past_seen_token_num = request.max_beam_num_tokens position_id = past_seen_token_num if self.mapping.has_cp_helix(): - # Do an allgather among CP ranks to get the complete sequence length seen by all CP ranks. - past_seen_token_nums = self.dist.cp_allgather( - past_seen_token_num) - position_id = sum(past_seen_token_nums) + # Warmup doesn't have `total_input_len_cp` set because merge_helix_requests is not called. + if not self.is_warmup and not request.is_cuda_graph_dummy: + position_id = request.total_input_len_cp + request.py_decoding_iter - 1 + # Assuming last CP rank is the active rank. 
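A standalone sketch of the generation-step bookkeeping in the hunk above and continued below, under the patch's current assumption that the last CP rank is the active one; the argument names mirror the request fields used here.

    def helix_gen_bookkeeping(total_input_len_cp, orig_prompt_len,
                              decoding_iter, cp_rank, cp_size):
        # Global position of the token generated at this decode step.
        position_id = total_input_len_cp + decoding_iter - 1
        if cp_rank == cp_size - 1:
            # Active rank: its KV length grows by one token per decode step.
            past_seen_token_num = orig_prompt_len + decoding_iter - 1
        else:
            # Inactive ranks: the cached length does not grow during decode.
            past_seen_token_num = orig_prompt_len
        return position_id, past_seen_token_num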
+ if self.mapping.cp_rank == self.mapping.cp_size - 1: + past_seen_token_num = request.orig_prompt_len + request.py_decoding_iter - 1 + else: + # past_seen_token_num doesn't grow on inactive ranks. + past_seen_token_num = request.orig_prompt_len position_ids.append(position_id) num_cached_tokens_per_seq.append(past_seen_token_num) + if self.mapping.has_cp_helix(): + helix_is_inactive_rank.append(request.py_helix_is_inactive_rank) request.cached_tokens = num_cached_tokens_per_seq[-1] prompt_lengths.append(request.py_prompt_len) draft_lens.append(0) @@ -1680,6 +1693,7 @@ def previous_seq_slots_device(): attn_metadata.request_ids = request_ids attn_metadata.prompt_lens = prompt_lengths + attn_metadata.helix_is_inactive_rank = helix_is_inactive_rank attn_metadata.num_contexts = len(scheduled_requests.context_requests) # Use num_chunked_ctx_requests to record the number of extend context requests, # so that we can update the kv_lens_cuda correctly in _preprocess_inputs. diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 248599835d7..13c50b18d6f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1871,7 +1871,8 @@ def _update_request_states_star_attention( @nvtx_range("_update_request_states") def _update_request_states(self, scheduled_requests: ScheduledRequests): cp_config = self.dist.cp_config - if 'cp_type' in cp_config: + # Note: Helix Parallelism uses the same logic as tp parallelism here. + if 'cp_type' in cp_config and cp_config['cp_type'] != CpType.HELIX: cp_type = cp_config['cp_type'] if cp_type == CpType.STAR: self._update_request_states_star_attention(scheduled_requests) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 764bd3937d2..c1140b0da4d 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -177,6 +177,10 @@ def __init__( indexer_k_cache_index_head_dim: int = 0, **kwargs, ) -> None: + # Couple of places where we assume tokens_per_block is 32: Let's assert here for now. + # 1) block assignment in merge_helix_requests + # 2) computation of cache_transceiver_config.max_tokens_in_buffer. + assert tokens_per_block == 32, "tokens_per_block must be 32 for helix benchmarking." self.mapping = mapping self.dtype = dtype self.kv_cache_type = kv_cache_type @@ -438,6 +442,18 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req, block_ids) for req in generation_batch: + # Skip allocating KV cache at decode for inactive helix ranks. + ################################################################## + # TODO: For now, we hardcode that last rank is active. 
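A sketch of the decode-time KV accounting that the block below applies: requests marked inactive for this helix rank skip add_token, so no new cache slot is claimed for them here. The manager and request objects are placeholders for the real KV cache manager impl and LlmRequest.

    def allocate_decode_kv(generation_batch, kv_cache_manager):
        for req in generation_batch:
            if getattr(req, "py_helix_is_inactive_rank", False):
                # Inactive helix rank: the query token is not appended on this
                # rank, so no KV slot is added for the request.
                continue
            kv_cache_manager.add_token(req.py_request_id)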
+ if self.mapping.has_cp_helix(): + if self.mapping.cp_rank != self.mapping.cp_size - 1: + req.py_helix_is_inactive_rank = True + ################################################################## + if req.py_helix_is_inactive_rank: + print(f"[ResourceManager::prepare_resources][rank {self.mapping.rank}] Skipping KV allocation for request {req.py_request_id}.") + continue + print(f"[ResourceManager::prepare_resources][rank {self.mapping.rank}] Adding KV allocation for request {req.py_request_id}.") + self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) diff --git a/tensorrt_llm/llmapi/disagg_utils.py b/tensorrt_llm/llmapi/disagg_utils.py index 1ef5f413973..b38fd514bc5 100644 --- a/tensorrt_llm/llmapi/disagg_utils.py +++ b/tensorrt_llm/llmapi/disagg_utils.py @@ -198,7 +198,7 @@ def extract_ctx_gen_cfgs(type: Literal['ctx', 'gen'], # Compute the number of ranks per instance instance_num_ranks = kwargs.get('tensor_parallel_size', 1) * kwargs.get( - 'pipeline_parallel_size', 1) + 'pipeline_parallel_size', 1) * kwargs.get('context_parallel_size', 1) cfgs = [] for hostname, port in zip(hostnames, ports): diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py index 6d38d948a21..0243d62fe09 100644 --- a/tensorrt_llm/mapping.py +++ b/tensorrt_llm/mapping.py @@ -63,6 +63,11 @@ def __init__( cp_type = CpType.ULYSSES if cp_config is None else cp_config.get( "cp_type", CpType.ULYSSES) + ################################################################# + # TODO: Remove this hardcoding. + if cp_size > 1: + assert cp_type == CpType.HELIX + ################################################################# moe_world_size = tp_size if cp_type == CpType.ULYSSES else tp_size * cp_size if moe_tp_size == -1 and moe_ep_size == -1: diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 1473a1cf29c..3bdf735ee85 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -632,7 +632,9 @@ async def get_unready_servers(servers: list[str]) -> list[str]: async def check_all_servers_ready(): iter = 0 - unready_servers = await get_unready_servers(ctx_servers + gen_servers) + # Skip context servers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set + servers_to_check = gen_servers if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1" else ctx_servers + gen_servers + unready_servers = await get_unready_servers(servers_to_check) while len(unready_servers) > 0: wait_time = 3 logger.info( @@ -645,7 +647,11 @@ async def check_all_servers_ready(): await asyncio.wait_for(check_all_servers_ready(), timeout=server_start_timeout_secs) except asyncio.CancelledError: raise TimeoutError("Timeout waiting for context and generation servers to be ready") - logger.info("Context and generation servers are ready") + + if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1": + logger.info("Generation servers are ready (context servers skipped)") + else: + logger.info("Context and generation servers are ready") async def is_ready(self) -> bool: if self.disagg_cluster_manager: diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml new file mode 100644 index 00000000000..270bd70994a --- /dev/null +++ 
b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml @@ -0,0 +1,36 @@ +hostname: localhost +port: 8000 +model: DeepSeek-V3-Lite/bf16 +free_gpu_memory_fraction: 0.25 +backend: "pytorch" +disable_overlap_scheduler: True +cuda_graph_config: null +context_servers: + num_instances: 1 + enable_chunked_prefill: False + kv_cache_config: + enable_block_reuse: False + enable_partial_reuse: False + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + cache_transceiver_config: + backend: "UCX" + max_tokens_in_buffer: 4096 + urls: + - "localhost:8001" +generation_servers: + num_instances: 1 + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + context_parallel_size: 2 + cp_config: + cp_type: helix + enable_chunked_prefill: False + kv_cache_config: + enable_block_reuse: False + enable_partial_reuse: False + cache_transceiver_config: + backend: "UCX" + max_tokens_in_buffer: 4096 + urls: + - "localhost:8002" \ No newline at end of file diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_fp8_tllm_gen.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_fp8_tllm_gen.yaml new file mode 100644 index 00000000000..fc38b1359e5 --- /dev/null +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_fp8_tllm_gen.yaml @@ -0,0 +1,40 @@ +hostname: localhost +port: 8000 +model: DeepSeek-V3-Lite/fp8 +free_gpu_memory_fraction: 0.25 +backend: "pytorch" +disable_overlap_scheduler: True +cuda_graph_config: null +context_servers: + num_instances: 1 + enable_chunked_prefill: False + kv_cache_config: + enable_block_reuse: False + enable_partial_reuse: False + moe_config: + backend: DEEPGEMM + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + cache_transceiver_config: + backend: "UCX" + max_tokens_in_buffer: 4096 + urls: + - "localhost:8001" +generation_servers: + num_instances: 1 + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + context_parallel_size: 2 + cp_config: + cp_type: helix + enable_chunked_prefill: False + kv_cache_config: + enable_block_reuse: False + enable_partial_reuse: False + cache_transceiver_config: + backend: "UCX" + max_tokens_in_buffer: 4096 + moe_config: + backend: DEEPGEMM + urls: + - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 720da1acbdc..c707fc399f0 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -261,6 +261,14 @@ def get_test_config(test_desc, example_dir, test_root): (4, f"{test_configs_root}/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_ctxpp2_gentp2.yaml" ), + "deepseek_v3_lite_fp8_tllm_gen_helix": + (4, + f"{test_configs_root}/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_fp8_tllm_gen.yaml" + ), + "deepseek_v3_lite_bf16_tllm_gen_helix": + (4, + f"{test_configs_root}/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml" + ), } if test_desc not in config_map: @@ -1493,6 +1501,48 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_two_mtp( cwd=llm_venv.get_working_directory()) +@pytest.mark.skip_less_device(4) +@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'], + indirect=True) +def test_disaggregated_deepseek_v3_lite_fp8_tllm_gen_helix( + disaggregated_test_root, disaggregated_example_root, llm_venv, + 
deepseek_v3_model_root): + src_dst_dict = { + deepseek_v3_model_root: + f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/fp8", + } + for src, dst in src_dst_dict.items(): + if not os.path.islink(dst): + os.makedirs(os.path.dirname(dst), exist_ok=True) + os.symlink(src, dst, target_is_directory=True) + + run_disaggregated_test(disaggregated_example_root, + "deepseek_v3_lite_fp8_tllm_gen_helix", + env=llm_venv._new_env, + cwd=llm_venv.get_working_directory()) + + +@pytest.mark.skip_less_device(4) +@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'], + indirect=True) +def test_disaggregated_deepseek_v3_lite_bf16_tllm_gen_helix( + disaggregated_test_root, disaggregated_example_root, llm_venv, + deepseek_v3_model_root): + src_dst_dict = { + deepseek_v3_model_root: + f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16", + } + for src, dst in src_dst_dict.items(): + if not os.path.islink(dst): + os.makedirs(os.path.dirname(dst), exist_ok=True) + os.symlink(src, dst, target_is_directory=True) + + run_disaggregated_test(disaggregated_example_root, + "deepseek_v3_lite_bf16_tllm_gen_helix", + env=llm_venv._new_env, + cwd=llm_venv.get_working_directory()) + + @pytest.fixture(scope="module") def benchmark_root(): llm_root = os.getenv("LLM_ROOT")
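The new helix tests stage the model under the directory layout the YAML configs expect before launching the workers; a trimmed, standalone version of that staging step with a placeholder source path.

    import os

    src = "/path/to/DeepSeek-V3-Lite-bf16"                    # placeholder model root
    dst = os.path.join(os.getcwd(), "DeepSeek-V3-Lite", "bf16")
    if not os.path.islink(dst):
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        os.symlink(src, dst, target_is_directory=True)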