@@ -333,6 +333,7 @@ def __init__(
         self.method = specdecode_config.method
         self.model_config = specdecode_config.model_config
         self.cache_config = specdecode_config.cache_config
+        self.num_spec_tokens = specdecode_config.num_speculative_tokens
         self.backend_config = backend_config
         self.device = device

@@ -365,12 +366,17 @@ def build_graph_runner(self):
     def build_cache_engine(self, cache_stream: torch.cuda.Stream):
         """Build cache engine."""
         if self.cache_config is not None:
-            self.cache_engine = CacheEngine(self.cache_config, self.model_config, rank=0, tp_rank=0, world_size=1, cache_stream=cache_stream)
+            self.cache_engine = CacheEngine(self.cache_config,
+                                            self.model_config,
+                                            rank=0,
+                                            tp_rank=0,
+                                            world_size=1,
+                                            cache_stream=cache_stream)

     def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
         """Forward impl."""
         cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
-        output = self.proposer.propose(inputs, cache_engine=self.cache_engine, stream=self.stream)
+        output = self.proposer._forward(inputs, cache_engine=self.cache_engine, stream=self.stream)
        return output

     async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
@@ -385,32 +391,122 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
         await asyncio.sleep(0)
         return output

+    async def _async_model_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
+        """Model forward.
+
+        Args:
+            inputs (ModelInputs): The input data produced by _make_inputs.
+            swap_in_map (SwapMap): Cache maps to swap in.
+            swap_out_map (SwapMap): Cache maps to swap out.
+        """
+        max_prefill_token_num = self.cache_config.max_prefill_token_num
+        swap_done = False
+
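+        # the swap maps are consumed by the first forward only; later chunked
+        # prefill steps pass empty maps so the cache is not swapped twice.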
+        async def __forward(inputs):
+            """forward."""
+            nonlocal swap_done, swap_in_map, swap_out_map
+            if swap_done:
+                return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict())
+            else:
+                swap_done = True
+                return await self.async_forward(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+
+        async def __long_context_single_forward(new_inputs):
+            """One large sequence."""
+            model_metas = new_inputs[0].model_metas
+            for inp in new_inputs:
+                inp.model_metas = model_metas
+                output = await __forward(inp)
+                model_metas = output.get('model_metas')
+            return output
+
+        # make long context inputs
+        is_long_context = inputs.input_ids.numel() > max_prefill_token_num and not inputs.is_decoding
+
+        if is_long_context:
+            seq_len = inputs.seq_length
+            batch_size = seq_len.size(0)
+            assert batch_size == 1, 'Do not support batched long context.'
+            inputs_li = inputs.split(max_prefill_token_num)
+            outputs = await __long_context_single_forward(inputs_li)
+        else:
+            outputs = await __forward(inputs)
+
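+        # autoregressive draft loop: the forward above already yields the first
+        # draft token, so run the proposer num_spec_tokens - 1 more times.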
+        loop_count = self.num_spec_tokens - 1
+        draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs)
+        draft_tokens_li = [draft_token_ids]
+        if loop_count > 0:
+            inputs = self.proposer.update_inputs_decoding(inputs, draft_token_ids.transpose(0, 1), target_hidden_states,
+                                                          model_metas)
+        for loop_idx in range(loop_count):
+            outputs = await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict())
+            draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs)
+            draft_tokens_li.append(draft_token_ids)
+            if loop_idx < loop_count - 1:
+                inputs.update(draft_token_ids.transpose(0, 1))
+                inputs.model_metas = model_metas
+                inputs.target_hidden_states = target_hidden_states
+                if inputs.target_position_ids is not None:
+                    inputs.target_position_ids += 1
+
+        return torch.cat(draft_tokens_li, dim=-1)
+
     async def async_model_forward(self,
                                   model_inputs: ModelInputs,
                                   spec_inputs: SpecDecodeInputs,
                                   swap_in_map: SwapMap = dict(),
                                   swap_out_map: SwapMap = dict()):
         """Draft model forward."""
-        if model_inputs.spec_metadata.draft_token_ids is not None:
-            spec_metadata = model_inputs.spec_metadata
-            output_token_ids, num_rejected_tokens, last_token_ids = self.rejection_sampler(
-                spec_inputs.target_logits, spec_metadata.draft_token_ids, spec_inputs.bonus_token_ids,
-                spec_metadata.num_draft_tokens, spec_metadata.max_spec_len)
-            spec_inputs.num_rejected_tokens = num_rejected_tokens
-            spec_inputs.reject_sample_tokens = output_token_ids
-            spec_inputs.next_token_ids = last_token_ids
-        else:
-            spec_inputs.next_token_ids = spec_inputs.bonus_token_ids
-            output_token_ids = spec_inputs.next_token_ids.unsqueeze(-1)
+        with torch.cuda.stream(self.stream):
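+            # verify the draft tokens from the previous step against the target
+            # model logits with rejection sampling.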
+            if model_inputs.spec_metadata.draft_token_ids is not None:
+                spec_metadata = model_inputs.spec_metadata
+                output_token_ids, num_rejected_tokens, last_token_ids = self.rejection_sampler(
+                    spec_inputs.target_logits, spec_metadata.draft_token_ids, spec_inputs.bonus_token_ids,
+                    spec_metadata.num_draft_tokens, spec_metadata.max_spec_len)
+                spec_inputs.num_rejected_tokens = num_rejected_tokens
+                spec_inputs.reject_sample_tokens = output_token_ids
+                spec_inputs.next_token_ids = last_token_ids
+            else:
+                spec_inputs.next_token_ids = spec_inputs.bonus_token_ids
+                output_token_ids = spec_inputs.next_token_ids.unsqueeze(-1)

-        with record_function('draft_prepare_inputs'):
-            draft_model_inputs = self.proposer.prepare_inputs(model_inputs, spec_inputs)
+            with record_function('draft_prepare_inputs'):
+                draft_model_inputs = self.proposer.prepare_inputs(model_inputs, spec_inputs)

-        new_draft_tokens = await self.async_forward(draft_model_inputs,
-                                                    swap_in_map=swap_in_map,
-                                                    swap_out_map=swap_out_map)
-        outputs = dict(output_token_ids=output_token_ids, spec_token_ids=new_draft_tokens)
-        return outputs
+            new_draft_tokens = await self._async_model_forward(draft_model_inputs,
+                                                               swap_in_map=swap_in_map,
+                                                               swap_out_map=swap_out_map)
+            outputs = dict(output_token_ids=output_token_ids, spec_token_ids=new_draft_tokens)
+            return outputs
+
+    def warmup(self, max_batches: int, target_model_config: ModelConfig):
+        """warmup."""
+        target_hidden_size = self.proposer.get_target_hidden_size(target_model_config)
+
+        # warmup prefill
+        inputs = ModelInputs.make_dummy(max_batches,
+                                        is_decoding=False,
+                                        device='cuda',
+                                        vocab_size=self.model_config.vocab_size)
+        inputs.target_hidden_states = torch.randn((1, max_batches, target_hidden_size),
+                                                  dtype=self.model_config.dtype,
+                                                  device='cuda')
+        self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict())
+
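+        # warmup decoding for each capture batch size of the graph runner, largest first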
+        capture_batch_sizes = self.proposer.model.get_capture_batch_sizes()
+        capture_batch_sizes = sorted(capture_batch_sizes, reverse=True)
+
+        for batch_size in capture_batch_sizes:
+            inputs = ModelInputs.make_dummy(
+                batch_size,
+                is_decoding=True,
+                device='cuda',
+                vocab_size=self.model_config.vocab_size,
+            )
+            inputs.target_hidden_states = torch.randn((1, batch_size, self.model_config.hidden_size),
+                                                      dtype=self.model_config.dtype,
+                                                      device='cuda')
+            self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict())


 class BaseModelAgent:
@@ -525,8 +621,9 @@ def get_free_mem(self):
     def warmup(self):
         """warmup."""
         # TODO: disable for now, do not remove the comments.
-        with self.all_context():
+        with self.all_context(), torch.cuda.stream(self.stream), torch.inference_mode():
             max_batches = self.cache_config.max_batches
+
             num_tokens = max_batches

             # warmup prefill
@@ -546,6 +643,10 @@ def warmup(self):
                                             vocab_size=self.model_config.vocab_size)
             self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict())

+            # warmup draft model
+            if self.spec_agent is not None:
+                self.spec_agent.warmup(max_batches, self.model_config)
+
     async def _async_model_forward(
         self,
         inputs: ModelInputs,
@@ -639,8 +740,8 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int):
             return tmp_out

         # make long context inputs
-        is_long_context = inputs.input_ids.numel(
-        ) > max_prefill_token_num and not inputs.is_decoding and inputs.seq_length[0] == 1
+        is_long_context = inputs.input_ids.numel() > max_prefill_token_num and not inputs.is_decoding
+
         max_seqlen = 0
         if is_long_context:
             seq_len = inputs.seq_length
@@ -1165,7 +1266,7 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
             inputs,
             self.cache_engine,
             stream=self.stream,
-            output_position_ids=self.spec_agent is not None)
+            output_position_ids=False)
         return output

     async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
@@ -1194,6 +1295,10 @@ def reset_graph_runner(self):
         if hasattr(self.patched_model, 'reset'):
             self.patched_model.reset()

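+        # reset the draft proposer's graph runner as well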
+        if self.spec_agent is not None:
+            if self.spec_agent.proposer.model is not None and hasattr(self.spec_agent.proposer.model, 'reset'):
+                self.spec_agent.proposer.model.reset()
+
     @torch.inference_mode()
     def update_params(self, request: UpdateParamsRequest):
         """Update params."""