import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
+from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel

@@ -92,12 +93,16 @@ def forward(
        if use_prompts:
            hidden_states = hidden_states[:, self.pre_seq_len :]

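+        # Create the remote cache wrapper on the first call and record how many tokens have been processed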
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
@@ -107,6 +112,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_missing += [r"^lm_head\."]  # Missing since they are shared with input embeddings
    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
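+    # Tell transformers that this model accepts Cache objects (RemotePastKeyValues) instead of legacy tuples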
+    _supports_cache_class = True

    config_class = DistributedBloomConfig

@@ -118,6 +124,58 @@ def __init__(self, config: DistributedBloomConfig):
        # Initialize weights and apply final processing
        self.post_init()

+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ) -> dict:
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
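+            # Determine how many tokens the cache already covers (new Cache API vs. legacy tuple format)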
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
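+            # Keep only the input tokens that are not yet covered by the cache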
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+
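+            # Crop the attention mask if the new tokens would exceed the maximum cache length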
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    def _temporary_reorder_cache(self, past_key_values, beam_idx):
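+        # The key/value tensors live on the remote servers, so there is nothing to reorder locally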
+        return past_key_values
+
    def get_output_embeddings(self):
        return self.lm_head
