2020else :
2121 _has_wandb = True
2222
23+ from fairseq2 .file_system import FileMode , FileSystem
2324from fairseq2 .logging import log
2425from fairseq2 .metrics import MetricDescriptor
2526from fairseq2 .registry import Provider
3536 NoopMetricRecorder ,
3637)
3738
38- WandbResume : TypeAlias = Literal ["allow" , "never" , "auto" ]
39-
4039
4140@final
4241class WandbRecorder (MetricRecorder ):
@@ -57,13 +56,8 @@ def __init__(
5756 self ._metric_descriptors = metric_descriptors
5857
5958 @override
60- def record_metrics (
61- self ,
62- section : str ,
63- values : Mapping [str , object ],
64- step_nr : int | None = None ,
65- * ,
66- flush : bool = True ,
59+ def record_metric_values (
60+ self , section : str , values : Mapping [str , object ], step_nr : int | None = None
6761 ) -> None :
6862 for name , value in values .items ():
6963 try :
@@ -91,28 +85,37 @@ def close(self) -> None:
9185WANDB_RECORDER : Final = "wandb"
9286
9387
88+ WandbResumeMode : TypeAlias = Literal ["allow" , "never" , "must" , "auto" ]
89+
90+
@dataclass(kw_only=True)
class WandbRecorderConfig:
    """Holds the settings used to set up the Weights & Biases recorder.

    Apart from ``enabled`` and ``run_id``, every field is forwarded as-is to
    ``wandb.init()`` by :meth:`WandbRecorderHandler.create`.
    """

    # NOTE(review): presumably `create()` hands back a no-op recorder while
    # this is ``False`` -- the check itself is not visible here; confirm.
    enabled: bool = False

    # Passed as ``entity`` to ``wandb.init()``.
    entity: str | None = None

    # Passed as ``project`` to ``wandb.init()``.
    project: str | None = None

    # ``None``: a new ID is generated on every call. ``"auto"``: the ID is
    # read from, or generated and saved to, the ``wandb_run_id`` file in the
    # output directory. Any other value is used verbatim.
    run_id: str | None = "auto"

    # Passed as ``name`` to ``wandb.init()``.
    run_name: str | None = None

    # Passed as ``group`` to ``wandb.init()``.
    group: str | None = None

    # Passed as ``job_type`` to ``wandb.init()``.
    job_type: str | None = None

    # Passed as ``resume`` to ``wandb.init()``.
    resume_mode: WandbResumeMode = "allow"
109108
110109
111110@final
112111class WandbRecorderHandler (MetricRecorderHandler ):
112+ _file_system : FileSystem
113113 _metric_descriptors : Provider [MetricDescriptor ]
114114
def __init__(
    self, file_system: FileSystem, metric_descriptors: Provider[MetricDescriptor]
) -> None:
    """
    :param file_system:
        Kept on the handler; used to read and write the persisted run ID
        file when the run ID is resolved.
    :param metric_descriptors:
        Kept on the handler and handed to each :class:`WandbRecorder` that
        it creates.
    """
    self._metric_descriptors = metric_descriptors
    self._file_system = file_system
117120
118121 @override
@@ -131,28 +134,32 @@ def create(
131134
132135 return NoopMetricRecorder ()
133136
134- try :
135- hyper_params = unstructure (hyper_params )
136- except StructureError as ex :
137- raise ValueError (
138- "`hyper_params` cannot be unstructured. See the nested exception for details."
139- ) from ex
137+ if hyper_params is not None :
138+ try :
139+ hyper_params = unstructure (hyper_params )
140+ except StructureError as ex :
141+ raise ValueError (
142+ "`hyper_params` cannot be unstructured. See the nested exception for details."
143+ ) from ex
140144
141- if not isinstance (hyper_params , dict ):
142- raise TypeError (
143- f"The unstructured form of `hyper_params` must be of type `dict`, but is of type `{ type (hyper_params )} ` instead."
144- )
145+ if not isinstance (hyper_params , dict ):
146+ raise TypeError (
147+ f"The unstructured form of `hyper_params` must be of type `dict`, but is of type `{ type (hyper_params )} ` instead."
148+ )
149+
150+ run_id = self ._get_run_id (output_dir , config )
145151
146152 try :
147153 run = wandb .init (
154+ entity = config .entity ,
148155 project = config .project ,
149156 dir = output_dir ,
150- id = config . run_id ,
157+ id = run_id ,
151158 name = config .run_name ,
152159 config = hyper_params ,
153160 group = config .group ,
154161 job_type = config .job_type ,
155- resume = config .resume ,
162+ resume = config .resume_mode ,
156163 )
157164 except (RuntimeError , ValueError ) as ex :
158165 raise MetricRecordError (
@@ -161,6 +168,43 @@ def create(
161168
162169 return WandbRecorder (run , self ._metric_descriptors )
163170
def _get_run_id(self, output_dir: Path, config: WandbRecorderConfig) -> str:
    """Resolve the run ID that is passed to ``wandb.init()``.

    - If ``config.run_id`` is ``None``, a new ID is generated and nothing is
      persisted.
    - If ``config.run_id`` is any string other than ``"auto"``, it is
      returned verbatim.
    - If ``config.run_id`` is ``"auto"``, the ID stored in the
      ``wandb_run_id`` file under ``output_dir`` is returned. If the file
      does not exist, or holds nothing but whitespace, a new ID is generated,
      written to that file, and returned.

    :param output_dir:
        The directory that holds the ``wandb_run_id`` file.
    :param config:
        The recorder configuration; only ``run_id`` is read.

    :returns:
        The run ID to use.

    :raises MetricRecordError:
        If the run ID file cannot be read (for a reason other than it being
        missing) or cannot be written.
    """
    run_id = config.run_id

    if run_id is None:
        return wandb.util.generate_id()

    if run_id != "auto":
        return run_id

    wandb_file = output_dir.joinpath("wandb_run_id")

    try:
        fp = self._file_system.open_text(wandb_file)

        with fp:
            # Strip surrounding whitespace so that a trailing newline (e.g.
            # from a hand-edited file) does not end up in the run ID.
            saved_run_id = fp.read().strip()
    except FileNotFoundError:
        saved_run_id = ""
    except OSError as ex:
        raise MetricRecordError(
            "The Weights & Biases run ID cannot be loaded. See the nested exception for details."
        ) from ex

    if saved_run_id:
        return saved_run_id

    # Either no file exists yet or it is empty (e.g. a previous write was
    # interrupted); in both cases start over with a fresh ID and persist it.
    run_id = wandb.util.generate_id()

    try:
        fp = self._file_system.open_text(wandb_file, mode=FileMode.WRITE)

        with fp:
            fp.write(run_id)
    except OSError as ex:
        raise MetricRecordError(
            "The Weights & Biases run ID cannot be saved. See the nested exception for details."
        ) from ex

    return run_id
207+
164208 @property
165209 @override
166210 def name (self ) -> str :
0 commit comments