
Commit 08ea80e

vmoens and cursoragent committed
[Example] Dreamer: distributed prof profiling integration
Integrate the prof distributed profiler into the Dreamer training loop and collector workers for coordinated cross-process profiling.

- dreamer_utils.py: add create_prof_handle(), extend DreamerProfiler with a prof_handle param, step(), a shm_name property, and finish() cleanup. Add a prof_shm_name param to make_collector with a PROF_SHM_NAME env var.
- dreamer.py: create the prof handle early, pass its shm_name to the collector, wrap the training phases with _prof_context (sample, world_model, actor, value, weight_update), call profiler.finish() at cleanup.
- config.yaml: add a profiling.distributed block, raise total_optim_steps to 70 for the prof window.
- _runner.py: the worker reads the PROF_SHM_NAME/PROF_ENABLED env vars and calls prof.prepare() to join profiling, and wraps the rollout in a prof context.

Co-authored-by: Cursor <[email protected]>
ghstack-source-id: 9fcc2e8
Pull-Request: #3461
1 parent a9d7b74 commit 08ea80e
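The _runner.py diff is not included in this excerpt. Based on the description above, a rough worker-side sketch might look like the following; only the env var names and the prof.prepare()/prof.profile() calls come from the commit message, while the prepare() arguments (none here) and the "rollout" region name are assumptions.

# Hypothetical worker-side sketch (not the actual _runner.py change).
# Assumes prof.prepare() discovers the session from the PROF_SHM_NAME env var;
# the "rollout" region name is illustrative only.
import contextlib
import os

try:
    import prof

    _PROF_AVAILABLE = True
except ImportError:
    _PROF_AVAILABLE = False


def _join_distributed_profiling() -> bool:
    """Join the profiling session advertised by the training process, if any."""
    enabled = os.environ.get("PROF_ENABLED") == "1"
    shm_name = os.environ.get("PROF_SHM_NAME")
    if _PROF_AVAILABLE and enabled and shm_name:
        prof.prepare()  # join the coordinated profiling session
        return True
    return False


def _rollout_context(joined: bool):
    """Wrap each rollout in a prof region so it lines up with the trainer's phases."""
    if joined and _PROF_AVAILABLE:
        return prof.profile("rollout")
    return contextlib.nullcontext()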

File tree

4 files changed, +301 -104 lines changed


sota-implementations/dreamer/config.yaml

Lines changed: 18 additions & 1 deletion
@@ -95,7 +95,10 @@ profiling:
   # Enable PyTorch profiling
   enabled: False
   # Total optim steps when profiling (overrides optimization.total_optim_steps)
-  total_optim_steps: 50
+  # When using distributed profiling (profiling.distributed.enabled=true), set this
+  # higher than PROF_ITERATIONS end value to allow prof to complete its window.
+  # e.g., if PROF_ITERATIONS=50-55, set total_optim_steps >= 60
+  total_optim_steps: 70
   # Skip the first N optim steps (no profiling at all)
   skip_first: 1
   # Warmup steps (profiler runs but data discarded for warmup)
@@ -129,3 +132,17 @@ profiling:
   trace_file: collector_trace_{worker_idx}.json
   # Override init_random_frames when collector profiling is enabled (0 = skip random frames phase)
   init_random_frames_override: 0
+
+  # Distributed profiling with the prof library
+  # This enables coordinated profiling across training and collector workers
+  # using shared memory signaling. Set environment variables to control behavior:
+  #   PROF_ENABLED=1           - Enable distributed profiling
+  #   PROF_ITERATIONS=50-55    - Which training steps to profile
+  #   PROF_OUTPUT_DIR=./traces - Where to save trace files
+  #   PROF_MODE=lite           - Only trace explicit regions (default)
+  distributed:
+    # Enable prof-based distributed profiling
+    enabled: False
+    # Backend for coordination: "shm" (shared memory) for subprocess collectors,
+    # "ray" for Ray actors (requires Ray-based collectors)
+    backend: shm
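For reference, a run with the env vars documented above might be launched as in the sketch below. This is a hedged example: it assumes the script accepts Hydra-style key=value overrides and is started from the sota-implementations/dreamer directory; whether profiling.enabled must also be set is not shown in this excerpt.

# Hypothetical launcher sketch: export the prof env vars documented in the
# config comments and run dreamer.py with distributed profiling enabled.
import os
import subprocess

env = dict(
    os.environ,
    PROF_ENABLED="1",            # enable distributed profiling
    PROF_ITERATIONS="50-55",     # profile training steps 50 through 55
    PROF_OUTPUT_DIR="./traces",  # where trace files are written
    PROF_MODE="lite",            # only trace explicit regions
)

subprocess.run(
    ["python", "dreamer.py", "profiling.distributed.enabled=True"],
    env=env,
    check=True,
)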

sota-implementations/dreamer/dreamer.py

Lines changed: 136 additions & 97 deletions
@@ -16,6 +16,7 @@

 from dreamer_utils import (
     _default_device,
+    create_prof_handle,
     DreamerProfiler,
     dump_video,
     log_metrics,
@@ -32,6 +33,23 @@
 from torch.autograd.profiler import record_function
 from torch.nn.utils import clip_grad_norm_
 from torchrl._utils import compile_with_warmup, logger as torchrl_logger, timeit
+
+# Distributed profiling with prof (only imported when enabled)
+try:
+    import prof
+
+    _PROF_AVAILABLE = True
+except ImportError:
+    _PROF_AVAILABLE = False
+
+
+def _prof_context(name: str):
+    """Return a prof.profile context manager if prof is available, else nullcontext."""
+    if _PROF_AVAILABLE:
+        return prof.profile(name)
+    return contextlib.nullcontext()
+
+
 from torchrl.envs.llm.transforms import PolicyVersion
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives.dreamer import (
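The guarded import means _prof_context can be used unconditionally: when prof is not installed, every region is a nullcontext and the wrapped code runs unchanged. A minimal standalone illustration (names mirror the diff; the wrapped body is a placeholder):

import contextlib

try:
    import prof

    _PROF_AVAILABLE = True
except ImportError:
    _PROF_AVAILABLE = False


def _prof_context(name: str):
    """Return a prof.profile region if prof is installed, else a nullcontext."""
    if _PROF_AVAILABLE:
        return prof.profile(name)
    return contextlib.nullcontext()


with _prof_context("sample"):
    # With prof installed and a session active, this body is recorded as the
    # "sample" region; otherwise it simply executes without tracing.
    pass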
@@ -80,6 +98,10 @@ def main(cfg: DictConfig): # noqa: F821
     if hasattr(logger, "log_hparams"):
         logger.log_hparams(cfg)
 
+    # Create prof handle early for distributed profiling (before collector creation)
+    # This allows the shm_name to be passed to collector workers
+    prof_handle = create_prof_handle(cfg)
+
     # make_environments returns (train_env_factory, test_env) for async collection
     train_env_factory, test_env = make_environments(
         cfg=cfg,
@@ -180,6 +202,7 @@ def main(cfg: DictConfig): # noqa: F821
         replay_buffer=replay_buffer,
         storage_transform=storage_transform,
         track_policy_version=policy_version,
+        prof_shm_name=prof_handle.shm_name if prof_handle is not None else None,
     )
 
     # Enable collector worker profiling if configured
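The dreamer_utils.py side of this change is not shown in this excerpt. A hypothetical sketch of how make_collector might forward the new prof_shm_name argument to subprocess workers via the PROF_SHM_NAME env var named in the commit message; the signature details and body are assumptions:

import os


def make_collector(cfg, *args, prof_shm_name: str | None = None, **kwargs):
    if prof_shm_name is not None:
        # Subprocess collector workers inherit the parent environment, so
        # exporting the shared-memory name is enough for _runner.py to find
        # and join the same profiling session.
        os.environ["PROF_SHM_NAME"] = prof_shm_name
    # ... build and return the collector exactly as before ...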
@@ -304,7 +327,9 @@ def main(cfg: DictConfig): # noqa: F821
     t_log_start = time.time()
 
     # Profiling setup (encapsulated in helper class)
-    profiler = DreamerProfiler(cfg, device, pbar, compile_warmup=compile_warmup)
+    profiler = DreamerProfiler(
+        cfg, device, pbar, compile_warmup=compile_warmup, prof_handle=prof_handle
+    )
 
     # Start async collection - collector fills the buffer in background
     torchrl_logger.info("Starting async collection...")
@@ -363,97 +388,107 @@ def main(cfg: DictConfig): # noqa: F821
         )
 
         # sample from replay buffer
-        with timeit("train/sample"), record_function("## train/sample ##"):
-            sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length)
-            if profiling_enabled:
-                torch.cuda.synchronize()
+        with _prof_context("sample"):
+            with timeit("train/sample"), record_function("## train/sample ##"):
+                sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length)
+                if profiling_enabled:
+                    torch.cuda.synchronize()
 
         # update world model
-        with timeit("train/world_model-forward"), record_function(
-            "## world_model/forward ##"
-        ):
-            # Mark step begin for CUDAGraph to prevent tensor overwrite issues
-            torch.compiler.cudagraph_mark_step_begin()
-            with torch.autocast(
-                device_type=device.type,
-                dtype=autocast_dtype,
-            ) if autocast_dtype else contextlib.nullcontext():
-                assert (
-                    sampled_tensordict.device.type == "cuda"
-                ), "sampled_tensordict should be on CUDA"
-                model_loss_td, sampled_tensordict = world_model_loss(sampled_tensordict)
-                loss_world_model = (
-                    model_loss_td["loss_model_kl"]
-                    + model_loss_td["loss_model_reco"]
-                    + model_loss_td["loss_model_reward"]
-                )
+        with _prof_context("world_model"):
+            with timeit("train/world_model-forward"), record_function(
+                "## world_model/forward ##"
+            ):
+                # Mark step begin for CUDAGraph to prevent tensor overwrite issues
+                torch.compiler.cudagraph_mark_step_begin()
+                with torch.autocast(
+                    device_type=device.type,
+                    dtype=autocast_dtype,
+                ) if autocast_dtype else contextlib.nullcontext():
+                    assert (
+                        sampled_tensordict.device.type == "cuda"
+                    ), "sampled_tensordict should be on CUDA"
+                    model_loss_td, sampled_tensordict = world_model_loss(
+                        sampled_tensordict
+                    )
+                    loss_world_model = (
+                        model_loss_td["loss_model_kl"]
+                        + model_loss_td["loss_model_reco"]
+                        + model_loss_td["loss_model_reward"]
+                    )
 
-        with timeit("train/world_model-backward"), record_function(
-            "## world_model/backward ##"
-        ):
-            world_model_opt.zero_grad()
-            if autocast_dtype:
-                scaler1.scale(loss_world_model).backward()
-                scaler1.unscale_(world_model_opt)
-            else:
-                loss_world_model.backward()
-            torchrl_logger.debug("world_model_loss backward OK")
-            world_model_grad = clip_grad_norm_(world_model.parameters(), grad_clip)
-            if autocast_dtype:
-                scaler1.step(world_model_opt)
-                scaler1.update()
-            else:
-                world_model_opt.step()
+            with timeit("train/world_model-backward"), record_function(
+                "## world_model/backward ##"
+            ):
+                world_model_opt.zero_grad()
+                if autocast_dtype:
+                    scaler1.scale(loss_world_model).backward()
+                    scaler1.unscale_(world_model_opt)
+                else:
+                    loss_world_model.backward()
+                torchrl_logger.debug("world_model_loss backward OK")
+                world_model_grad = clip_grad_norm_(world_model.parameters(), grad_clip)
+                if autocast_dtype:
+                    scaler1.step(world_model_opt)
+                    scaler1.update()
+                else:
+                    world_model_opt.step()
 
         # update actor network
-        with timeit("train/actor-forward"), record_function("## actor/forward ##"):
-            # Mark step begin for CUDAGraph to prevent tensor overwrite issues
-            torch.compiler.cudagraph_mark_step_begin()
-            with torch.autocast(
-                device_type=device.type, dtype=autocast_dtype
-            ) if autocast_dtype else contextlib.nullcontext():
-                actor_loss_td, sampled_tensordict = actor_loss(
-                    sampled_tensordict.reshape(-1)
-                )
+        with _prof_context("actor"):
+            with timeit("train/actor-forward"), record_function("## actor/forward ##"):
+                # Mark step begin for CUDAGraph to prevent tensor overwrite issues
+                torch.compiler.cudagraph_mark_step_begin()
+                with torch.autocast(
+                    device_type=device.type, dtype=autocast_dtype
+                ) if autocast_dtype else contextlib.nullcontext():
+                    actor_loss_td, sampled_tensordict = actor_loss(
+                        sampled_tensordict.reshape(-1)
+                    )
 
-        with timeit("train/actor-backward"), record_function("## actor/backward ##"):
-            actor_opt.zero_grad()
-            if autocast_dtype:
-                scaler2.scale(actor_loss_td["loss_actor"]).backward()
-                scaler2.unscale_(actor_opt)
-            else:
-                actor_loss_td["loss_actor"].backward()
-            torchrl_logger.debug("actor_loss backward OK")
-            actor_model_grad = clip_grad_norm_(actor_model.parameters(), grad_clip)
-            if autocast_dtype:
-                scaler2.step(actor_opt)
-                scaler2.update()
-            else:
-                actor_opt.step()
+            with timeit("train/actor-backward"), record_function(
+                "## actor/backward ##"
+            ):
+                actor_opt.zero_grad()
+                if autocast_dtype:
+                    scaler2.scale(actor_loss_td["loss_actor"]).backward()
+                    scaler2.unscale_(actor_opt)
+                else:
+                    actor_loss_td["loss_actor"].backward()
+                torchrl_logger.debug("actor_loss backward OK")
+                actor_model_grad = clip_grad_norm_(actor_model.parameters(), grad_clip)
+                if autocast_dtype:
+                    scaler2.step(actor_opt)
+                    scaler2.update()
+                else:
+                    actor_opt.step()
 
         # update value network
-        with timeit("train/value-forward"), record_function("## value/forward ##"):
-            # Mark step begin for CUDAGraph to prevent tensor overwrite issues
-            torch.compiler.cudagraph_mark_step_begin()
-            with torch.autocast(
-                device_type=device.type, dtype=autocast_dtype
-            ) if autocast_dtype else contextlib.nullcontext():
-                value_loss_td, sampled_tensordict = value_loss(sampled_tensordict)
-
-        with timeit("train/value-backward"), record_function("## value/backward ##"):
-            value_opt.zero_grad()
-            if autocast_dtype:
-                scaler3.scale(value_loss_td["loss_value"]).backward()
-                scaler3.unscale_(value_opt)
-            else:
-                value_loss_td["loss_value"].backward()
-            torchrl_logger.debug("value_loss backward OK")
-            critic_model_grad = clip_grad_norm_(value_model.parameters(), grad_clip)
-            if autocast_dtype:
-                scaler3.step(value_opt)
-                scaler3.update()
-            else:
-                value_opt.step()
+        with _prof_context("value"):
+            with timeit("train/value-forward"), record_function("## value/forward ##"):
+                # Mark step begin for CUDAGraph to prevent tensor overwrite issues
+                torch.compiler.cudagraph_mark_step_begin()
+                with torch.autocast(
+                    device_type=device.type, dtype=autocast_dtype
+                ) if autocast_dtype else contextlib.nullcontext():
+                    value_loss_td, sampled_tensordict = value_loss(sampled_tensordict)
+
+            with timeit("train/value-backward"), record_function(
+                "## value/backward ##"
+            ):
+                value_opt.zero_grad()
+                if autocast_dtype:
+                    scaler3.scale(value_loss_td["loss_value"]).backward()
+                    scaler3.unscale_(value_opt)
+                else:
+                    value_loss_td["loss_value"].backward()
+                torchrl_logger.debug("value_loss backward OK")
+                critic_model_grad = clip_grad_norm_(value_model.parameters(), grad_clip)
+                if autocast_dtype:
+                    scaler3.step(value_opt)
+                    scaler3.update()
+                else:
+                    value_opt.step()
 
         # Step profiler (returns True if profiling complete)
         if profiler.step():
@@ -544,19 +579,20 @@ def main(cfg: DictConfig): # noqa: F821
             frames_at_log_start = collected_frames
 
         # Update policy weights in collector (for async collection)
-        with timeit("train/weight_update") as weight_update_timer:
-            torchrl_logger.debug(
-                f"optim_step={optim_step}: Starting weight update..."
-            )
-            policy[1].step(frames_collected_this_interval)
-            collector.update_policy_weights_()
-            # Increment policy version after weight update
-            collector.increment_version()
-            torchrl_logger.debug(
-                f"optim_step={optim_step}: Weight update completed in "
-                f"{weight_update_timer.elapsed():.3f}s, "
-                f"policy_version={policy_version.version}"
-            )
+        with _prof_context("weight_update"):
+            with timeit("train/weight_update") as weight_update_timer:
+                torchrl_logger.debug(
+                    f"optim_step={optim_step}: Starting weight update..."
+                )
+                policy[1].step(frames_collected_this_interval)
+                collector.update_policy_weights_()
+                # Increment policy version after weight update
+                collector.increment_version()
+                torchrl_logger.debug(
+                    f"optim_step={optim_step}: Weight update completed in "
+                    f"{weight_update_timer.elapsed():.3f}s, "
+                    f"policy_version={policy_version.version}"
+                )
 
         # Evaluation (every eval_every optimization steps)
         if (optim_step + 1) % eval_every == 0:
@@ -598,6 +634,9 @@ def main(cfg: DictConfig): # noqa: F821
             if logger is not None:
                 log_metrics(logger, eval_metrics, replay_buffer.write_count)
 
+    # Finish profiling and clean up resources
+    profiler.finish()
+
     if not test_env.is_closed:
         test_env.close()
     # Shutdown async collector (use async_shutdown since we used start())