Description & Motivation
Currently, when using PyTorch Lightning with `WandbLogger` for multi-node distributed training, only the system metrics (CPU/GPU utilization, memory usage, etc.) from node 0 and the console logs from rank 0 are recorded by wandb. This limits visibility into critical performance data from non-zero ranks, making it harder to debug issues such as uneven resource utilization across nodes/GPUs or to pinpoint hardware bottlenecks in distributed setups.
With wandb's recently added distributed experiment tracking support (see also the end-to-end example report), it is now possible to collect system metrics and logs from all nodes. However, Lightning's `WandbLogger` does not yet leverage this capability out of the box. This feature request proposes updating Lightning's wandb integration to support full distributed system monitoring.
Pitch
Add native support in `pytorch_lightning.loggers.WandbLogger` to:
- Log system metrics (hardware telemetry) from all nodes in distributed training, not just node 0
This could be implemented by:
- Adding a `log_all_ranks: bool` parameter to `WandbLogger` to enable/disable this behavior
- Calling `wandb.init(..., settings=wandb.Settings(x_label="rank_0", mode="shared", x_primary=True))` for rank 0
- Calling `wandb.init(..., settings=wandb.Settings(x_label=f"rank_{rank}", mode="shared", x_primary=False))` for non-zero ranks
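A minimal sketch of that rank-conditional initialization (assuming a torchrun-style `RANK` environment variable and a run id agreed on across ranks; `init_shared_run` is a hypothetical helper, not an existing API):

```python
import os

import wandb


def init_shared_run(run_id: str, project: str):
    """Hypothetical helper: attach the current rank to one shared wandb run."""
    rank = int(os.environ.get("RANK", "0"))  # global rank, set by torchrun
    return wandb.init(
        id=run_id,  # must be identical on every rank
        project=project,
        settings=wandb.Settings(
            x_label=f"rank_{rank}",  # tags this rank's system metrics in the UI
            mode="shared",           # all ranks stream into the same run
            x_primary=(rank == 0),   # rank 0 owns the run lifecycle
        ),
    )
```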
Alternatives
While users could manually override logging behavior by calling `wandb.init()` with custom parameters on non-zero ranks, this:
- Conflicts with Lightning's logger orchestration
- Risks creating multiple wandb runs unintentionally
- Requires error-prone custom code outside Lightning's abstractions
A cleaner native implementation would provide better safety and usability.
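For illustration, the manual workaround would look roughly like the sketch below (hypothetical glue code; assumes torchrun sets `RANK` and that the shared run id is distributed via the standard `WANDB_RUN_ID` environment variable):

```python
import os

import wandb

rank = int(os.environ.get("RANK", "0"))
shared_run_id = os.environ["WANDB_RUN_ID"]  # id must be pre-agreed across ranks

if rank != 0:
    # Attach this rank to the run that rank 0's WandbLogger will create,
    # entirely outside Lightning's logger orchestration.
    wandb.init(
        id=shared_run_id,
        settings=wandb.Settings(
            x_label=f"rank_{rank}", mode="shared", x_primary=False
        ),
    )
# Keeping this path consistent with WandbLogger on rank 0 (same id, same
# settings, correct teardown) is exactly the error-prone part.
```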
Additional context
wandb has recently added this feature; here is the relevant issue:
In the `WandbLogger` source code, the `@rank_zero_experiment` decorator enforces wandb initialization and logging exclusively on rank 0. This creates architectural challenges for multi-rank logging because:
- Non-zero ranks never initialize a `wandb.Run`
- Any attempt to modify the existing logger would require significant refactoring of the decorator logic
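For reference, the decorator is applied like this (a simplified illustration, not the actual `WandbLogger` source):

```python
from lightning.pytorch.loggers.logger import rank_zero_experiment


class RankZeroOnlyLogger:
    _experiment = None

    @property
    @rank_zero_experiment
    def experiment(self):
        # This body only ever executes on global rank 0; the decorator hands
        # every other rank a no-op dummy experiment, so wandb.init() is never
        # reached on non-zero ranks.
        import wandb

        if self._experiment is None:
            self._experiment = wandb.init(project="demo")
        return self._experiment
```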
Instead of adding a `log_all_ranks` parameter to `WandbLogger`, a cleaner solution could be:
- Create a new `WandbDistributedLogger` class under `lightning.pytorch.loggers.wandb`
- Design this class to (see the sketch after this list):
  - Initialize wandb runs on all ranks using `wandb.init(..., settings=wandb.Settings(x_label=..., mode="shared", x_primary=...))`
  - Bypass the `@rank_zero_experiment` restriction
  - Preserve the original `WandbLogger` behavior (e.g. `log_metrics`, `log_hyperparams`, etc.)
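A rough sketch of what this subclass could look like (hypothetical code; it reuses `WandbLogger`'s internal `_experiment`/`_wandb_init` attributes, which are implementation details and may change):

```python
import os

import wandb
from lightning.pytorch.loggers import WandbLogger


class WandbDistributedLogger(WandbLogger):
    """Hypothetical sketch: attach every rank to one shared wandb run."""

    @property
    def experiment(self):
        # Deliberately not wrapped in @rank_zero_experiment, so this body
        # runs on every rank, not only on rank 0.
        if self._experiment is None:
            rank = int(os.environ.get("RANK", "0"))  # assumes torchrun sets RANK
            self._experiment = wandb.init(
                **self._wandb_init,  # the logger's stored wandb.init kwargs
                settings=wandb.Settings(
                    x_label=f"rank_{rank}",
                    mode="shared",
                    x_primary=(rank == 0),
                ),
            )
        return self._experiment
```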
This approach:
- Maintains backward compatibility
- Avoids complicating the existing `WandbLogger` API
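Hypothetical usage of the sketched class; `version` is `WandbLogger`'s existing way to pin the run id, which shared mode requires to be identical on every rank:

```python
from lightning.pytorch import Trainer

logger = WandbDistributedLogger(project="my-project", version="shared-run-id")
trainer = Trainer(logger=logger, num_nodes=2, devices=8, strategy="ddp")
```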
I would be happy to submit a PR implementing this if needed, including:
- The new `WandbDistributedLogger` class
- Integration tests verifying multi-node metric collection
cc @lantiga @Borda @morganmcg1 @borisdayma @scottire @parambharat