
Commit 8fd8c96

feat: Fixes and enhancements for Nsight Systems profiling (#865)
Signed-off-by: Guyue Huang <[email protected]>
1 parent 2b87def commit 8fd8c96

7 files changed (+89 / -4 lines)


docs/nsys-profiling.md

Lines changed: 13 additions & 4 deletions
@@ -17,7 +17,7 @@ NeMo RL supports Nsight profiling for Ray workers through environment variable p
 Set the `NRL_NSYS_WORKER_PATTERNS` environment variable with a comma-separated list of patterns to match worker names:
 
 ```bash
-export NRL_NSYS_WORKER_PATTERNS="*policy*,*other-worker*"
+export NRL_NSYS_WORKER_PATTERNS="*policy*,*vllm*"
 ```
 
 Set the `NRL_NSYS_PROFILE_STEP_RANGE` environment variable to control which training steps the profiler captures. Its
@@ -40,7 +40,7 @@ export NRL_NSYS_PROFILE_STEP_RANGE=3:5
 
 The supported worker types are:
 - **DTensorPolicyWorker**: Pattern matched against `"dtensor_policy_worker"`
-- **MegatronPolicyWorker**: Pattern matched against `"megatron_policy_worker"`
+- **VllmGenerationWorker**: Pattern matched against `"vllm_generation_worker"`
 
 ## Example Usage
 
@@ -49,10 +49,16 @@ The supported worker types are:
 NRL_NSYS_PROFILE_STEP_RANGE=2:3 NRL_NSYS_WORKER_PATTERNS="*policy*" uv run examples/run_grpo_math.py grpo.max_num_steps=5
 ```
 
+### Profile Multiple Worker Types
+
+```bash
+NRL_NSYS_PROFILE_STEP_RANGE=1:2 NRL_NSYS_WORKER_PATTERNS="*policy*,*vllm*" uv run examples/run_grpo_math.py grpo.max_num_steps=5
+```
+
 ### Profile Workers with Exact Names
 
 ```bash
-NRL_NSYS_PROFILE_STEP_RANGE=3:10 NRL_NSYS_WORKER_PATTERNS="dtensor_policy_worker" uv run examples/run_grpo_math.py grpo.max_num_steps=5
+NRL_NSYS_PROFILE_STEP_RANGE=3:10 NRL_NSYS_WORKER_PATTERNS="dtensor_policy_worker,vllm_generation_worker" uv run examples/run_grpo_math.py grpo.max_num_steps=5
 ```
 
 ### Profile Megatron Workers
@@ -63,7 +69,7 @@ To profile a Megatron worker, you should set `LD_LIBRARY_PATH` as follows, other
 
 ```bash
 LD_LIBRARY_PATH="/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/cuda/lib64:/usr/local/cuda/lib:/usr/local/nvidia/lib64:/usr/local/nvidia/lib:/usr/lib/x86_64-linux-gnu" \
-NRL_NSYS_PROFILE_STEP_RANGE=2:3 NRL_NSYS_WORKER_PATTERNS="megatron_policy_worker" uv run examples/run_grpo_math.py --config examples/configs/grpo_math_1B_megatron.yaml grpo.max_num_steps=5
+NRL_NSYS_PROFILE_STEP_RANGE=2:3 NRL_NSYS_WORKER_PATTERNS="megatron_policy_worker,vllm_generation_worker" uv run examples/run_grpo_math.py --config examples/configs/grpo_math_1B_megatron.yaml grpo.max_num_steps=5
 ```
 
 ## Profile Output
@@ -78,7 +84,10 @@ When profiling is enabled, it generates the following logs and files:
 2. **Profile Files**: Each profiled worker generates a `.nsys-rep` file with naming pattern:
    ```
    dtensor_policy_worker_<NRL_NSYS_PROFILE_STEP_RANGE>_<PID>.nsys-rep
+   vllm_generation_worker_<NRL_NSYS_PROFILE_STEP_RANGE>_<PID>.nsys-rep
+   worker_process_<PID>.nsys-rep
    ```
+   If you are not using model parallelism in vLLM, refer directly to `vllm_generation_worker_<NRL_NSYS_PROFILE_STEP_RANGE>_<PID>.nsys-rep` for the Nsight report. If you are using model parallelism, that file will be empty, and the `worker_process_<PID>.nsys-rep` files are the Nsight profiles from vLLM's Ray distributed executor workers (see https://github.com/vllm-project/vllm/blob/7e3a8dc90670fd312ce1e0d4eba9bf11c571e3ad/vllm/executor/ray_distributed_executor.py#L136 for more information).
 
 3. **File Location**: Profile files are saved in `/tmp/ray/session*/logs/nsight/` directory on each worker node. Ensure you check both `ls /tmp/ray/session_[0-9]*/logs/nsight` and `ls /tmp/ray/session_latest/logs/nsight` for the profiles, since the "latest" pointer may be stale.
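
The `start:stop` values in `NRL_NSYS_PROFILE_STEP_RANGE` (e.g. `3:5` above) are just a colon-separated step range read from the environment. As a minimal illustrative sketch only (the actual parsing and validation live in `nemo_rl/utils/nsys.py` and may differ, e.g. in whether the stop step is inclusive):

```python
import os


def parse_step_range(raw: str) -> tuple[int, int]:
    """Hypothetical helper: split 'start:stop' into integers."""
    start, stop = raw.split(":")
    return int(start), int(stop)


# e.g. NRL_NSYS_PROFILE_STEP_RANGE="3:5" -> (3, 5)
start_step, stop_step = parse_step_range(
    os.environ.get("NRL_NSYS_PROFILE_STEP_RANGE", "0:0")
)
```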

nemo_rl/distributed/worker_group_utils.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ def get_nsight_config_if_pattern_matches(worker_name: str) -> dict[str, Any]:
             # Profile will only start/stop when torch.cuda.profiler.start()/stop() is called
             "capture-range": "cudaProfilerApi",
             "capture-range-end": "stop",
+            "cuda-graph-trace": "node",
         }
     }
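
For context on how this config is applied: glob patterns from `NRL_NSYS_WORKER_PATTERNS` are matched against worker names such as `dtensor_policy_worker`, and only matching workers receive the nsight options above through Ray's runtime environment. A rough sketch of the matching idea, with fnmatch-style globbing assumed for illustration (see `get_nsight_config_if_pattern_matches` in this file for the real logic and return shape):

```python
import fnmatch
import os
from typing import Any


def matches_nsys_pattern(worker_name: str) -> bool:
    """Illustrative only: does any comma-separated glob pattern match this worker?"""
    patterns = [
        p.strip()
        for p in os.environ.get("NRL_NSYS_WORKER_PATTERNS", "").split(",")
        if p.strip()
    ]
    return any(fnmatch.fnmatch(worker_name, pattern) for pattern in patterns)


# Options mirroring the diff above; passed to nsys via Ray for matching workers.
nsight_options: dict[str, Any] = {
    "capture-range": "cudaProfilerApi",  # arm the profiler, but stay idle until
    "capture-range-end": "stop",         # torch.cuda.profiler.start()/stop() fire
    "cuda-graph-trace": "node",          # added in this commit: node-level CUDA graph tracing
}
```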

nemo_rl/models/generation/vllm.py

Lines changed: 21 additions & 0 deletions
@@ -53,6 +53,7 @@
 )
 from nemo_rl.models.huggingface.common import ModelFlag
 from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled
+from nemo_rl.utils.nsys import wrap_with_nvtx_name
 
 
 class VllmSpecificArgs(TypedDict):
@@ -323,6 +324,18 @@ def _patch_vllm_init_workers_ray():
         if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(self.model_name):
             load_format = "auto"
 
+        if (
+            len(get_nsight_config_if_pattern_matches("vllm_generation_worker")) > 0
+            and vllm_kwargs["distributed_executor_backend"] == "ray"
+        ):
+            logger.warning(
+                "Nsight profiling is enabled for vllm generation worker through the vllm ray distributed executor. "
+                "The nsight command-line args and output file names are automatically picked by the ray distributed "
+                "executor. Refer to https://github.com/vllm-project/vllm/blob/7e3a8dc90670fd312ce1e0d4eba9bf11c571e3ad/vllm/executor/ray_distributed_executor.py#L136 "
+                "for more information."
+            )
+            vllm_kwargs["ray_workers_use_nsight"] = True
+
         llm_kwargs = dict(
             model=self.model_name,
             load_format=load_format,
@@ -436,6 +449,7 @@ def _build_sampling_params(
             include_stop_str_in_output=True,
         )
 
+    @wrap_with_nvtx_name("vllm_generation_worker/generate")
     def generate(
         self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
     ) -> BatchedDataDict[GenerationOutputSpec]:
@@ -799,6 +813,7 @@ async def process_single_sample(sample_idx):
             await asyncio.gather(*sample_tasks, return_exceptions=True)
             raise e
 
+    @wrap_with_nvtx_name("vllm_generation_worker/generate_text")
     def generate_text(
         self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
     ) -> BatchedDataDict[GenerationOutputSpec]:
@@ -1033,6 +1048,7 @@ async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> Non
         """Async version of prepare_refit_info."""
         await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,))
 
+    @wrap_with_nvtx_name("vllm_generation_worker/update_weights_from_ipc_handles")
     def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
         """Update weights from IPC handles by delegating to the vLLM Worker implementation.
 
@@ -1144,6 +1160,7 @@ async def update_weights_from_ipc_handles_async(
             traceback.print_exc()
             return False
 
+    @wrap_with_nvtx_name("vllm_generation_worker/update_weights_from_collective")
     def update_weights_from_collective(self) -> bool:
         """Update the model weights from collective communication."""
         try:
@@ -1317,10 +1334,14 @@ async def wake_up_async(self, **kwargs):
     def start_gpu_profiling(self) -> None:
         """Start GPU profiling."""
         torch.cuda.profiler.start()
+        if self.llm is not None:
+            self.llm.collective_rpc("start_gpu_profiling", args=tuple())
 
     def stop_gpu_profiling(self) -> None:
         """Stop GPU profiling."""
         torch.cuda.profiler.stop()
+        if self.llm is not None:
+            self.llm.collective_rpc("stop_gpu_profiling", args=tuple())
 
 
 class VllmGeneration(GenerationInterface):
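
Taken together, these changes mean the driver only needs to call `start_gpu_profiling()` / `stop_gpu_profiling()` on the generation worker at the configured steps; the worker then fans the call out to vLLM's own Ray workers via `collective_rpc`, so the capture range set up above actually records something. A minimal driver-side sketch of that step gating (the loop, `run_step`, and the policy object are illustrative stand-ins, not NeMo RL's actual trainer):

```python
start_step, stop_step = 2, 3  # e.g. from NRL_NSYS_PROFILE_STEP_RANGE="2:3"
max_num_steps = 5


def run_step(step: int) -> None:
    ...  # one training/generation step (placeholder)


class IllustrativePolicy:
    def start_gpu_profiling(self) -> None:
        ...  # real worker: torch.cuda.profiler.start() + collective_rpc to vLLM workers

    def stop_gpu_profiling(self) -> None:
        ...  # real worker: torch.cuda.profiler.stop() + collective_rpc to vLLM workers


policy = IllustrativePolicy()
for step in range(1, max_num_steps + 1):
    if step == start_step:
        policy.start_gpu_profiling()
    run_step(step)
    if step == stop_step:
        policy.stop_gpu_profiling()
```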

nemo_rl/models/generation/vllm_backend.py

Lines changed: 19 additions & 0 deletions
@@ -17,6 +17,8 @@
 import torch
 from torch.multiprocessing.reductions import rebuild_cuda_tensor
 
+from nemo_rl.utils.nsys import wrap_with_nvtx_name
+
 try:
     import vllm  # noqa: F401
 except ImportError:
@@ -66,6 +68,9 @@ def prepare_refit_info(
         """
         self.state_dict_info = state_dict_info  # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored
 
+    @wrap_with_nvtx_name(
+        "vllm_internal_worker_extension/update_weights_from_global_ipc_handles"
+    )
     def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
         """Update weights from global IPC handles.
@@ -79,6 +84,9 @@ def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
         local_device_ipc_handles = global_device_ipc_handles[device_uuid]
         return self.update_weights_from_local_ipc_handles(local_device_ipc_handles)
 
+    @wrap_with_nvtx_name(
+        "vllm_internal_worker_extension/update_weights_from_local_ipc_handles"
+    )
     def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
         """Update weights from local IPC handles.
@@ -155,6 +163,9 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
             )
             return False
 
+    @wrap_with_nvtx_name(
+        "vllm_internal_worker_extension/update_weights_from_collective"
+    )
     def update_weights_from_collective(self) -> bool:
         """Update the model weights from collective communication."""
         assert self.state_dict_info is not None, (
@@ -174,3 +185,11 @@
             return False
 
         return True
+
+    def start_gpu_profiling(self) -> None:
+        """Start GPU profiling."""
+        torch.cuda.profiler.start()
+
+    def stop_gpu_profiling(self) -> None:
+        """Stop GPU profiling."""
+        torch.cuda.profiler.stop()

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 9 additions & 0 deletions
@@ -77,6 +77,7 @@
     load_checkpoint,
     save_checkpoint,
 )
+from nemo_rl.utils.nsys import wrap_with_nvtx_name
 
 
 @contextmanager
@@ -513,6 +514,7 @@ def get_gpu_info(self) -> dict[str, Any]:
         """Return information about the GPU being used by this worker."""
         return get_gpu_info(self.model)
 
+    @wrap_with_nvtx_name("dtensor_policy_worker/train")
     def train(
         self,
         data: BatchedDataDict[Any],
@@ -855,6 +857,7 @@ def train(
 
         return metrics
 
+    @wrap_with_nvtx_name("dtensor_policy_worker/get_logprobs")
     def get_logprobs(
         self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
     ) -> BatchedDataDict[LogprobOutputSpec]:
@@ -1137,6 +1140,7 @@ def use_reference_model(self) -> Generator[None, None, None]:
                 val = to_local_if_dtensor(v)
                 val.copy_(curr_state_dict[k])
 
+    @wrap_with_nvtx_name("dtensor_policy_worker/get_reference_policy_logprobs")
     def get_reference_policy_logprobs(
         self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
     ) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
@@ -1234,6 +1238,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         return self.refit_param_info, total_available_bytes
 
     @torch.no_grad()
+    @wrap_with_nvtx_name("dtensor_policy_worker/get_weights_ipc_handles")
     def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
         assert self._held_sharded_state_dict_reference is not None, (
             "prepare_weights_for_ipc must be called before get_weights_ipc_handles"
@@ -1296,6 +1301,7 @@ def broadcast_weights_for_collective(self) -> None:
         if self.cpu_offload:
             self.model = self.move_to_cpu(self.model)
 
+    @wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_lp_inference")
     def prepare_for_lp_inference(self) -> None:
         if not self.cpu_offload:
             self.move_to_cuda(self.model)
@@ -1305,6 +1311,7 @@ def prepare_for_lp_inference(self) -> None:
         self.model.eval()
         self.offload_before_refit()
 
+    @wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_training")
     def prepare_for_training(self, *args, **kwargs) -> None:
         # onload models and optimizer state to cuda
         if not self.cpu_offload:
@@ -1329,6 +1336,7 @@ def prepare_for_training(self, *args, **kwargs) -> None:
         torch.cuda.empty_cache()
 
     @torch.no_grad()
+    @wrap_with_nvtx_name("dtensor_policy_worker/offload_before_refit")
     def offload_before_refit(self) -> None:
         """Offload the optimizer to the CPU."""
         torch.randn(1).cuda()  # wake up torch allocator
@@ -1342,6 +1350,7 @@ def offload_before_refit(self) -> None:
         torch.cuda.empty_cache()
 
     @torch.no_grad()
+    @wrap_with_nvtx_name("dtensor_policy_worker/offload_after_refit")
     def offload_after_refit(self) -> None:
         # Offload as much as possible on the CPU
         self.model = self.move_to_cpu(self.model)
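
Because NVTX ranges nest, additional `range_push`/`range_pop` calls placed inside one of the decorated methods show up as children of ranges like `dtensor_policy_worker/train` in the Nsight timeline. A small user-side illustration (not part of this commit):

```python
import torch


def train_microbatch_with_ranges(model, batch, loss_fn, optimizer):
    """Illustrative only: sub-phases appear nested under the enclosing NVTX range."""
    torch.cuda.nvtx.range_push("forward")
    loss = loss_fn(model(batch["input"]), batch["target"])
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("backward")
    loss.backward()
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("optimizer_step")
    optimizer.step()
    optimizer.zero_grad()
    torch.cuda.nvtx.range_pop()
```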

nemo_rl/models/policy/megatron_policy_worker.py

Lines changed: 10 additions & 0 deletions
@@ -124,6 +124,7 @@
     get_megatron_checkpoint_dir,
     get_runtime_env_for_policy_worker,
 )
+from nemo_rl.utils.nsys import wrap_with_nvtx_name
 
 TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)
 
@@ -765,6 +766,7 @@ def disable_forward_pre_hook(self, param_sync=True):
         assert isinstance(self.model, DistributedDataParallel)
         self.model.disable_forward_pre_hook(param_sync=param_sync)
 
+    @wrap_with_nvtx_name("megatron_policy_worker/train")
     def train(
         self,
         data: BatchedDataDict,
@@ -1010,6 +1012,7 @@ def train(
         }
         return metrics
 
+    @wrap_with_nvtx_name("megatron_policy_worker/get_logprobs")
     def get_logprobs(
         self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
     ) -> BatchedDataDict[LogprobOutputSpec]:
@@ -1240,6 +1243,7 @@ def use_reference_model(self):
             self.enable_forward_pre_hook()
 
     # Temporary fix, 'data' is a kwarg due to some sort of ray bug
+    @wrap_with_nvtx_name("megatron_policy_worker/get_reference_policy_logprobs")
     def get_reference_policy_logprobs(
         self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
     ) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
@@ -1262,6 +1266,7 @@ def get_reference_policy_logprobs(
         return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu()
         return return_data
 
+    @wrap_with_nvtx_name("megatron_policy_worker/generate")
     def generate(
         self, *, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
     ) -> BatchedDataDict[GenerationOutputSpec]:
@@ -1405,6 +1410,7 @@ def report_device_id(self) -> str:
         return get_device_uuid(device_idx)
 
     @torch.no_grad()
+    @wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info")
     def prepare_refit_info(self) -> None:
         # Get parameter info for refit
         # param_info: list of ((name, shape, dtype), size_in_bytes) tuples
@@ -1439,6 +1445,7 @@ def prepare_refit_info(self) -> None:
 
         return refit_param_info_hf
 
+    @wrap_with_nvtx_name("megatron_policy_worker/prepare_weights_for_ipc")
     def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         """Prepare Megatron model weights for IPC transfer to vLLM.
 
@@ -1460,6 +1467,7 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
 
     # Temporary fix, 'keys' is a kwarg due to some sort of ray bug
     @torch.no_grad()
+    @wrap_with_nvtx_name("megatron_policy_worker/get_weights_ipc_handles")
     def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
         """Get IPC handles for the requested Megatron model weights.
 
@@ -1592,6 +1600,7 @@ def prepare_for_training(self, *args, **kwargs):
 
         torch.cuda.empty_cache()
 
+    @wrap_with_nvtx_name("megatron_policy_worker/offload_before_refit")
     def offload_before_refit(self):
         """Offload the optimizer and buffers to the CPU."""
         no_grad = torch.no_grad()
@@ -1630,6 +1639,7 @@ def offload_before_refit(self):
         )
         no_grad.__exit__(None, None, None)
 
+    @wrap_with_nvtx_name("megatron_policy_worker/offload_after_refit")
     def offload_after_refit(self):
         no_grad = torch.no_grad()
         no_grad.__enter__()

nemo_rl/utils/nsys.py

Lines changed: 16 additions & 0 deletions
@@ -16,6 +16,7 @@
 from typing import Protocol
 
 import rich
+import torch
 
 NRL_NSYS_WORKER_PATTERNS = os.environ.get("NRL_NSYS_WORKER_PATTERNS", "")
 NRL_NSYS_PROFILE_STEP_RANGE = os.environ.get("NRL_NSYS_PROFILE_STEP_RANGE", "")
@@ -76,3 +77,18 @@ def stop_profiler_on_exit():
         )
         policy.stop_gpu_profiling()
         policy.__NRL_PROFILE_STARTED = False
+
+
+def wrap_with_nvtx_name(name: str):
+    """A decorator to wrap a function with an NVTX range with the given name."""
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            torch.cuda.nvtx.range_push(name)
+            ret = func(*args, **kwargs)
+            torch.cuda.nvtx.range_pop()
+            return ret
+
+        return wrapper
+
+    return decorator
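
One caveat with the decorator as committed: if the wrapped function raises, `range_pop()` is never reached, and the wrapper does not preserve the wrapped function's name or docstring. A variant that addresses both (a suggestion, not part of this commit) could look like:

```python
import functools

import torch


def wrap_with_nvtx_name(name: str):
    """Decorator that wraps a function in an NVTX range and always closes it."""

    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's metadata
        def wrapper(*args, **kwargs):
            torch.cuda.nvtx.range_push(name)
            try:
                return func(*args, **kwargs)
            finally:
                # Pop the range even if func raises.
                torch.cuda.nvtx.range_pop()

        return wrapper

    return decorator
```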

0 commit comments
