
Commit 909d641

add success into GenerationResult
1 parent dfb3dd1 commit 909d641

File tree

4 files changed: +373 -7 lines changed


tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 3 additions & 1 deletion
@@ -163,14 +163,16 @@ def __init__(self,
                  return_log_probs: bool = False,
                  return_context_logits: bool = False,
                  return_generation_logits: bool = False,
-                 exclude_last_generation_logits: bool = False):
+                 exclude_last_generation_logits: bool = False,
+                 success: bool = False):
         self._streaming = streaming
         self._context_logits = LogitsStorage(
             prompt_len, use_device_memory) if return_context_logits else None
         self._generation_logits = LogitsStorage(
             max_new_tokens, use_device_memory, exclude_last_generation_logits
         ) if return_generation_logits else None
         self._log_probs = LogProbStorage() if return_log_probs else None
+        self._success = success
 
     def append_context_logits(self, context_logits: torch.Tensor):
         if self._context_logits:
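
For orientation, here is a minimal sketch of how the widened constructor lines up with the PyResult(0, 0, success=...) call sites in the next file. Judging from the storage sizes referenced above, the first two positional parameters are presumably prompt_len and max_new_tokens; the reduced class below is illustrative only, not the real PyResult.

# Illustrative reduction of PyResult: only the parameters touched by this
# commit are kept; the real class also wires up logits/log-prob storage.
class PyResultSketch:
    def __init__(self, prompt_len: int, max_new_tokens: int,
                 success: bool = False):
        self.prompt_len = prompt_len
        self.max_new_tokens = max_new_tokens
        # New in this commit: whether the request completed successfully.
        self._success = success

# Control responses carry no tokens, so the executor passes zeros:
assert PyResultSketch(0, 0, success=True)._success is True
assert PyResultSketch(0, 0)._success is False  # default mirrors the signature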

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 5 additions & 5 deletions
@@ -34,7 +34,7 @@
 from ..speculative.drafter import Drafter
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
-                          LlmResponse, LlmResult, executor_request_to_llm_request)
+                          LlmResponse, LlmResult, executor_request_to_llm_request, PyResult)
 from .model_engine import ModelEngine
 from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler
 from .scheduler import RequestScheduler, ScheduledRequests
@@ -1086,11 +1086,11 @@ def _prepare_draft_requests(self):
 
     def _sleep(self, sleep_request):
         self.is_sleep_request = False
-        self._enqueue_responses({sleep_request.id: LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=None, is_final=True), client_id=sleep_request.id)})
+        self._enqueue_responses({sleep_request.id: LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=sleep_request.id)})
 
     def _wakeup(self, wakeup_request):
         self.is_wakeup_request = False
-        self._enqueue_responses({wakeup_request.id: LlmResponse(request_id=wakeup_request.id, result=LlmResult(result=None, py_result=None, is_final=True), client_id=wakeup_request.id)})
+        self._enqueue_responses({wakeup_request.id: LlmResponse(request_id=wakeup_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=wakeup_request.id)})
 
     def _update_weight(self, update_weight_request):
         self.is_update_weight_request = False
@@ -1119,13 +1119,13 @@ def _update_weight(self, update_weight_request):
             self.model_engine.model.load_weights(weights)
 
             torch.cuda.synchronize()
-            update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=None, is_final=True), client_id=update_weight_request.id)
+            update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=update_weight_request.id)
             self._enqueue_responses({update_weight_request.id: update_weight_response})
         except Exception as e:
             print(
                 f"Error in VllmInternalWorkerExtension.update_weights_from_ipc_handles: {e}"
             )
-            update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=None, is_final=True), client_id=update_weight_request.id)
+            update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=False), is_final=True), client_id=update_weight_request.id)
             self._enqueue_responses({update_weight_request.id: update_weight_response})
 
     def _executor_loop_overlap(self):
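
The pattern above generalizes: success=True on the happy path, success=False in the except branch, with PyResult(0, 0, ...) because control responses carry no generated tokens. Below is a minimal sketch of that structure in isolation; do_work and enqueue are hypothetical stand-ins for the executor's real weight-loading and response-queue machinery, not names from this diff.

from tensorrt_llm._torch.pyexecutor.llm_request import (LlmResponse, LlmResult,
                                                        PyResult)

def handle_control_request(request, do_work, enqueue):
    # Run a control action and report the outcome through the new flag.
    try:
        do_work()
        success = True
    except Exception as e:
        print(f"Error while handling control request {request.id}: {e}")
        success = False
    # Control responses carry no generated tokens, hence PyResult(0, 0, ...).
    response = LlmResponse(
        request_id=request.id,
        result=LlmResult(result=None,
                         py_result=PyResult(0, 0, success=success),
                         is_final=True),
        client_id=request.id)
    enqueue({request.id: response})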

tensorrt_llm/executor/result.py

Lines changed: 10 additions & 1 deletion
@@ -146,6 +146,7 @@ def __init__(self,
         self.disaggregated_params = None
         self.decoding_iter = 0
         self._done = False
+        self._success = False
 
         if has_event_loop():
             self.aqueue = AsyncQueue()
@@ -303,6 +304,7 @@ def _handle_response(self,
             response_result.deserialize()
 
         self._done = response_result.is_final
+        self._success = True  # TODO: replace with response_result._py_result._success
         context_phase_params = response_result.context_phase_params
         self.decoding_iter = response_result.decoding_iter
         if context_phase_params is not None:
@@ -332,10 +334,13 @@
                     handler := self._background_error_handler()):
                 handler()
         elif is_update_weights_response(response):
+            self._success = response.result._py_result._success
             self._done = True
         elif is_sleep_response(response):
+            self._success = response.result._py_result._success
             self._done = True
         elif is_wakeup_response(response):
+            self._success = response.result._py_result._success
             self._done = True
         elif isinstance(response, ErrorResponse):
             if self._background_error_handler is not None and (
@@ -463,6 +468,10 @@ def aborted(self) -> bool:
     def finished(self) -> bool:
         return self._done
 
+    @property
+    def success(self) -> bool:
+        return self._success
+
     def clear_logprob_params(self) -> None:
         # Remove temporary attribute used in executor
         # for a cleaner external-facing output.
@@ -533,7 +542,7 @@ def _exception(self, timeout: Optional[float] = None):
 
     def _repr_fields(self):
         return [
-            'request_id', 'prompt_token_ids', 'outputs', 'finished',
+            'request_id', 'prompt_token_ids', 'outputs', 'finished', 'success',
             "context_logits"
         ]
 
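With this change a caller can distinguish "the response arrived" (finished) from "the operation worked" (success). A hypothetical caller-side sketch follows; check_control_result is not part of this diff, and result stands for any GenerationResult-like object carrying the two properties added above.

def check_control_result(result):
    # `finished` flips when the final response is handled; `success` reports
    # whether the sleep/wakeup/update-weights action actually succeeded.
    if not result.finished:
        raise RuntimeError("result not final yet; keep draining responses")
    if not result.success:
        raise RuntimeError("control request completed but reported failure")
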