@@ -13,22 +13,21 @@
 # limitations under the License.
 
 import logging
-from unittest.mock import MagicMock, patch
+import importlib
+import torch
+import torch.nn as nn
+from contextlib import AbstractContextManager
+from types import SimpleNamespace
+from unittest.mock import MagicMock, call, patch
+
+from nemo_automodel.components.config.loader import ConfigNode
 from nemo_automodel.recipes.llm.train_ft import (
     TrainFinetuneRecipeForNextTokenPrediction,
-    build_validation_dataloader,
     build_dataloader,
     build_model_and_optimizer,
+    build_validation_dataloader,
 )
-from nemo_automodel.components.config.loader import ConfigNode
-from unittest.mock import patch
-import importlib
-import torch
-import torch.nn as nn
 from torch.utils.data import IterableDataset
-from types import SimpleNamespace
-from contextlib import AbstractContextManager
-from unittest.mock import MagicMock
 
 
 class DummyIterableDataset(IterableDataset):  # noqa: D401
@@ -567,3 +566,37 @@ def test_nvtx_false_skips_patching(monkeypatch): |
 
     assert trainer.enable_nvtx is False
     patch_mock.assert_not_called()
+
+
+def test_nvtx_true_pipeline_patches_all_parts(monkeypatch):
+    cfg = _minimal_cfg_with_nvtx(nvtx_value=True)
+    patch_mock = MagicMock()
+    _patch_setup_minimals(monkeypatch, patch_mock)
+
+    class DummyAutoPipeline(SimpleNamespace):
+        pass
+
+    # Make isinstance(model, AutoPipeline) succeed with our dummy
+    monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.AutoPipeline", DummyAutoPipeline)
+
+    parts = [DummyModel(), DummyModel()]
+
+    def _build_model_and_optimizer_stub(*args, **kwargs):
+        ap = DummyAutoPipeline(parts=parts, info=SimpleNamespace(has_last_stage=False, has_first_stage=False, schedule=None))
+        dummy_opt = SimpleNamespace(param_groups=[{"lr": 0.01}], step=lambda: None, zero_grad=lambda: None)
+        return ap, ["w"], [dummy_opt], "loss_fn", {"trainable_params": 2, "total_params": 2}
+
+    # Override the default stub to return a pipeline-wrapped model
+    monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_model_and_optimizer", _build_model_and_optimizer_stub)
+
+    trainer = TrainFinetuneRecipeForNextTokenPrediction(cfg)
+    trainer.setup()
+
+    assert trainer.enable_nvtx is True
+    patch_mock.assert_has_calls(
+        [
+            call(parts[0], name="PipelineStage_0"),
+            call(parts[1], name="PipelineStage_1"),
+        ],
+        any_order=False,
+    )
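
Editor's note: the assertions in the new test encode a contract that `setup()` walks `AutoPipeline.parts` in order and patches each stage under the name `PipelineStage_<index>`. A minimal sketch of that contract, purely illustrative; the helper and callback names below (`_annotate_pipeline_stages`, `nvtx_patch`) are assumptions, not the recipe's actual API in `train_ft.py`:

```python
from typing import Callable, Sequence


def _annotate_pipeline_stages(parts: Sequence[object], nvtx_patch: Callable[..., None]) -> None:
    """Hypothetical helper mirroring what the test above expects from setup():
    patch every pipeline stage, in order, under the name PipelineStage_<index>."""
    for i, part in enumerate(parts):
        nvtx_patch(part, name=f"PipelineStage_{i}")
```

Under that reading, `patch_mock.assert_has_calls([...], any_order=False)` checks both the per-stage naming and the in-order traversal of the pipeline parts.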