Skip to content

Commit 928f03b

Browse files
author
Felipe Mello
committed
add kwargs
1 parent ff3290e commit 928f03b

File tree

2 files changed

+85
-52
lines changed

2 files changed

+85
-52
lines changed

src/forge/observability/metric_actors.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,16 @@ def _validate_backend_config(
224224
f"{', '.join([mode.value for mode in LoggingMode])}."
225225
)
226226

227-
mode_str = config["logging_mode"]
228-
mode = LoggingMode(mode_str)
227+
# Convert string to LoggingMode enum
228+
mode_value = config["logging_mode"]
229+
if isinstance(mode_value, str):
230+
mode = LoggingMode(mode_value)
231+
elif isinstance(mode_value, LoggingMode):
232+
mode = mode_value
233+
else:
234+
raise TypeError(
235+
f"logging_mode must be str or LoggingMode enum, got {type(mode_value)}"
236+
)
229237

230238
# Validate per_rank_share_run configuration
231239
share_run = config.get("per_rank_share_run", False)
@@ -302,7 +310,7 @@ async def init_backends(self, config: dict[str, Any]) -> None:
302310
mode = backend_config["logging_mode"]
303311

304312
backend: LoggerBackend = get_logger_backend_class(backend_name)(
305-
backend_config
313+
**backend_config
306314
)
307315
await backend.init(role=BackendRole.GLOBAL, process_name="global_reduce")
308316

src/forge/observability/metrics.py

Lines changed: 74 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ async def init_backends(
503503

504504
# instantiate local backend
505505
backend: LoggerBackend = get_logger_backend_class(backend_name)(
506-
backend_config
506+
**backend_config
507507
)
508508
await backend.init(
509509
role=BackendRole.LOCAL,
@@ -643,10 +643,21 @@ async def shutdown(self):
643643

644644

645645
class LoggerBackend(ABC):
646-
"""Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc."""
646+
"""Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.
647647
648-
def __init__(self, logger_backend_config: dict[str, Any]) -> None:
649-
self.logger_backend_config = logger_backend_config
648+
Args:
649+
logging_mode: Logging behavior mode.
650+
per_rank_share_run: Whether ranks share run. Default False.
651+
**kwargs: Backend-specific arguments (e.g., project, name, tags for WandB).
652+
"""
653+
654+
def __init__(
655+
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
656+
) -> None:
657+
658+
self.logging_mode = logging_mode
659+
self.per_rank_share_run = per_rank_share_run
660+
self.backend_kwargs = kwargs
650661

651662
@abstractmethod
652663
async def init(
@@ -706,8 +717,13 @@ def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None:
706717
class ConsoleBackend(LoggerBackend):
707718
"""Simple console logging of metrics."""
708719

709-
def __init__(self, logger_backend_config: dict[str, Any]) -> None:
710-
super().__init__(logger_backend_config)
720+
def __init__(
721+
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
722+
) -> None:
723+
super().__init__(
724+
logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
725+
)
726+
self.process_name = None
711727

712728
async def init(
713729
self,
@@ -741,84 +757,98 @@ class WandbBackend(LoggerBackend):
741757
742758
For logging mode details, see `forge.observability.metrics.LoggingMode` documentation.
743759
744-
More details on wandb distributed logging here: https://docs.wandb.ai/guides/track/log/distributed-training/
760+
More details on wandb distributed logging: https://docs.wandb.ai/guides/track/log/distributed-training/
745761
746762
Configuration:
747-
logging_mode (LoggingMode): Determines logging behavior
763+
logging_mode (LoggingMode): Determines logging behavior.
748764
per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks.
749-
If true, then a single wandb is created and all ranks log to it. Its particularly useful if
750-
logging with no_reduce to capture a time based stream of information. Not recommended if reducing values.
751-
project (str): WandB project name
752-
group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group"
765+
If true, a single wandb run is created and all ranks log to it. Particularly useful for
766+
logging with no_reduce to capture time-based streams. Not recommended if reducing values.
767+
**kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, etc.)
768+
769+
Example:
770+
WandbBackend(
771+
logging_mode=LoggingMode.PER_RANK_REDUCE,
772+
per_rank_share_run=False,
773+
project="my_project",
774+
group="exp_group",
775+
name="my_experiment",
776+
tags=["rl", "v2"],
777+
notes="Testing new reward"
778+
)
753779
"""
754780

755-
def __init__(self, logger_backend_config: dict[str, Any]) -> None:
756-
super().__init__(logger_backend_config)
757-
self.project = logger_backend_config["project"]
758-
self.group = logger_backend_config.get("group", "experiment_group")
759-
self.process_name = None
781+
def __init__(
782+
self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
783+
) -> None:
784+
super().__init__(
785+
logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
786+
)
760787
self.run = None
761-
self.logging_mode = LoggingMode(logger_backend_config["logging_mode"])
762-
self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False)
788+
self.process_name = None
763789

764790
async def init(
765791
self,
766792
role: BackendRole,
767793
controller_logger_metadata: dict[str, Any] | None = None,
768794
process_name: str | None = None,
769795
) -> None:
770-
771796
if controller_logger_metadata is None:
772797
controller_logger_metadata = {}
773798

799+
# Pop name, if any, to concat to process_name.
800+
run_name = self.backend_kwargs.pop("name", None)
774801
self.process_name = process_name
775802

776-
# GLOBAL_REDUCE mode: only inits on controller
803+
# Format run name based on mode and role
777804
if self.logging_mode == LoggingMode.GLOBAL_REDUCE:
778805
if role != BackendRole.GLOBAL:
779806
logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.")
780807
return
781-
await self._init_global()
808+
# use name as-is, no need to append controller process_name
809+
await self._init_global(run_name)
782810

783-
# Per-rank modes based on per_rank_share_run bool
784811
elif role == BackendRole.GLOBAL and self.per_rank_share_run:
785-
await self._init_shared_global()
812+
# use name as-is, no need to append controller process_name
813+
await self._init_shared_global(run_name)
786814

787815
elif role == BackendRole.LOCAL:
816+
# Per-rank: append process_name
817+
run_name = f"{run_name}_{process_name}" if run_name else process_name
818+
788819
if self.per_rank_share_run:
789-
await self._init_shared_local(controller_logger_metadata)
820+
shared_id = controller_logger_metadata.get("shared_run_id")
821+
if shared_id is None:
822+
raise ValueError(
823+
f"Shared ID required but not provided for {process_name} backend init"
824+
)
825+
await self._init_shared_local(run_name, shared_id, process_name)
790826
else:
791-
await self._init_per_rank()
827+
await self._init_per_rank(run_name)
792828

793-
async def _init_global(self):
829+
async def _init_global(self, run_name: str | None):
794830
import wandb
795831

796-
self.run = wandb.init(project=self.project, group=self.group)
832+
self.run = wandb.init(name=run_name, **self.backend_kwargs)
797833

798-
async def _init_per_rank(self):
834+
async def _init_per_rank(self, run_name: str):
799835
import wandb
800836

801-
self.run = wandb.init(
802-
project=self.project, group=self.group, name=self.process_name
803-
)
837+
self.run = wandb.init(name=run_name, **self.backend_kwargs)
804838

805-
async def _init_shared_global(self):
839+
async def _init_shared_global(self, run_name: str | None):
806840
import wandb
807841

808842
settings = wandb.Settings(
809843
mode="shared", x_primary=True, x_label="controller_primary"
810844
)
811-
self.run = wandb.init(project=self.project, group=self.group, settings=settings)
845+
self.run = wandb.init(name=run_name, settings=settings, **self.backend_kwargs)
812846

813-
async def _init_shared_local(self, controller_metadata: dict[str, Any]):
847+
async def _init_shared_local(
848+
self, run_name: str, shared_id: str, process_name: str
849+
):
814850
import wandb
815851

816-
shared_id = controller_metadata.get("shared_run_id")
817-
if shared_id is None:
818-
raise ValueError(
819-
f"Shared ID required but not provided for {self.process_name} backend init"
820-
)
821-
822852
# Clear any stale service tokens that might be pointing to dead processes
823853
# In multiprocessing environments, WandB service tokens can become stale and point
824854
# to dead service processes. This causes wandb.init() to hang indefinitely trying
@@ -827,14 +857,9 @@ async def _init_shared_local(self, controller_metadata: dict[str, Any]):
827857

828858
service_token.clear_service_in_env()
829859

830-
settings = wandb.Settings(
831-
mode="shared", x_primary=False, x_label=self.process_name
832-
)
860+
settings = wandb.Settings(mode="shared", x_primary=False, x_label=process_name)
833861
self.run = wandb.init(
834-
id=shared_id,
835-
project=self.project,
836-
group=self.group,
837-
settings=settings,
862+
name=run_name, id=shared_id, settings=settings, **self.backend_kwargs
838863
)
839864

840865
async def log_batch(
@@ -862,7 +887,7 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
862887
return
863888

864889
# Log with custom timestamp for precision
865-
# Users can choose x-axis as timestamp in WandB UI and display as dateimte
890+
# Users can choose x-axis as timestamp in WandB UI and display as datetime
866891
log_data = {
867892
metric.key: metric.value,
868893
"timestamp": metric.timestamp,

0 commit comments

Comments
 (0)