
Commit 4d62783

feat: add gpu mem and util logging to wandb/tensorboard (#37)
Signed-off-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Parth Chadha <pchadha@nvidia.com>
1 parent 43ace69 commit 4d62783

File tree

16 files changed: +790 −33 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ dist/
 # Cache
 uv_cache/
 hf_home/
+hf_datasets_cache/
 *logs/
 datasets/
 docker/

docs/design_docs/gpu_logger.md

Whitespace-only changes.

docs/design_docs/logger.md

Lines changed: 28 additions & 0 deletions
@@ -78,3 +78,31 @@ When enabled, the pretty logging will generate formatted text similar to:
 
 ![Validation Pretty Logging Example](../assets/val-log.png)
 
+## GPU Metric Logging
+
+Reinforcer monitors GPU memory and utilization through [system metrics](https://docs.ray.io/en/latest/ray-observability/reference/system-metrics.html#system-metrics) exposed by Ray nodes. While Ray makes these metrics available to tools such as Prometheus, Reinforcer polls the GPU memory and utilization data directly and logs it to TensorBoard and/or Weights & Biases.
+
+This approach lets us offer the same GPU metric tracking across all loggers (not just wandb) and greatly simplifies the implementation.
+
+This feature is enabled with the `monitor_gpus` configuration parameter; the frequency of collection and of flushing to the loggers is controlled by `gpu_monitoring.collection_interval` and `gpu_monitoring.flush_interval` (both in seconds), respectively:
+
+```yaml
+logger:
+  wandb_enabled: false
+  tensorboard_enabled: false
+  monitor_gpus: true
+  gpu_monitoring:
+    collection_interval: 10
+    flush_interval: 10
+```
+
+:::{note}
+While monitoring through the remote workers is possible, it requires some delicate implementation details to ensure that:
+* sending logs back to the driver does not incur a large overhead
+* metrics remain easily interpretable, since colocated workers can lead to double counting
+* workers gracefully flush their logs in the event of failure
+* the logging is the same for tensorboard and wandb
+* workers that spawn other workers correctly report the total usage of their grandchild workers
+
+These considerations led us to the simpler approach of collecting on the driver.
+:::
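
The commit does not inline the collector itself here, but a driver-side poller along these lines would match the design above. This is a minimal sketch, assuming `pynvml` (the NVIDIA Management Library bindings) as the metrics source and a user-supplied `log_metrics` callback; both the class name `GPUMonitor` and the callback are illustrative, and the actual implementation reads Ray's per-node system metrics instead.

```python
# Minimal sketch of a driver-side GPU monitor. Assumes pynvml as the
# metrics source and a log_metrics(metrics, step) callback; the real
# implementation polls Ray's system metrics instead.
import threading
import time

import pynvml


class GPUMonitor:
    def __init__(self, log_metrics, collection_interval=10, flush_interval=10):
        self.log_metrics = log_metrics
        self.collection_interval = collection_interval
        self.flush_interval = flush_interval
        self._samples = []
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self):
        pynvml.nvmlInit()
        self._thread.start()

    def stop(self):
        self._stop.set()
        self._thread.join()
        self._flush()  # flush remaining samples so none are lost on shutdown
        pynvml.nvmlShutdown()

    def _collect(self):
        # One sample per GPU: memory used (bytes) and utilization (%).
        sample = {}
        for i in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
            util = pynvml.nvmlDeviceGetUtilizationRates(handle)
            sample[f"gpu{i}/memory_used"] = mem.used
            sample[f"gpu{i}/utilization"] = util.gpu
        self._samples.append((time.time(), sample))

    def _flush(self):
        for timestamp, sample in self._samples:
            self.log_metrics(sample, step=int(timestamp))
        self._samples.clear()

    def _run(self):
        last_flush = time.time()
        while not self._stop.is_set():
            self._collect()
            if time.time() - last_flush >= self.flush_interval:
                self._flush()
                last_flush = time.time()
            self._stop.wait(self.collection_interval)
```

Collecting on the driver keeps the sketch free of the worker-side pitfalls listed in the note: there is one poller, one flush path, and no cross-process log traffic.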

examples/configs/grpo_math_1B.yaml

Lines changed: 4 additions & 0 deletions
@@ -77,10 +77,14 @@ logger:
   num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
   wandb_enabled: false
   tensorboard_enabled: false
+  monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
   wandb:
     project: "grpo-dev"
     name: "grpo-dev-logger"
   tensorboard: {}
+  gpu_monitoring:
+    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
+    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)
 
 cluster:
   gpus_per_node: 1
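
For reference, the new keys map naturally onto a small typed schema. The sketch below is illustrative only; `GPUMonitoringConfig` and `LoggerConfig` are hypothetical names, not the repository's actual config classes, and the fields simply mirror the YAML above.

```python
# Illustrative schema for the new logger keys; the repository's real
# config classes may differ.
from typing import TypedDict


class GPUMonitoringConfig(TypedDict):
    collection_interval: int  # seconds between metric samples
    flush_interval: int  # seconds between flushes to the loggers


class LoggerConfig(TypedDict):
    wandb_enabled: bool
    tensorboard_enabled: bool
    monitor_gpus: bool
    gpu_monitoring: GPUMonitoringConfig
```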

examples/configs/sft.yaml

Lines changed: 5 additions & 1 deletion
@@ -44,13 +44,17 @@ data:
 
 logger:
   log_dir: "logs" # Base directory for all logs
-  wandb_enabled: true
+  wandb_enabled: false
   tensorboard_enabled: false
+  monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
   wandb:
     project: "sft-dev"
     name: "sft-dev-logger"
   tensorboard:
     log_dir: "tb_logs"
+  gpu_monitoring:
+    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
+    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)
 
 cluster:
   gpus_per_node: 8

examples/run_grpo_math.py

Lines changed: 4 additions & 1 deletion
@@ -195,7 +195,10 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs
     task_data_processors["math"] = (math_task_spec, openinstructmath2_data_processor)
 
     math_env = MathEnvironment.options(
-        runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE}
+        runtime_env={
+            "py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE,
+            "env_vars": dict(os.environ),  # Pass through all user environment variables
+        }
     ).remote(env_configs["math"])
     dataset = AllTaskProcessedDataset(
         data.formatted_ds["train"],

nemo_reinforcer/algorithms/grpo.py

Lines changed: 6 additions & 2 deletions
@@ -137,6 +137,12 @@ def setup(
     logger_config = master_config["logger"]
     cluster_config = master_config["cluster"]
 
+    # ==========================
+    # Logger
+    # ==========================
+    logger = Logger(logger_config)
+    logger.log_hyperparams(master_config)
+
     # ==========================
     # Checkpointing
     # ==========================
@@ -238,8 +244,6 @@ def setup(
     )
 
     loss_fn = ClippedPGLossFn(loss_config)
-    logger = Logger(logger_config)
-    logger.log_hyperparams(master_config)
 
     print("\n" + "=" * 60)
     print(" " * 18 + "SETUP COMPLETE")

nemo_reinforcer/algorithms/loss_functions.py

Lines changed: 5 additions & 1 deletion
@@ -166,4 +166,8 @@ def __call__(
             num_unmasked_tokens = torch.tensor(1)
         loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens
 
-        return loss, {"loss": loss.item(), "num_unmasked_tokens": num_unmasked_tokens.item(), "total_tokens": mask.numel()}
+        return loss, {
+            "loss": loss.item(),
+            "num_unmasked_tokens": num_unmasked_tokens.item(),
+            "total_tokens": mask.numel(),
+        }

nemo_reinforcer/algorithms/sft.py

Lines changed: 7 additions & 3 deletions
@@ -61,6 +61,7 @@ class SFTConfig(TypedDict):
     val_at_start: bool
     seed: int
 
+
 class MasterConfig(TypedDict):
     policy: PolicyConfig
     data: DataConfig
@@ -102,6 +103,12 @@ def setup(
     cluster_config = master_config["cluster"]
     sft_config = master_config["sft"]
 
+    # ==========================
+    # Logger
+    # ==========================
+    logger = Logger(logger_config)
+    logger.log_hyperparams(master_config)
+
     # ==========================
     # Checkpointing
     # ==========================
@@ -179,9 +186,6 @@ def setup(
     loss_fn = NLLLoss()
     print(f" ✓ Model initialized")
 
-    logger = Logger(logger_config)
-    logger.log_hyperparams(master_config)
-
     print("\n" + "=" * 60)
     print(" " * 18 + "SETUP COMPLETE")
     print("=" * 60 + "\n")

nemo_reinforcer/algorithms/utils.py

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ def masked_mean(values, mask, dim=None):
         return values[mask.bool()].mean()
     return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan)
 
+
 def set_seed(seed: int):
     """Sets the seed for python, numpy, and pytorch."""
     random.seed(seed)
