Skip to content

Commit ab1fa49

Browse files
Add attn_implementation option to mining recipe
Signed-off-by: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
1 parent bf08d6d commit ab1fa49

File tree

1 file changed

+13
-3
lines changed

nemo_automodel/recipes/biencoder/mine_hard_negatives.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,8 @@
6969
# Model loading parameters (loaded directly, not from config)
7070
"model_name_or_path": None, # Required: path to model checkpoint
7171
"tokenizer_name_or_path": None, # Optional: defaults to model_name_or_path
72+
# Attention implementation for model loading
73+
"attn_implementation": None, # None = use model default; "sdpa", "flash_attention_2", "eager"
7274
}
7375

7476

@@ -175,6 +177,7 @@ def __init__(self, cfg):
175177
self.tokenizer_name_or_path = None
176178
self.add_bos_token = None
177179
self.add_eos_token = None
180+
self.attn_implementation = None
178181

179182
# Model and tokenizer (populated in setup)
180183
self.model = None
@@ -234,11 +237,15 @@ def setup(self):
234237
# Load model directly from checkpoint path
235238
# This loads the saved model without requiring architecture config
236239
logger.info(f"Loading biencoder model from {self.model_name_or_path}...")
240+
model_kwargs = {
241+
"use_liger_kernel": False, # Not needed for inference
242+
"use_sdpa_patching": True,
243+
}
244+
if self.attn_implementation is not None:
245+
model_kwargs["attn_implementation"] = self.attn_implementation
237246
self.model = NeMoAutoModelBiencoder.from_pretrained(
238247
self.model_name_or_path,
239-
# Use inference-appropriate settings
240-
use_liger_kernel=False, # Not needed for inference
241-
use_sdpa_patching=True,
248+
**model_kwargs,
242249
)
243250
self.model = self.model.to(self.dist_env.device)
244251
self.model.eval()
@@ -297,6 +304,9 @@ def _extract_mining_params(self):
297304
self.add_bos_token = self._get_mining_param("add_bos_token")
298305
self.add_eos_token = self._get_mining_param("add_eos_token")
299306

307+
# Attention implementation for model loading
308+
self.attn_implementation = self._get_mining_param("attn_implementation")
309+
300310
# Prefix and length parameters for embedding generation
301311
self.query_prefix = self._get_mining_param("query_prefix")
302312
self.passage_prefix = self._get_mining_param("passage_prefix")

0 commit comments

Comments
 (0)