Skip to content

Commit ad9eb87

Browse files
lc5211The tunix Authors
authored andcommitted
[Tunix] Add backend_kwargs to MetricsLoggerOptions.
PiperOrigin-RevId: 890163481
1 parent 5872542 commit ad9eb87

File tree

2 files changed

+80
-10
lines changed

2 files changed

+80
-10
lines changed

tests/sft/metrics_logger_test.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def test_custom_backends_override_defaults(self, mock_register):
4444
mock_backend_instance = mock.Mock(spec=metrax_logging.LoggingBackend)
4545
mock_factory = mock.Mock(return_value=mock_backend_instance)
4646
options = metrics_logger.MetricsLoggerOptions(
47-
log_dir=self.log_dir, backend_factories=[mock_factory]
47+
log_dir=self.log_dir,
48+
backend_kwargs={"custom_backend": [mock_factory]},
4849
)
4950

5051
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
@@ -115,7 +116,8 @@ def test_options_deepcopy_safety(self):
115116
self.mock_backends[0].assert_called_with(log_dir=new_log_dir)
116117
else:
117118
self.mock_backends[0].assert_called_with(
118-
log_dir=new_log_dir, flush_every_n_steps=100
119+
log_dir=new_log_dir,
120+
flush_every_n_steps=100,
119121
)
120122

121123
logger1.close()
@@ -124,7 +126,7 @@ def test_options_deepcopy_safety(self):
124126
@mock.patch.object(jax.monitoring, "record_scalar")
125127
def test_log_metrics(self, mock_record_scalar):
126128
options = metrics_logger.MetricsLoggerOptions(
127-
log_dir=self.log_dir, backend_factories=[]
129+
log_dir=self.log_dir, backend_kwargs={"custom_backend": []}
128130
)
129131
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
130132
logger.log("test_prefix", "loss", 0.1, metrics_logger.Mode.TRAIN, 1)
@@ -145,7 +147,7 @@ def test_log_metrics(self, mock_record_scalar):
145147
@mock.patch.object(jax.monitoring, "record_scalar")
146148
def test_log_perplexity(self, mock_record_scalar):
147149
options = metrics_logger.MetricsLoggerOptions(
148-
log_dir=self.log_dir, backend_factories=[]
150+
log_dir=self.log_dir, backend_kwargs={"custom_backend": []}
149151
)
150152
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
151153
logger.log("test_prefix", "perplexity", 10.0, metrics_logger.Mode.EVAL, 1)
@@ -189,6 +191,43 @@ def test_raises_when_clu_backend_missing_in_internal_env(
189191
):
190192
options.create_backends()
191193

194+
def test_backend_kwargs_are_passed_to_backends(self):
195+
"""Tests that backend_kwargs are passed to the backends initialization."""
196+
mock_wandb = mock.Mock()
197+
options = metrics_logger.MetricsLoggerOptions(
198+
log_dir=self.log_dir,
199+
backend_kwargs={
200+
"wandb": {
201+
"resume": "must",
202+
"id": "12345",
203+
"settings": mock_wandb.Settings(console="off"),
204+
},
205+
"tensorboard": {"flush_interval_s": 10.0},
206+
"clu": {"some_arg": "value"},
207+
},
208+
)
209+
210+
# Trigger backend creation
211+
_ = metrics_logger.MetricsLogger(metrics_logger_options=options)
212+
213+
if env_utils.is_internal_env():
214+
self.mock_backends[0].assert_called_once_with(
215+
log_dir=self.log_dir, **options.backend_kwargs["clu"]
216+
)
217+
else:
218+
self.mock_backends[0].assert_called_once_with(
219+
log_dir=self.log_dir,
220+
flush_every_n_steps=100,
221+
**options.backend_kwargs["tensorboard"],
222+
)
223+
self.mock_backends[1].assert_called_once_with(
224+
project="tunix",
225+
name="",
226+
resume="must",
227+
id="12345",
228+
settings=mock_wandb.Settings.return_value,
229+
)
230+
192231

193232
if __name__ == "__main__":
194233
absltest.main()

tunix/sft/metrics_logger.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import collections
44
import dataclasses
55
import enum
6-
from typing import Callable
6+
from typing import Any, Callable
77

88
from absl import logging
99
import jax
@@ -28,7 +28,26 @@ class MetricsLoggerOptions:
2828
project_name: str = "tunix"
2929
run_name: str = ""
3030
flush_every_n_steps: int = 100
31-
backend_factories: list[BackendFactory] | None = None
31+
# Keyword arguments for backend initialization. The key is the backend name
32+
# (e.g., 'wandb', 'clu', 'tensorboard' or 'custom_backend' which uses custom
33+
# LoggingBackend factories) and the value is a dictionary of
34+
# keyword arguments to be passed to the backend's constructor.
35+
# For example:
36+
# backend_kwargs={
37+
# 'wandb': {
38+
# 'resume': 'must',
39+
# 'id': '12345',
40+
# 'project': 'my-project',
41+
# 'name': 'my-run',
42+
# },
43+
# 'tensorboard': {
44+
# 'log_dir': '/path/to/log',
45+
# 'flush_every_n_steps': 100,
46+
# }
47+
# }
48+
backend_kwargs: dict[str, dict[str, Any] | list[BackendFactory]] = (
49+
dataclasses.field(default_factory=dict)
50+
)
3251

3352
def create_backends(self) -> list[LoggingBackend]:
3453
"""Factory method to create a fresh set of live backends."""
@@ -37,28 +56,40 @@ def create_backends(self) -> list[LoggingBackend]:
3756
return []
3857

3958
# Case 1: Override. Use user-provided factories.
40-
if self.backend_factories is not None:
41-
return [factory() for factory in self.backend_factories]
59+
if (
60+
"custom_backend" in self.backend_kwargs
61+
and self.backend_kwargs["custom_backend"]
62+
):
63+
return [factory() for factory in self.backend_kwargs["custom_backend"]]
4264

4365
# Case 2: Defaults.
4466
active_backends = []
67+
kwargs_dict = self.backend_kwargs or {}
4568

4669
if env_utils.is_internal_env():
4770
if CluBackend is None:
4871
raise ImportError(
4972
"Internal environment detected, but CluBackend not available."
5073
)
51-
active_backends.append(CluBackend(log_dir=self.log_dir))
74+
clu_kwargs = kwargs_dict.get("clu", {})
75+
active_backends.append(CluBackend(log_dir=self.log_dir, **clu_kwargs))
5276
else:
77+
tb_kwargs = kwargs_dict.get("tensorboard", {})
5378
active_backends.append(
5479
TensorboardBackend(
5580
log_dir=self.log_dir,
5681
flush_every_n_steps=self.flush_every_n_steps,
82+
**tb_kwargs,
5783
)
5884
)
5985
try:
86+
wandb_kwargs = kwargs_dict.get("wandb", {})
6087
active_backends.append(
61-
WandbBackend(project=self.project_name, name=self.run_name)
88+
WandbBackend(
89+
project=self.project_name,
90+
name=self.run_name,
91+
**wandb_kwargs,
92+
)
6293
)
6394
except ImportError:
6495
logging.info("WandbBackend skipped: 'wandb' library not installed.")

0 commit comments

Comments
 (0)