Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/fairseq2/metrics/recorders/_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def __init__(self, recorders: Sequence[MetricRecorder]) -> None:
@override
def record_metrics(
self,
run: str,
section: str,
values: Mapping[str, object],
step_nr: int | None = None,
*,
flush: bool = True,
) -> None:
for recorder in self._inner_recorders:
recorder.record_metrics(run, values, step_nr, flush=flush)
recorder.record_metrics(section, values, step_nr, flush=flush)

@override
def close(self) -> None:
Expand Down
4 changes: 3 additions & 1 deletion src/fairseq2/metrics/recorders/_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

class MetricRecorderHandler(ABC):
@abstractmethod
def create(self, output_dir: Path, config: object) -> MetricRecorder: ...
def create(
self, output_dir: Path, config: object, hyper_params: object
) -> MetricRecorder: ...

@property
@abstractmethod
Expand Down
30 changes: 16 additions & 14 deletions src/fairseq2/metrics/recorders/_jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
class JsonlMetricRecorder(MetricRecorder):
"""Records metric values to JSONL files."""

_RUN_PART_REGEX: Final = re.compile("^[-_a-zA-Z0-9]+$")
_SECTION_PART_REGEX: Final = re.compile("^[-_a-zA-Z0-9]+$")

_output_dir: Path
_file_system: FileSystem
Expand All @@ -63,21 +63,21 @@ def __init__(
@override
def record_metrics(
self,
run: str,
section: str,
values: Mapping[str, object],
step_nr: int | None = None,
*,
flush: bool = True,
) -> None:
run = run.strip()
section = section.strip()

for part in run.split("/"):
if re.match(self._RUN_PART_REGEX, part) is None:
for part in section.split("/"):
if re.match(self._SECTION_PART_REGEX, part) is None:
raise ValueError(
f"`run` must contain only alphanumeric characters, dash, underscore, and forward slash, but is '{run}' instead."
f"`section` must contain only alphanumeric characters, dash, underscore, and forward slash, but is '{section}' instead."
)

stream = self._get_stream(run)
stream = self._get_stream(section)

values_and_descriptors = []

Expand Down Expand Up @@ -127,16 +127,16 @@ def sanitize(value: object, descriptor: MetricDescriptor) -> object:
stream.flush()
except OSError as ex:
raise MetricRecordError(
f"The metric values of the '{run}' cannot be saved to the JSON file. See the nested exception for details."
f"The metric values of the '{section}' cannot be saved to the JSON file. See the nested exception for details."
) from ex

def _get_stream(self, run: str) -> TextIO:
def _get_stream(self, section: str) -> TextIO:
try:
return self._streams[run]
return self._streams[section]
except KeyError:
pass

file = self._output_dir.joinpath(run).with_suffix(".jsonl")
file = self._output_dir.joinpath(section).with_suffix(".jsonl")

try:
self._file_system.make_directory(file.parent)
Expand All @@ -149,10 +149,10 @@ def _get_stream(self, run: str) -> TextIO:
fp = self._file_system.open_text(file, mode=FileMode.APPEND)
except OSError as ex:
raise MetricRecordError(
f"The '{file}' metric file for the '{run} run cannot be created. See the nested exception for details."
f"The '{file}' metric file for the '{section} section cannot be created. See the nested exception for details."
) from ex

self._streams[run] = fp
self._streams[section] = fp

return fp

Expand Down Expand Up @@ -184,7 +184,9 @@ def __init__(
self._metric_descriptors = metric_descriptors

@override
def create(self, output_dir: Path, config: object) -> MetricRecorder:
def create(
self, output_dir: Path, config: object, hyper_params: object
) -> MetricRecorder:
config = structure(config, JsonlMetricRecorderConfig)

validate(config)
Expand Down
22 changes: 12 additions & 10 deletions src/fairseq2/metrics/recorders/_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
@override
def record_metrics(
self,
run: str,
section: str,
values: Mapping[str, object],
step_nr: int | None = None,
*,
Expand Down Expand Up @@ -85,19 +85,19 @@ def record_metrics(
if not s:
s = "N/A"

run_parts = run.split("/")
section_parts = section.split("/")

phase = self._display_names.get(run_parts[0])
if phase is None:
phase = run_parts[0].capitalize()
title = self._display_names.get(section_parts[0])
if title is None:
title = section_parts[0].capitalize()

if step_nr is None:
m = f"{phase} Metrics"
m = f"{title} Metrics"
else:
m = f"{phase} Metrics (step {step_nr})"
m = f"{title} Metrics (step {step_nr})"

if len(run_parts) > 1:
m = f"{m} - {'/'.join(run_parts[1:])}"
if len(section_parts) > 1:
m = f"{m} - {'/'.join(section_parts[1:])}"

self._log.info("{} - {}", m, s)

Expand Down Expand Up @@ -126,7 +126,9 @@ def __init__(
self._metric_descriptors = metric_descriptors

@override
def create(self, output_dir: Path, config: object) -> MetricRecorder:
def create(
self, output_dir: Path, config: object, hyper_params: object
) -> MetricRecorder:
config = structure(config, LogMetricRecorderConfig)

validate(config)
Expand Down
22 changes: 9 additions & 13 deletions src/fairseq2/metrics/recorders/_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,18 @@ class MetricRecorder(ABC):
@abstractmethod
def record_metrics(
self,
run: str,
section: str,
values: Mapping[str, object],
step_nr: int | None = None,
*,
flush: bool = True,
) -> None:
"""Record ``values``.

:param run:
The name of the run (e.g. 'train', 'eval').
:param values:
The metric values.
:param step_nr:
The step number of the run.
:param flush:
If ``True``, flushes any buffers after recording.
:param section: The run section (e.g. 'train', 'eval').
:param values: The metric values.
:param step_nr: The step number of the run.
:param flush: If ``True``, flushes any buffers after recording.
"""

@abstractmethod
Expand All @@ -51,7 +47,7 @@ class NoopMetricRecorder(MetricRecorder):
@override
def record_metrics(
self,
run: str,
section: str,
values: Mapping[str, object],
step_nr: int | None = None,
*,
Expand All @@ -66,7 +62,7 @@ def close(self) -> None:

def record_metrics(
recorders: Sequence[MetricRecorder],
run: str,
section: str,
values: Mapping[str, object],
step_nr: int | None = None,
*,
Expand All @@ -75,10 +71,10 @@ def record_metrics(
"""Record ``values`` to ``recorders``.

:param recorders: The recorders to record to.
:param run: The name of the run (e.g. 'train', 'eval').
:param section: The run section (e.g. 'train', 'eval').
:param values: The metric values.
:param step_nr: The step number of the run.
:param flush: If ``True``, flushes any buffers after recording.
"""
for recorder in recorders:
recorder.record_metrics(run, values, step_nr, flush=flush)
recorder.record_metrics(section, values, step_nr, flush=flush)
18 changes: 10 additions & 8 deletions src/fairseq2/metrics/recorders/_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ def __init__(
@override
def record_metrics(
self,
run: str,
section: str,
values: Mapping[str, object],
step_nr: int | None = None,
*,
flush: bool = True,
) -> None:
writer = self._get_writer(run)
writer = self._get_writer(section)
if writer is None:
return

Expand All @@ -90,18 +90,18 @@ def record_metrics(
writer.flush()
except RuntimeError as ex:
raise MetricRecordError(
f"The metric values of the '{run}' cannot be saved to TensorBoard. See the nested exception for details."
f"The metric values of the '{section}' section cannot be saved to TensorBoard. See the nested exception for details."
) from ex

def _get_writer(self, run: str) -> SummaryWriter | None:
def _get_writer(self, section: str) -> SummaryWriter | None:
if not _has_tensorboard:
return None

writer = self._writers.get(run)
writer = self._writers.get(section)
if writer is None:
writer = SummaryWriter(self._output_dir.joinpath(run))
writer = SummaryWriter(self._output_dir.joinpath(section))

self._writers[run] = writer
self._writers[section] = writer

return writer

Expand Down Expand Up @@ -129,7 +129,9 @@ def __init__(self, metric_descriptors: Provider[MetricDescriptor]) -> None:
self._metric_descriptors = metric_descriptors

@override
def create(self, output_dir: Path, config: object) -> MetricRecorder:
def create(
self, output_dir: Path, config: object, hyper_params: object
) -> MetricRecorder:
config = structure(config, TensorBoardRecorderConfig)

validate(config)
Expand Down
Loading
Loading