diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 3a454cb740f..052350ecd89 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1197,16 +1197,16 @@ def _prepare_tp_inputs(
         new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]
         next_draft_tokens_device = new_tensors_device.next_draft_tokens  # [batch, draft_len]

-        # Requests with draft tokens are treated like extend requests. Dummy extend requests should be
-        # at the end of extend_requests.
+        # Requests with draft tokens are treated like extend requests. CUDA graph dummy extend
+        # requests should be at the beginning of extend_requests.
         extend_requests = []
-        extend_dummy_requests = []
+        extend_cuda_graph_dummy_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
             if len(request.py_draft_tokens
                    ) > 0 or next_draft_tokens_device is not None:
-                if request.is_dummy:
-                    extend_dummy_requests.append(request)
+                if request.is_cuda_graph_dummy:
+                    extend_cuda_graph_dummy_requests.append(request)
                 else:
                     extend_requests.append(request)
             else:
@@ -1219,8 +1219,8 @@ def _prepare_tp_inputs(
                     pin_memory=True)
                 mrope_config['mrope_position_deltas'].append(
                     mrope_position_deltas.to('cuda', non_blocking=True))
-        extend_requests += extend_dummy_requests
+        extend_requests = extend_cuda_graph_dummy_requests + extend_requests
         if not self._disable_overlap_scheduler and self.is_spec_decode:
             spec_dec_mode = self.spec_config.spec_dec_mode
             assert spec_dec_mode.support_overlap_scheduler(
             )
@@ -1229,18 +1229,18 @@ def _prepare_tp_inputs(
         # will contain previous batch incices of generation requests
         previous_batch_indices = []
         previous_pos_indices = []
+        request_ids_with_previous_batch = []
+        num_extend_reqs_wo_previous_batch = 0
         for request in extend_requests:
             # the request has no previous tensor:
             # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
             # (3) the first step in the generation server of disaggregated serving
             if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
-                # get token ids, including input token ids and draft token ids. For these dummy requests,
-                # no need to copy the token ids.
-                if not request.is_dummy:
-                    input_ids.append(request.get_last_tokens(0))
-                    input_ids.extend(request.py_draft_tokens)
-                    draft_tokens.extend(request.py_draft_tokens)
+                # get token ids, including input token ids and draft token ids
+                input_ids.append(request.get_last_tokens(0))
+                input_ids.extend(request.py_draft_tokens)
+                draft_tokens.extend(request.py_draft_tokens)
                 # get other ids and lengths
                 num_draft_tokens = len(request.py_draft_tokens)
                 past_seen_token_num = request.max_beam_num_tokens - 1
@@ -1268,6 +1268,7 @@ def _prepare_tp_inputs(
                 # update batch index
                 request.py_batch_idx = batch_idx
                 batch_idx += 1
+                num_extend_reqs_wo_previous_batch += 1
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
@@ -1294,7 +1295,10 @@ def _prepare_tp_inputs(
                 num_cached_tokens_per_seq.append(past_seen_token_num +
                                                  self.max_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
-                request_ids.append(request.py_request_id)
+                request_ids_with_previous_batch.append(request.py_request_id)
+
+        # move requests with previous batch to the end of the list
+        request_ids.extend(request_ids_with_previous_batch)

         sequence_lengths.extend([1] * len(generation_requests))
         gather_ids.extend(
@@ -1329,7 +1333,6 @@ def _prepare_tp_inputs(
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
         previous_batchs = len(previous_batch_indices)
-        num_requests = len(request_ids)
         total_num_tokens = len(position_ids)
         assert total_num_tokens <= self.max_num_tokens, (
             "total_num_tokens should be less than or equal to max_num_tokens")
@@ -1371,27 +1374,31 @@ def _prepare_tp_inputs(
                     non_blocking=True)
                 # prepare data for the preprocess inputs
                 kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
+                pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
+                    1 + self.max_draft_len)
+                pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
+                pre_batch_start_idx = num_extend_reqs_wo_previous_batch
+                pre_batch_end_idx = pre_batch_start_idx + previous_batchs
                 previous_pos_indices = torch.tensor(previous_pos_indices,
                                                     dtype=torch.int,
                                                     pin_memory=True)
-                self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
-                    previous_pos_indices, non_blocking=True)
+                self.previous_pos_indices_cuda[
+                    pre_tokens_start_idx:pre_tokens_end_idx].copy_(
+                        previous_pos_indices, non_blocking=True)
                 self.previous_pos_id_offsets_cuda[
-                    0:previous_batch_tokens].copy_(
+                    pre_tokens_start_idx:pre_tokens_end_idx].copy_(
                         new_tokens_lens_device[self.previous_pos_indices_cuda[
-                            0:previous_batch_tokens]],
+                            pre_tokens_start_idx:pre_tokens_end_idx]],
+                        non_blocking=True)
+                self.previous_kv_lens_offsets_cuda[
+                    pre_batch_start_idx:pre_batch_end_idx].copy_(
+                        kv_len_offsets_device[
+                            self.previous_batch_indices_cuda[:previous_batchs]],
                         non_blocking=True)
-                self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
-                    kv_len_offsets_device[
-                        self.previous_batch_indices_cuda[:previous_batchs]],
-                    non_blocking=True)
                 # for the requests that do not have previous batch, set the previous_pos_id_offsets and
                 # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-                self.previous_pos_id_offsets_cuda[
-                    previous_batch_tokens:num_requests *
-                    (1 + self.max_draft_len)] *= 0
-                self.previous_kv_lens_offsets_cuda[
-                    previous_batchs:num_requests] *= 0
+                self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
+                self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
             else:
                 # change the data to zeros to skip the value changes in _preprocess_inputs
                 self.previous_pos_id_offsets_cuda *= 0
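The new pre_tokens_start_idx / pre_batch_start_idx arithmetic relies on the ordering invariant established above: every extend request without a previous batch (the prepended CUDA graph dummies plus disaggregated-serving first steps) occupies the leading slots of the offset buffers, and only the trailing slots are overwritten from the previous iteration's tensors. A rough, self-contained sketch of that layout follows; the sizes, the fake_new_tokens_lens values, and the CPU stand-in for previous_pos_id_offsets_cuda are invented for illustration (the real code derives the per-token values through previous_pos_indices_cuda on the GPU).

import torch

# Toy sizes, assumed only for illustration.
max_draft_len = 2            # one target token plus two draft tokens per request
num_wo_previous_batch = 2    # e.g. two CUDA graph dummy requests at the front
previous_batchs = 3          # generation requests that do have a previous batch

tokens_per_req = 1 + max_draft_len
pre_tokens_start_idx = num_wo_previous_batch * tokens_per_req
pre_tokens_end_idx = pre_tokens_start_idx + previous_batchs * tokens_per_req

# Stand-in for self.previous_pos_id_offsets_cuda, kept on CPU here.
previous_pos_id_offsets = torch.zeros(pre_tokens_end_idx, dtype=torch.int)

# Only the trailing slots receive the accepted-token counts of the previous
# batch; the leading slots stay zero, so _preprocess_inputs leaves the
# dummy / no-previous-batch requests untouched.
fake_new_tokens_lens = torch.tensor([2, 1, 3], dtype=torch.int)  # per previous request
previous_pos_id_offsets[pre_tokens_start_idx:pre_tokens_end_idx] = (
    fake_new_tokens_lens.repeat_interleave(tokens_per_req))

print(previous_pos_id_offsets.tolist())
# -> [0, 0, 0, 0, 0, 0, 2, 2, 2, 1, 1, 1, 3, 3, 3]

Zeroing only the prefix [:pre_tokens_start_idx] is what lets _preprocess_inputs leave the position ids and KV lengths of the padded requests unchanged, regardless of how many dummies were added.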
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 2aa50df07f1..67e807ef4ba 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1491,7 +1491,7 @@ def _check_disagg_gen_transfer_status(self):
     @nvtx_range("_pad_attention_dp_dummy_request")
     def _pad_attention_dp_dummy_request(self):
         """
-        Pad with a dummy request, if required, to ensure every attention_dp rank has at least one active request.
+        Pad with dummy requests, if required, to avoid empty attention_dp rank.
         """
         if not self.enable_attention_dp:
             return
@@ -1505,20 +1505,22 @@ def _pad_attention_dp_dummy_request(self):
             or req.is_disagg_generation_transmission_in_progress else 1
             for req in self.active_requests
         ])
-
-        if self.expected_num_active_requests - num_active_request > 0 and num_active_request == 0:
-            llm_request = self.kv_cache_manager.add_dummy_requests(
-                request_ids=[0],
+        num_dummy_request = self.expected_num_active_requests - num_active_request
+        if num_dummy_request > 0:
+            llm_request_list = self.kv_cache_manager.add_dummy_requests(
+                request_ids=list(range(num_dummy_request)),
                 is_gen=not self.has_context_request,
                 prepare_resource=not self.has_context_request,
                 max_num_draft_tokens=self.max_draft_tokens,
-            )[0]
-            llm_request.is_attention_dp_dummy = True
+            )
+            for llm_request in llm_request_list:
+                llm_request.is_attention_dp_dummy = True
             spec_resource_manager = self.resource_manager.get_resource_manager(
                 ResourceManagerType.SPEC_RESOURCE_MANAGER)
             if spec_resource_manager is not None:
-                spec_resource_manager.add_dummy_requests([0])
-            self.active_requests.append(llm_request)
+                spec_resource_manager.add_dummy_requests(
+                    list(range(num_dummy_request)))
+            self.active_requests += llm_request_list

     @nvtx_range("_prepare_disagg_gen_init")
     def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
@@ -1635,13 +1637,12 @@ def forward(scheduled_requests, resource_manager, new_tensors_device,

     def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
         # handle potential attention dp dummy request
-        if self.active_requests and self.active_requests[
-                -1].is_attention_dp_dummy:
-            request = self.active_requests[-1]
-            request.state = LlmRequestState.GENERATION_COMPLETE
-            self.inflight_req_ids.erase(request.py_request_id)
-            self._terminate_request(request)
-            self.active_requests.remove(request)
+        for request in self.active_requests[:]:
+            if request.is_attention_dp_dummy:
+                request.state = LlmRequestState.GENERATION_COMPLETE
+                self.inflight_req_ids.erase(request.py_request_id)
+                self._terminate_request(request)
+                self.active_requests.remove(request)
         for request in scheduled_requests.context_requests:
             if request.state != LlmRequestState.GENERATION_COMPLETE:  # skip failed requests
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 82ed94ab0fd..29a17593e97 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -267,6 +267,12 @@ def update_requests(self, state: SampleStateMTP) -> None:
             request.py_decoding_iter += 1
             idx += 1

+        # skip the results of cuda graph dummy requests
+        if idx == 0:
+            num_cuda_graph_dummy_requests = len(new_tokens_list) - len(
+                state.scheduled_requests.generation_requests)
+            idx += num_cuda_graph_dummy_requests
+
         for request in state.scheduled_requests.generation_requests:
             assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
             assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
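With attention DP, each rank is now padded with up to expected_num_active_requests - num_active_request dummy requests instead of at most one, and a batch that was additionally padded for CUDA graphs hands the MTP sampler more sampled rows than there are real generation requests; the leading rows belong to the dummies and are skipped by advancing the read index. A minimal sketch of that bookkeeping, using plain lists with made-up values in place of the real SampleStateMTP tensors:

# Four rows were sampled for a CUDA-graph-padded batch, but only two real
# generation requests exist; rows 0 and 1 belong to dummy requests.
new_tokens_list = [[11], [12], [7, 8], [9, 10]]
generation_requests = ["req_a", "req_b"]

idx = 0  # no context requests were processed in this pure-generation batch
if idx == 0:
    num_cuda_graph_dummy_requests = len(new_tokens_list) - len(generation_requests)
    idx += num_cuda_graph_dummy_requests  # skip the dummy rows

for request, tokens in zip(generation_requests, new_tokens_list[idx:]):
    print(request, tokens)
# -> req_a [7, 8]
# -> req_b [9, 10]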
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 0e4cd289944..8c6e4377b90 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -635,7 +635,8 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
     @pytest.mark.skip_device_not_contain(["H100"])
     @parametrize_with_ids("mtp_nextn", [0, 2])
     def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
-        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
+        # OOM on H100 with default free_gpu_memory_fraction=0.9
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
         mtp_config = None
         if mtp_nextn > 0:
             mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
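For reference, a hedged sketch of how the adjusted configs are typically handed to the LLM API in these accuracy tests; everything beyond the two config objects shown in the diff (model path, speculative_config wiring, import locations) is an assumption rather than text taken from the test body.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, MTPDecodingConfig

# 0.9 left too little free memory on H100 for the CUDA graph padding test.
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=2)  # the mtp_nextn == 2 case

llm = LLM(model="<FP8 block-scales checkpoint path>",  # placeholder, not the test's path
          kv_cache_config=kv_cache_config,
          speculative_config=mtp_config)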