Skip to content

Commit 284ad69

Browse files
author
The tunix Authors
committed
Code update
PiperOrigin-RevId: 890187291
1 parent b6e98ae commit 284ad69

File tree

2 files changed

+10
-80
lines changed

2 files changed

+10
-80
lines changed

tests/sft/metrics_logger_test.py

Lines changed: 4 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def test_custom_backends_override_defaults(self, mock_register):
4444
mock_backend_instance = mock.Mock(spec=metrax_logging.LoggingBackend)
4545
mock_factory = mock.Mock(return_value=mock_backend_instance)
4646
options = metrics_logger.MetricsLoggerOptions(
47-
log_dir=self.log_dir,
48-
backend_kwargs={"custom_backend": [mock_factory]},
47+
log_dir=self.log_dir, backend_factories=[mock_factory]
4948
)
5049

5150
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
@@ -116,8 +115,7 @@ def test_options_deepcopy_safety(self):
116115
self.mock_backends[0].assert_called_with(log_dir=new_log_dir)
117116
else:
118117
self.mock_backends[0].assert_called_with(
119-
log_dir=new_log_dir,
120-
flush_every_n_steps=100,
118+
log_dir=new_log_dir, flush_every_n_steps=100
121119
)
122120

123121
logger1.close()
@@ -126,7 +124,7 @@ def test_options_deepcopy_safety(self):
126124
@mock.patch.object(jax.monitoring, "record_scalar")
127125
def test_log_metrics(self, mock_record_scalar):
128126
options = metrics_logger.MetricsLoggerOptions(
129-
log_dir=self.log_dir, backend_kwargs={"custom_backend": []}
127+
log_dir=self.log_dir, backend_factories=[]
130128
)
131129
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
132130
logger.log("test_prefix", "loss", 0.1, metrics_logger.Mode.TRAIN, 1)
@@ -147,7 +145,7 @@ def test_log_metrics(self, mock_record_scalar):
147145
@mock.patch.object(jax.monitoring, "record_scalar")
148146
def test_log_perplexity(self, mock_record_scalar):
149147
options = metrics_logger.MetricsLoggerOptions(
150-
log_dir=self.log_dir, backend_kwargs={"custom_backend": []}
148+
log_dir=self.log_dir, backend_factories=[]
151149
)
152150
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
153151
logger.log("test_prefix", "perplexity", 10.0, metrics_logger.Mode.EVAL, 1)
@@ -191,43 +189,6 @@ def test_raises_when_clu_backend_missing_in_internal_env(
191189
):
192190
options.create_backends()
193191

194-
def test_backend_kwargs_are_passed_to_backends(self):
195-
"""Tests that backend_kwargs are passed to the backends initialization."""
196-
mock_wandb = mock.Mock()
197-
options = metrics_logger.MetricsLoggerOptions(
198-
log_dir=self.log_dir,
199-
backend_kwargs={
200-
"wandb": {
201-
"resume": "must",
202-
"id": "12345",
203-
"settings": mock_wandb.Settings(console="off"),
204-
},
205-
"tensorboard": {"flush_interval_s": 10.0},
206-
"clu": {"some_arg": "value"},
207-
},
208-
)
209-
210-
# Trigger backend creation
211-
_ = metrics_logger.MetricsLogger(metrics_logger_options=options)
212-
213-
if env_utils.is_internal_env():
214-
self.mock_backends[0].assert_called_once_with(
215-
log_dir=self.log_dir, **options.backend_kwargs["clu"]
216-
)
217-
else:
218-
self.mock_backends[0].assert_called_once_with(
219-
log_dir=self.log_dir,
220-
flush_every_n_steps=100,
221-
**options.backend_kwargs["tensorboard"],
222-
)
223-
self.mock_backends[1].assert_called_once_with(
224-
project="tunix",
225-
name="",
226-
resume="must",
227-
id="12345",
228-
settings=mock_wandb.Settings.return_value,
229-
)
230-
231192

232193
if __name__ == "__main__":
233194
absltest.main()

tunix/sft/metrics_logger.py

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import collections
44
import dataclasses
55
import enum
6-
from typing import Any, Callable
6+
from typing import Callable
77

88
from absl import logging
99
import jax
@@ -28,26 +28,7 @@ class MetricsLoggerOptions:
2828
project_name: str = "tunix"
2929
run_name: str = ""
3030
flush_every_n_steps: int = 100
31-
# Keyword arguments for backend initialization. The key is the backend name
32-
# (e.g., 'wandb', 'clu', 'tensorboard' or 'custom_backend' which uses custom
33-
# LoggingBackend factories) and the value is a dictionary of
34-
# keyword arguments to be passed to the backend's constructor.
35-
# For example:
36-
# backend_kwargs={
37-
# 'wandb': {
38-
# 'resume': 'must',
39-
# 'id': '12345',
40-
# 'project': 'my-project',
41-
# 'name': 'my-run',
42-
# },
43-
# 'tensorboard': {
44-
# 'log_dir': '/path/to/log',
45-
# 'flush_every_n_steps': 100,
46-
# }
47-
# }
48-
backend_kwargs: dict[str, dict[str, Any] | list[BackendFactory]] = (
49-
dataclasses.field(default_factory=dict)
50-
)
31+
backend_factories: list[BackendFactory] | None = None
5132

5233
def create_backends(self) -> list[LoggingBackend]:
5334
"""Factory method to create a fresh set of live backends."""
@@ -56,40 +37,28 @@ def create_backends(self) -> list[LoggingBackend]:
5637
return []
5738

5839
# Case 1: Override. Use user-provided factories.
59-
if (
60-
"custom_backend" in self.backend_kwargs
61-
and self.backend_kwargs["custom_backend"]
62-
):
63-
return [factory() for factory in self.backend_kwargs["custom_backend"]]
40+
if self.backend_factories is not None:
41+
return [factory() for factory in self.backend_factories]
6442

6543
# Case 2: Defaults.
6644
active_backends = []
67-
kwargs_dict = self.backend_kwargs or {}
6845

6946
if env_utils.is_internal_env():
7047
if CluBackend is None:
7148
raise ImportError(
7249
"Internal environment detected, but CluBackend not available."
7350
)
74-
clu_kwargs = kwargs_dict.get("clu", {})
75-
active_backends.append(CluBackend(log_dir=self.log_dir, **clu_kwargs))
51+
active_backends.append(CluBackend(log_dir=self.log_dir))
7652
else:
77-
tb_kwargs = kwargs_dict.get("tensorboard", {})
7853
active_backends.append(
7954
TensorboardBackend(
8055
log_dir=self.log_dir,
8156
flush_every_n_steps=self.flush_every_n_steps,
82-
**tb_kwargs,
8357
)
8458
)
8559
try:
86-
wandb_kwargs = kwargs_dict.get("wandb", {})
8760
active_backends.append(
88-
WandbBackend(
89-
project=self.project_name,
90-
name=self.run_name,
91-
**wandb_kwargs,
92-
)
61+
WandbBackend(project=self.project_name, name=self.run_name)
9362
)
9463
except ImportError:
9564
logging.info("WandbBackend skipped: 'wandb' library not installed.")

0 commit comments

Comments
 (0)