Commit 079ef8a

jellysnack and syuoni authored
[None][feat] Graceful Error Handling for Guided Decoder (NVIDIA#9078)
Signed-off-by: jellysnack <oleg.jellysnack@gmail.com>
Signed-off-by: jellysnack <158609015+jellysnack@users.noreply.github.com>
Co-authored-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent 85406f9 commit 079ef8a

File tree

2 files changed: +127 / -63 lines changed

tensorrt_llm/_torch/pyexecutor/guided_decoder.py

Lines changed: 77 additions & 60 deletions
@@ -204,73 +204,84 @@ def __init__(self,
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)

-    def _build(self, requests: GuidedRequests) -> None:
+    def _build(self, requests: GuidedRequests) -> List[Tuple[int, str]]:
         """Build the bitmask for requests with guided decoding enabled.

         Specifically, this method:
         - build and advance the grammar matcher for context and generation requests, respectively;
         - call the grammar matcher to fill the bitmask on CPU;
         - asynchronously copy the bitmask to GPU.
         """
+        failed_requests = []
         self.token_mask_host[:requests.num_bitmask_tokens].fill_(0)

         for req, offset in requests.valid_requests_with_offsets():
             slot = req.seq_slot
-            self.num_advanced_tokens[slot] = 0
-            self.num_guided_tokens[slot] = 0
+            try:
+                self.num_advanced_tokens[slot] = 0
+                self.num_guided_tokens[slot] = 0

-            matcher_init: bool = req.require_matcher_init()
-            matcher_advance: bool = req.require_matcher_advance()
-            if not (matcher_init or matcher_advance):
-                continue
-
-            if matcher_init:
-                matcher = self.grammar_matcher_factory.create(
-                    req.guided_decoding_params)
-                self.grammar_matchers[slot] = matcher
-
-            if matcher_advance:
-                matcher = self.grammar_matchers[slot]
-                # The last new token must be acceptable unless the matcher is terminated:
-                # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
-                # 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
-                if matcher.is_terminated() or self.is_draft_terminated[slot]:
+                matcher_init: bool = req.require_matcher_init()
+                matcher_advance: bool = req.require_matcher_advance()
+                if not (matcher_init or matcher_advance):
                     continue
-                accepted = matcher.accept_token(req.new_token)
-                if not accepted:
-                    if req.is_draft:
-                        self.is_draft_terminated[slot] = True
-                        logger.debug(
-                            f"Draft request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
-                        )
+
+                if matcher_init:
+                    matcher = self.grammar_matcher_factory.create(
+                        req.guided_decoding_params)
+                    self.grammar_matchers[slot] = matcher
+
+                if matcher_advance:
+                    matcher = self.grammar_matchers[slot]
+                    # The last new token must be acceptable unless the matcher is terminated or None:
+                    # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
+                    # 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
+                    # 3. The matcher can be None if there was an error during its creation.
+                    if matcher is None or matcher.is_terminated(
+                    ) or self.is_draft_terminated[slot]:
                         continue
-                    # TODO: Make this an error response.
-                    raise ValueError(
-                        f"Request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
-                    )
-
-            self.num_advanced_tokens[slot] += 1
-            if not matcher.is_terminated():
-                matcher.fill_next_token_bitmask(self.bitmask_host, offset)
-                self.token_mask_host[offset] = 1
-                self.num_guided_tokens[slot] += 1
-                # Process draft tokens
-                for i, tid in enumerate(req.draft_tokens, 1):
-                    accepted = matcher.accept_token(tid)
+                    accepted = matcher.accept_token(req.new_token)
                     if not accepted:
-                        break
-                    self.num_advanced_tokens[slot] += 1
-                    if matcher.is_terminated():
-                        break
-                    matcher.fill_next_token_bitmask(self.bitmask_host,
-                                                    offset + i)
-                    self.token_mask_host[offset + i] = 1
+                        if req.is_draft:
+                            self.is_draft_terminated[slot] = True
+                            logger.debug(
+                                f"Draft request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
+                            )
+                            continue
+                        raise ValueError(
+                            f"Request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
+                        )
+
+                self.num_advanced_tokens[slot] += 1
+                if not matcher.is_terminated():
+                    matcher.fill_next_token_bitmask(self.bitmask_host, offset)
+                    self.token_mask_host[offset] = 1
                     self.num_guided_tokens[slot] += 1
+                    # Process draft tokens
+                    for i, tid in enumerate(req.draft_tokens, 1):
+                        accepted = matcher.accept_token(tid)
+                        if not accepted:
+                            break
+                        self.num_advanced_tokens[slot] += 1
+                        if matcher.is_terminated():
+                            break
+                        matcher.fill_next_token_bitmask(self.bitmask_host,
+                                                        offset + i)
+                        self.token_mask_host[offset + i] = 1
+                        self.num_guided_tokens[slot] += 1
+
+                if req.is_draft:
+                    assert len(req.draft_tokens) == 0
+                    self.num_advanced_draft_tokens[
+                        slot] += self.num_advanced_tokens[slot]
+            except Exception as e:
+                error_msg = f"Guided decoding error: {str(e)}"
+                failed_requests.append((req.request_id, error_msg))
+                logger.error(
+                    f"Request {req.request_id} at slot {slot} failed during guided decoding: {error_msg}"
+                )

-            if req.is_draft:
-                assert len(req.draft_tokens) == 0
-                self.num_advanced_draft_tokens[
-                    slot] += self.num_advanced_tokens[slot]
+        return failed_requests

     def _copy_bitmask(self,
                       requests: GuidedRequests,
@@ -306,8 +317,8 @@ def add_batch(self, scheduled_requests: ScheduledRequests) -> None:
             scheduled_requests, self.max_num_draft_tokens)

     @nvtx_range("GuideDecoder.build")
-    def build(self) -> None:
-        self._build(self.requests)
+    def build(self) -> List[Tuple[int, str]]:
+        return self._build(self.requests)

     @nvtx_range("GuideDecoder.copy_bitmask")
     def copy_bitmask(self, num_bitmask_tokens: Optional[int] = None) -> None:
@@ -325,8 +336,8 @@ def apply_bitmask(self,

     def execute(self,
                 logits: torch.Tensor,
-                d2t: Optional[torch.Tensor] = None) -> None:
-        self.build()
+                d2t: Optional[torch.Tensor] = None) -> List[Tuple[int, str]]:
+        failed_requests = self.build()

         with torch.cuda.stream(self.stream):
             torch.cuda.current_stream().wait_event(self.token_event)
@@ -337,6 +348,8 @@ def execute(self,
             self.apply_bitmask(logits, d2t=d2t)
         self.token_event.record()

+        return failed_requests
+
     def _rollback_rejected_tokens(self, requests: GuidedRequests) -> None:
         """Rollback the grammar matcher for rejected tokens.

@@ -460,23 +473,25 @@ def fetch_batch(self) -> None:
         )

     @hostfunc
-    def build(self) -> None:
-        self._build(self.requests_hostfunc)
+    def build(self) -> List[Tuple[int, str]]:
+        return self._build(self.requests_hostfunc)

     def execute(self,
                 logits: torch.Tensor,
-                d2t: Optional[torch.Tensor] = None) -> None:
+                d2t: Optional[torch.Tensor] = None) -> List[Tuple[int, str]]:
         with torch.cuda.stream(self.stream):
             torch.cuda.current_stream().wait_event(self.token_event)
             self.fetch_batch()
             self.init_disagg_gen_requests()
-            self.build()
+            failed_requests = self.build()
             self.copy_bitmask()
             self.bitmask_event.record()

         torch.cuda.current_stream().wait_event(self.bitmask_event)
         self.apply_bitmask(logits, d2t=d2t)

+        return failed_requests
+
     @hostfunc
     def rollback_rejected_tokens(self) -> None:
         self._rollback_rejected_tokens(self.requests_hostfunc)
@@ -532,13 +547,13 @@ def fetch_draft_batch(self, draft_step: int = 0) -> None:
     def execute_draft_batch(self,
                             logits: torch.Tensor,
                             d2t: Optional[torch.Tensor] = None,
-                            draft_step: int = 0) -> None:
+                            draft_step: int = 0) -> List[Tuple[int, str]]:
         with torch.cuda.stream(self.stream):
             torch.cuda.current_stream().wait_event(self.token_event)
             self.fetch_draft_batch(draft_step=draft_step)
             if draft_step == 0:
                 self.rollback_rejected_tokens()
-            self.build()
+            failed_requests = self.build()
            if draft_step == self.max_num_draft_tokens - 1:
                 self.rollback_draft_tokens()
             # Overwrite num_bitmask_tokens since the request might not be updated on CUDA stream yet.
@@ -550,3 +565,5 @@ def execute_draft_batch(self,
             self.apply_bitmask(logits,
                                d2t=d2t,
                                num_bitmask_tokens=len(self.requests))
+
+        return failed_requests
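In summary, `_build` now wraps the per-request matcher work in a try/except, records `(request_id, error_message)` pairs instead of raising, and `build`/`execute` propagate that list to the caller. A minimal, self-contained sketch of this collect-and-continue pattern (the request payload and `advance` callable below are simplified stand-ins for illustration, not the TensorRT-LLM classes):

from typing import Callable, List, Tuple

def build_with_error_collection(
        requests: List[Tuple[int, str]],  # (request_id, payload) stand-ins
        advance: Callable[[str], None],  # stand-in for matcher init/advance work
) -> List[Tuple[int, str]]:
    """Process each request; on failure, record the error and keep going."""
    failed_requests: List[Tuple[int, str]] = []
    for request_id, payload in requests:
        try:
            advance(payload)  # may raise, e.g. on an invalid grammar or token
        except Exception as e:
            # A single bad request no longer aborts the whole batch.
            failed_requests.append((request_id, f"Guided decoding error: {e}"))
    return failed_requests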

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 50 additions & 3 deletions
@@ -984,14 +984,22 @@ def _executor_loop_pp(self):

                 batch_outputs = self._forward_step(scheduled_batch)

+                guided_decoder_failed_requests = None
                 if self.guided_decoder is not None:
                     self.guided_decoder.add_batch(scheduled_batch)
-                    self.guided_decoder.execute(
+                    guided_decoder_failed_requests = self.guided_decoder.execute(
                         batch_outputs['logits'])

                 sample_state = self._sample_async(
                     scheduled_batch, batch_outputs)
                 assert sample_state is not None, "Sampling failed"
+
+                # Handle guided decoder errors after _sample_async to avoid state conflicts.
+                # If called before, failed requests would be marked as GENERATION_COMPLETE,
+                # causing _sample_async to fail when accessing context_chunk_size property.
+                self._handle_guided_decoder_errors(
+                    scheduled_batch, guided_decoder_failed_requests)
+
                 self._update_request_states(scheduled_batch)

                 if self.enable_iter_perf_stats:
@@ -1306,11 +1314,21 @@ def _executor_loop(self):
                     self.guided_decoder.rollback_draft_tokens()

                 batch_outputs = self._forward_step(scheduled_batch)
+
+                guided_decoder_failed_requests = None
                 if self.guided_decoder is not None:
-                    self.guided_decoder.execute(batch_outputs['logits'])
+                    guided_decoder_failed_requests = self.guided_decoder.execute(
+                        batch_outputs['logits'])

                 sample_state = self._sample_async(scheduled_batch,
                                                   batch_outputs)
+
+                # Handle guided decoder errors after _sample_async to avoid state conflicts.
+                # If called before, failed requests would be marked as GENERATION_COMPLETE,
+                # causing _sample_async to fail when accessing context_chunk_size property.
+                self._handle_guided_decoder_errors(
+                    scheduled_batch, guided_decoder_failed_requests)
+
                 if self.drafter is not None:
                     self.drafter.run_drafter_post(scheduled_batch,
                                                   self.resource_manager,
@@ -1562,15 +1580,23 @@ def _executor_loop_overlap(self):
                     self.drafter.cleanup_previous_draft_resources()

                 if can_queue:
+                    guided_decoder_failed_requests = None
                     if self.guided_decoder is not None:
                         # add_batch must be called again to have updated new tokens.
                         self.guided_decoder.add_batch(scheduled_batch)
-                        self.guided_decoder.execute(batch_outputs['logits'])
+                        guided_decoder_failed_requests = self.guided_decoder.execute(
+                            batch_outputs['logits'])

                     sample_state = self._sample_async(scheduled_batch,
                                                       batch_outputs)
                     assert sample_state is not None, "Sampling failed"

+                    # Handle guided decoder errors after _sample_async to avoid state conflicts.
+                    # If called before, failed requests would be marked as GENERATION_COMPLETE,
+                    # causing _sample_async to fail when accessing context_chunk_size property.
+                    self._handle_guided_decoder_errors(
+                        scheduled_batch, guided_decoder_failed_requests)
+
                     self._update_request_states(scheduled_batch)

                     ctx_transmission_reqs = self._send_disagg_ctx_cache(
@@ -2694,6 +2720,27 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
     def reset_prefix_cache(self):
         self.kv_cache_manager.reset_reuse_state()

+    def _handle_guided_decoder_errors(
+            self, scheduled_batch: ScheduledRequests,
+            failed_requests: Optional[List[Tuple[int, str]]]):
+        """Handle errors that occurred during guided decoding.
+
+        Args:
+            scheduled_batch: The current batch of scheduled requests
+            failed_requests: List of (request_id, error_message) tuples for failed requests,
+                or None if no failures occurred
+        """
+        if not failed_requests:
+            return
+
+        failed_req_id_to_err = {req_id: err for req_id, err in failed_requests}
+
+        for request in scheduled_batch.all_requests():
+            if request.py_request_id not in failed_req_id_to_err:
+                continue
+            error_msg = failed_req_id_to_err[request.py_request_id]
+            self._handle_errors(error_msg, requests=[request])
+

 class DisaggPPTerminationHandler:
     """Handles termination synchronization across pipeline parallel ranks under disaggregated serving.
