@@ -503,7 +503,7 @@ async def init_backends(
503503
504504 # instantiate local backend
505505 backend : LoggerBackend = get_logger_backend_class (backend_name )(
506- backend_config
506+ ** backend_config
507507 )
508508 await backend .init (
509509 role = BackendRole .LOCAL ,
@@ -643,10 +643,21 @@ async def shutdown(self):
643643
644644
645645class LoggerBackend (ABC ):
646- """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc."""
646+ """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.
647647
648- def __init__ (self , logger_backend_config : dict [str , Any ]) -> None :
649- self .logger_backend_config = logger_backend_config
648+ Args:
649+ logging_mode: Logging behavior mode.
650+ per_rank_share_run: Whether ranks share run. Default False.
651+ **kwargs: Backend-specific arguments (e.g., project, name, tags for WandB).
652+ """
653+
def __init__(
    self,
    *,
    logging_mode: LoggingMode,
    per_rank_share_run: bool = False,
    **kwargs,
) -> None:
    """Store the options common to every backend and keep the rest verbatim.

    Args:
        logging_mode: Logging behavior mode; stored as ``self.logging_mode``.
        per_rank_share_run: Whether ranks share a run; stored as
            ``self.per_rank_share_run``. Defaults to False.
        **kwargs: Backend-specific arguments, stored untouched as
            ``self.backend_kwargs`` for the concrete backend to consume.
    """
    # Anything not named explicitly above is backend-specific and is kept as-is.
    self.backend_kwargs = kwargs
    self.per_rank_share_run = per_rank_share_run
    self.logging_mode = logging_mode
650661
651662 @abstractmethod
652663 async def init (
@@ -706,8 +717,13 @@ def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None:
706717class ConsoleBackend (LoggerBackend ):
707718 """Simple console logging of metrics."""
708719
def __init__(
    self,
    *,
    logging_mode: LoggingMode,
    per_rank_share_run: bool = False,
    **kwargs,
) -> None:
    """Forward all options to the base class and reset the process name.

    Args:
        logging_mode: Logging behavior mode, passed through to the base class.
        per_rank_share_run: Whether ranks share a run, passed through to the
            base class. Defaults to False.
        **kwargs: Extra backend arguments, passed through to the base class.
    """
    super().__init__(
        logging_mode=logging_mode,
        per_rank_share_run=per_rank_share_run,
        **kwargs,
    )
    # No process is associated with this backend at construction time.
    # NOTE(review): presumably assigned later by init() — confirm.
    self.process_name = None
711727
712728 async def init (
713729 self ,
@@ -741,84 +757,98 @@ class WandbBackend(LoggerBackend):
741757
742758 For logging mode details, see `forge.observability.metrics.LoggingMode` documentation.
743759
744- More details on wandb distributed logging here : https://docs.wandb.ai/guides/track/log/distributed-training/
760+ More details on wandb distributed logging: https://docs.wandb.ai/guides/track/log/distributed-training/
745761
746762 Configuration:
747- logging_mode (LoggingMode): Determines logging behavior
763+ logging_mode (LoggingMode): Determines logging behavior.
748764 per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks.
749- If true, then a single wandb is created and all ranks log to it. Its particularly useful if
750- logging with no_reduce to capture a time based stream of information. Not recommended if reducing values.
751- project (str): WandB project name
752- group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group"
765+ If true, a single wandb run is created and all ranks log to it. Particularly useful for
766+ logging with no_reduce to capture time-based streams. Not recommended if reducing values.
767+ **kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, etc.)
768+
769+ Example:
770+ WandbBackend(
771+ logging_mode=LoggingMode.PER_RANK_REDUCE,
772+ per_rank_share_run=False,
773+ project="my_project",
774+ group="exp_group",
775+ name="my_experiment",
776+ tags=["rl", "v2"],
777+ notes="Testing new reward"
778+ )
753779 """
754780
def __init__(
    self,
    *,
    logging_mode: LoggingMode,
    per_rank_share_run: bool = False,
    **kwargs,
) -> None:
    """Forward all options to the base class and clear per-run state.

    Args:
        logging_mode: Logging behavior mode, passed through to the base class.
        per_rank_share_run: Whether ranks share a run, passed through to the
            base class. Defaults to False.
        **kwargs: Extra arguments (e.g. project, group, name, tags), passed
            through to the base class.
    """
    super().__init__(
        logging_mode=logging_mode,
        per_rank_share_run=per_rank_share_run,
        **kwargs,
    )
    # Neither a wandb run nor a process name exists until init() is awaited.
    self.run = self.process_name = None
763789
async def init(
    self,
    role: BackendRole,
    controller_logger_metadata: dict[str, Any] | None = None,
    process_name: str | None = None,
) -> None:
    """Start the wandb run appropriate for this backend's mode and role.

    Dispatch summary:
      * ``GLOBAL_REDUCE`` mode: only the ``GLOBAL`` role creates a run; any
        other role logs a warning and returns without initializing.
      * ``GLOBAL`` role with ``per_rank_share_run``: creates the shared
        primary run.
      * ``LOCAL`` role: joins the shared run when ``per_rank_share_run`` is
        set, otherwise creates its own per-rank run.
      * Any other combination falls through without initializing a run.

    Args:
        role: Whether this backend runs as the global or a local logger.
        controller_logger_metadata: Metadata from the controller; must hold
            ``"shared_run_id"`` for the shared local case. ``None`` is
            treated as an empty dict.
        process_name: Name of the calling process; appended to the
            user-supplied run name for ``LOCAL`` roles.

    Raises:
        ValueError: If a shared local run is requested but
            ``controller_logger_metadata`` has no ``"shared_run_id"``.
    """
    if controller_logger_metadata is None:
        controller_logger_metadata = {}

    # The user-supplied ``name`` must leave backend_kwargs, because the
    # _init_* helpers pass ``name=`` explicitly next to ``**self.backend_kwargs``.
    # Cache it on removal so that a repeated init() on the same instance
    # still sees the original name instead of silently losing it.
    if "name" in self.backend_kwargs:
        self._base_run_name = self.backend_kwargs.pop("name")
    run_name = getattr(self, "_base_run_name", None)
    self.process_name = process_name

    # Format run name based on mode and role
    if self.logging_mode == LoggingMode.GLOBAL_REDUCE:
        if role != BackendRole.GLOBAL:
            logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.")
            return
        # use name as-is, no need to append controller process_name
        await self._init_global(run_name)

    elif role == BackendRole.GLOBAL and self.per_rank_share_run:
        # use name as-is, no need to append controller process_name
        await self._init_shared_global(run_name)

    elif role == BackendRole.LOCAL:
        # Per-rank: append process_name. Only non-empty parts are joined, so
        # a missing process_name no longer yields a literal "_None" suffix.
        name_parts = [part for part in (run_name, process_name) if part]
        run_name = "_".join(name_parts) if name_parts else process_name

        if self.per_rank_share_run:
            shared_id = controller_logger_metadata.get("shared_run_id")
            if shared_id is None:
                raise ValueError(
                    f"Shared ID required but not provided for {process_name} backend init"
                )
            await self._init_shared_local(run_name, shared_id, process_name)
        else:
            await self._init_per_rank(run_name)
792828
async def _init_global(self, run_name: str | None):
    """Create the wandb run and store it on ``self.run``.

    Args:
        run_name: Value passed to ``wandb.init`` as ``name``. Every entry of
            ``self.backend_kwargs`` is forwarded as a further keyword argument.
    """
    import wandb

    started_run = wandb.init(name=run_name, **self.backend_kwargs)
    self.run = started_run
797833
async def _init_per_rank(self, run_name: str):
    """Create this rank's own wandb run and store it on ``self.run``.

    Args:
        run_name: Value passed to ``wandb.init`` as ``name``. Every entry of
            ``self.backend_kwargs`` is forwarded as a further keyword argument.
    """
    import wandb

    rank_run = wandb.init(name=run_name, **self.backend_kwargs)
    self.run = rank_run
804838
async def _init_shared_global(self, run_name: str | None):
    """Create the primary side of a shared wandb run and store it on ``self.run``.

    The run is started with ``mode="shared"`` settings, marked as primary and
    labelled ``"controller_primary"``.

    Args:
        run_name: Value passed to ``wandb.init`` as ``name``. Every entry of
            ``self.backend_kwargs`` is forwarded as a further keyword argument.
    """
    import wandb

    primary_settings = wandb.Settings(
        mode="shared",
        x_primary=True,
        x_label="controller_primary",
    )
    self.run = wandb.init(
        name=run_name,
        settings=primary_settings,
        **self.backend_kwargs,
    )
812846
813- async def _init_shared_local (self , controller_metadata : dict [str , Any ]):
847+ async def _init_shared_local (
848+ self , run_name : str , shared_id : str , process_name : str
849+ ):
814850 import wandb
815851
816- shared_id = controller_metadata .get ("shared_run_id" )
817- if shared_id is None :
818- raise ValueError (
819- f"Shared ID required but not provided for { self .process_name } backend init"
820- )
821-
822852 # Clear any stale service tokens that might be pointing to dead processes
823853 # In multiprocessing environments, WandB service tokens can become stale and point
824854 # to dead service processes. This causes wandb.init() to hang indefinitely trying
@@ -827,14 +857,9 @@ async def _init_shared_local(self, controller_metadata: dict[str, Any]):
827857
828858 service_token .clear_service_in_env ()
829859
830- settings = wandb .Settings (
831- mode = "shared" , x_primary = False , x_label = self .process_name
832- )
860+ settings = wandb .Settings (mode = "shared" , x_primary = False , x_label = process_name )
833861 self .run = wandb .init (
834- id = shared_id ,
835- project = self .project ,
836- group = self .group ,
837- settings = settings ,
862+ name = run_name , id = shared_id , settings = settings , ** self .backend_kwargs
838863 )
839864
840865 async def log_batch (
@@ -862,7 +887,7 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
862887 return
863888
864889 # Log with custom timestamp for precision
865- # Users can choose x-axis as timestamp in WandB UI and display as dateimte
890+ # Users can choose x-axis as timestamp in WandB UI and display as datetime
866891 log_data = {
867892 metric .key : metric .value ,
868893 "timestamp" : metric .timestamp ,
0 commit comments