Skip to content

Commit 87959ff

Browse files
committed
test: add unittest
Signed-off-by: ooooo <[email protected]>
1 parent 7111497 commit 87959ff

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def compute_trust_remote_code_from_model(cfg_model):
412412
cfg_model (ConfigNode): Model configuration.
413413
414414
Returns:
415-
Whether to trust remote code.
415+
bool: Whether to trust remote code.
416416
"""
417417
if hasattr(cfg_model, "trust_remote_code"):
418418
return getattr(cfg_model, "trust_remote_code")

tests/unit_tests/recipes/test_train_ft.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
build_dataloader,
2929
build_model_and_optimizer,
3030
build_validation_dataloader,
31+
compute_trust_remote_code_from_model,
3132
)
3233
from torch.utils.data import IterableDataset
3334

@@ -632,3 +633,36 @@ def _build_model_and_optimizer_stub(*args, **kwargs):
632633
(parts[0], "PipelineStage_0"),
633634
(parts[1], "PipelineStage_1"),
634635
]
636+
637+
638+
def test_compute_trust_remote_code_prefers_cfg_flag():
639+
cfg_model = ConfigNode({"trust_remote_code": False, "pretrained_model_name_or_path": "ignored"})
640+
641+
with patch("nemo_automodel.recipes.llm.train_ft.resolve_trust_remote_code") as mock_resolve:
642+
result = compute_trust_remote_code_from_model(cfg_model)
643+
644+
assert result is False
645+
mock_resolve.assert_not_called()
646+
647+
648+
def test_compute_trust_remote_code_prefers_nested_config():
649+
cfg_model = ConfigNode({"config": {"trust_remote_code": True}})
650+
651+
with patch("nemo_automodel.recipes.llm.train_ft.resolve_trust_remote_code") as mock_resolve:
652+
result = compute_trust_remote_code_from_model(cfg_model)
653+
654+
assert result is True
655+
mock_resolve.assert_not_called()
656+
657+
658+
def test_compute_trust_remote_code_falls_back_to_resolve():
659+
cfg_model = ConfigNode({"pretrained_model_name_or_path": "nvidia/foo"})
660+
661+
with patch(
662+
"nemo_automodel.recipes.llm.train_ft.resolve_trust_remote_code",
663+
return_value=True,
664+
) as mock_resolve:
665+
result = compute_trust_remote_code_from_model(cfg_model)
666+
667+
assert result is True
668+
mock_resolve.assert_called_once_with("nvidia/foo")

0 commit comments

Comments
 (0)