@@ -128,6 +128,101 @@ def test_log_hyperparams(self, mock_summary_writer, temp_dir):
128128 "model.hidden_size" : 128 ,
129129 }
130130
131+ @patch ("nemo_rl.utils.logger.SummaryWriter" )
132+ def test_coerce_to_scalar_python_primitives (self , mock_summary_writer , temp_dir ):
133+ """Test that Python primitives pass through unchanged."""
134+ cfg = {"log_dir" : temp_dir }
135+ logger = TensorboardLogger (cfg , log_dir = temp_dir )
136+
137+ assert logger ._coerce_to_scalar (42 ) == 42
138+ assert logger ._coerce_to_scalar (3.14 ) == 3.14
139+ assert logger ._coerce_to_scalar (True ) is True
140+ assert logger ._coerce_to_scalar ("hello" ) == "hello"
141+
142+ @patch ("nemo_rl.utils.logger.SummaryWriter" )
143+ def test_coerce_to_scalar_numpy_types (self , mock_summary_writer , temp_dir ):
144+ """Test that numpy scalar types are coerced to Python primitives."""
145+ import numpy as np
146+
147+ cfg = {"log_dir" : temp_dir }
148+ logger = TensorboardLogger (cfg , log_dir = temp_dir )
149+
150+ # numpy scalar types
151+ assert logger ._coerce_to_scalar (np .float32 (1.5 )) == 1.5
152+ assert logger ._coerce_to_scalar (np .float64 (2.5 )) == 2.5
153+ assert logger ._coerce_to_scalar (np .int32 (10 )) == 10
154+ assert logger ._coerce_to_scalar (np .int64 (20 )) == 20
155+ assert logger ._coerce_to_scalar (np .bool_ (True )) is True
156+
157+ # 0-d numpy arrays
158+ assert logger ._coerce_to_scalar (np .array (3.14 )) == 3.14
159+ # 1-element numpy arrays
160+ assert logger ._coerce_to_scalar (np .array ([42 ])) == 42
161+
162+ # Multi-element arrays should return None
163+ assert logger ._coerce_to_scalar (np .array ([1 , 2 , 3 ])) is None
164+
165+ @patch ("nemo_rl.utils.logger.SummaryWriter" )
166+ def test_coerce_to_scalar_torch_tensors (self , mock_summary_writer , temp_dir ):
167+ """Test that torch scalar tensors are coerced to Python primitives."""
168+ cfg = {"log_dir" : temp_dir }
169+ logger = TensorboardLogger (cfg , log_dir = temp_dir )
170+
171+ # 0-d tensors
172+ assert logger ._coerce_to_scalar (torch .tensor (3.14 )) == pytest .approx (3.14 )
173+ assert logger ._coerce_to_scalar (torch .tensor (42 )) == 42
174+
175+ # 1-element tensors
176+ assert logger ._coerce_to_scalar (torch .tensor ([99 ])) == 99
177+
178+ # Multi-element tensors should return None
179+ assert logger ._coerce_to_scalar (torch .tensor ([1 , 2 , 3 ])) is None
180+
181+ @patch ("nemo_rl.utils.logger.SummaryWriter" )
182+ def test_coerce_to_scalar_incompatible_types (self , mock_summary_writer , temp_dir ):
183+ """Test that incompatible types return None."""
184+ cfg = {"log_dir" : temp_dir }
185+ logger = TensorboardLogger (cfg , log_dir = temp_dir )
186+
187+ assert logger ._coerce_to_scalar ({"key" : "value" }) is None
188+ assert logger ._coerce_to_scalar ([1 , 2 , 3 ]) is None
189+ assert logger ._coerce_to_scalar (None ) is None
190+ assert logger ._coerce_to_scalar (object ()) is None
191+
192+ @patch ("nemo_rl.utils.logger.SummaryWriter" )
193+ def test_log_metrics_coerces_numpy_and_torch (self , mock_summary_writer , temp_dir ):
194+ """Test that log_metrics correctly logs numpy/torch scalars."""
195+ import numpy as np
196+
197+ cfg = {"log_dir" : temp_dir }
198+ logger = TensorboardLogger (cfg , log_dir = temp_dir )
199+
200+ metrics = {
201+ "python_float" : 1.0 ,
202+ "numpy_float32" : np .float32 (2.0 ),
203+ "numpy_float64" : np .float64 (3.0 ),
204+ "torch_scalar" : torch .tensor (4.0 ),
205+ "numpy_0d" : np .array (5.0 ),
206+ "torch_1elem" : torch .tensor ([6.0 ]),
207+ "skip_list" : [1 , 2 , 3 ],
208+ "skip_dict" : {"a" : 1 },
209+ "skip_multi_tensor" : torch .tensor ([1.0 , 2.0 ]),
210+ }
211+ logger .log_metrics (metrics , step = 1 )
212+
213+ mock_writer = mock_summary_writer .return_value
214+ # Should log 6 scalars, skip 3 incompatible
215+ assert mock_writer .add_scalar .call_count == 6
216+
217+ # Verify each scalar was logged with correct value
218+ calls = {c [0 ][0 ]: c [0 ][1 ] for c in mock_writer .add_scalar .call_args_list }
219+ assert calls ["python_float" ] == 1.0
220+ assert calls ["numpy_float32" ] == pytest .approx (2.0 )
221+ assert calls ["numpy_float64" ] == pytest .approx (3.0 )
222+ assert calls ["torch_scalar" ] == pytest .approx (4.0 )
223+ assert calls ["numpy_0d" ] == pytest .approx (5.0 )
224+ assert calls ["torch_1elem" ] == pytest .approx (6.0 )
225+
131226
132227class TestWandbLogger :
133228 """Test the WandbLogger class."""
0 commit comments