@@ -550,9 +550,10 @@ def __init__(self) -> None:
550550 async def init_backends (
551551 self ,
552552 metadata_per_controller_backend : dict [str , dict [str , Any ]] | None ,
553- config : dict [str , Any ],
553+ backend_config : dict [str , Any ],
554554 global_step : int = 0 ,
555555 process_name : str | None = None ,
556+ run_config : dict [str , Any ] | None = None ,
556557 ) -> None :
557558 """Initialize per-rank logger backends and MetricCollector state.
558559
@@ -563,12 +564,15 @@ async def init_backends(
563564 metadata_per_controller_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from controller
564565 for backends that require shared state across processes, e.g.,
565566 {"wandb": {"shared_run_id": "abc123"}}.
566- config (Dict[str, Any]): Backend configurations where each key is a backend name
567+ backend_config (Dict[str, Any]): Backend configurations where each key is a backend name
567568 and value contains logging_mode and backend-specific settings.
568569 e.g., {"wandb": {"logging_mode": "per_rank_no_reduce", "project": "my_proj"}}
569570 global_step (int, default 0): Initial step for logging. Can be used when
570571 resuming from a checkpoint.
571572 process_name (str | None): The meaningful process name for logging.
573+ run_config (dict[str, Any] | None): Your application's configuration
574+ (hyperparameters, dataset, model settings) to log to backends for
575+ experiment tracking.
572576 """
573577 if self ._is_initialized :
574578 logger .debug (
@@ -583,8 +587,8 @@ async def init_backends(
583587 self .per_rank_no_reduce_backends : list [LoggerBackend ] = []
584588
585589 # Initialize backends based on logging mode
586- for backend_name , backend_config in config .items ():
587- mode = backend_config ["logging_mode" ]
590+ for backend_name , cfg in backend_config .items ():
591+ mode = cfg ["logging_mode" ]
588592
589593 # sanity check
590594 if not isinstance (mode , LoggingMode ):
@@ -605,13 +609,12 @@ async def init_backends(
605609 )
606610
607611 # instantiate local backend
608- backend : LoggerBackend = get_logger_backend_class (backend_name )(
609- ** backend_config
610- )
612+ backend : LoggerBackend = get_logger_backend_class (backend_name )(** cfg )
611613 await backend .init (
612614 role = BackendRole .LOCAL ,
613615 controller_logger_metadata = controller_metadata ,
614616 process_name = self .proc_name_with_rank ,
617+ run_config = run_config ,
615618 )
616619
617620 # Categorize by logging mode
@@ -781,6 +784,7 @@ async def init(
781784 role : BackendRole ,
782785 controller_logger_metadata : dict [str , Any ] | None = None ,
783786 process_name : str | None = None ,
787+ run_config : dict [str , Any ] | None = None ,
784788 ) -> None :
785789 """
786790 Initializes backend, e.g. wandb.run.init().
@@ -791,6 +795,9 @@ async def init(
791795 controller_logger_metadata (dict[str, Any] | None): From global backend for
792796 backend that required shared info, e.g. {"shared_run_id": "abc123"}.
793797 process_name (str | None): Process name for logging.
798+ run_config (dict[str, Any] | None): Your application's configuration
799+ (hyperparameters, dataset, model settings) to log to backend for
800+ experiment tracking.
794801
795802 Raises: ValueError if missing metadata for shared local init.
796803 """
@@ -856,6 +863,7 @@ async def init(
856863 role : BackendRole ,
857864 controller_logger_metadata : dict [str , Any ] | None = None ,
858865 process_name : str | None = None ,
866+ run_config : dict [str , Any ] | None = None ,
859867 ) -> None :
860868 self .process_name = process_name
861869
@@ -927,13 +935,15 @@ async def init(
927935 role : BackendRole ,
928936 controller_logger_metadata : dict [str , Any ] | None = None ,
929937 process_name : str | None = None ,
938+ run_config : dict [str , Any ] | None = None ,
930939 ) -> None :
931940 if controller_logger_metadata is None :
932941 controller_logger_metadata = {}
933942
934943 # Pop name, if any, to concat to process_name.
935944 run_name = self .backend_kwargs .pop ("name" , None )
936945 self .process_name = process_name
946+ self .run_config = run_config
937947
938948 # Format run name based on mode and role
939949 if self .logging_mode == LoggingMode .GLOBAL_REDUCE :
@@ -964,20 +974,29 @@ async def init(
964974 async def _init_global (self , run_name : str | None ):
965975 import wandb
966976
967- self .run = wandb .init (name = run_name , ** self .backend_kwargs )
977+ self .run = wandb .init (
978+ name = run_name , config = self .run_config , ** self .backend_kwargs
979+ )
968980
969981 async def _init_per_rank (self , run_name : str ):
970982 import wandb
971983
972- self .run = wandb .init (name = run_name , ** self .backend_kwargs )
984+ self .run = wandb .init (
985+ name = run_name , config = self .run_config , ** self .backend_kwargs
986+ )
973987
974988 async def _init_shared_global (self , run_name : str | None ):
975989 import wandb
976990
977991 settings = wandb .Settings (
978992 mode = "shared" , x_primary = True , x_label = "controller_primary"
979993 )
980- self .run = wandb .init (name = run_name , settings = settings , ** self .backend_kwargs )
994+ self .run = wandb .init (
995+ name = run_name ,
996+ config = self .run_config ,
997+ settings = settings ,
998+ ** self .backend_kwargs ,
999+ )
9811000
9821001 async def _init_shared_local (
9831002 self , run_name : str , shared_id : str , process_name : str
@@ -994,7 +1013,11 @@ async def _init_shared_local(
9941013
9951014 settings = wandb .Settings (mode = "shared" , x_primary = False , x_label = process_name )
9961015 self .run = wandb .init (
997- name = run_name , id = shared_id , settings = settings , ** self .backend_kwargs
1016+ name = run_name ,
1017+ id = shared_id ,
1018+ config = self .run_config ,
1019+ settings = settings ,
1020+ ** self .backend_kwargs ,
9981021 )
9991022
10001023 async def log_batch (
0 commit comments