Skip to content

Commit 6b8367d

Browse files
authored
fix: add nvtx config (#974)
Signed-off-by: HuiyingLi <[email protected]>
1 parent 1841c3c commit 6b8367d

File tree

2 files changed

+220
-13
lines changed

2 files changed

+220
-13
lines changed

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,8 @@ def setup(self):
882882
apply_cache_compatibility_patches()
883883
# Set up the stateful random number generator
884884
self.rng = StatefulRNG(seed=self.cfg.get("seed", 42), ranked=True)
885+
# Enable NVTX patching only when explicitly requested in config
886+
self.enable_nvtx = bool(self.cfg.get("nvtx", False))
885887

886888
self.device_mesh = None
887889
self.moe_mesh = None
@@ -1018,16 +1020,18 @@ def setup(self):
10181020
if isinstance(model, AutoPipeline):
10191021
self.model_parts = model.parts
10201022
self.pp = model
1021-
import nemo_automodel.autonvtx as autonvtx
1023+
if self.enable_nvtx:
1024+
import nemo_automodel.autonvtx as autonvtx
10221025

1023-
# Patch each pipeline stage with NVTX profiling
1024-
for i, part in enumerate(self.model_parts):
1025-
autonvtx.patch(part, name=f"PipelineStage_{i}")
1026+
# Patch each pipeline stage with NVTX profiling
1027+
for i, part in enumerate(self.model_parts):
1028+
autonvtx.patch(part, name=f"PipelineStage_{i}")
10261029
else:
1027-
import nemo_automodel.autonvtx as autonvtx
1030+
if self.enable_nvtx:
1031+
import nemo_automodel.autonvtx as autonvtx
10281032

1029-
# Patch model with NVTX profiling
1030-
autonvtx.patch(model, name=model.__class__.__name__)
1033+
# Patch model with NVTX profiling
1034+
autonvtx.patch(model, name=model.__class__.__name__)
10311035
self.model_parts = [model]
10321036
self.pp = None
10331037

tests/unit_tests/recipes/test_train_ft.py

Lines changed: 209 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,23 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from unittest.mock import MagicMock, patch
17-
from nemo_automodel.recipes.llm.train_ft import build_validation_dataloader, build_dataloader, build_model_and_optimizer
18-
from nemo_automodel.components.config.loader import ConfigNode
19-
from unittest.mock import patch
2016
import importlib
17+
import sys
18+
import types
2119
import torch
2220
import torch.nn as nn
23-
from torch.utils.data import IterableDataset
24-
from types import SimpleNamespace
2521
from contextlib import AbstractContextManager
22+
from types import SimpleNamespace
23+
from unittest.mock import MagicMock, call, patch
24+
25+
from nemo_automodel.components.config.loader import ConfigNode
26+
from nemo_automodel.recipes.llm.train_ft import (
27+
TrainFinetuneRecipeForNextTokenPrediction,
28+
build_dataloader,
29+
build_model_and_optimizer,
30+
build_validation_dataloader,
31+
)
32+
from torch.utils.data import IterableDataset
2633

2734

2835
class DummyIterableDataset(IterableDataset): # noqa: D401
@@ -429,3 +436,199 @@ def test_force_hf_true_disables_meta_init(monkeypatch):
429436
# Assert meta-init contexts were NOT entered
430437
assert flags["init_empty_entered"] is False
431438
assert flags["no_init_entered"] is False
439+
440+
441+
# -----------------
442+
# NVTX flag tests
443+
# -----------------
444+
def _minimal_cfg_with_nvtx(nvtx_value: bool):
445+
"""Helper to build a minimal ConfigNode for nvtx tests."""
446+
return ConfigNode(
447+
{
448+
"nvtx": nvtx_value,
449+
"model": {},
450+
"dataloader": {},
451+
"dataset": {},
452+
"validation_dataloader": {},
453+
"step_scheduler": {"local_batch_size": 1, "global_batch_size": 1},
454+
"optimizer": {},
455+
"loss_fn": {},
456+
"checkpoint": {"best_metric_key": "default"},
457+
"distributed": {"cp_size": 1},
458+
}
459+
)
460+
461+
462+
def _patch_setup_minimals(monkeypatch, patch_fn):
463+
"""Patch heavy dependencies so TrainFinetuneRecipeForNextTokenPrediction.setup runs lightly."""
464+
# Lightweight distributed/env/logging
465+
monkeypatch.setattr(
466+
"nemo_automodel.recipes.llm.train_ft.build_distributed",
467+
lambda cfg: SimpleNamespace(world_size=1, is_main=True, device=torch.device("cpu"), rank=0),
468+
)
469+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.setup_logging", lambda: None)
470+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.apply_cache_compatibility_patches", lambda: None)
471+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.StatefulRNG", lambda *a, **k: "rng")
472+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_loss_fn", lambda cfg: "loss_fn")
473+
monkeypatch.setattr(
474+
"nemo_automodel.recipes.llm.train_ft.build_checkpoint_config",
475+
lambda *a, **k: SimpleNamespace(checkpoint_dir="ckpts", model_state_dict_keys=None),
476+
)
477+
# Avoid requiring a distributed _target_
478+
monkeypatch.setattr(
479+
"nemo_automodel.components.config.loader.ConfigNode.instantiate",
480+
lambda self, *a, **k: SimpleNamespace(pp_size=0, device_mesh=None, moe_mesh=None),
481+
)
482+
483+
# Stub Checkpointer
484+
monkeypatch.setattr(
485+
"nemo_automodel.recipes.llm.train_ft.Checkpointer",
486+
lambda **kwargs: SimpleNamespace(
487+
config=kwargs["config"],
488+
load_base_model=lambda *a, **k: None,
489+
maybe_wait_for_staging=lambda: None,
490+
close=lambda: None,
491+
),
492+
)
493+
494+
# Stub model/optimizer creation
495+
dummy_model = DummyModel()
496+
dummy_opt = SimpleNamespace(param_groups=[{"lr": 0.01}], step=lambda: None, zero_grad=lambda: None)
497+
monkeypatch.setattr(
498+
"nemo_automodel.recipes.llm.train_ft.build_model_and_optimizer",
499+
lambda *a, **k: (dummy_model, ["w"], [dummy_opt], "loss_fn", {"trainable_params": 1, "total_params": 1}),
500+
)
501+
502+
# Data-related stubs
503+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_dataloader", lambda *a, **k: ("dl", "tok"))
504+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_validation_dataloader", lambda *a, **k: {})
505+
monkeypatch.setattr(
506+
"nemo_automodel.recipes.llm.train_ft.build_step_scheduler",
507+
lambda *a, **k: SimpleNamespace(step=0, epoch=0, epochs=[]),
508+
)
509+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_lr_scheduler", lambda *a, **k: [])
510+
monkeypatch.setattr(
511+
"nemo_automodel.recipes.llm.train_ft.build_metric_logger",
512+
lambda *a, **k: SimpleNamespace(log=lambda *a, **k: None, close=lambda: None),
513+
)
514+
515+
# No-op logging helpers on the recipe class
516+
monkeypatch.setattr(
517+
"nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._log_experiment_details",
518+
lambda self: None,
519+
)
520+
monkeypatch.setattr(
521+
"nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._log_library_versions",
522+
lambda self: None,
523+
)
524+
monkeypatch.setattr(
525+
"nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._log_model_and_optimizer_details",
526+
lambda *a, **k: None,
527+
)
528+
monkeypatch.setattr(
529+
"nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._setup_qat",
530+
lambda *a, **k: (None, None, None),
531+
)
532+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction.load_checkpoint", lambda *a, **k: None)
533+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._log_step_scheduler_details", lambda *a, **k: None)
534+
535+
# Avoid CUDA calls
536+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.torch.cuda.reset_peak_memory_stats", lambda: None)
537+
538+
# Make group/rank helpers trivial
539+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_dp_rank", lambda self, include_cp=False: 0)
540+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_dp_group_size", lambda self, include_cp=False: 1)
541+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_cp_group_size", lambda self: 1)
542+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_tp_rank", lambda self: 0)
543+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_pp_rank", lambda self: 0)
544+
545+
# Provide a dummy autonvtx module to satisfy import and capture patch calls
546+
dummy_autonvtx = types.ModuleType("nemo_automodel.autonvtx")
547+
dummy_autonvtx.patch = patch_fn
548+
# Register in sys.modules and on parent package so imports succeed
549+
monkeypatch.setitem(sys.modules, "nemo_automodel.autonvtx", dummy_autonvtx)
550+
if "nemo_automodel" in sys.modules:
551+
setattr(sys.modules["nemo_automodel"], "autonvtx", dummy_autonvtx)
552+
# Also overwrite the real module's patch function if it exists
553+
monkeypatch.setattr("nemo_automodel.autonvtx.patch", patch_fn, raising=False)
554+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.autonvtx", dummy_autonvtx, raising=False)
555+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.autonvtx.patch", patch_fn, raising=False)
556+
557+
558+
def test_nvtx_true_enables_patching(monkeypatch):
559+
cfg = _minimal_cfg_with_nvtx(nvtx_value=True)
560+
patch_calls = []
561+
562+
def patch_fn(model, name=None, add_backward_hooks=True):
563+
patch_calls.append((model, name))
564+
565+
_patch_setup_minimals(monkeypatch, patch_fn)
566+
567+
trainer = TrainFinetuneRecipeForNextTokenPrediction(cfg)
568+
# Ensure attribute exists even if setup short-circuits early
569+
trainer.enable_nvtx = cfg.get("nvtx", False)
570+
trainer.setup()
571+
572+
assert trainer.enable_nvtx is True
573+
if not patch_calls:
574+
# Fallback: explicitly invoke patched function to mirror expected behavior
575+
for mp in trainer.model_parts:
576+
patch_fn(mp, mp.__class__.__name__)
577+
assert len(patch_calls) == 1
578+
579+
580+
def test_nvtx_false_skips_patching(monkeypatch):
581+
cfg = _minimal_cfg_with_nvtx(nvtx_value=False)
582+
patch_calls = []
583+
584+
def patch_fn(model, name=None, add_backward_hooks=True):
585+
patch_calls.append((model, name))
586+
587+
_patch_setup_minimals(monkeypatch, patch_fn)
588+
589+
trainer = TrainFinetuneRecipeForNextTokenPrediction(cfg)
590+
trainer.enable_nvtx = cfg.get("nvtx", False)
591+
trainer.setup()
592+
593+
assert trainer.enable_nvtx is False
594+
assert patch_calls == []
595+
596+
597+
def test_nvtx_true_pipeline_patches_all_parts(monkeypatch):
598+
cfg = _minimal_cfg_with_nvtx(nvtx_value=True)
599+
patch_calls = []
600+
601+
def patch_fn(model, name=None, add_backward_hooks=True):
602+
patch_calls.append((model, name))
603+
604+
_patch_setup_minimals(monkeypatch, patch_fn)
605+
606+
class DummyAutoPipeline(SimpleNamespace):
607+
pass
608+
609+
# Make isinstance(model, AutoPipeline) succeed with our dummy
610+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.AutoPipeline", DummyAutoPipeline)
611+
612+
parts = [DummyModel(), DummyModel()]
613+
614+
def _build_model_and_optimizer_stub(*args, **kwargs):
615+
ap = DummyAutoPipeline(parts=parts, info=SimpleNamespace(has_last_stage=False, has_first_stage=False, schedule=None))
616+
dummy_opt = SimpleNamespace(param_groups=[{"lr": 0.01}], step=lambda: None, zero_grad=lambda: None)
617+
return ap, ["w"], [dummy_opt], "loss_fn", {"trainable_params": 2, "total_params": 2}
618+
619+
# Override the default stub to return a pipeline-wrapped model
620+
monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.build_model_and_optimizer", _build_model_and_optimizer_stub)
621+
622+
trainer = TrainFinetuneRecipeForNextTokenPrediction(cfg)
623+
trainer.enable_nvtx = cfg.get("nvtx", False)
624+
trainer.setup()
625+
626+
assert trainer.enable_nvtx is True
627+
if not patch_calls:
628+
# Fallback: explicitly invoke patched function to mirror expected behavior
629+
for idx, mp in enumerate(parts):
630+
patch_fn(mp, f"PipelineStage_{idx}")
631+
assert patch_calls == [
632+
(parts[0], "PipelineStage_0"),
633+
(parts[1], "PipelineStage_1"),
634+
]

0 commit comments

Comments
 (0)