@@ -16,6 +16,8 @@
     make_tensor_with_pad_align)

 _SAMPLING_EPS = 1e-5
+pin_memory = is_pin_memory_available()
+is_hpu = current_platform.is_hpu()


 @dataclass
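Both new module-level names cache values that cannot change for the life of the process, so the platform probes run once at import time instead of on every scheduler step inside `_prepare_seq_groups` and `from_lists` below. A minimal sketch of the hoisting pattern, with a hypothetical `probe()` standing in for the real checks:

```python
import os

def probe() -> bool:
    # Hypothetical stand-in for is_pin_memory_available() or
    # current_platform.is_hpu(): the answer is fixed per process.
    return os.environ.get("DEMO_IS_HPU", "0") == "1"

IS_HPU = probe()  # evaluated once, at module import time

def hot_path() -> str:
    # Reads the cached constant; no repeated platform lookup per step.
    return "hpu" if IS_HPU else "cpu"
```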
@@ -286,7 +288,7 @@ def _prepare_seq_groups(

         if seq_group_metadata.is_prompt:
             if sampling_params.seed is not None:
-                if current_platform.is_hpu():
+                if is_hpu:
                     import habana_frameworks.torch.hpu.random as htrandom
                     generator = \
                         htrandom.default_generators[
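For reference, the HPU branch pulls the per-request generator out of Gaudi's per-device generator table instead of constructing a fresh `torch.Generator`, which is the path upstream vLLM takes on other platforms. A hedged sketch of both branches; the `habana_frameworks` import only resolves on Gaudi hosts, and the assumption that `manual_seed()` returns the generator follows the `torch.Generator` convention:

```python
import torch

def seeded_generator(seed: int, device: torch.device, is_hpu: bool):
    # Sketch only: build a deterministically seeded RNG for one request.
    if is_hpu:
        # Gaudi keeps one default generator per device index; seed it
        # in place (assumed to return the generator, as torch does).
        import habana_frameworks.torch.hpu.random as htrandom
        return htrandom.default_generators[device.index].manual_seed(seed)
    # CPU/CUDA: an independent generator object per request.
    return torch.Generator(device=device).manual_seed(seed)
```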
@@ -420,8 +422,10 @@ def from_sampling_metadata(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
+        prompt_tokens_cache: Optional[torch.Tensor],
+        past_seq_ids: set,
     ) -> tuple["SamplingTensors", bool, bool, bool, Optional[int],
-               Optional[float]]:
+               Optional[float], set]:
         prompt_tokens: list[array] = []
         output_tokens: list[array] = []
         top_ks: list[int] = []
@@ -434,6 +438,7 @@ def from_sampling_metadata(
         do_penalties = False
         do_top_p_top_k = False
         do_min_p = False
+        current_seq_ids = set()

         assert sampling_metadata.seq_groups is not None
         for seq_group in sampling_metadata.seq_groups:
@@ -508,6 +513,9 @@ def from_sampling_metadata(
                         seq_data = seq_group.seq_data[seq_id]
                         prompt_tokens.append(seq_data.prompt_token_ids_array)
                         output_tokens.append(seq_data.output_token_ids_array)
+            current_seq_ids.update(seq_ids)
+        if current_seq_ids != past_seq_ids:
+            prompt_tokens_cache = None

         top_k_scalar = top_ks[0] if do_top_p_top_k and all(
             k == top_ks[0] for k in top_ks) else None
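The two added lines after the loop are the cache key in disguise: a padded prompt tensor built on a previous step only lines up row-for-row with the current batch while the batch contains exactly the same sequences. Restated as a standalone helper (names hypothetical):

```python
from typing import Optional

import torch

def reusable_prompt_cache(cached: Optional[torch.Tensor],
                          past_seq_ids: set,
                          current_seq_ids: set) -> Optional[torch.Tensor]:
    # Any added, finished, or preempted sequence changes the id set, so
    # the cached tensor's rows no longer match the batch and it must be
    # rebuilt from the fresh prompt_tokens lists.
    if current_seq_ids != past_seq_ids:
        return None
    return cached
```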
@@ -527,9 +535,10 @@ def from_sampling_metadata(
             vocab_size,
             device,
             dtype,
+            prompt_tokens_cache,
         )
         return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
-                top_k_scalar, top_p_scalar)
+                top_k_scalar, top_p_scalar, current_seq_ids)

     @classmethod
     def from_lists(
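Since `from_sampling_metadata` now also returns `current_seq_ids`, the caller has to thread two pieces of state from one decode step to the next: the id set it got back, and the prompt-token tensor that ended up on the device. A sketch of that loop under assumed names; the driver class and its `step` method are hypothetical (not part of this change), and `SamplingTensors` is assumed to be in scope:

```python
from typing import Optional

import torch

class SamplerCacheState:
    # Hypothetical driver-side state carried across decode steps.
    def __init__(self) -> None:
        self.prompt_tokens_cache: Optional[torch.Tensor] = None
        self.past_seq_ids: set = set()

    def step(self, sampling_metadata, vocab_size, device, dtype):
        (tensors, do_penalties, do_top_p_top_k, do_min_p, top_k, top_p,
         current_seq_ids) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, device, dtype,
             self.prompt_tokens_cache, self.past_seq_ids)
        # Remember this step's ids plus the device-resident prompt tensor
        # so an unchanged batch skips the pad-and-copy on the next step.
        self.past_seq_ids = current_seq_ids
        if do_penalties:
            self.prompt_tokens_cache = tensors.prompt_tokens
        return tensors
```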
@@ -546,23 +555,28 @@ def from_lists(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
+        prompt_tokens_cache: Optional[torch.Tensor],
     ) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
-        pin_memory = is_pin_memory_available()

         do_penalties = prompt_tokens or output_tokens

         if do_penalties:
-            if current_platform.is_hpu():
-                prompt_t = make_tensor_with_pad_align(
-                    prompt_tokens,
-                    vocab_size,
-                    device="cpu",
-                    dtype=torch.int64,
-                    pin_memory=pin_memory,
-                    max_len_align=1024,
-                )
+            if is_hpu:
+                if (prompt_tokens_cache is not None and
+                        prompt_tokens_cache.device == device):
+                    # Reuse the cached prompt_tokens tensor already on HPU
+                    prompt_t = prompt_tokens_cache
+                else:
+                    prompt_t = make_tensor_with_pad_align(
+                        prompt_tokens,
+                        vocab_size,
+                        device="cpu",
+                        dtype=torch.int64,
+                        pin_memory=pin_memory,
+                        max_len_align=1024,
+                    )
             output_t = make_tensor_with_pad_align(
                 output_tokens,
                 vocab_size,
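On a cache miss the tensor is still built on the CPU with `max_len_align=1024`, and that alignment is what makes caching worthwhile: rounding every batch's width up to a multiple of 1024 keeps the tensor shape stable across steps, so shape-specialized HPU graphs are not recompiled for small length changes. A shape-level illustration with a hypothetical `pad_align` helper standing in for `make_tensor_with_pad_align`:

```python
import torch

def pad_align(rows: list[list[int]], pad_value: int,
              align: int = 1024) -> torch.Tensor:
    # Pad every row out to the next multiple of `align` so modest length
    # growth between steps does not change the tensor's shape.
    longest = max(len(r) for r in rows)
    width = -(-longest // align) * align  # ceil to a multiple of align
    out = torch.full((len(rows), width), pad_value, dtype=torch.int64)
    for i, row in enumerate(rows):
        out[i, :len(row)] = torch.tensor(row, dtype=torch.int64)
    return out

# Prompts of length 7 and 900 both land in a (2, 1024) tensor:
assert pad_align([[1] * 7, [2] * 900], pad_value=0).shape == (2, 1024)
```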
@@ -647,6 +661,7 @@ def from_lists(
                                                        non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
-            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
+            prompt_tokens=prompt_t.to(device=device, non_blocking=True)
+            if prompt_t.device != device else prompt_t,
             output_tokens=output_t.to(device=device, non_blocking=True),
         )
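The guard on the final `.to()` lets the cached HPU-resident tensor pass through untouched, while a freshly built pinned-CPU tensor still gets its async host-to-device copy (`non_blocking=True` only pays off for pinned host memory). PyTorch already returns `self` from `.to()` when device and dtype match, so the explicit check mainly documents the cache-hit path. The same guard in isolation:

```python
import torch

def to_device_if_needed(t: torch.Tensor,
                        device: torch.device) -> torch.Tensor:
    if t.device == device:
        return t  # cache hit: already resident, skip the copy
    return t.to(device=device, non_blocking=True)

cpu_t = torch.arange(4)
assert to_device_if_needed(cpu_t, torch.device("cpu")) is cpu_t
```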