@@ -44,7 +44,8 @@ def test_custom_backends_override_defaults(self, mock_register):
4444 mock_backend_instance = mock .Mock (spec = metrax_logging .LoggingBackend )
4545 mock_factory = mock .Mock (return_value = mock_backend_instance )
4646 options = metrics_logger .MetricsLoggerOptions (
47- log_dir = self .log_dir , backend_factories = [mock_factory ]
47+ log_dir = self .log_dir ,
48+ backend_kwargs = {"custom_backend" : [mock_factory ]},
4849 )
4950
5051 logger = metrics_logger .MetricsLogger (metrics_logger_options = options )
@@ -115,7 +116,8 @@ def test_options_deepcopy_safety(self):
115116 self .mock_backends [0 ].assert_called_with (log_dir = new_log_dir )
116117 else :
117118 self .mock_backends [0 ].assert_called_with (
118- log_dir = new_log_dir , flush_every_n_steps = 100
119+ log_dir = new_log_dir ,
120+ flush_every_n_steps = 100 ,
119121 )
120122
121123 logger1 .close ()
@@ -124,7 +126,7 @@ def test_options_deepcopy_safety(self):
124126 @mock .patch .object (jax .monitoring , "record_scalar" )
125127 def test_log_metrics (self , mock_record_scalar ):
126128 options = metrics_logger .MetricsLoggerOptions (
127- log_dir = self .log_dir , backend_factories = []
129+ log_dir = self .log_dir , backend_kwargs = { "custom_backend" : []}
128130 )
129131 logger = metrics_logger .MetricsLogger (metrics_logger_options = options )
130132 logger .log ("test_prefix" , "loss" , 0.1 , metrics_logger .Mode .TRAIN , 1 )
@@ -145,7 +147,7 @@ def test_log_metrics(self, mock_record_scalar):
145147 @mock .patch .object (jax .monitoring , "record_scalar" )
146148 def test_log_perplexity (self , mock_record_scalar ):
147149 options = metrics_logger .MetricsLoggerOptions (
148- log_dir = self .log_dir , backend_factories = []
150+ log_dir = self .log_dir , backend_kwargs = { "custom_backend" : []}
149151 )
150152 logger = metrics_logger .MetricsLogger (metrics_logger_options = options )
151153 logger .log ("test_prefix" , "perplexity" , 10.0 , metrics_logger .Mode .EVAL , 1 )
@@ -189,6 +191,43 @@ def test_raises_when_clu_backend_missing_in_internal_env(
189191 ):
190192 options .create_backends ()
191193
194+ def test_backend_kwargs_are_passed_to_backends (self ):
195+ """Tests that backend_kwargs are passed to the backends initialization."""
196+ mock_wandb = mock .Mock ()
197+ options = metrics_logger .MetricsLoggerOptions (
198+ log_dir = self .log_dir ,
199+ backend_kwargs = {
200+ "wandb" : {
201+ "resume" : "must" ,
202+ "id" : "12345" ,
203+ "settings" : mock_wandb .Settings (console = "off" ),
204+ },
205+ "tensorboard" : {"flush_interval_s" : 10.0 },
206+ "clu" : {"some_arg" : "value" },
207+ },
208+ )
209+
210+ # Trigger backend creation
211+ _ = metrics_logger .MetricsLogger (metrics_logger_options = options )
212+
213+ if env_utils .is_internal_env ():
214+ self .mock_backends [0 ].assert_called_once_with (
215+ log_dir = self .log_dir , ** options .backend_kwargs ["clu" ]
216+ )
217+ else :
218+ self .mock_backends [0 ].assert_called_once_with (
219+ log_dir = self .log_dir ,
220+ flush_every_n_steps = 100 ,
221+ ** options .backend_kwargs ["tensorboard" ],
222+ )
223+ self .mock_backends [1 ].assert_called_once_with (
224+ project = "tunix" ,
225+ name = "" ,
226+ resume = "must" ,
227+ id = "12345" ,
228+ settings = mock_wandb .Settings .return_value ,
229+ )
230+
192231
193232if __name__ == "__main__" :
194233 absltest .main ()
0 commit comments