Commit ab56f2f
feat: add nsys model layer name scope and benchmark support (with nsys) in app (#951)
Signed-off-by: Zhiyu Li <[email protected]>
1 parent ce09c48 commit ab56f2f

5 files changed (+154, -8 lines changed)

nemo_automodel/_cli/app.py

Lines changed: 43 additions & 8 deletions
@@ -42,6 +42,24 @@
 # ├── ...
 # └── qwen2_5_vl_3b_rdr.yaml
 
+COMMAND_ALIASES = {"finetune": "train_ft", "pretrain": "train_ft", "benchmark": "benchmark"}
+
+
+def get_recipe_script_path(command: str, domain: str, repo_root: str | Path) -> str:
+    """
+    Get the script path for a given command and domain.
+
+    Args:
+        command: The command name (e.g., 'finetune', 'benchmark', 'pretrain')
+        domain: The domain (e.g., 'llm', 'vlm')
+        repo_root: The repository root path
+
+    Returns:
+        str: Full path to the recipe script
+    """
+    recipe_name = COMMAND_ALIASES.get(command, command)
+    return f"{repo_root}/nemo_automodel/recipes/{domain}/{recipe_name}.py"
+
 
 def load_function(file_path: str | Path, func_name: str):
     """
@@ -135,16 +153,36 @@ def launch_with_slurm(args, job_conf_path, job_dir, slurm_config, extra_args=Non
     if slurm_config.get("job_name", "") == "":
         slurm_config["job_name"] = f"{args.domain}_{args.command}"
 
+    # Get the recipe script path
+    script_path = get_recipe_script_path(args.command, args.domain, repo_root)
+
+    # Build nsys profile command if enabled
+    if slurm_config.get("nsys_enabled", False):
+        profile_cmd = (
+            f"nsys profile -s none "
+            f"--trace=cuda,cudnn,nvtx "
+            f"--cudabacktrace=all "
+            f"--cuda-graph-trace=node "
+            f"--python-backtrace=cuda "
+            f"--wait all "
+            f"-o {job_dir}/automodel_profile_%p.nsys-rep "
+            f"--force-overwrite true "
+            f"--capture-range=cudaProfilerApi "
+            f"--capture-range-end=stop "
+        )
+    else:
+        profile_cmd = ""
+
     # create the command
     command_parts = [
         f"PYTHONPATH={repo_root}:$PYTHONPATH",
         # Use torchrun to launch multiple processes instead
-        "uv sync --inexact --frozen $(cat /opt/uv_args.txt) && uv run --no-sync torchrun ",
+        f"uv sync --inexact --frozen $(cat /opt/uv_args.txt) && {profile_cmd}uv run --no-sync torchrun ",
         f"--nproc_per_node={slurm_config['ntasks_per_node']} ",
         f"--nnodes={slurm_config['nodes']} ",
         "--rdzv_backend=c10d ",
         f"--rdzv_endpoint=${{MASTER_ADDR}}:${{MASTER_PORT}}",  # noqa: F541
-        f"{repo_root}/examples/{args.domain}_{args.command}/{args.command}.py",
+        script_path,
         "-c",
         f"{job_conf_path}",
     ]
@@ -174,8 +212,8 @@ def build_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "command",
         metavar="<command>",
-        choices=["finetune", "pretrain", "kd"],
-        help="Command within the domain (e.g., finetune, pretrain, kd, etc)",
+        choices=["finetune", "pretrain", "kd", "benchmark"],
+        help="Command within the domain (e.g., finetune, pretrain, kd, benchmark, etc)",
     )
     parser.add_argument(
         "domain",
@@ -229,12 +267,9 @@ def run_interactive(args):
     from torch.distributed.run import determine_local_world_size, get_args_parser
     from torch.distributed.run import run as thrun
 
-    COMMAND_ALIASES = {"finetune": "train_ft", "pretrain": "train_ft"}
-    # remap commands: finetune -> train_ft
-    command = COMMAND_ALIASES.get(args.command, args.command)
     config_path = args.config.resolve()
     repo_root = get_repo_root()
-    script_path = repo_root / "nemo_automodel" / "recipes" / args.domain / f"{command}.py"
+    script_path = Path(get_recipe_script_path(args.command, args.domain, repo_root))
 
     # launch job on this node
     num_devices = determine_local_world_size(nproc_per_node="gpu")
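
For reference, a short sketch (not part of the diff) of what the new helper resolves to; the repo_root value below is made up:

# Hypothetical usage of get_recipe_script_path; repo_root is illustrative.
from nemo_automodel._cli.app import get_recipe_script_path

print(get_recipe_script_path("finetune", "llm", "/opt/nemo-automodel"))
# /opt/nemo-automodel/nemo_automodel/recipes/llm/train_ft.py  ("finetune" aliases to train_ft)
print(get_recipe_script_path("benchmark", "llm", "/opt/nemo-automodel"))
# /opt/nemo-automodel/nemo_automodel/recipes/llm/benchmark.py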
nemo_automodel/autonvtx.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from threading import local
+
+import torch
+
+# inspired by https://github.com/zasdfgbnm/autonvtx
+
+# Thread-local storage to track active NVTX ranges and prevent recursion
+_thread_local = local()
+
+
+def _get_active_ranges():
+    """Get the set of currently active NVTX ranges for this thread."""
+    if not hasattr(_thread_local, "active_ranges"):
+        _thread_local.active_ranges = set()
+    return _thread_local.active_ranges
+
+
+def _add_nvtx_hooks(model, name, add_backward_hooks=True):
+    """Add NVTX range hooks to a model's forward and optionally backward passes."""
+    if hasattr(model, "_nvtx_patched"):
+        return
+
+    def push_fwd(module, *args, **kwargs):
+        if name in _get_active_ranges():
+            module._nvtx_skipped = True
+            return
+        module._nvtx_skipped = False
+        _get_active_ranges().add(name)
+        torch.cuda.nvtx.range_push(name)
+
+    def pop_fwd(module, *args, **kwargs):
+        if getattr(module, "_nvtx_skipped", False):
+            return
+        torch.cuda.nvtx.range_pop()
+        _get_active_ranges().discard(name)
+
+    model.register_forward_pre_hook(push_fwd)
+    model.register_forward_hook(pop_fwd)
+
+    if add_backward_hooks:
+
+        def push_bwd(module, grad_input):
+            if name in _get_active_ranges():
+                module._nvtx_skipped = True
+                return
+            module._nvtx_skipped = False
+            _get_active_ranges().add(name)
+            torch.cuda.nvtx.range_push(name)
+
+        def pop_bwd(module, grad_input, grad_output):
+            if getattr(module, "_nvtx_skipped", False):
+                return
+            torch.cuda.nvtx.range_pop()
+            _get_active_ranges().discard(name)
+
+        model.register_full_backward_pre_hook(push_bwd)
+        model.register_full_backward_hook(pop_bwd)
+
+    model._nvtx_patched = True
+
+
+def patch(model, name=None, add_backward_hooks=True):
+    """
+    Recursively patch a model with NVTX profiling annotations.
+
+    Prevents duplicate scopes when activation checkpointing reruns forward passes.
+    """
+    if hasattr(model, "_nvtx_patched"):
+        return model
+
+    name = type(model).__name__ if name is None else f"{name}: {type(model).__name__}"
+    _add_nvtx_hooks(model, name, add_backward_hooks=add_backward_hooks)
+
+    # Recursively patch all children
+    for child_name, child in model.named_children():
+        patch(child, child_name, add_backward_hooks)
+
+    return model
+
+
+# Export the functions properly
+__all__ = ["patch"]
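
As a usage note, a minimal sketch (not code from this commit): patch() can be applied to any nn.Module and emits one NVTX range per submodule, which shows up in an nsys timeline when nvtx tracing is enabled. A CUDA-enabled PyTorch build is assumed for the torch.cuda.nvtx calls; the toy model below is illustrative.

import torch
import torch.nn as nn

import nemo_automodel.autonvtx as autonvtx  # module added by this commit

# Toy model; real use targets the training model set up in train_ft.py.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
autonvtx.patch(model)  # recursively registers range_push/range_pop hooks

x = torch.randn(8, 16)
loss = model(x).sum()  # forward hooks open/close ranges like "0: Linear"
loss.backward()        # backward hooks do the same when add_backward_hooks=True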

nemo_automodel/components/_peft/lora.py

Lines changed: 4 additions & 0 deletions
@@ -268,6 +268,7 @@ def patch_linear_module(
     lora_A_init_method="xavier",
     lora_dtype=None,
     use_triton=True,
+    layer_name=None,
 ):
     """
     Monkey-patches a nn.Linear (orig_linear param) to be a LinearLoRA.
@@ -321,6 +322,8 @@ def patch_linear_module(
     orig_linear.super_fwd = orig_linear.forward
 
     orig_linear.__class__ = new_cls
+    if layer_name is not None:
+        orig_linear._layer_name = layer_name
     return orig_linear
 
 
@@ -382,6 +385,7 @@ def apply_lora_to_linear_modules(
            lora_A_init_method=peft_config.lora_A_init,
            lora_dtype=lora_dtype,
            use_triton=peft_config.use_triton,
+           layer_name=name,
        )
 
    return num_modules_matched
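
The layer_name plumbing simply records the fully qualified module name on each patched linear as _layer_name, presumably so profiling or scoping code can label LoRA layers by name. A hypothetical check (not part of the commit), assuming apply_lora_to_linear_modules has already run on the model:

import torch.nn as nn

def list_lora_layer_names(model: nn.Module) -> None:
    # Print the recorded name for every module patched with a layer_name.
    for qualified_name, module in model.named_modules():
        recorded = getattr(module, "_layer_name", None)
        if recorded is not None:
            print(f"{qualified_name} -> {recorded}")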

nemo_automodel/components/launcher/slurm/config.py

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ class SlurmConfig:
     # User command
     command: str = field(default="", metadata=dict(help="Shell command(s) to run inside container"))
     chdir: str = field(default=None, metadata=dict(help="Working directory of the job"))
+    nsys_enabled: bool = field(default=False, metadata=dict(help="Enable nsys profiling"))
 
     def __post_init__(self):
         if isinstance(self.extra_mounts, list):
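
Since launch_with_slurm reads the flag with slurm_config.get("nsys_enabled", False), enabling profiling is a one-line config change. A minimal sketch of the relevant keys (only the fields the diff actually reads; a real launcher config carries more, and the values below are placeholders):

# Illustrative slurm_config fragment.
slurm_config = {
    "job_name": "",          # empty -> app.py fills in f"{domain}_{command}"
    "nodes": 1,
    "ntasks_per_node": 8,
    "nsys_enabled": True,    # new flag: prefixes torchrun with the nsys profile command
}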

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 9 additions & 0 deletions
@@ -1006,7 +1006,16 @@ def setup(self):
         if isinstance(model, AutoPipeline):
             self.model_parts = model.parts
             self.pp = model
+            import nemo_automodel.autonvtx as autonvtx
+
+            # Patch each pipeline stage with NVTX profiling
+            for i, part in enumerate(self.model_parts):
+                autonvtx.patch(part, name=f"PipelineStage_{i}")
         else:
+            import nemo_automodel.autonvtx as autonvtx
+
+            # Patch model with NVTX profiling
+            autonvtx.patch(model, name=model.__class__.__name__)
             self.model_parts = [model]
             self.pp = None
 
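One interaction worth noting: the nsys invocation in app.py uses --capture-range=cudaProfilerApi and --capture-range-end=stop, so recording only covers the window the training process opens via the CUDA profiler API. A sketch (not from this commit) of how a training loop could bracket a few steps, assuming a hypothetical step_fn callable:

import torch

def run_with_capture(step_fn, num_steps, warmup=3, capture=5):
    """Warm up, then record a few steps inside the nsys capture range."""
    for i in range(num_steps):
        if i == warmup:
            torch.cuda.profiler.start()  # cudaProfilerStart -> nsys begins capturing
        step_fn(i)
        if i == warmup + capture - 1:
            torch.cuda.profiler.stop()   # cudaProfilerStop -> --capture-range-end=stop ends recording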