@@ -204,73 +204,84 @@ def __init__(self,
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)
 
-    def _build(self, requests: GuidedRequests) -> None:
+    def _build(self, requests: GuidedRequests) -> List[Tuple[int, str]]:
         """Build the bitmask for requests with guided decoding enabled.
 
         Specifically, this method:
         - build and advance the grammar matcher for context and generation requests, respectively;
         - call the grammar matcher to fill the bitmask on CPU;
         - asynchronously copy the bitmask to GPU.
         """
+        failed_requests = []
         self.token_mask_host[:requests.num_bitmask_tokens].fill_(0)
 
         for req, offset in requests.valid_requests_with_offsets():
             slot = req.seq_slot
-            self.num_advanced_tokens[slot] = 0
-            self.num_guided_tokens[slot] = 0
+            try:
+                self.num_advanced_tokens[slot] = 0
+                self.num_guided_tokens[slot] = 0
 
-            matcher_init: bool = req.require_matcher_init()
-            matcher_advance: bool = req.require_matcher_advance()
-            if not (matcher_init or matcher_advance):
-                continue
-
-            if matcher_init:
-                matcher = self.grammar_matcher_factory.create(
-                    req.guided_decoding_params)
-                self.grammar_matchers[slot] = matcher
-
-            if matcher_advance:
-                matcher = self.grammar_matchers[slot]
-                # The last new token must be acceptable unless the matcher is terminated:
-                # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
-                # 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
-                if matcher.is_terminated() or self.is_draft_terminated[slot]:
+                matcher_init: bool = req.require_matcher_init()
+                matcher_advance: bool = req.require_matcher_advance()
+                if not (matcher_init or matcher_advance):
                     continue
-                accepted = matcher.accept_token(req.new_token)
-                if not accepted:
-                    if req.is_draft:
-                        self.is_draft_terminated[slot] = True
-                        logger.debug(
-                            f"Draft request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
-                        )
+
+                if matcher_init:
+                    matcher = self.grammar_matcher_factory.create(
+                        req.guided_decoding_params)
+                    self.grammar_matchers[slot] = matcher
+
+                if matcher_advance:
+                    matcher = self.grammar_matchers[slot]
+                    # The last new token must be acceptable unless the matcher is terminated or None:
+                    # 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
+                    # 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
+                    # 3. The matcher can be None if there was an error during its creation.
+                    if matcher is None or matcher.is_terminated(
+                    ) or self.is_draft_terminated[slot]:
                         continue
-                    # TODO: Make this an error response.
-                    raise ValueError(
-                        f"Request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
-                    )
-
-            self.num_advanced_tokens[slot] += 1
-            if not matcher.is_terminated():
-                matcher.fill_next_token_bitmask(self.bitmask_host, offset)
-                self.token_mask_host[offset] = 1
-                self.num_guided_tokens[slot] += 1
-                # Process draft tokens
-                for i, tid in enumerate(req.draft_tokens, 1):
-                    accepted = matcher.accept_token(tid)
+                    accepted = matcher.accept_token(req.new_token)
                     if not accepted:
-                        break
-                    self.num_advanced_tokens[slot] += 1
-                    if matcher.is_terminated():
-                        break
-                    matcher.fill_next_token_bitmask(self.bitmask_host,
-                                                    offset + i)
-                    self.token_mask_host[offset + i] = 1
+                        if req.is_draft:
+                            self.is_draft_terminated[slot] = True
+                            logger.debug(
+                                f"Draft request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
+                            )
+                            continue
+                        raise ValueError(
+                            f"Request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
+                        )
+
+                self.num_advanced_tokens[slot] += 1
+                if not matcher.is_terminated():
+                    matcher.fill_next_token_bitmask(self.bitmask_host, offset)
+                    self.token_mask_host[offset] = 1
                     self.num_guided_tokens[slot] += 1
+                    # Process draft tokens
+                    for i, tid in enumerate(req.draft_tokens, 1):
+                        accepted = matcher.accept_token(tid)
+                        if not accepted:
+                            break
+                        self.num_advanced_tokens[slot] += 1
+                        if matcher.is_terminated():
+                            break
+                        matcher.fill_next_token_bitmask(self.bitmask_host,
+                                                        offset + i)
+                        self.token_mask_host[offset + i] = 1
+                        self.num_guided_tokens[slot] += 1
+
+                if req.is_draft:
+                    assert len(req.draft_tokens) == 0
+                    self.num_advanced_draft_tokens[
+                        slot] += self.num_advanced_tokens[slot]
+            except Exception as e:
+                error_msg = f"Guided decoding error: {str(e)}"
+                failed_requests.append((req.request_id, error_msg))
+                logger.error(
+                    f"Request {req.request_id} at slot {slot} failed during guided decoding: {error_msg}"
+                )
 
-            if req.is_draft:
-                assert len(req.draft_tokens) == 0
-                self.num_advanced_draft_tokens[
-                    slot] += self.num_advanced_tokens[slot]
+        return failed_requests
 
     def _copy_bitmask(self,
                       requests: GuidedRequests,
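For context on the bitmask that `_build` fills: `bitmask_size` above fixes the layout at one row of ceil(vocab_size_padded / 32) 32-bit words per token position, one bit per (padded) vocabulary entry. A minimal sketch of the packing arithmetic, with an illustrative vocabulary size (the real sizes come from the model config, not this diff):

import math

vocab_size_padded = 32064  # illustrative only
bitmask_size = math.ceil(vocab_size_padded / 32)  # 1002 int32 words per token

# Token id t maps to bit (t % 32) of word (t // 32) within its row.
word, bit = divmod(12345, 32)  # word 385, bit 25
assert word < bitmask_size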
@@ -306,8 +317,8 @@ def add_batch(self, scheduled_requests: ScheduledRequests) -> None:
             scheduled_requests, self.max_num_draft_tokens)
 
     @nvtx_range("GuideDecoder.build")
-    def build(self) -> None:
-        self._build(self.requests)
+    def build(self) -> List[Tuple[int, str]]:
+        return self._build(self.requests)
 
     @nvtx_range("GuideDecoder.copy_bitmask")
     def copy_bitmask(self, num_bitmask_tokens: Optional[int] = None) -> None:
@@ -325,8 +336,8 @@ def apply_bitmask(self,
 
     def execute(self,
                 logits: torch.Tensor,
-                d2t: Optional[torch.Tensor] = None) -> None:
-        self.build()
+                d2t: Optional[torch.Tensor] = None) -> List[Tuple[int, str]]:
+        failed_requests = self.build()
 
         with torch.cuda.stream(self.stream):
             torch.cuda.current_stream().wait_event(self.token_event)
@@ -337,6 +348,8 @@ def execute(self,
         self.apply_bitmask(logits, d2t=d2t)
         self.token_event.record()
 
+        return failed_requests
+
     def _rollback_rejected_tokens(self, requests: GuidedRequests) -> None:
         """Rollback the grammar matcher for rejected tokens.
 
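With this change, `_build` catches per-request failures instead of raising (the old `# TODO: Make this an error response`), and `build`/`execute` propagate them as `(request_id, error_msg)` pairs. A hypothetical caller-side sketch of consuming that return value; the handler name and the way errors are surfaced are illustrative, not part of this diff:

from typing import List, Tuple

def handle_guided_decoding_failures(failed: List[Tuple[int, str]]) -> None:
    # Illustrative only: a real executor would finish each failed request
    # with an error response; here we just report the failures.
    for request_id, error_msg in failed:
        print(f"request {request_id} failed guided decoding: {error_msg}")

# e.g. handle_guided_decoding_failures(guided_decoder.execute(logits))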
@@ -460,23 +473,25 @@ def fetch_batch(self) -> None:
         )
 
     @hostfunc
-    def build(self) -> None:
-        self._build(self.requests_hostfunc)
+    def build(self) -> List[Tuple[int, str]]:
+        return self._build(self.requests_hostfunc)
 
     def execute(self,
                 logits: torch.Tensor,
-                d2t: Optional[torch.Tensor] = None) -> None:
+                d2t: Optional[torch.Tensor] = None) -> List[Tuple[int, str]]:
         with torch.cuda.stream(self.stream):
             torch.cuda.current_stream().wait_event(self.token_event)
             self.fetch_batch()
             self.init_disagg_gen_requests()
-            self.build()
+            failed_requests = self.build()
             self.copy_bitmask()
             self.bitmask_event.record()
 
         torch.cuda.current_stream().wait_event(self.bitmask_event)
         self.apply_bitmask(logits, d2t=d2t)
 
+        return failed_requests
+
     @hostfunc
     def rollback_rejected_tokens(self) -> None:
         self._rollback_rejected_tokens(self.requests_hostfunc)
@@ -532,13 +547,13 @@ def fetch_draft_batch(self, draft_step: int = 0) -> None:
     def execute_draft_batch(self,
                             logits: torch.Tensor,
                             d2t: Optional[torch.Tensor] = None,
-                            draft_step: int = 0) -> None:
+                            draft_step: int = 0) -> List[Tuple[int, str]]:
         with torch.cuda.stream(self.stream):
             torch.cuda.current_stream().wait_event(self.token_event)
             self.fetch_draft_batch(draft_step=draft_step)
             if draft_step == 0:
                 self.rollback_rejected_tokens()
-            self.build()
+            failed_requests = self.build()
             if draft_step == self.max_num_draft_tokens - 1:
                 self.rollback_draft_tokens()
             # Overwrite num_bitmask_tokens since the request might not be updated on CUDA stream yet.
@@ -550,3 +565,5 @@ def execute_draft_batch(self,
         self.apply_bitmask(logits,
                            d2t=d2t,
                            num_bitmask_tokens=len(self.requests))
+
+        return failed_requests
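`apply_bitmask` itself is untouched by this diff. For intuition, a reference-style sketch of what applying a packed 32-bit token bitmask to logits can look like; this assumes xgrammar-style bit packing and is not the implementation used here, which applies the mask on the GPU stream:

import torch

def apply_packed_bitmask_reference(logits: torch.Tensor,
                                   bitmask: torch.Tensor) -> torch.Tensor:
    # logits:  [num_tokens, vocab_size_padded], float
    # bitmask: [num_tokens, ceil(vocab_size_padded / 32)], int32;
    #          bit j of word w marks token id (w * 32 + j) as allowed.
    num_tokens, vocab = logits.shape
    bits = torch.arange(32, dtype=torch.int32, device=bitmask.device)
    # Expand each 32-bit word into 32 boolean "allowed" flags.
    allowed = ((bitmask.unsqueeze(-1) >> bits) & 1).bool()
    allowed = allowed.reshape(num_tokens, -1)[:, :vocab]
    # Disallowed tokens get -inf so they can never be sampled.
    return logits.masked_fill(~allowed, float("-inf"))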