@@ -17,7 +17,7 @@
 import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import openvino
@@ -28,6 +28,10 @@
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.generation import GenerationMixin
+from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
+from transformers.generation.logits_process import LogitsProcessorList
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from optimum.utils.normalized_config import NormalizedConfigManager
@@ -41,6 +45,11 @@
 from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
 
 
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+    from transformers.generation.streamers import BaseStreamer
+
+
 logger = logging.getLogger(__name__)
 
 core = Core()
@@ -122,6 +131,8 @@ def __init__(
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
         self._past_length = 0
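+        # Beam search duplicates each prompt num_beams times; these flags let the first
+        # forward pass deduplicate its inputs and the following step re-expand the results.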
+        self._first_iter_beam_search = False
+        self._second_iter_beam_search = False
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -375,7 +386,11 @@ def prepare_inputs(
         inputs = {}
         if not self.stateful:
             if past_key_values is not None:
-                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
+                    self.config.model_type == "falcon" and self.config.new_decoder_architecture
+                ):
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -418,7 +433,6 @@ def prepare_inputs(
                 self.next_beam_idx = np.arange(batch_size, dtype=int)
                 self._past_length = 0
         past_len = self._get_past_length(past_key_values)
-
         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
         if "attention_mask" in self.input_names or "position_ids" in self.input_names:
@@ -468,6 +482,8 @@ def forward(
             **kwargs,
         )
 
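+        # On the first beam-search step every beam holds the same prompt, so run
+        # inference on the unique rows only and fan the results back out afterwards.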
+        if self._first_iter_beam_search:
+            inputs, duplication_indices = self._deduplicate_inputs(inputs)
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
@@ -483,14 +499,22 @@ def forward(
         if self.use_cache:
             # Tuple of length equal to: number of layers * number of past_key_values per decoder layer (2 corresponds to the self-attention layer)
             past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
+            ):
                 # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                 past_key_values = tuple(
                     past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
                 )
         else:
             past_key_values = None
 
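+        # Fan the deduplicated results back out so downstream beam search sees the
+        # full batch it originally submitted.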
+        if self._first_iter_beam_search:
+            logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values)
+            self._first_iter_beam_search = False
+
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
     # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
@@ -520,20 +544,124 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
-        return {
+        model_inputs = {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
             "use_cache": use_cache,
             "position_ids": position_ids,
             "attention_mask": attention_mask,
         }
 
+        return model_inputs
+
+    def _expand_outputs_for_generation(self, indices, logits: torch.Tensor, past_key_values: Tuple):
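+        # `indices` is the inverse mapping produced by np.unique in _deduplicate_inputs;
+        # it fans the deduplicated logits and cache back out to the full beam batch.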
+        batch_size = logits.shape[0]
+        if indices.shape[0] != 1:
+            logits = logits[indices]
+            if past_key_values and not self.stateful:
+                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
+                    self.config.model_type == "falcon" and self.config.new_decoder_architecture
+                ):
+                    past_key_values = tuple(
+                        tuple(
+                            past_state[indices]
+                            if self.config.model_type != "chatglm"
+                            else past_state[:, indices, ...]
+                            for past_state in layer_past
+                        )
+                        for layer_past in past_key_values
+                    )
+                else:
+                    past_key_values = tuple([past_state[indices] for past_state in past_key_values])
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indices]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indices]
+                )
+            self._second_iter_beam_search = True
+        return logits, past_key_values
+
+    def _deduplicate_inputs(self, model_inputs: Dict):
+        input_ids = model_inputs["input_ids"]
+        upd_model_inputs = {}
+        unique_input_ids, indices, reverse_indices = np.unique(
+            input_ids, axis=0, return_index=True, return_inverse=True
+        )
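+        # e.g. input_ids [[1, 2], [1, 2], [3, 4]] (a prompt repeated per beam) gives
+        # unique rows [[1, 2], [3, 4]], first-occurrence indices [0, 2], and inverse
+        # indices [0, 0, 1], which _expand_outputs_for_generation uses to re-expand.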
+        for input_name, input_tensor in model_inputs.items():
+            if input_name not in ["input_ids", "beam_idx"]:
+                if not isinstance(input_tensor, Tensor):
+                    upd_model_inputs[input_name] = input_tensor[indices]
+                else:
+                    # for openvino.Tensor inputs, rebuild the tensor around the
+                    # deduplicated batch dimension
+                    shape = input_tensor.shape
+                    dtype = input_tensor.element_type
+                    upd_batch_size = indices.shape[0]
+                    if self.config.model_type == "bloom":
+                        upd_batch_size *= self.config.num_attention_heads
+                    shape[0 if self.config.model_type != "chatglm" else 1] = upd_batch_size
+                    upd_model_inputs[input_name] = Tensor(dtype, shape)
+        upd_model_inputs["input_ids"] = unique_input_ids
+        if "beam_idx" in model_inputs:
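+            # Stateful models take beam_idx as an explicit input; bloom replicates its
+            # cache per attention head, so the identity mapping spans batch * num_heads.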
+            beam_range = (
+                unique_input_ids.shape[0]
+                if self.config.model_type != "bloom"
+                else unique_input_ids.shape[0] * self.config.num_attention_heads
+            )
+            beam_idx = np.arange(beam_range, dtype=int)
+            upd_model_inputs["beam_idx"] = beam_idx
+        return upd_model_inputs, reverse_indices
+
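+    # Mirrors the signature of transformers.GenerationMixin.generate so the override
+    # can flag beam-search runs before delegating to the parent implementation.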
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)
+        generation_mode = _generation_config.get_generation_mode(assistant_model)
+
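+        # Only beam-search style modes replicate the prompt once per beam, so only
+        # they benefit from first-iteration input deduplication.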
+        is_beam_search = generation_mode in [
+            GenerationMode.BEAM_SEARCH,
+            GenerationMode.BEAM_SAMPLE,
+            GenerationMode.GROUP_BEAM_SEARCH,
+            GenerationMode.CONSTRAINED_BEAM_SEARCH,
+        ]
+        if is_beam_search:
+            self._first_iter_beam_search = True
+        result = super().generate(
+            inputs,
+            generation_config,
+            logits_processor,
+            stopping_criteria,
+            prefix_allowed_tokens_fn,
+            synced_gpus,
+            assistant_model,
+            streamer,
+            negative_prompt_ids,
+            negative_prompt_attention_mask,
+            **kwargs,
+        )
+        return result
+
     def _get_past_length(self, past_key_values=None):
         if past_key_values is None:
             return 0
         if self.stateful:
             return self._past_length
-        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
+            self.config.model_type == "falcon" and self.config.new_decoder_architecture
+        ):
             return past_key_values[0].shape[-2]
         seq_length_dim = -2
         if self.config.model_type == "chatglm":
@@ -558,12 +686,20 @@ def _reorder_cache(
         if self.stateful:
             # TODO: Apply it differently based on model type
             # TODO: At least for bloom we need to replicate values for each attention head
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
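+            # On the step right after deduplication, next_beam_idx was already set by
+            # _expand_outputs_for_generation, so keep that mapping instead of beam_idx.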
+            self.next_beam_idx = (
+                np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            )  # save beam_idx to be used as an input in the next iteration
+            self._second_iter_beam_search = False
             return past_key_values
         else:
-            return tuple(
-                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
-            )
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
+            ):
+                return tuple(
+                    tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
+                    for layer_past in past_key_values
+                )
+            return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)
 
     def can_generate(self):
         """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
@@ -684,11 +820,12 @@ def _reorder_cache(
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
         if self.stateful:
-            beam_idx = np.array(beam_idx)
             batch_size = beam_idx.shape[0]
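+            # As in the base class: right after deduplication, reuse the mapping set by
+            # _expand_outputs_for_generation instead of the incoming beam_idx.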
+            beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
             indices = np.array(range(batch_size * self.config.num_attention_heads))
             indices = indices.reshape([batch_size, self.config.num_attention_heads])
             self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
+            self._second_iter_beam_search = False
             return past_key_values
         else:
             standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
@@ -738,14 +875,34 @@ def _convert_to_standard_cache(
             for layer_past in past_key_value
         )
 
+    def _expand_outputs_for_generation(self, indices, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
+        if indices.shape[0] != 1:
+            logits = logits[indices]
+            if past_key_values and not self.stateful:
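+                # Bloom stores the cache as (batch * num_heads, ...); convert to the
+                # standard layout, index it by the duplication indices, then convert back.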
+                pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size)
+                pkv = tuple(tuple(past_state[indices] for past_state in layer_past) for layer_past in pkv_standard)
+                past_key_values = self._convert_to_bloom_cache(pkv)
+
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indices]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indices]
+                )
+            self._second_iter_beam_search = True
+        return logits, past_key_values
+
 
 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
         if self.stateful:
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            # save beam_idx to be used as an input in the next iteration
+            self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            self._second_iter_beam_search = False
             return past_key_values
         else:
             return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)