Skip to content

Commit 7111497

Browse files
committed
fix: respect cfg.model.trust_remote_code when building AutoConfig
Signed-off-by: ooooo <[email protected]>
1 parent 1939fc4 commit 7111497

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,24 @@ def build_loss_fn(cfg_loss):
405405
return cfg_loss.instantiate()
406406

407407

408-
def _build_tokenizer(cfg_model, cfg_ds):
409-
def compute_trust_remote_code():
410-
if hasattr(cfg_model, "trust_remote_code"):
411-
return getattr(cfg_model, "trust_remote_code")
412-
return resolve_trust_remote_code(_get_model_name(cfg_model))
408+
def compute_trust_remote_code_from_model(cfg_model):
409+
"""Compute the value of trust_remote_code based on the model configuration.
410+
411+
Args:
412+
cfg_model (ConfigNode): Model configuration.
413413
414-
trust_remote_code = compute_trust_remote_code()
414+
Returns:
415+
Whether to trust remote code.
416+
"""
417+
if hasattr(cfg_model, "trust_remote_code"):
418+
return getattr(cfg_model, "trust_remote_code")
419+
elif hasattr(cfg_model, "config") and hasattr(cfg_model.config, "trust_remote_code"):
420+
return getattr(cfg_model.config, "trust_remote_code")
421+
return resolve_trust_remote_code(_get_model_name(cfg_model))
422+
423+
424+
def _build_tokenizer(cfg_model, cfg_ds):
425+
trust_remote_code = compute_trust_remote_code_from_model(cfg_model)
415426
# if tokenizer is not provided, use the model config to instantiate it
416427
if "tokenizer" not in cfg_ds and _get_model_name(cfg_model) is not None:
417428
logging.info("Using model config to instantiate tokenizer")
@@ -592,7 +603,9 @@ def build_dataloader(
592603
if pp_enabled:
593604
from nemo_automodel.components.datasets.utils import add_causal_masks_to_batch
594605

595-
hf_model_config = AutoConfig.from_pretrained(_get_model_name(cfg_model))
606+
hf_model_config = AutoConfig.from_pretrained(
607+
_get_model_name(cfg_model), trust_remote_code=compute_trust_remote_code_from_model(cfg_model)
608+
)
596609

597610
if "collate_fn" in dl_kwargs:
598611
# Case 1: PP enabled + collate_fn exists -> chain them

0 commit comments

Comments
 (0)