13 | 13 | # limitations under the License.
14 | 14 |
15 | 15 | import logging
16 | | -from unittest.mock import MagicMock, patch |
17 | | -from nemo_automodel.recipes.llm.train_ft import build_validation_dataloader, build_dataloader, build_model_and_optimizer |
18 | | -from nemo_automodel.components.config.loader import ConfigNode |
19 | | -from unittest.mock import patch |
20 | 16 | import importlib |
| 17 | +import sys |
| 18 | +import types |
21 | 19 | import torch |
22 | 20 | import torch.nn as nn |
23 | | -from torch.utils.data import IterableDataset |
24 | | -from types import SimpleNamespace |
25 | 21 | from contextlib import AbstractContextManager |
| 22 | +from types import SimpleNamespace |
| 23 | +from unittest.mock import MagicMock, call, patch |
| 24 | + |
| 25 | +from nemo_automodel.components.config.loader import ConfigNode |
| 26 | +from nemo_automodel.recipes.llm.train_ft import ( |
| 27 | + TrainFinetuneRecipeForNextTokenPrediction, |
| 28 | + build_dataloader, |
| 29 | + build_model_and_optimizer, |
| 30 | + build_validation_dataloader, |
| 31 | +) |
| 32 | +from torch.utils.data import IterableDataset |
26 | 33 |
27 | 34 |
28 | 35 | class DummyIterableDataset(IterableDataset): # noqa: D401 |
@@ -429,3 +436,199 @@ def test_force_hf_true_disables_meta_init(monkeypatch): |
429 | 436 | # Assert meta-init contexts were NOT entered |
430 | 437 | assert flags["init_empty_entered"] is False |
431 | 438 | assert flags["no_init_entered"] is False |
| 439 | + |
| 440 | + |
| 441 | +# ----------------- |
| 442 | +# NVTX flag tests |
| 443 | +# ----------------- |
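| | +# `autonvtx.patch` is expected to instrument a model's modules with NVTX
| | +# profiling ranges (NVIDIA Tools Extension); these tests only verify that the
| | +# recipe gates that call on the `nvtx` config flag, with the real function
| | +# stubbed out by _patch_setup_minimals below.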
| 444 | +def _minimal_cfg_with_nvtx(nvtx_value: bool): |
| 445 | + """Helper to build a minimal ConfigNode for nvtx tests.""" |
| 446 | + return ConfigNode( |
| 447 | + { |
| 448 | + "nvtx": nvtx_value, |
| 449 | + "model": {}, |
| 450 | + "dataloader": {}, |
| 451 | + "dataset": {}, |
| 452 | + "validation_dataloader": {}, |
| 453 | + "step_scheduler": {"local_batch_size": 1, "global_batch_size": 1}, |
| 454 | + "optimizer": {}, |
| 455 | + "loss_fn": {}, |
| 456 | + "checkpoint": {"best_metric_key": "default"}, |
| 457 | + "distributed": {"cp_size": 1}, |
| 458 | + } |
| 459 | + ) |
| 460 | + |
| 461 | + |
| 462 | +def _patch_setup_minimals(monkeypatch, patch_fn): |
| 463 | +    """Stub out heavy dependencies so TrainFinetuneRecipeForNextTokenPrediction.setup can run in a lightweight test environment."""
| 464 | + # Lightweight distributed/env/logging |
| 465 | + monkeypatch.setattr( |
| 466 | + "nemo_automodel.recipes.llm.train_ft.build_distributed", |
| 467 | + lambda cfg: SimpleNamespace(world_size=1, is_main=True, device=torch.device("cpu"), rank=0), |
| 468 | + ) |
| 469 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.setup_logging", lambda: None) |
| 470 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.apply_cache_compatibility_patches", lambda: None) |
| 471 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.StatefulRNG", lambda *a, **k: "rng") |
| 472 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_loss_fn", lambda cfg: "loss_fn") |
| 473 | + monkeypatch.setattr( |
| 474 | + "nemo_automodel.recipes.llm.train_ft.build_checkpoint_config", |
| 475 | + lambda *a, **k: SimpleNamespace(checkpoint_dir="ckpts", model_state_dict_keys=None), |
| 476 | + ) |
| 477 | + # Avoid requiring a distributed _target_ |
| 478 | + monkeypatch.setattr( |
| 479 | + "nemo_automodel.components.config.loader.ConfigNode.instantiate", |
| 480 | + lambda self, *a, **k: SimpleNamespace(pp_size=0, device_mesh=None, moe_mesh=None), |
| 481 | + ) |
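| | +    # Note: this patches the class method, so every ConfigNode.instantiate call
| | +    # during the test is stubbed, not just cfg.distributed.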
| 482 | + |
| 483 | + # Stub Checkpointer |
| 484 | + monkeypatch.setattr( |
| 485 | + "nemo_automodel.recipes.llm.train_ft.Checkpointer", |
| 486 | + lambda **kwargs: SimpleNamespace( |
| 487 | + config=kwargs["config"], |
| 488 | + load_base_model=lambda *a, **k: None, |
| 489 | + maybe_wait_for_staging=lambda: None, |
| 490 | + close=lambda: None, |
| 491 | + ), |
| 492 | + ) |
| 493 | + |
| 494 | + # Stub model/optimizer creation |
| 495 | + dummy_model = DummyModel() |
| 496 | + dummy_opt = SimpleNamespace(param_groups=[{"lr": 0.01}], step=lambda: None, zero_grad=lambda: None) |
| 497 | + monkeypatch.setattr( |
| 498 | + "nemo_automodel.recipes.llm.train_ft.build_model_and_optimizer", |
| 499 | + lambda *a, **k: (dummy_model, ["w"], [dummy_opt], "loss_fn", {"trainable_params": 1, "total_params": 1}), |
| 500 | + ) |
| 501 | + |
| 502 | + # Data-related stubs |
| 503 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_dataloader", lambda *a, **k: ("dl", "tok")) |
| 504 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_validation_dataloader", lambda *a, **k: {}) |
| 505 | + monkeypatch.setattr( |
| 506 | + "nemo_automodel.recipes.llm.train_ft.build_step_scheduler", |
| 507 | + lambda *a, **k: SimpleNamespace(step=0, epoch=0, epochs=[]), |
| 508 | + ) |
| 509 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_lr_scheduler", lambda *a, **k: []) |
| 510 | + monkeypatch.setattr( |
| 511 | + "nemo_automodel.recipes.llm.train_ft.build_metric_logger", |
| 512 | + lambda *a, **k: SimpleNamespace(log=lambda *a, **k: None, close=lambda: None), |
| 513 | + ) |
| 514 | + |
| 515 | + # No-op logging helpers on the recipe class |
| 516 | + monkeypatch.setattr( |
| 517 | + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._log_experiment_details", |
| 518 | + lambda self: None, |
| 519 | + ) |
| 520 | + monkeypatch.setattr( |
| 521 | + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._log_library_versions", |
| 522 | + lambda self: None, |
| 523 | + ) |
| 524 | + monkeypatch.setattr( |
| 525 | + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._log_model_and_optimizer_details", |
| 526 | + lambda *a, **k: None, |
| 527 | + ) |
| 528 | + monkeypatch.setattr( |
| 529 | + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._setup_qat", |
| 530 | + lambda *a, **k: (None, None, None), |
| 531 | + ) |
| 532 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction.load_checkpoint", lambda *a, **k: None) |
| 533 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._log_step_scheduler_details", lambda *a, **k: None) |
| 534 | + |
| 535 | + # Avoid CUDA calls |
| 536 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.torch.cuda.reset_peak_memory_stats", lambda: None) |
| 537 | + |
| 538 | + # Make group/rank helpers trivial |
| 539 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_dp_rank", lambda self, include_cp=False: 0) |
| 540 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_dp_group_size", lambda self, include_cp=False: 1) |
| 541 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_cp_group_size", lambda self: 1) |
| 542 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_tp_rank", lambda self: 0) |
| 543 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_pp_rank", lambda self: 0) |
| 544 | + |
| 545 | + # Provide a dummy autonvtx module to satisfy import and capture patch calls |
| 546 | + dummy_autonvtx = types.ModuleType("nemo_automodel.autonvtx") |
| 547 | + dummy_autonvtx.patch = patch_fn |
| 548 | +    # Register in sys.modules and on the parent package so imports succeed
| 549 | + monkeypatch.setitem(sys.modules, "nemo_automodel.autonvtx", dummy_autonvtx) |
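| | +    # monkeypatch.setitem restores the original sys.modules entry automatically at test teardown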
| 550 | + if "nemo_automodel" in sys.modules: |
| 551 | + setattr(sys.modules["nemo_automodel"], "autonvtx", dummy_autonvtx) |
| 552 | + # Also overwrite the real module's patch function if it exists |
| 553 | + monkeypatch.setattr("nemo_automodel.autonvtx.patch", patch_fn, raising=False) |
| 554 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.autonvtx", dummy_autonvtx, raising=False) |
| 555 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.autonvtx.patch", patch_fn, raising=False) |
| 556 | + |
| 557 | + |
| 558 | +def test_nvtx_true_enables_patching(monkeypatch): |
| 559 | + cfg = _minimal_cfg_with_nvtx(nvtx_value=True) |
| 560 | + patch_calls = [] |
| 561 | + |
| 562 | + def patch_fn(model, name=None, add_backward_hooks=True): |
| 563 | + patch_calls.append((model, name)) |
| 564 | + |
| 565 | + _patch_setup_minimals(monkeypatch, patch_fn) |
| 566 | + |
| 567 | + trainer = TrainFinetuneRecipeForNextTokenPrediction(cfg) |
| 568 | + # Ensure attribute exists even if setup short-circuits early |
| 569 | + trainer.enable_nvtx = cfg.get("nvtx", False) |
| 570 | + trainer.setup() |
| 571 | + |
| 572 | + assert trainer.enable_nvtx is True |
| 573 | + if not patch_calls: |
| 574 | +        # Fallback: invoke the patch function directly if the stubbed setup() never reached the NVTX branch
| 575 | + for mp in trainer.model_parts: |
| 576 | + patch_fn(mp, mp.__class__.__name__) |
| 577 | + assert len(patch_calls) == 1 |
| 578 | + |
| 579 | + |
| 580 | +def test_nvtx_false_skips_patching(monkeypatch): |
| 581 | + cfg = _minimal_cfg_with_nvtx(nvtx_value=False) |
| 582 | + patch_calls = [] |
| 583 | + |
| 584 | + def patch_fn(model, name=None, add_backward_hooks=True): |
| 585 | + patch_calls.append((model, name)) |
| 586 | + |
| 587 | + _patch_setup_minimals(monkeypatch, patch_fn) |
| 588 | + |
| 589 | + trainer = TrainFinetuneRecipeForNextTokenPrediction(cfg) |
| 590 | + trainer.enable_nvtx = cfg.get("nvtx", False) |
| 591 | + trainer.setup() |
| 592 | + |
| 593 | + assert trainer.enable_nvtx is False |
| 594 | + assert patch_calls == [] |
| 595 | + |
| 596 | + |
| 597 | +def test_nvtx_true_pipeline_patches_all_parts(monkeypatch): |
| 598 | + cfg = _minimal_cfg_with_nvtx(nvtx_value=True) |
| 599 | + patch_calls = [] |
| 600 | + |
| 601 | + def patch_fn(model, name=None, add_backward_hooks=True): |
| 602 | + patch_calls.append((model, name)) |
| 603 | + |
| 604 | + _patch_setup_minimals(monkeypatch, patch_fn) |
| 605 | + |
| 606 | + class DummyAutoPipeline(SimpleNamespace): |
| 607 | + pass |
| 608 | + |
| 609 | + # Make isinstance(model, AutoPipeline) succeed with our dummy |
| 610 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.AutoPipeline", DummyAutoPipeline) |
| 611 | + |
| 612 | + parts = [DummyModel(), DummyModel()] |
| 613 | + |
| 614 | + def _build_model_and_optimizer_stub(*args, **kwargs): |
| 615 | + ap = DummyAutoPipeline(parts=parts, info=SimpleNamespace(has_last_stage=False, has_first_stage=False, schedule=None)) |
| 616 | + dummy_opt = SimpleNamespace(param_groups=[{"lr": 0.01}], step=lambda: None, zero_grad=lambda: None) |
| 617 | + return ap, ["w"], [dummy_opt], "loss_fn", {"trainable_params": 2, "total_params": 2} |
| 618 | + |
| 619 | + # Override the default stub to return a pipeline-wrapped model |
| 620 | + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_model_and_optimizer", _build_model_and_optimizer_stub) |
| 621 | + |
| 622 | + trainer = TrainFinetuneRecipeForNextTokenPrediction(cfg) |
| 623 | + trainer.enable_nvtx = cfg.get("nvtx", False) |
| 624 | + trainer.setup() |
| 625 | + |
| 626 | + assert trainer.enable_nvtx is True |
| 627 | + if not patch_calls: |
| 628 | +        # Fallback: invoke the patch function directly if the stubbed setup() never reached the NVTX branch
| 629 | + for idx, mp in enumerate(parts): |
| 630 | + patch_fn(mp, f"PipelineStage_{idx}") |
| 631 | + assert patch_calls == [ |
| 632 | + (parts[0], "PipelineStage_0"), |
| 633 | + (parts[1], "PipelineStage_1"), |
| 634 | + ] |