     LlamaForSequenceClassification,
     LlamaModel,
 )
-from transformers.utils import auto_docstring, can_return_tuple, logging
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, logging
+from transformers.utils.generic import check_model_inputs
+
+try:
+    from nemo_automodel.components.models.biencoder.state_dict_adapter import BiencoderStateDictAdapter
+except ImportError:
+    BiencoderStateDictAdapter = object
 
 logger = logging.get_logger(__name__)
 
@@ -170,7 +177,7 @@ def _update_causal_mask(
             return attention_mask
         return None
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -179,40 +186,22 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs,
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
+            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+            past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
+            cache_position: torch.Tensor = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
 
@@ -222,46 +211,23 @@ def forward(
         causal_mask = self._update_causal_mask(attention_mask=attention_mask)
 
         hidden_states = inputs_embeds
-
-        # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
+                past_key_values=past_key_values,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+            past_key_values=past_key_values,
         )
 
 
@@ -432,6 +398,15 @@ def __init__(
         self.config = self.lm_q.config
         self.trainer = None
 
+        # For HuggingFace consolidated checkpoint compatibility
+        self.name_or_path = os.path.abspath(__file__)
+        self.state_dict_adapter = BiencoderStateDictAdapter()
+        self.config.architectures = ["LlamaBidirectionalModel"]
+        self.config.auto_map = {
+            "AutoModel": "llama_bidirectional_model.LlamaBidirectionalModel",
+            "AutoConfig": "llama_bidirectional_model.LlamaBidirectionalConfig",
+        }
+
     def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None):
         """Forward pass for training."""
 
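For context on the last hunk: the `architectures` and `auto_map` entries follow Hugging Face's custom-code convention, so a consolidated checkpoint exported next to a `llama_bidirectional_model.py` defining the referenced classes can be reloaded through the Auto classes. A minimal sketch, assuming such an export directory exists; the path and the export step are illustrative, not part of this change:

# Illustrative only: reload a consolidated checkpoint via the auto_map
# entries set in __init__. Assumes the directory contains
# llama_bidirectional_model.py plus config.json and weights; the path
# below is a placeholder.
from transformers import AutoConfig, AutoModel

checkpoint_dir = "./biencoder_export"  # hypothetical export location
config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint_dir, trust_remote_code=True)
print(type(model).__name__)  # expected: LlamaBidirectionalModel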