Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions tests/sft/metrics_logger_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def test_custom_backends_override_defaults(self, mock_register):
mock_backend_instance = mock.Mock(spec=metrax_logging.LoggingBackend)
mock_factory = mock.Mock(return_value=mock_backend_instance)
options = metrics_logger.MetricsLoggerOptions(
log_dir=self.log_dir, backend_factories=[mock_factory]
log_dir=self.log_dir,
backend_kwargs={"custom_backend": [mock_factory]},
)

logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
Expand Down Expand Up @@ -115,7 +116,8 @@ def test_options_deepcopy_safety(self):
self.mock_backends[0].assert_called_with(log_dir=new_log_dir)
else:
self.mock_backends[0].assert_called_with(
log_dir=new_log_dir, flush_every_n_steps=100
log_dir=new_log_dir,
flush_every_n_steps=100,
)

logger1.close()
Expand All @@ -124,7 +126,7 @@ def test_options_deepcopy_safety(self):
@mock.patch.object(jax.monitoring, "record_scalar")
def test_log_metrics(self, mock_record_scalar):
options = metrics_logger.MetricsLoggerOptions(
log_dir=self.log_dir, backend_factories=[]
log_dir=self.log_dir, backend_kwargs={"custom_backend": []}
)
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
logger.log("test_prefix", "loss", 0.1, metrics_logger.Mode.TRAIN, 1)
Expand All @@ -145,7 +147,7 @@ def test_log_metrics(self, mock_record_scalar):
@mock.patch.object(jax.monitoring, "record_scalar")
def test_log_perplexity(self, mock_record_scalar):
options = metrics_logger.MetricsLoggerOptions(
log_dir=self.log_dir, backend_factories=[]
log_dir=self.log_dir, backend_kwargs={"custom_backend": []}
)
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
logger.log("test_prefix", "perplexity", 10.0, metrics_logger.Mode.EVAL, 1)
Expand Down Expand Up @@ -189,6 +191,43 @@ def test_raises_when_clu_backend_missing_in_internal_env(
):
options.create_backends()

def test_backend_kwargs_are_passed_to_backends(self):
"""Tests that backend_kwargs are passed to the backends initialization."""
mock_wandb = mock.Mock()
options = metrics_logger.MetricsLoggerOptions(
log_dir=self.log_dir,
backend_kwargs={
"wandb": {
"resume": "must",
"id": "12345",
"settings": mock_wandb.Settings(console="off"),
},
"tensorboard": {"flush_interval_s": 10.0},
"clu": {"some_arg": "value"},
},
)

# Trigger backend creation
_ = metrics_logger.MetricsLogger(metrics_logger_options=options)

if env_utils.is_internal_env():
self.mock_backends[0].assert_called_once_with(
log_dir=self.log_dir, **options.backend_kwargs["clu"]
)
else:
self.mock_backends[0].assert_called_once_with(
log_dir=self.log_dir,
flush_every_n_steps=100,
**options.backend_kwargs["tensorboard"],
)
self.mock_backends[1].assert_called_once_with(
project="tunix",
name="",
resume="must",
id="12345",
settings=mock_wandb.Settings.return_value,
)


if __name__ == "__main__":
absltest.main()
43 changes: 37 additions & 6 deletions tunix/sft/metrics_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import collections
import dataclasses
import enum
from typing import Callable
from typing import Any, Callable

from absl import logging
import jax
Expand All @@ -28,7 +28,26 @@ class MetricsLoggerOptions:
project_name: str = "tunix"
run_name: str = ""
flush_every_n_steps: int = 100
backend_factories: list[BackendFactory] | None = None
# Keyword arguments for backend initialization. The key is the backend name
# (e.g., 'wandb', 'clu', 'tensorboard' or 'custom_backend' which uses custom
# LoggingBackend factories) and the value is a dictionary of
# keyword arguments to be passed to the backend's constructor.
# For example:
# backend_kwargs={
# 'wandb': {
# 'resume': 'must',
# 'id': '12345',
# 'project': 'my-project',
# 'name': 'my-run',
# },
# 'tensorboard': {
# 'log_dir': '/path/to/log',
# 'flush_every_n_steps': 100,
# }
# }
backend_kwargs: dict[str, dict[str, Any] | list[BackendFactory]] = (
dataclasses.field(default_factory=dict)
)

def create_backends(self) -> list[LoggingBackend]:
"""Factory method to create a fresh set of live backends."""
Expand All @@ -37,28 +56,40 @@ def create_backends(self) -> list[LoggingBackend]:
return []

# Case 1: Override. Use user-provided factories.
if self.backend_factories is not None:
return [factory() for factory in self.backend_factories]
if (
"custom_backend" in self.backend_kwargs
and self.backend_kwargs["custom_backend"]
):
return [factory() for factory in self.backend_kwargs["custom_backend"]]

# Case 2: Defaults.
active_backends = []
kwargs_dict = self.backend_kwargs or {}

if env_utils.is_internal_env():
if CluBackend is None:
raise ImportError(
"Internal environment detected, but CluBackend not available."
)
active_backends.append(CluBackend(log_dir=self.log_dir))
clu_kwargs = kwargs_dict.get("clu", {})
active_backends.append(CluBackend(log_dir=self.log_dir, **clu_kwargs))
else:
tb_kwargs = kwargs_dict.get("tensorboard", {})
active_backends.append(
TensorboardBackend(
log_dir=self.log_dir,
flush_every_n_steps=self.flush_every_n_steps,
**tb_kwargs,
)
)
try:
wandb_kwargs = kwargs_dict.get("wandb", {})
active_backends.append(
WandbBackend(project=self.project_name, name=self.run_name)
WandbBackend(
project=self.project_name,
name=self.run_name,
**wandb_kwargs,
)
)
except ImportError:
logging.info("WandbBackend skipped: 'wandb' library not installed.")
Expand Down
Loading