|
69 | 69 | # Model loading parameters (loaded directly, not from config) |
70 | 70 | "model_name_or_path": None, # Required: path to model checkpoint |
71 | 71 | "tokenizer_name_or_path": None, # Optional: defaults to model_name_or_path |
| 72 | + # Attention implementation for model loading |
| 73 | + "attn_implementation": None, # None = use model default; "sdpa", "flash_attention_2", "eager" |
72 | 74 | } |
73 | 75 |
|
74 | 76 |
|
@@ -175,6 +177,7 @@ def __init__(self, cfg): |
175 | 177 | self.tokenizer_name_or_path = None |
176 | 178 | self.add_bos_token = None |
177 | 179 | self.add_eos_token = None |
| 180 | + self.attn_implementation = None |
178 | 181 |
|
179 | 182 | # Model and tokenizer (populated in setup) |
180 | 183 | self.model = None |
@@ -234,11 +237,15 @@ def setup(self): |
234 | 237 | # Load model directly from checkpoint path |
235 | 238 | # This loads the saved model without requiring architecture config |
236 | 239 | logger.info(f"Loading biencoder model from {self.model_name_or_path}...") |
| 240 | + model_kwargs = { |
| 241 | + "use_liger_kernel": False, # Not needed for inference |
| 242 | + "use_sdpa_patching": True, |
| 243 | + } |
| 244 | + if self.attn_implementation is not None: |
| 245 | + model_kwargs["attn_implementation"] = self.attn_implementation |
237 | 246 | self.model = NeMoAutoModelBiencoder.from_pretrained( |
238 | 247 | self.model_name_or_path, |
239 | | - # Use inference-appropriate settings |
240 | | - use_liger_kernel=False, # Not needed for inference |
241 | | - use_sdpa_patching=True, |
| 248 | + **model_kwargs, |
242 | 249 | ) |
243 | 250 | self.model = self.model.to(self.dist_env.device) |
244 | 251 | self.model.eval() |
@@ -297,6 +304,9 @@ def _extract_mining_params(self): |
297 | 304 | self.add_bos_token = self._get_mining_param("add_bos_token") |
298 | 305 | self.add_eos_token = self._get_mining_param("add_eos_token") |
299 | 306 |
|
| 307 | + # Attention implementation for model loading |
| 308 | + self.attn_implementation = self._get_mining_param("attn_implementation") |
| 309 | + |
300 | 310 | # Prefix and length parameters for embedding generation |
301 | 311 | self.query_prefix = self._get_mining_param("query_prefix") |
302 | 312 | self.passage_prefix = self._get_mining_param("passage_prefix") |
|
0 commit comments