1717from typing import Any , Dict , List
1818
1919import pytz
20- from monarch .actor import current_rank
2120
2221from forge .observability .utils import get_proc_name_with_rank
2322
2423from forge .util .logging import get_logger , log_once
24+ from monarch .actor import current_rank
2525
2626logger = get_logger ("INFO" )
2727
@@ -606,7 +606,7 @@ async def init_backends(
606606
607607 # instantiate local backend
608608 backend : LoggerBackend = get_logger_backend_class (backend_name )(
609- backend_config
609+ ** backend_config
610610 )
611611 await backend .init (
612612 role = BackendRole .LOCAL ,
@@ -760,10 +760,21 @@ async def shutdown(self):
760760
761761
762762class LoggerBackend (ABC ):
763- """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc."""
763+ """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.
764+
765+ Args:
766+ logging_mode: Logging behavior mode.
767+ per_rank_share_run: Whether ranks share run. Default False.
768+ **kwargs: Backend-specific arguments (e.g., project, name, tags for WandB).
769+ """
770+
def __init__(
    self,
    *,
    logging_mode: LoggingMode,
    per_rank_share_run: bool = False,
    **kwargs,
) -> None:
    """Store the settings common to every logger backend.

    Args:
        logging_mode: Logging behavior mode, kept on ``self.logging_mode``.
        per_rank_share_run: Whether ranks share a run. Defaults to False.
        **kwargs: Backend-specific arguments, kept untouched on
            ``self.backend_kwargs`` for the concrete backend to consume.
    """
    # Backend-specific options are stored as-is; no validation happens here.
    self.backend_kwargs = kwargs
    self.per_rank_share_run = per_rank_share_run
    self.logging_mode = logging_mode
767778
768779 @abstractmethod
769780 async def init (
@@ -823,8 +834,13 @@ def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None:
823834class ConsoleBackend (LoggerBackend ):
824835 """Simple console logging of metrics."""
825836
def __init__(
    self,
    *,
    logging_mode: LoggingMode,
    per_rank_share_run: bool = False,
    **kwargs,
) -> None:
    """Initialise the console backend.

    Args:
        logging_mode: Logging behavior mode, forwarded to the base class.
        per_rank_share_run: Whether ranks share a run, forwarded to the
            base class. Defaults to False.
        **kwargs: Extra backend arguments, forwarded to the base class.
    """
    super().__init__(
        logging_mode=logging_mode,
        per_rank_share_run=per_rank_share_run,
        **kwargs,
    )
    # Starts unset; NOTE(review): presumably filled in by init() — confirm.
    self.process_name = None
828844
829845 async def init (
830846 self ,
@@ -868,85 +884,101 @@ class WandbBackend(LoggerBackend):
868884
869885 For logging mode details, see `forge.observability.metrics.LoggingMode` documentation.
870886
871- More details on wandb distributed logging here : https://docs.wandb.ai/guides/track/log/distributed-training/
887+ More details on wandb distributed logging: https://docs.wandb.ai/guides/track/log/distributed-training/
872888
873889 Configuration:
874- logging_mode (LoggingMode): Determines logging behavior
890+ logging_mode (LoggingMode): Determines logging behavior.
875891 per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks.
876- If true, then a single wandb is created and all ranks log to it. Its particularly useful if
877- logging with no_reduce to capture a time based stream of information. Not recommended if reducing values.
878- project (str): WandB project name
879- group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group"
892+ If true, a single wandb run is created and all ranks log to it. Particularly useful for
893+ logging with no_reduce to capture time-based streams. Not recommended if reducing values.
894+ **kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, etc.)
895+
896+ Example:
897+ WandbBackend(
898+ logging_mode=LoggingMode.PER_RANK_REDUCE,
899+ per_rank_share_run=False,
900+ project="my_project",
901+ group="exp_group",
902+ name="my_experiment",
903+ tags=["rl", "v2"],
904+ notes="Testing new reward"
905+ )
880906 """
881907
def __init__(
    self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
) -> None:
    """Configure a WandB backend; no run is started until ``init`` is awaited.

    Args:
        logging_mode: Logging behavior mode. Normalised through
            ``LoggingMode(...)`` so that a raw config value is accepted as
            well as an existing ``LoggingMode`` member (assumes
            ``LoggingMode`` is an Enum — TODO confirm).
        per_rank_share_run: For per-rank modes, whether all ranks log to a
            single shared run. Defaults to False.
        **kwargs: Remaining keyword arguments, stored on
            ``self.backend_kwargs`` and later passed to ``wandb.init()``.
    """
    super().__init__(
        logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
    )
    # The previous body read both settings from `logger_backend_config`, a
    # name that is no longer a parameter of this constructor (NameError).
    # `per_rank_share_run` is already stored by the base class; only the
    # enum coercion of `logging_mode` needs to be kept.
    self.logging_mode = LoggingMode(logging_mode)
    self.run = None
    self.process_name = None
    self._tables: dict[str, "wandb.Table"] = {}
891919
async def init(
    self,
    role: BackendRole,
    controller_logger_metadata: dict[str, Any] | None = None,
    process_name: str | None = None,
) -> None:
    """Start the WandB run appropriate for this logging mode and role.

    Args:
        role: Whether this backend runs as ``BackendRole.GLOBAL`` or
            ``BackendRole.LOCAL``.
        controller_logger_metadata: Metadata from the controller; only the
            ``"shared_run_id"`` key is read, and only for LOCAL backends
            with ``per_rank_share_run`` enabled. ``None`` is treated as empty.
        process_name: Name of the calling process; stored on the instance
            and appended to the run name for LOCAL backends.

    Raises:
        ValueError: If a LOCAL backend with ``per_rank_share_run`` enabled
            receives no ``"shared_run_id"``.
    """
    metadata = (
        {} if controller_logger_metadata is None else controller_logger_metadata
    )

    # "name" is removed from the stored kwargs because the _init_* helpers
    # pass `name=` explicitly next to `**self.backend_kwargs`.
    base_name = self.backend_kwargs.pop("name", None)
    self.process_name = process_name

    if self.logging_mode == LoggingMode.GLOBAL_REDUCE:
        # Only the GLOBAL role logs in this mode; the name is used as-is.
        if role == BackendRole.GLOBAL:
            await self._init_global(base_name)
        else:
            logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.")
        return

    if role == BackendRole.GLOBAL and self.per_rank_share_run:
        # Shared run owned by the GLOBAL role; the name is used as-is.
        await self._init_shared_global(base_name)
        return

    if role != BackendRole.LOCAL:
        # Any remaining role/mode combination starts no run.
        return

    # Per-rank run: suffix the configured name with the process name, or
    # fall back to the process name alone when no name was configured.
    rank_run_name = f"{base_name}_{process_name}" if base_name else process_name

    if not self.per_rank_share_run:
        await self._init_per_rank(rank_run_name)
        return

    shared_id = metadata.get("shared_run_id")
    if shared_id is None:
        raise ValueError(
            f"Shared ID required but not provided for {process_name} backend init"
        )
    await self._init_shared_local(rank_run_name, shared_id, process_name)
920958
async def _init_global(self, run_name: str | None):
    """Start a WandB run and keep it on ``self.run``.

    Args:
        run_name: Value passed as ``name`` to ``wandb.init()``; may be None.
    """
    # Imported here rather than at module level, as in the sibling helpers.
    import wandb

    started_run = wandb.init(name=run_name, **self.backend_kwargs)
    self.run = started_run
925963
async def _init_per_rank(self, run_name: str):
    """Start this rank's own WandB run and keep it on ``self.run``.

    Args:
        run_name: Value passed as ``name`` to ``wandb.init()``.
    """
    # Imported here rather than at module level, as in the sibling helpers.
    import wandb

    rank_run = wandb.init(name=run_name, **self.backend_kwargs)
    self.run = rank_run
932968
async def _init_shared_global(self, run_name: str | None):
    """Start the primary side of a shared WandB run and keep it on ``self.run``.

    Args:
        run_name: Value passed as ``name`` to ``wandb.init()``; may be None.
    """
    # Imported here rather than at module level, as in the sibling helpers.
    import wandb

    # Shared mode with this process flagged as the primary writer.
    primary_settings = wandb.Settings(
        mode="shared",
        x_primary=True,
        x_label="controller_primary",
    )
    self.run = wandb.init(
        name=run_name,
        settings=primary_settings,
        **self.backend_kwargs,
    )
940976
941- async def _init_shared_local (self , controller_metadata : dict [str , Any ]):
977+ async def _init_shared_local (
978+ self , run_name : str , shared_id : str , process_name : str
979+ ):
942980 import wandb
943981
944- shared_id = controller_metadata .get ("shared_run_id" )
945- if shared_id is None :
946- raise ValueError (
947- f"Shared ID required but not provided for { self .process_name } backend init"
948- )
949-
950982 # Clear any stale service tokens that might be pointing to dead processes
951983 # In multiprocessing environments, WandB service tokens can become stale and point
952984 # to dead service processes. This causes wandb.init() to hang indefinitely trying
@@ -955,14 +987,9 @@ async def _init_shared_local(self, controller_metadata: dict[str, Any]):
955987
956988 service_token .clear_service_in_env ()
957989
958- settings = wandb .Settings (
959- mode = "shared" , x_primary = False , x_label = self .process_name
960- )
990+ settings = wandb .Settings (mode = "shared" , x_primary = False , x_label = process_name )
961991 self .run = wandb .init (
962- id = shared_id ,
963- project = self .project ,
964- group = self .group ,
965- settings = settings ,
992+ name = run_name , id = shared_id , settings = settings , ** self .backend_kwargs
966993 )
967994
968995 async def log_batch (
@@ -990,7 +1017,7 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
9901017 return
9911018
9921019 # Log with custom timestamp for precision
993- # Users can choose x-axis as timestamp in WandB UI and display as dateimte
1020+ # Users can choose x-axis as timestamp in WandB UI and display as datetime
9941021 log_data = {
9951022 metric .key : metric .value ,
9961023 "timestamp" : metric .timestamp ,
0 commit comments