Skip to content

Commit 0d6cece

Browse files
committed
Improve wandb support
1 parent ed104db commit 0d6cece

File tree

25 files changed

+144
-108
lines changed

25 files changed

+144
-108
lines changed

src/fairseq2/metrics/recorders/_composite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ def __init__(self, recorders: Sequence[MetricRecorder]) -> None:
2323
@override
2424
def record_metrics(
2525
self,
26-
run: str,
26+
section: str,
2727
values: Mapping[str, object],
2828
step_nr: int | None = None,
2929
*,
3030
flush: bool = True,
3131
) -> None:
3232
for recorder in self._inner_recorders:
33-
recorder.record_metrics(run, values, step_nr, flush=flush)
33+
recorder.record_metrics(section, values, step_nr, flush=flush)
3434

3535
@override
3636
def close(self) -> None:

src/fairseq2/metrics/recorders/_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
class MetricRecorderHandler(ABC):
1616
@abstractmethod
17-
def create(self, output_dir: Path, config: object) -> MetricRecorder: ...
17+
def create(
18+
self, output_dir: Path, config: object, hyper_params: object
19+
) -> MetricRecorder: ...
1820

1921
@property
2022
@abstractmethod

src/fairseq2/metrics/recorders/_jsonl.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
class JsonlMetricRecorder(MetricRecorder):
3838
"""Records metric values to JSONL files."""
3939

40-
_RUN_PART_REGEX: Final = re.compile("^[-_a-zA-Z0-9]+$")
40+
_SECTION_PART_REGEX: Final = re.compile("^[-_a-zA-Z0-9]+$")
4141

4242
_output_dir: Path
4343
_file_system: FileSystem
@@ -63,21 +63,21 @@ def __init__(
6363
@override
6464
def record_metrics(
6565
self,
66-
run: str,
66+
section: str,
6767
values: Mapping[str, object],
6868
step_nr: int | None = None,
6969
*,
7070
flush: bool = True,
7171
) -> None:
72-
run = run.strip()
72+
section = section.strip()
7373

74-
for part in run.split("/"):
75-
if re.match(self._RUN_PART_REGEX, part) is None:
74+
for part in section.split("/"):
75+
if re.match(self._SECTION_PART_REGEX, part) is None:
7676
raise ValueError(
77-
f"`run` must contain only alphanumeric characters, dash, underscore, and forward slash, but is '{run}' instead."
77+
f"`section` must contain only alphanumeric characters, dash, underscore, and forward slash, but is '{section}' instead."
7878
)
7979

80-
stream = self._get_stream(run)
80+
stream = self._get_stream(section)
8181

8282
values_and_descriptors = []
8383

@@ -127,16 +127,16 @@ def sanitize(value: object, descriptor: MetricDescriptor) -> object:
127127
stream.flush()
128128
except OSError as ex:
129129
raise MetricRecordError(
130-
f"The metric values of the '{run}' cannot be saved to the JSON file. See the nested exception for details."
130+
f"The metric values of the '{section}' cannot be saved to the JSON file. See the nested exception for details."
131131
) from ex
132132

133-
def _get_stream(self, run: str) -> TextIO:
133+
def _get_stream(self, section: str) -> TextIO:
134134
try:
135-
return self._streams[run]
135+
return self._streams[section]
136136
except KeyError:
137137
pass
138138

139-
file = self._output_dir.joinpath(run).with_suffix(".jsonl")
139+
file = self._output_dir.joinpath(section).with_suffix(".jsonl")
140140

141141
try:
142142
self._file_system.make_directory(file.parent)
@@ -149,10 +149,10 @@ def _get_stream(self, run: str) -> TextIO:
149149
fp = self._file_system.open_text(file, mode=FileMode.APPEND)
150150
except OSError as ex:
151151
raise MetricRecordError(
152-
f"The '{file}' metric file for the '{run} run cannot be created. See the nested exception for details."
152+
f"The '{file}' metric file for the '{section} section cannot be created. See the nested exception for details."
153153
) from ex
154154

155-
self._streams[run] = fp
155+
self._streams[section] = fp
156156

157157
return fp
158158

@@ -184,7 +184,9 @@ def __init__(
184184
self._metric_descriptors = metric_descriptors
185185

186186
@override
187-
def create(self, output_dir: Path, config: object) -> MetricRecorder:
187+
def create(
188+
self, output_dir: Path, config: object, hyper_params: object
189+
) -> MetricRecorder:
188190
config = structure(config, JsonlMetricRecorderConfig)
189191

190192
validate(config)

src/fairseq2/metrics/recorders/_log.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
@override
4545
def record_metrics(
4646
self,
47-
run: str,
47+
section: str,
4848
values: Mapping[str, object],
4949
step_nr: int | None = None,
5050
*,
@@ -85,19 +85,19 @@ def record_metrics(
8585
if not s:
8686
s = "N/A"
8787

88-
run_parts = run.split("/")
88+
section_parts = section.split("/")
8989

90-
phase = self._display_names.get(run_parts[0])
91-
if phase is None:
92-
phase = run_parts[0].capitalize()
90+
title = self._display_names.get(section_parts[0])
91+
if title is None:
92+
title = section_parts[0].capitalize()
9393

9494
if step_nr is None:
95-
m = f"{phase} Metrics"
95+
m = f"{title} Metrics"
9696
else:
97-
m = f"{phase} Metrics (step {step_nr})"
97+
m = f"{title} Metrics (step {step_nr})"
9898

99-
if len(run_parts) > 1:
100-
m = f"{m} - {'/'.join(run_parts[1:])}"
99+
if len(section_parts) > 1:
100+
m = f"{m} - {'/'.join(section_parts[1:])}"
101101

102102
self._log.info("{} - {}", m, s)
103103

@@ -126,7 +126,9 @@ def __init__(
126126
self._metric_descriptors = metric_descriptors
127127

128128
@override
129-
def create(self, output_dir: Path, config: object) -> MetricRecorder:
129+
def create(
130+
self, output_dir: Path, config: object, hyper_params: object
131+
) -> MetricRecorder:
130132
config = structure(config, LogMetricRecorderConfig)
131133

132134
validate(config)

src/fairseq2/metrics/recorders/_recorder.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,18 @@ class MetricRecorder(ABC):
1919
@abstractmethod
2020
def record_metrics(
2121
self,
22-
run: str,
22+
section: str,
2323
values: Mapping[str, object],
2424
step_nr: int | None = None,
2525
*,
2626
flush: bool = True,
2727
) -> None:
2828
"""Record ``values``.
2929
30-
:param run:
31-
The name of the run (e.g. 'train', 'eval').
32-
:param values:
33-
The metric values.
34-
:param step_nr:
35-
The step number of the run.
36-
:param flush:
37-
If ``True``, flushes any buffers after recording.
30+
:param section: The run section (e.g. 'train', 'eval').
31+
:param values: The metric values.
32+
:param step_nr: The step number of the run.
33+
:param flush: If ``True``, flushes any buffers after recording.
3834
"""
3935

4036
@abstractmethod
@@ -51,7 +47,7 @@ class NoopMetricRecorder(MetricRecorder):
5147
@override
5248
def record_metrics(
5349
self,
54-
run: str,
50+
section: str,
5551
values: Mapping[str, object],
5652
step_nr: int | None = None,
5753
*,
@@ -66,7 +62,7 @@ def close(self) -> None:
6662

6763
def record_metrics(
6864
recorders: Sequence[MetricRecorder],
69-
run: str,
65+
section: str,
7066
values: Mapping[str, object],
7167
step_nr: int | None = None,
7268
*,
@@ -75,10 +71,10 @@ def record_metrics(
7571
"""Record ``values`` to ``recorders``.
7672
7773
:param recorders: The recorders to record to.
78-
:param run: The name of the run (e.g. 'train', 'eval').
74+
:param section: The run section (e.g. 'train', 'eval').
7975
:param values: The metric values.
8076
:param step_nr: The step number of the run.
8177
:param flush: If ``True``, flushes any buffers after recording.
8278
"""
8379
for recorder in recorders:
84-
recorder.record_metrics(run, values, step_nr, flush=flush)
80+
recorder.record_metrics(section, values, step_nr, flush=flush)

src/fairseq2/metrics/recorders/_tensorboard.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ def __init__(
6262
@override
6363
def record_metrics(
6464
self,
65-
run: str,
65+
section: str,
6666
values: Mapping[str, object],
6767
step_nr: int | None = None,
6868
*,
6969
flush: bool = True,
7070
) -> None:
71-
writer = self._get_writer(run)
71+
writer = self._get_writer(section)
7272
if writer is None:
7373
return
7474

@@ -90,18 +90,18 @@ def record_metrics(
9090
writer.flush()
9191
except RuntimeError as ex:
9292
raise MetricRecordError(
93-
f"The metric values of the '{run}' cannot be saved to TensorBoard. See the nested exception for details."
93+
f"The metric values of the '{section}' section cannot be saved to TensorBoard. See the nested exception for details."
9494
) from ex
9595

96-
def _get_writer(self, run: str) -> SummaryWriter | None:
96+
def _get_writer(self, section: str) -> SummaryWriter | None:
9797
if not _has_tensorboard:
9898
return None
9999

100-
writer = self._writers.get(run)
100+
writer = self._writers.get(section)
101101
if writer is None:
102-
writer = SummaryWriter(self._output_dir.joinpath(run))
102+
writer = SummaryWriter(self._output_dir.joinpath(section))
103103

104-
self._writers[run] = writer
104+
self._writers[section] = writer
105105

106106
return writer
107107

@@ -129,7 +129,9 @@ def __init__(self, metric_descriptors: Provider[MetricDescriptor]) -> None:
129129
self._metric_descriptors = metric_descriptors
130130

131131
@override
132-
def create(self, output_dir: Path, config: object) -> MetricRecorder:
132+
def create(
133+
self, output_dir: Path, config: object, hyper_params: object
134+
) -> MetricRecorder:
133135
config = structure(config, TensorBoardRecorderConfig)
134136

135137
validate(config)

0 commit comments

Comments
 (0)