@@ -44,8 +44,7 @@ def test_custom_backends_override_defaults(self, mock_register):
4444 mock_backend_instance = mock .Mock (spec = metrax_logging .LoggingBackend )
4545 mock_factory = mock .Mock (return_value = mock_backend_instance )
4646 options = metrics_logger .MetricsLoggerOptions (
47- log_dir = self .log_dir ,
48- backend_kwargs = {"custom_backend" : [mock_factory ]},
47+ log_dir = self .log_dir , backend_factories = [mock_factory ]
4948 )
5049
5150 logger = metrics_logger .MetricsLogger (metrics_logger_options = options )
@@ -116,8 +115,7 @@ def test_options_deepcopy_safety(self):
116115 self .mock_backends [0 ].assert_called_with (log_dir = new_log_dir )
117116 else :
118117 self .mock_backends [0 ].assert_called_with (
119- log_dir = new_log_dir ,
120- flush_every_n_steps = 100 ,
118+ log_dir = new_log_dir , flush_every_n_steps = 100
121119 )
122120
123121 logger1 .close ()
@@ -126,7 +124,7 @@ def test_options_deepcopy_safety(self):
126124 @mock .patch .object (jax .monitoring , "record_scalar" )
127125 def test_log_metrics (self , mock_record_scalar ):
128126 options = metrics_logger .MetricsLoggerOptions (
129- log_dir = self .log_dir , backend_kwargs = { "custom_backend" : []}
127+ log_dir = self .log_dir , backend_factories = []
130128 )
131129 logger = metrics_logger .MetricsLogger (metrics_logger_options = options )
132130 logger .log ("test_prefix" , "loss" , 0.1 , metrics_logger .Mode .TRAIN , 1 )
@@ -147,7 +145,7 @@ def test_log_metrics(self, mock_record_scalar):
147145 @mock .patch .object (jax .monitoring , "record_scalar" )
148146 def test_log_perplexity (self , mock_record_scalar ):
149147 options = metrics_logger .MetricsLoggerOptions (
150- log_dir = self .log_dir , backend_kwargs = { "custom_backend" : []}
148+ log_dir = self .log_dir , backend_factories = []
151149 )
152150 logger = metrics_logger .MetricsLogger (metrics_logger_options = options )
153151 logger .log ("test_prefix" , "perplexity" , 10.0 , metrics_logger .Mode .EVAL , 1 )
@@ -191,43 +189,6 @@ def test_raises_when_clu_backend_missing_in_internal_env(
191189 ):
192190 options .create_backends ()
193191
194- def test_backend_kwargs_are_passed_to_backends (self ):
195- """Tests that backend_kwargs are passed to the backends initialization."""
196- mock_wandb = mock .Mock ()
197- options = metrics_logger .MetricsLoggerOptions (
198- log_dir = self .log_dir ,
199- backend_kwargs = {
200- "wandb" : {
201- "resume" : "must" ,
202- "id" : "12345" ,
203- "settings" : mock_wandb .Settings (console = "off" ),
204- },
205- "tensorboard" : {"flush_interval_s" : 10.0 },
206- "clu" : {"some_arg" : "value" },
207- },
208- )
209-
210- # Trigger backend creation
211- _ = metrics_logger .MetricsLogger (metrics_logger_options = options )
212-
213- if env_utils .is_internal_env ():
214- self .mock_backends [0 ].assert_called_once_with (
215- log_dir = self .log_dir , ** options .backend_kwargs ["clu" ]
216- )
217- else :
218- self .mock_backends [0 ].assert_called_once_with (
219- log_dir = self .log_dir ,
220- flush_every_n_steps = 100 ,
221- ** options .backend_kwargs ["tensorboard" ],
222- )
223- self .mock_backends [1 ].assert_called_once_with (
224- project = "tunix" ,
225- name = "" ,
226- resume = "must" ,
227- id = "12345" ,
228- settings = mock_wandb .Settings .return_value ,
229- )
230-
231192
232193if __name__ == "__main__" :
233194 absltest .main ()
0 commit comments