Commit 5f27227

fix: Biencoder consolidated checkpoint and transformers issue (#936)

1 parent 99d214e
File tree

6 files changed: +374 -70 lines changed

6 files changed

+374
-70
lines changed

nemo_automodel/components/checkpoint/addons.py

Lines changed: 5 additions & 2 deletions
@@ -253,9 +253,12 @@ def _maybe_save_custom_model_code(original_model_path: str | None, hf_metadata_d
     """
     if original_model_path is None:
         return
-    if not os.path.isdir(original_model_path):
+    if os.path.isfile(original_model_path):
+        pattern = original_model_path
+    elif os.path.isdir(original_model_path):
+        pattern = os.path.join(original_model_path, "**", "*.py")
+    else:
         return
-    pattern = os.path.join(original_model_path, "**", "*.py")
     for src_path in glob.glob(pattern, recursive=True):
         # Skip any .hidden paths
         rel_path = os.path.relpath(src_path, original_model_path)
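For context, the addons.py change above lets the custom-code copy step accept either a single Python file or a directory of model code. A minimal standalone sketch of the new pattern selection, using a hypothetical path that is not part of the commit:

import glob
import os

def resolve_custom_code_pattern(path: str):
    # Hypothetical helper mirroring the new logic in _maybe_save_custom_model_code:
    # a single .py file is copied as-is, a directory is searched recursively for *.py.
    if os.path.isfile(path):
        return path
    elif os.path.isdir(path):
        return os.path.join(path, "**", "*.py")
    return None

pattern = resolve_custom_code_pattern("models/llama_bidirectional_model.py")  # hypothetical path
if pattern is not None:
    for src_path in glob.glob(pattern, recursive=True):
        print(src_path)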

nemo_automodel/components/models/biencoder/llama_bidirectional_model.py

Lines changed: 27 additions & 52 deletions
@@ -42,7 +42,14 @@
     LlamaForSequenceClassification,
     LlamaModel,
 )
-from transformers.utils import auto_docstring, can_return_tuple, logging
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, logging
+from transformers.utils.generic import check_model_inputs
+
+try:
+    from nemo_automodel.components.models.biencoder.state_dict_adapter import BiencoderStateDictAdapter
+except ImportError:
+    BiencoderStateDictAdapter = object
 
 logger = logging.get_logger(__name__)
 
@@ -170,7 +177,7 @@ def _update_causal_mask(
             return attention_mask
         return None
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -179,40 +186,22 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs,
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
+            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+            past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
+            cache_position: torch.Tensor = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
 
@@ -222,46 +211,23 @@ def forward(
         causal_mask = self._update_causal_mask(attention_mask=attention_mask)
 
         hidden_states = inputs_embeds
-
-        # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
+                past_key_values=past_key_values,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+            past_key_values=past_key_values,
         )
 
 
@@ -432,6 +398,15 @@ def __init__(
         self.config = self.lm_q.config
         self.trainer = None
 
+        # For HuggingFace consolidated checkpoint compatibility
+        self.name_or_path = os.path.abspath(__file__)
+        self.state_dict_adapter = BiencoderStateDictAdapter()
+        self.config.architectures = ["LlamaBidirectionalModel"]
+        self.config.auto_map = {
+            "AutoModel": "llama_bidirectional_model.LlamaBidirectionalModel",
+            "AutoConfig": "llama_bidirectional_model.LlamaBidirectionalConfig",
+        }
+
     def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None):
        """Forward pass for training."""
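With config.architectures and config.auto_map set in __init__, and the model source copied next to the consolidated checkpoint by the addons.py change above, the exported query encoder should be loadable through the standard Hugging Face auto classes. A minimal sketch, assuming the consolidated checkpoint was written to a local directory named consolidated/ (hypothetical path):

from transformers import AutoModel

ckpt_dir = "consolidated/"  # hypothetical location of the consolidated checkpoint
# trust_remote_code lets AutoModel resolve config.auto_map to the copied
# llama_bidirectional_model.py and instantiate the custom class.
model = AutoModel.from_pretrained(ckpt_dir, trust_remote_code=True)
print(type(model).__name__)  # expected: LlamaBidirectionalModel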

nemo_automodel/components/models/biencoder/state_dict_adapter.py

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+from torch.distributed.device_mesh import DeviceMesh
+
+from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter
+
+
+class BiencoderStateDictAdapter(StateDictAdapter):
+    """Adapter for converting BiencoderModel state dict to single encoder format.
+
+    This adapter extracts only the query encoder (lm_q) state dict and converts
+    the "lm_q." prefix to "model." prefix, making it compatible with standard
+    HuggingFace model format.
+    """
+
+    def __init__(self):
+        """Initialize the adapter."""
+        self._uses_model_prefix = True
+
+    def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
+        """Convert from biencoder state dict to HuggingFace format.
+
+        Filters to only lm_q keys and converts "lm_q." prefix to "model." prefix.
+
+        Args:
+            state_dict: The biencoder model state dict
+
+        Returns:
+            The converted HuggingFace format state dict with only query encoder
+        """
+        hf_state_dict = {}
+
+        for key, value in state_dict.items():
+            if key.startswith("lm_q."):
+                new_key = key.replace("lm_q.", "model.")
+                hf_state_dict[new_key] = value
+
+        return hf_state_dict
+
+    def from_hf(
+        self,
+        hf_state_dict: dict[str, Any],
+        device_mesh: Optional["DeviceMesh"] = None,
+        **kwargs,
+    ) -> dict[str, Any]:
+        """Convert HuggingFace state dict to biencoder format.
+
+        Converts "model." prefix to "lm_q." prefix for loading into biencoder.
+
+        Args:
+            hf_state_dict: The HuggingFace format state dict
+            device_mesh: Optional device mesh (not used in this adapter)
+
+        Returns:
+            The converted biencoder format state dict
+        """
+        biencoder_state_dict = {}
+
+        for key, value in hf_state_dict.items():
+            if key.startswith("model."):
+                new_key_q = key.replace("model.", "lm_q.")
+                biencoder_state_dict[new_key_q] = value
+                new_key_p = key.replace("model.", "lm_p.")
+                biencoder_state_dict[new_key_p] = value
+
+        return biencoder_state_dict
+
+    def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
+        """Convert a single tensor from biencoder format to HuggingFace format.
+
+        Args:
+            fqn: Fully qualified name of the tensor in biencoder format
+            tensor: The tensor to convert
+            **kwargs: Additional arguments (unused)
+
+        Returns:
+            List of (fqn, tensor) tuples in HuggingFace format.
+            Returns empty list if tensor is not part of lm_q.
+        """
+        if fqn.startswith("lm_q."):
+            new_fqn = fqn.replace("lm_q.", "model.")
+            return [(new_fqn, tensor)]
+
+        # Skip tensors that are not part of lm_q
+        return []
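As a quick illustration of the key mapping in the new adapter, a round trip over a toy state dict (tensor names and shapes are made up for the example):

import torch

from nemo_automodel.components.models.biencoder.state_dict_adapter import BiencoderStateDictAdapter

adapter = BiencoderStateDictAdapter()

biencoder_sd = {
    "lm_q.embed_tokens.weight": torch.zeros(2, 4),  # query encoder weight (kept on export)
    "lm_p.embed_tokens.weight": torch.zeros(2, 4),  # passage encoder weight (dropped on export)
}

hf_sd = adapter.to_hf(biencoder_sd)
print(sorted(hf_sd))      # ['model.embed_tokens.weight']

restored = adapter.from_hf(hf_sd)
print(sorted(restored))   # ['lm_p.embed_tokens.weight', 'lm_q.embed_tokens.weight']

Note that from_hf fans each "model." weight out to both lm_q and lm_p, which lets both encoders be initialized from a single exported encoder checkpoint.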

nemo_automodel/recipes/biencoder/train_biencoder.py

Lines changed: 14 additions & 5 deletions
@@ -433,6 +433,8 @@ def setup(self):
         self.model_parts = [model]
         self.pp = None
 
+        self.checkpointer.config.model_state_dict_keys = ["model." + k for k in model.lm_q.state_dict().keys()]
+
         # Build optimizer
         logger.info("Building optimizer...")
         trainable_params = list(filter(lambda x: x.requires_grad, self.model_parts[0].parameters()))
@@ -518,17 +520,24 @@ def run_train_validation_loop(self):
                 # Log metrics
                 self.log_train_metrics(train_log_data)
 
-                # Save checkpoint every ckpt_every_steps
-                if self.step_scheduler.is_ckpt_step:
-                    self.save_checkpoint(epoch, self.step_scheduler.step)
-
                 # Run validation every val_every_steps
+                val_loss = None
                 if self.step_scheduler.is_val_step and self.val_dataloader is not None:
                     val_log_data = self._run_validation_epoch(self.val_dataloader)
                     self.log_val_metrics(val_log_data)
+                    val_loss = {"val_loss": val_log_data.metrics["val_loss"]}
                     for mp in self.model_parts:
                         mp.train()
 
+                # Save checkpoint every ckpt_every_steps
+                if self.step_scheduler.is_ckpt_step:
+                    self.save_checkpoint(
+                        epoch,
+                        self.step_scheduler.step,
+                        train_loss=train_log_data.metrics["loss"],
+                        val_loss=val_loss,
+                    )
+
         # Close JSONL loggers after training loop completes
         self.metric_logger_train.close()
         self.metric_logger_valid.close()
@@ -611,7 +620,7 @@ def _run_train_optim_step(self, batches, max_grad_norm=None):
             scheduler.step(1)
 
         # Compute average loss across gradient accumulation and DP ranks
-        reporting_loss = torch.sum(torch.stack(loss_buffer))
+        reporting_loss = torch.mean(torch.stack(loss_buffer))
         if torch.distributed.is_initialized():
             reporting_loss = self._dp_allreduce(reporting_loss, include_cp=True)
             # Divide by DP group size to get average across all ranks
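The last hunk switches the reported loss from a sum to a mean over the per-micro-batch losses collected during gradient accumulation, so the logged value no longer scales with the number of accumulation steps. A small illustration with made-up values:

import torch

# Per-micro-batch losses gathered across one optimizer step (made-up values).
loss_buffer = [torch.tensor(2.0), torch.tensor(4.0), torch.tensor(6.0)]

print(torch.sum(torch.stack(loss_buffer)))   # tensor(12.) -- grows with accumulation steps
print(torch.mean(torch.stack(loss_buffer)))  # tensor(4.)  -- comparable across settings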
