Skip to content

Commit 012de9b

Browse files
authored
Use a static wandb run_id if none specified (#1175)
1 parent b91190f commit 012de9b

File tree

14 files changed

+100
-99
lines changed

14 files changed

+100
-99
lines changed

requirements-devel.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,3 +13,4 @@ types-tqdm~=4.67.0
1313
types-editdistance~=0.6
1414
types-psutil~=5.9
1515
types-setuptools~=75.8
16+
wandb~=0.19

src/fairseq2/metrics/recorders/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,6 @@
3838
from fairseq2.metrics.recorders._recorder import (
3939
NoopMetricRecorder as NoopMetricRecorder,
4040
)
41-
from fairseq2.metrics.recorders._recorder import record_metrics as record_metrics
4241
from fairseq2.metrics.recorders._tensorboard import (
4342
TENSORBOARD_RECORDER as TENSORBOARD_RECORDER,
4443
)

src/fairseq2/metrics/recorders/_composite.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -21,16 +21,11 @@ def __init__(self, recorders: Sequence[MetricRecorder]) -> None:
2121
self._inner_recorders = recorders
2222

2323
@override
24-
def record_metrics(
25-
self,
26-
section: str,
27-
values: Mapping[str, object],
28-
step_nr: int | None = None,
29-
*,
30-
flush: bool = True,
24+
def record_metric_values(
25+
self, section: str, values: Mapping[str, object], step_nr: int | None = None
3126
) -> None:
3227
for recorder in self._inner_recorders:
33-
recorder.record_metrics(section, values, step_nr, flush=flush)
28+
recorder.record_metric_values(section, values, step_nr)
3429

3530
@override
3631
def close(self) -> None:

src/fairseq2/metrics/recorders/_jsonl.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -61,13 +61,8 @@ def __init__(
6161
self._streams = {}
6262

6363
@override
64-
def record_metrics(
65-
self,
66-
section: str,
67-
values: Mapping[str, object],
68-
step_nr: int | None = None,
69-
*,
70-
flush: bool = True,
64+
def record_metric_values(
65+
self, section: str, values: Mapping[str, object], step_nr: int | None = None
7166
) -> None:
7267
section = section.strip()
7368

@@ -123,8 +118,7 @@ def sanitize(value: object, descriptor: MetricDescriptor) -> object:
123118

124119
stream.write("\n")
125120

126-
if flush:
127-
stream.flush()
121+
stream.flush()
128122
except OSError as ex:
129123
raise MetricRecordError(
130124
f"The metric values of the '{section}' cannot be saved to the JSON file. See the nested exception for details."

src/fairseq2/metrics/recorders/_log.py

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -42,13 +42,8 @@ def __init__(
4242
self._display_names = {"valid": "Validation", "eval": "Evaluation"}
4343

4444
@override
45-
def record_metrics(
46-
self,
47-
section: str,
48-
values: Mapping[str, object],
49-
step_nr: int | None = None,
50-
*,
51-
flush: bool = True,
45+
def record_metric_values(
46+
self, section: str, values: Mapping[str, object], step_nr: int | None = None
5247
) -> None:
5348
if not self._log.is_enabled_for_info():
5449
return

src/fairseq2/metrics/recorders/_recorder.py

Lines changed: 5 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88

99
from abc import abstractmethod
10-
from collections.abc import Mapping, Sequence
10+
from collections.abc import Mapping
1111
from typing import final
1212

1313
from typing_extensions import override
@@ -19,20 +19,14 @@ class MetricRecorder(Closable):
1919
"""Records metric values."""
2020

2121
@abstractmethod
22-
def record_metrics(
23-
self,
24-
section: str,
25-
values: Mapping[str, object],
26-
step_nr: int | None = None,
27-
*,
28-
flush: bool = True,
22+
def record_metric_values(
23+
self, section: str, values: Mapping[str, object], step_nr: int | None = None
2924
) -> None:
3025
"""Record ``values``.
3126
3227
:param section: The run section (e.g. 'train', 'eval').
3328
:param values: The metric values.
3429
:param step_nr: The step number of the run.
35-
:param flush: If ``True``, flushes any buffers after recording.
3630
"""
3731

3832

@@ -43,36 +37,11 @@ class MetricRecordError(Exception):
4337
@final
4438
class NoopMetricRecorder(MetricRecorder):
4539
@override
46-
def record_metrics(
47-
self,
48-
section: str,
49-
values: Mapping[str, object],
50-
step_nr: int | None = None,
51-
*,
52-
flush: bool = True,
40+
def record_metric_values(
41+
self, section: str, values: Mapping[str, object], step_nr: int | None = None
5342
) -> None:
5443
pass
5544

5645
@override
5746
def close(self) -> None:
5847
pass
59-
60-
61-
def record_metrics(
62-
recorders: Sequence[MetricRecorder],
63-
section: str,
64-
values: Mapping[str, object],
65-
step_nr: int | None = None,
66-
*,
67-
flush: bool = True,
68-
) -> None:
69-
"""Record ``values`` to ``recorders``.
70-
71-
:param recorders: The recorders to record to.
72-
:param section: The run section (e.g. 'train', 'eval').
73-
:param values: The metric values.
74-
:param step_nr: The step number of the run.
75-
:param flush: If ``True``, flushes any buffers after recording.
76-
"""
77-
for recorder in recorders:
78-
recorder.record_metrics(section, values, step_nr, flush=flush)

src/fairseq2/metrics/recorders/_tensorboard.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -60,13 +60,8 @@ def __init__(
6060
self._writers = {}
6161

6262
@override
63-
def record_metrics(
64-
self,
65-
section: str,
66-
values: Mapping[str, object],
67-
step_nr: int | None = None,
68-
*,
69-
flush: bool = True,
63+
def record_metric_values(
64+
self, section: str, values: Mapping[str, object], step_nr: int | None = None
7065
) -> None:
7166
writer = self._get_writer(section)
7267
if writer is None:
@@ -86,8 +81,7 @@ def record_metrics(
8681

8782
writer.add_scalar(display_name, value, step_nr)
8883

89-
if flush:
90-
writer.flush()
84+
writer.flush()
9185
except RuntimeError as ex:
9286
raise MetricRecordError(
9387
f"The metric values of the '{section}' section cannot be saved to TensorBoard. See the nested exception for details."

src/fairseq2/metrics/recorders/_wandb.py

Lines changed: 68 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
else:
2121
_has_wandb = True
2222

23+
from fairseq2.file_system import FileMode, FileSystem
2324
from fairseq2.logging import log
2425
from fairseq2.metrics import MetricDescriptor
2526
from fairseq2.registry import Provider
@@ -35,8 +36,6 @@
3536
NoopMetricRecorder,
3637
)
3738

38-
WandbResume: TypeAlias = Literal["allow", "never", "auto"]
39-
4039

4140
@final
4241
class WandbRecorder(MetricRecorder):
@@ -57,13 +56,8 @@ def __init__(
5756
self._metric_descriptors = metric_descriptors
5857

5958
@override
60-
def record_metrics(
61-
self,
62-
section: str,
63-
values: Mapping[str, object],
64-
step_nr: int | None = None,
65-
*,
66-
flush: bool = True,
59+
def record_metric_values(
60+
self, section: str, values: Mapping[str, object], step_nr: int | None = None
6761
) -> None:
6862
for name, value in values.items():
6963
try:
@@ -91,28 +85,37 @@ def close(self) -> None:
9185
WANDB_RECORDER: Final = "wandb"
9286

9387

88+
WandbResumeMode: TypeAlias = Literal["allow", "never", "must", "auto"]
89+
90+
9491
@dataclass(kw_only=True)
9592
class WandbRecorderConfig:
9693
enabled: bool = False
9794

95+
entity: str | None = None
96+
9897
project: str | None = None
9998

100-
run_id: str | None = None
99+
run_id: str | None = "auto"
101100

102101
run_name: str | None = None
103102

104103
group: str | None = None
105104

106105
job_type: str | None = None
107106

108-
resume: WandbResume = "allow"
107+
resume_mode: WandbResumeMode = "allow"
109108

110109

111110
@final
112111
class WandbRecorderHandler(MetricRecorderHandler):
112+
_file_system: FileSystem
113113
_metric_descriptors: Provider[MetricDescriptor]
114114

115-
def __init__(self, metric_descriptors: Provider[MetricDescriptor]) -> None:
115+
def __init__(
116+
self, file_system: FileSystem, metric_descriptors: Provider[MetricDescriptor]
117+
) -> None:
118+
self._file_system = file_system
116119
self._metric_descriptors = metric_descriptors
117120

118121
@override
@@ -131,28 +134,32 @@ def create(
131134

132135
return NoopMetricRecorder()
133136

134-
try:
135-
hyper_params = unstructure(hyper_params)
136-
except StructureError as ex:
137-
raise ValueError(
138-
"`hyper_params` cannot be unstructured. See the nested exception for details."
139-
) from ex
137+
if hyper_params is not None:
138+
try:
139+
hyper_params = unstructure(hyper_params)
140+
except StructureError as ex:
141+
raise ValueError(
142+
"`hyper_params` cannot be unstructured. See the nested exception for details."
143+
) from ex
140144

141-
if not isinstance(hyper_params, dict):
142-
raise TypeError(
143-
f"The unstructured form of `hyper_params` must be of type `dict`, but is of type `{type(hyper_params)}` instead."
144-
)
145+
if not isinstance(hyper_params, dict):
146+
raise TypeError(
147+
f"The unstructured form of `hyper_params` must be of type `dict`, but is of type `{type(hyper_params)}` instead."
148+
)
149+
150+
run_id = self._get_run_id(output_dir, config)
145151

146152
try:
147153
run = wandb.init(
154+
entity=config.entity,
148155
project=config.project,
149156
dir=output_dir,
150-
id=config.run_id,
157+
id=run_id,
151158
name=config.run_name,
152159
config=hyper_params,
153160
group=config.group,
154161
job_type=config.job_type,
155-
resume=config.resume,
162+
resume=config.resume_mode,
156163
)
157164
except (RuntimeError, ValueError) as ex:
158165
raise MetricRecordError(
@@ -161,6 +168,43 @@ def create(
161168

162169
return WandbRecorder(run, self._metric_descriptors)
163170

171+
def _get_run_id(self, output_dir: Path, config: WandbRecorderConfig) -> str:
    """Resolve the Weights & Biases run ID to pass to ``wandb.init()``.

    The result depends on ``config.run_id``:

    - ``None``: a fresh ID is generated on every call and nothing is saved.
    - ``"auto"``: the ID is read from the ``wandb_run_id`` file under
      ``output_dir``. If that file does not exist, or holds nothing but
      whitespace, a new ID is generated, written to the file, and returned.
    - any other string: returned unchanged.

    :param output_dir: The directory holding the ``wandb_run_id`` file.
    :param config: The recorder configuration supplying ``run_id``.

    :returns: The run ID to use.

    :raises MetricRecordError: If the ``wandb_run_id`` file cannot be read
        or written.
    """
    run_id = config.run_id

    if run_id is None:
        return wandb.util.generate_id()

    if run_id != "auto":
        return run_id

    wandb_file = output_dir.joinpath("wandb_run_id")

    try:
        fp = self._file_system.open_text(wandb_file)

        with fp:
            # Strip so that a trailing newline (e.g. left by an editor or by
            # `echo`) does not end up as part of the run ID.
            saved_run_id = fp.read().strip()
    except FileNotFoundError:
        saved_run_id = ""
    except OSError as ex:
        raise MetricRecordError(
            "The Weights & Biases run ID cannot be loaded. See the nested exception for details."
        ) from ex

    if saved_run_id:
        return saved_run_id

    # No usable ID on disk (missing or blank file): create one and persist it
    # so that later invocations with the same output directory reuse it.
    run_id = wandb.util.generate_id()

    try:
        fp = self._file_system.open_text(wandb_file, mode=FileMode.WRITE)

        with fp:
            fp.write(run_id)
    except OSError as ex:
        raise MetricRecordError(
            "The Weights & Biases run ID cannot be saved. See the nested exception for details."
        ) from ex

    return run_id
207+
164208
@property
165209
@override
166210
def name(self) -> str:

src/fairseq2/recipes/_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def _publish_metrics(self, unit: EvalUnit[BatchT]) -> None:
352352
section = f"{section}/{unit.name}"
353353

354354
try:
355-
self._metric_recorder.record_metrics(section, values)
355+
self._metric_recorder.record_metric_values(section, values)
356356
except MetricRecordError as ex:
357357
s = "evaluation"
358358

src/fairseq2/recipes/_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def _publish_metrics(self) -> None:
284284
values["wall_time"] = self._wall_watch.get_elapsed_time()
285285

286286
try:
287-
self._metric_recorder.record_metrics("generation", values)
287+
self._metric_recorder.record_metric_values("generation", values)
288288
except MetricRecordError as ex:
289289
raise RecipeError(
290290
"The generation metric values cannot recorded. See the nested exception for details."

0 commit comments

Comments
 (0)