Skip to content

Commit f37d114

Browse files
committed
add test
Signed-off-by: HuiyingLi <[email protected]>
1 parent 938967f commit f37d114

File tree

1 file changed

+43
-10
lines changed

1 file changed

+43
-10
lines changed

tests/unit_tests/recipes/test_train_ft.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,21 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from unittest.mock import MagicMock, patch
16+
import importlib
17+
import torch
18+
import torch.nn as nn
19+
from contextlib import AbstractContextManager
20+
from types import SimpleNamespace
21+
from unittest.mock import MagicMock, call, patch
22+
23+
from nemo_automodel.components.config.loader import ConfigNode
1724
from nemo_automodel.recipes.llm.train_ft import (
1825
TrainFinetuneRecipeForNextTokenPrediction,
19-
build_validation_dataloader,
2026
build_dataloader,
2127
build_model_and_optimizer,
28+
build_validation_dataloader,
2229
)
23-
from nemo_automodel.components.config.loader import ConfigNode
24-
from unittest.mock import patch
25-
import importlib
26-
import torch
27-
import torch.nn as nn
2830
from torch.utils.data import IterableDataset
29-
from types import SimpleNamespace
30-
from contextlib import AbstractContextManager
31-
from unittest.mock import MagicMock
3231

3332

3433
class DummyIterableDataset(IterableDataset): # noqa: D401
@@ -567,3 +566,37 @@ def test_nvtx_false_skips_patching(monkeypatch):
567566

568567
assert trainer.enable_nvtx is False
569568
patch_mock.assert_not_called()
569+
570+
571+
def test_nvtx_true_pipeline_patches_all_parts(monkeypatch):
572+
cfg = _minimal_cfg_with_nvtx(nvtx_value=True)
573+
patch_mock = MagicMock()
574+
_patch_setup_minimals(monkeypatch, patch_mock)
575+
576+
class DummyAutoPipeline(SimpleNamespace):
577+
pass
578+
579+
# Make isinstance(model, AutoPipeline) succeed with our dummy
580+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.AutoPipeline", DummyAutoPipeline)
581+
582+
parts = [DummyModel(), DummyModel()]
583+
584+
def _build_model_and_optimizer_stub(*args, **kwargs):
585+
ap = DummyAutoPipeline(parts=parts, info=SimpleNamespace(has_last_stage=False, has_first_stage=False, schedule=None))
586+
dummy_opt = SimpleNamespace(param_groups=[{"lr": 0.01}], step=lambda: None, zero_grad=lambda: None)
587+
return ap, ["w"], [dummy_opt], "loss_fn", {"trainable_params": 2, "total_params": 2}
588+
589+
# Override the default stub to return a pipeline-wrapped model
590+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_model_and_optimizer", _build_model_and_optimizer_stub)
591+
592+
trainer = TrainFinetuneRecipeForNextTokenPrediction(cfg)
593+
trainer.setup()
594+
595+
assert trainer.enable_nvtx is True
596+
patch_mock.assert_has_calls(
597+
[
598+
call(parts[0], name="PipelineStage_0"),
599+
call(parts[1], name="PipelineStage_1"),
600+
],
601+
any_order=False,
602+
)

0 commit comments

Comments
 (0)