|
16 | 16 |
|
17 | 17 | from dreamer_utils import ( |
18 | 18 | _default_device, |
| 19 | + create_prof_handle, |
19 | 20 | DreamerProfiler, |
20 | 21 | dump_video, |
21 | 22 | log_metrics, |
|
32 | 33 | from torch.autograd.profiler import record_function |
33 | 34 | from torch.nn.utils import clip_grad_norm_ |
34 | 35 | from torchrl._utils import compile_with_warmup, logger as torchrl_logger, timeit |
| 36 | + |
| 38 | +# Distributed profiling with prof (optional dependency; profiling becomes a no-op if unavailable) |
| 38 | +try: |
| 39 | + import prof |
| 40 | + |
| 41 | + _PROF_AVAILABLE = True |
| 42 | +except ImportError: |
| 43 | + _PROF_AVAILABLE = False |
| 44 | + |
| 45 | + |
| 46 | +def _prof_context(name: str): |
| 47 | + """Return a prof.profile context manager if prof is available, else nullcontext.""" |
| 48 | + if _PROF_AVAILABLE: |
| 49 | + return prof.profile(name) |
| 50 | + return contextlib.nullcontext() |
| 51 | + |
| 52 | + |
35 | 53 | from torchrl.envs.llm.transforms import PolicyVersion |
36 | 54 | from torchrl.envs.utils import ExplorationType, set_exploration_type |
37 | 55 | from torchrl.objectives.dreamer import ( |
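The optional-import block added in this hunk makes every `_prof_context(...)` wrapper degrade to a no-op when `prof` is not installed. A minimal standalone sketch of the same pattern, assuming only that `prof.profile(name)` is a context manager, which is how the patch uses it:

```python
import contextlib

try:
    import prof  # optional dependency; its absence must not break training
    _PROF_AVAILABLE = True
except ImportError:
    _PROF_AVAILABLE = False


def _prof_context(name: str):
    """Return prof.profile(name) when prof is importable, else a no-op context."""
    if _PROF_AVAILABLE:
        return prof.profile(name)
    return contextlib.nullcontext()


# Wrapping a region is always safe: with prof installed the region is recorded,
# without it the nullcontext makes the `with` statement essentially free.
with _prof_context("sample"):
    batch = list(range(8))  # stand-in for replay_buffer.sample()
```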
@@ -80,6 +98,10 @@ def main(cfg: DictConfig): # noqa: F821 |
80 | 98 | if hasattr(logger, "log_hparams"): |
81 | 99 | logger.log_hparams(cfg) |
82 | 100 |
|
| 101 | + # Create prof handle early for distributed profiling (before collector creation) |
| 102 | + # This allows the shm_name to be passed to collector workers |
| 103 | + prof_handle = create_prof_handle(cfg) |
| 104 | + |
83 | 105 | # make_environments returns (train_env_factory, test_env) for async collection |
84 | 106 | train_env_factory, test_env = make_environments( |
85 | 107 | cfg=cfg, |
@@ -180,6 +202,7 @@ def main(cfg: DictConfig): # noqa: F821 |
180 | 202 | replay_buffer=replay_buffer, |
181 | 203 | storage_transform=storage_transform, |
182 | 204 | track_policy_version=policy_version, |
| 205 | + prof_shm_name=prof_handle.shm_name if prof_handle is not None else None, |
183 | 206 | ) |
184 | 207 |
|
185 | 208 | # Enable collector worker profiling if configured |
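Creating the handle before the collector means only the `shm_name` string has to cross the process boundary; each collector worker can then attach to the same shared-memory block by name. The internals of `create_prof_handle` live in `dreamer_utils` and are not part of this diff, so the snippet below is a hypothetical illustration of that name-based handoff using the standard library, not the actual `prof` mechanism:

```python
from multiprocessing import shared_memory

# "Parent" side: create a named block; only the name string is passed to workers.
shm = shared_memory.SharedMemory(create=True, size=4096)
shm_name = shm.name  # plays the role of prof_handle.shm_name in this patch

# "Worker" side: re-attach to the existing block purely by name.
worker_view = shared_memory.SharedMemory(name=shm_name)
worker_view.buf[0] = 1  # writes are visible to every process attached to the block

worker_view.close()  # each attachment closes its own view
shm.close()
shm.unlink()  # the creator removes the block once profiling is finished
```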
@@ -304,7 +327,9 @@ def main(cfg: DictConfig): # noqa: F821 |
304 | 327 | t_log_start = time.time() |
305 | 328 |
|
306 | 329 | # Profiling setup (encapsulated in helper class) |
307 | | - profiler = DreamerProfiler(cfg, device, pbar, compile_warmup=compile_warmup) |
| 330 | + profiler = DreamerProfiler( |
| 331 | + cfg, device, pbar, compile_warmup=compile_warmup, prof_handle=prof_handle |
| 332 | + ) |
308 | 333 |
|
309 | 334 | # Start async collection - collector fills the buffer in background |
310 | 335 | torchrl_logger.info("Starting async collection...") |
@@ -363,97 +388,107 @@ def main(cfg: DictConfig): # noqa: F821 |
363 | 388 | ) |
364 | 389 |
|
365 | 390 | # sample from replay buffer |
366 | | - with timeit("train/sample"), record_function("## train/sample ##"): |
367 | | - sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length) |
368 | | - if profiling_enabled: |
369 | | - torch.cuda.synchronize() |
| 391 | + with _prof_context("sample"): |
| 392 | + with timeit("train/sample"), record_function("## train/sample ##"): |
| 393 | + sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length) |
| 394 | + if profiling_enabled: |
| 395 | + torch.cuda.synchronize() |
370 | 396 |
|
371 | 397 | # update world model |
372 | | - with timeit("train/world_model-forward"), record_function( |
373 | | - "## world_model/forward ##" |
374 | | - ): |
375 | | - # Mark step begin for CUDAGraph to prevent tensor overwrite issues |
376 | | - torch.compiler.cudagraph_mark_step_begin() |
377 | | - with torch.autocast( |
378 | | - device_type=device.type, |
379 | | - dtype=autocast_dtype, |
380 | | - ) if autocast_dtype else contextlib.nullcontext(): |
381 | | - assert ( |
382 | | - sampled_tensordict.device.type == "cuda" |
383 | | - ), "sampled_tensordict should be on CUDA" |
384 | | - model_loss_td, sampled_tensordict = world_model_loss(sampled_tensordict) |
385 | | - loss_world_model = ( |
386 | | - model_loss_td["loss_model_kl"] |
387 | | - + model_loss_td["loss_model_reco"] |
388 | | - + model_loss_td["loss_model_reward"] |
389 | | - ) |
| 398 | + with _prof_context("world_model"): |
| 399 | + with timeit("train/world_model-forward"), record_function( |
| 400 | + "## world_model/forward ##" |
| 401 | + ): |
| 402 | + # Mark step begin for CUDAGraph to prevent tensor overwrite issues |
| 403 | + torch.compiler.cudagraph_mark_step_begin() |
| 404 | + with torch.autocast( |
| 405 | + device_type=device.type, |
| 406 | + dtype=autocast_dtype, |
| 407 | + ) if autocast_dtype else contextlib.nullcontext(): |
| 408 | + assert ( |
| 409 | + sampled_tensordict.device.type == "cuda" |
| 410 | + ), "sampled_tensordict should be on CUDA" |
| 411 | + model_loss_td, sampled_tensordict = world_model_loss( |
| 412 | + sampled_tensordict |
| 413 | + ) |
| 414 | + loss_world_model = ( |
| 415 | + model_loss_td["loss_model_kl"] |
| 416 | + + model_loss_td["loss_model_reco"] |
| 417 | + + model_loss_td["loss_model_reward"] |
| 418 | + ) |
390 | 419 |
|
391 | | - with timeit("train/world_model-backward"), record_function( |
392 | | - "## world_model/backward ##" |
393 | | - ): |
394 | | - world_model_opt.zero_grad() |
395 | | - if autocast_dtype: |
396 | | - scaler1.scale(loss_world_model).backward() |
397 | | - scaler1.unscale_(world_model_opt) |
398 | | - else: |
399 | | - loss_world_model.backward() |
400 | | - torchrl_logger.debug("world_model_loss backward OK") |
401 | | - world_model_grad = clip_grad_norm_(world_model.parameters(), grad_clip) |
402 | | - if autocast_dtype: |
403 | | - scaler1.step(world_model_opt) |
404 | | - scaler1.update() |
405 | | - else: |
406 | | - world_model_opt.step() |
| 420 | + with timeit("train/world_model-backward"), record_function( |
| 421 | + "## world_model/backward ##" |
| 422 | + ): |
| 423 | + world_model_opt.zero_grad() |
| 424 | + if autocast_dtype: |
| 425 | + scaler1.scale(loss_world_model).backward() |
| 426 | + scaler1.unscale_(world_model_opt) |
| 427 | + else: |
| 428 | + loss_world_model.backward() |
| 429 | + torchrl_logger.debug("world_model_loss backward OK") |
| 430 | + world_model_grad = clip_grad_norm_(world_model.parameters(), grad_clip) |
| 431 | + if autocast_dtype: |
| 432 | + scaler1.step(world_model_opt) |
| 433 | + scaler1.update() |
| 434 | + else: |
| 435 | + world_model_opt.step() |
407 | 436 |
|
408 | 437 | # update actor network |
409 | | - with timeit("train/actor-forward"), record_function("## actor/forward ##"): |
410 | | - # Mark step begin for CUDAGraph to prevent tensor overwrite issues |
411 | | - torch.compiler.cudagraph_mark_step_begin() |
412 | | - with torch.autocast( |
413 | | - device_type=device.type, dtype=autocast_dtype |
414 | | - ) if autocast_dtype else contextlib.nullcontext(): |
415 | | - actor_loss_td, sampled_tensordict = actor_loss( |
416 | | - sampled_tensordict.reshape(-1) |
417 | | - ) |
| 438 | + with _prof_context("actor"): |
| 439 | + with timeit("train/actor-forward"), record_function("## actor/forward ##"): |
| 440 | + # Mark step begin for CUDAGraph to prevent tensor overwrite issues |
| 441 | + torch.compiler.cudagraph_mark_step_begin() |
| 442 | + with torch.autocast( |
| 443 | + device_type=device.type, dtype=autocast_dtype |
| 444 | + ) if autocast_dtype else contextlib.nullcontext(): |
| 445 | + actor_loss_td, sampled_tensordict = actor_loss( |
| 446 | + sampled_tensordict.reshape(-1) |
| 447 | + ) |
418 | 448 |
|
419 | | - with timeit("train/actor-backward"), record_function("## actor/backward ##"): |
420 | | - actor_opt.zero_grad() |
421 | | - if autocast_dtype: |
422 | | - scaler2.scale(actor_loss_td["loss_actor"]).backward() |
423 | | - scaler2.unscale_(actor_opt) |
424 | | - else: |
425 | | - actor_loss_td["loss_actor"].backward() |
426 | | - torchrl_logger.debug("actor_loss backward OK") |
427 | | - actor_model_grad = clip_grad_norm_(actor_model.parameters(), grad_clip) |
428 | | - if autocast_dtype: |
429 | | - scaler2.step(actor_opt) |
430 | | - scaler2.update() |
431 | | - else: |
432 | | - actor_opt.step() |
| 449 | + with timeit("train/actor-backward"), record_function( |
| 450 | + "## actor/backward ##" |
| 451 | + ): |
| 452 | + actor_opt.zero_grad() |
| 453 | + if autocast_dtype: |
| 454 | + scaler2.scale(actor_loss_td["loss_actor"]).backward() |
| 455 | + scaler2.unscale_(actor_opt) |
| 456 | + else: |
| 457 | + actor_loss_td["loss_actor"].backward() |
| 458 | + torchrl_logger.debug("actor_loss backward OK") |
| 459 | + actor_model_grad = clip_grad_norm_(actor_model.parameters(), grad_clip) |
| 460 | + if autocast_dtype: |
| 461 | + scaler2.step(actor_opt) |
| 462 | + scaler2.update() |
| 463 | + else: |
| 464 | + actor_opt.step() |
433 | 465 |
|
434 | 466 | # update value network |
435 | | - with timeit("train/value-forward"), record_function("## value/forward ##"): |
436 | | - # Mark step begin for CUDAGraph to prevent tensor overwrite issues |
437 | | - torch.compiler.cudagraph_mark_step_begin() |
438 | | - with torch.autocast( |
439 | | - device_type=device.type, dtype=autocast_dtype |
440 | | - ) if autocast_dtype else contextlib.nullcontext(): |
441 | | - value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) |
442 | | - |
443 | | - with timeit("train/value-backward"), record_function("## value/backward ##"): |
444 | | - value_opt.zero_grad() |
445 | | - if autocast_dtype: |
446 | | - scaler3.scale(value_loss_td["loss_value"]).backward() |
447 | | - scaler3.unscale_(value_opt) |
448 | | - else: |
449 | | - value_loss_td["loss_value"].backward() |
450 | | - torchrl_logger.debug("value_loss backward OK") |
451 | | - critic_model_grad = clip_grad_norm_(value_model.parameters(), grad_clip) |
452 | | - if autocast_dtype: |
453 | | - scaler3.step(value_opt) |
454 | | - scaler3.update() |
455 | | - else: |
456 | | - value_opt.step() |
| 467 | + with _prof_context("value"): |
| 468 | + with timeit("train/value-forward"), record_function("## value/forward ##"): |
| 469 | + # Mark step begin for CUDAGraph to prevent tensor overwrite issues |
| 470 | + torch.compiler.cudagraph_mark_step_begin() |
| 471 | + with torch.autocast( |
| 472 | + device_type=device.type, dtype=autocast_dtype |
| 473 | + ) if autocast_dtype else contextlib.nullcontext(): |
| 474 | + value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) |
| 475 | + |
| 476 | + with timeit("train/value-backward"), record_function( |
| 477 | + "## value/backward ##" |
| 478 | + ): |
| 479 | + value_opt.zero_grad() |
| 480 | + if autocast_dtype: |
| 481 | + scaler3.scale(value_loss_td["loss_value"]).backward() |
| 482 | + scaler3.unscale_(value_opt) |
| 483 | + else: |
| 484 | + value_loss_td["loss_value"].backward() |
| 485 | + torchrl_logger.debug("value_loss backward OK") |
| 486 | + critic_model_grad = clip_grad_norm_(value_model.parameters(), grad_clip) |
| 487 | + if autocast_dtype: |
| 488 | + scaler3.step(value_opt) |
| 489 | + scaler3.update() |
| 490 | + else: |
| 491 | + value_opt.step() |
457 | 492 |
|
458 | 493 | # Step profiler (returns True if profiling complete) |
459 | 494 | if profiler.step(): |
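Throughout this hunk the new `_prof_context` wrapper is nested around the pre-existing `timeit`/`record_function` pairs, so the prof region always opens first and closes last. The ordering guarantee that makes this equivalent to stacking the managers in a single `with` statement is illustrated by the short sketch below (the `region` helper is hypothetical, used only to print enter/exit order):

```python
import contextlib


@contextlib.contextmanager
def region(name):
    # Hypothetical stand-in for _prof_context / timeit / record_function.
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")


# Comma-separated managers enter left to right and exit in reverse, so
# `with region("prof"), region("timeit"), region("record_function"):`
# behaves exactly like the three nested `with` blocks used in the patch.
with region("prof"), region("timeit"), region("record_function"):
    pass
# -> enter prof, enter timeit, enter record_function,
#    exit record_function, exit timeit, exit prof
```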
@@ -544,19 +579,20 @@ def main(cfg: DictConfig): # noqa: F821 |
544 | 579 | frames_at_log_start = collected_frames |
545 | 580 |
|
546 | 581 | # Update policy weights in collector (for async collection) |
547 | | - with timeit("train/weight_update") as weight_update_timer: |
548 | | - torchrl_logger.debug( |
549 | | - f"optim_step={optim_step}: Starting weight update..." |
550 | | - ) |
551 | | - policy[1].step(frames_collected_this_interval) |
552 | | - collector.update_policy_weights_() |
553 | | - # Increment policy version after weight update |
554 | | - collector.increment_version() |
555 | | - torchrl_logger.debug( |
556 | | - f"optim_step={optim_step}: Weight update completed in " |
557 | | - f"{weight_update_timer.elapsed():.3f}s, " |
558 | | - f"policy_version={policy_version.version}" |
559 | | - ) |
| 582 | + with _prof_context("weight_update"): |
| 583 | + with timeit("train/weight_update") as weight_update_timer: |
| 584 | + torchrl_logger.debug( |
| 585 | + f"optim_step={optim_step}: Starting weight update..." |
| 586 | + ) |
| 587 | + policy[1].step(frames_collected_this_interval) |
| 588 | + collector.update_policy_weights_() |
| 589 | + # Increment policy version after weight update |
| 590 | + collector.increment_version() |
| 591 | + torchrl_logger.debug( |
| 592 | + f"optim_step={optim_step}: Weight update completed in " |
| 593 | + f"{weight_update_timer.elapsed():.3f}s, " |
| 594 | + f"policy_version={policy_version.version}" |
| 595 | + ) |
560 | 596 |
|
561 | 597 | # Evaluation (every eval_every optimization steps) |
562 | 598 | if (optim_step + 1) % eval_every == 0: |
@@ -598,6 +634,9 @@ def main(cfg: DictConfig): # noqa: F821 |
598 | 634 | if logger is not None: |
599 | 635 | log_metrics(logger, eval_metrics, replay_buffer.write_count) |
600 | 636 |
|
| 637 | + # Finish profiling and clean up resources |
| 638 | + profiler.finish() |
| 639 | + |
601 | 640 | if not test_env.is_closed: |
602 | 641 | test_env.close() |
603 | 642 | # Shutdown async collector (use async_shutdown since we used start()) |
|