Skip to content

Commit d3d0f8f

Browse files
committed
propagate attn_implementation
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent e761815 commit d3d0f8f

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

nemo_automodel/_transformers/auto_model.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -426,12 +426,9 @@ def _retry(**override):
426 426
_download_model_weights(hf_config, pretrained_model_name_or_path)
427 427
logger.info(f"Using custom model implementation for {architectures[0]}")
428 428
kwargs.pop("trust_remote_code", None)
429+
# TODO: restore weights after initialization.
429 430
with local_torch_dtype(torch_dtype, ModelRegistry.model_arch_name_to_cls[architectures[0]].__name__):
430-
return ModelRegistry.model_arch_name_to_cls[architectures[0]](
431-
hf_config,
432-
*model_args,
433-
**kwargs,
434-
)
431+
return ModelRegistry.model_arch_name_to_cls[architectures[0]](hf_config)
435 432

436 433
# 3. fallback to parent class
437 434
model = None
@@ -564,7 +561,10 @@ def _retry(**override):
564 561

565 562
# handle model_id passed as config
566 563
if isinstance(config, str):
567-
config = AutoConfig.from_pretrained(config, trust_remote_code=kwargs.get("trust_remote_code", False))
564+
config = AutoConfig.from_pretrained(
565+
config, trust_remote_code=kwargs.get("trust_remote_code", False),
566+
attn_implementation=attn_implementation,
567+
)
568 568
# 1. if force_hf is True, we will use the parent class to load and return the model as is
569 569
if force_hf:
570 570
return cls._from_config_parent_class(
@@ -578,7 +578,8 @@ def _retry(**override):
578 578
# 2. If we have a custom model implementation available, we prioritize that over HF
579 579
architectures = get_architectures(config)
580 580
if len(architectures) > 0 and architectures[0] in ModelRegistry.model_arch_name_to_cls:
581-
return ModelRegistry.model_arch_name_to_cls[architectures[0]](config, *model_args, **kwargs)
581+
with local_torch_dtype(torch_dtype, ModelRegistry.model_arch_name_to_cls[architectures[0]].__name__):
582+
return ModelRegistry.model_arch_name_to_cls[architectures[0]](config)
582 583

583 584
# 3. fallback to parent class
584 585
model = None

0 commit comments

Comments
 (0)