Skip to content

Commit 932c72d

Browse files
authored
fix: log metrics that can be coerced to scalars (#1723)
Signed-off-by: Terry Kong <[email protected]>
1 parent 705d25f commit 932c72d

File tree

2 files changed

+117
-9
lines changed

2 files changed

+117
-9
lines changed

nemo_rl/utils/logger.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,23 @@ def __init__(self, cfg: TensorboardConfig, log_dir: Optional[str] = None):
121121
self.writer = SummaryWriter(log_dir=log_dir)
122122
print(f"Initialized TensorboardLogger at {log_dir}")
123123

124+
@staticmethod
125+
def _coerce_to_scalar(value: Any) -> int | float | bool | str | None:
126+
"""Coerce a value to a Python scalar for TensorBoard logging.
127+
128+
Returns the coerced value, or None if it can't be converted to a scalar.
129+
"""
130+
if isinstance(value, (int, float, bool, str)):
131+
return value
132+
if isinstance(value, (np.floating, np.integer, np.bool_)):
133+
return value.item()
134+
if isinstance(value, np.ndarray) and (value.ndim == 0 or value.size == 1):
135+
return value.item()
136+
if isinstance(value, torch.Tensor) and (value.ndim == 0 or value.numel() == 1):
137+
return value.item()
138+
# dict, list, multi-element arrays/tensors, or incompatible types
139+
return None
140+
124141
def log_metrics(
125142
self,
126143
metrics: dict[str, Any],
@@ -137,23 +154,19 @@ def log_metrics(
137154
step_metric: Optional step metric name (ignored in TensorBoard)
138155
"""
139156
for name, value in metrics.items():
140-
# NeMo-Gym will add additional metrics like wandb histograms. However, some people will log to Tensorboard instead which may not be compatible
141-
# This logic catches non-compatible objects being logged.
142-
if not isinstance(value, (int, float, bool, str)):
143-
continue
144-
145157
if prefix:
146158
name = f"{prefix}/{name}"
147159

148-
# Skip non-scalar values that TensorBoard can't handle
149-
if isinstance(value, (dict, list)):
160+
scalar = self._coerce_to_scalar(value)
161+
if scalar is None:
150162
print(
151-
f"Warning: Skipping non-scalar metric '{name}' for TensorBoard logging (type: {type(value).__name__})"
163+
f"Warning: Skipping metric '{name}' for TensorBoard logging "
164+
f"(unsupported type: {type(value).__name__})"
152165
)
153166
continue
154167

155168
try:
156-
self.writer.add_scalar(name, value, step)
169+
self.writer.add_scalar(name, scalar, step)
157170
except Exception as e:
158171
print(f"Warning: Failed to log metric '{name}' to TensorBoard: {e}")
159172
continue

tests/unit/utils/test_logger.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,101 @@ def test_log_hyperparams(self, mock_summary_writer, temp_dir):
128128
"model.hidden_size": 128,
129129
}
130130

131+
@patch("nemo_rl.utils.logger.SummaryWriter")
def test_coerce_to_scalar_python_primitives(self, mock_summary_writer, temp_dir):
    """Python primitives must survive coercion untouched."""
    logger = TensorboardLogger({"log_dir": temp_dir}, log_dir=temp_dir)

    for raw in (42, 3.14, "hello"):
        assert logger._coerce_to_scalar(raw) == raw
    # Identity check for bool: True must stay True, not become 1.
    assert logger._coerce_to_scalar(True) is True
141+
142+
@patch("nemo_rl.utils.logger.SummaryWriter")
def test_coerce_to_scalar_numpy_types(self, mock_summary_writer, temp_dir):
    """Numpy scalars and 0-d/1-element arrays coerce to Python primitives."""
    import numpy as np

    logger = TensorboardLogger({"log_dir": temp_dir}, log_dir=temp_dir)

    # Scalar numpy types unwrap to the equivalent Python value.
    cases = [
        (np.float32(1.5), 1.5),
        (np.float64(2.5), 2.5),
        (np.int32(10), 10),
        (np.int64(20), 20),
    ]
    for raw, want in cases:
        assert logger._coerce_to_scalar(raw) == want
    assert logger._coerce_to_scalar(np.bool_(True)) is True

    # 0-d and single-element arrays hold exactly one scalar.
    assert logger._coerce_to_scalar(np.array(3.14)) == 3.14
    assert logger._coerce_to_scalar(np.array([42])) == 42

    # Anything with more than one element has no scalar representation.
    assert logger._coerce_to_scalar(np.array([1, 2, 3])) is None
164+
165+
@patch("nemo_rl.utils.logger.SummaryWriter")
def test_coerce_to_scalar_torch_tensors(self, mock_summary_writer, temp_dir):
    """Torch 0-d and single-element tensors coerce; larger ones do not."""
    logger = TensorboardLogger({"log_dir": temp_dir}, log_dir=temp_dir)

    # 0-d tensors unwrap to Python scalars.
    assert logger._coerce_to_scalar(torch.tensor(3.14)) == pytest.approx(3.14)
    assert logger._coerce_to_scalar(torch.tensor(42)) == 42

    # A single-element 1-d tensor also counts as a scalar.
    assert logger._coerce_to_scalar(torch.tensor([99])) == 99

    # Multi-element tensors cannot be represented as one scalar.
    assert logger._coerce_to_scalar(torch.tensor([1, 2, 3])) is None
180+
181+
@patch("nemo_rl.utils.logger.SummaryWriter")
def test_coerce_to_scalar_incompatible_types(self, mock_summary_writer, temp_dir):
    """Containers and arbitrary objects coerce to None (i.e. get skipped)."""
    logger = TensorboardLogger({"log_dir": temp_dir}, log_dir=temp_dir)

    for bad in ({"key": "value"}, [1, 2, 3], None, object()):
        assert logger._coerce_to_scalar(bad) is None
191+
192+
@patch("nemo_rl.utils.logger.SummaryWriter")
def test_log_metrics_coerces_numpy_and_torch(self, mock_summary_writer, temp_dir):
    """log_metrics logs coercible numpy/torch values and skips the rest."""
    import numpy as np

    logger = TensorboardLogger({"log_dir": temp_dir}, log_dir=temp_dir)

    # Values that should reach the writer as scalars...
    loggable = {
        "python_float": 1.0,
        "numpy_float32": np.float32(2.0),
        "numpy_float64": np.float64(3.0),
        "torch_scalar": torch.tensor(4.0),
        "numpy_0d": np.array(5.0),
        "torch_1elem": torch.tensor([6.0]),
    }
    # ...and values that must be silently skipped.
    unloggable = {
        "skip_list": [1, 2, 3],
        "skip_dict": {"a": 1},
        "skip_multi_tensor": torch.tensor([1.0, 2.0]),
    }
    logger.log_metrics({**loggable, **unloggable}, step=1)

    mock_writer = mock_summary_writer.return_value
    # Only the coercible entries produce an add_scalar call.
    assert mock_writer.add_scalar.call_count == len(loggable)

    # Each logged name carries the numerically-correct coerced value.
    logged = {
        call.args[0]: call.args[1]
        for call in mock_writer.add_scalar.call_args_list
    }
    for name, want in loggable.items():
        assert logged[name] == pytest.approx(float(want))
225+
131226

132227
class TestWandbLogger:
133228
"""Test the WandbLogger class."""

0 commit comments

Comments
 (0)