Skip to content

Commit de31799

Browse files
move base class to interfaces file
1 parent 1c7483d commit de31799

File tree

2 files changed

+114
-114
lines changed

2 files changed

+114
-114
lines changed

src/forge/interfaces.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from abc import ABC, abstractmethod
8-
from typing import Any
8+
from typing import Any, Mapping, Optional
99

1010
from monarch.actor import Actor, endpoint
1111

12-
from forge.types import Action, Message, Observation, State
12+
from forge.types import Action, Message, Observation, Scalar, State
1313

1414

1515
class Transform(ABC):
@@ -150,3 +150,110 @@ def tokenize_messages(
150150
tuple[list[int], list[bool]]: The list of token ids and the list of masks.
151151
"""
152152
pass
153+
154+
155+
class MetricLogger(ABC):
    """Abstract metric logger.

    The public `log` and `log_dict` methods filter calls by a per-metric
    logging frequency and delegate the actual writing to the abstract
    `_log` and `_log_dict` hooks, which subclasses must implement.

    Args:
        log_freq (Mapping[str, int]):
            calls to `log` and `log_dict` will be ignored if `step % log_freq[metric_name] != 0`.
            Every metric name passed to a logging method is looked up in this
            mapping, so an unknown name raises `KeyError`.
    """

    def __init__(self, log_freq: Mapping[str, int]):
        self._log_freq = log_freq
        # Default step for log calls; stays None until `set_step` is called.
        self._step: Optional[int] = None

    def set_step(self, step: int) -> None:
        """Subsequent log calls will use this step number by default if not provided to the log call."""
        self._step = step

    def _resolve_step(self, step: Optional[int]) -> int:
        """Return `step`, falling back to the value last given to `set_step`.

        Args:
            step (Optional[int]): explicitly supplied step, or None to use the default

        Raises:
            AssertionError: if `step` is None and `set_step` has not been called
        """
        if step is not None:
            return step
        if self._step is None:
            # Raised explicitly instead of via `assert` so the check survives
            # `python -O`; the exception type and message are unchanged.
            raise AssertionError(
                "`step` arg required if `set_step` has not been called."
            )
        return self._step

    def is_log_step(self, name: str, step: Optional[int] = None) -> bool:
        """Returns true if the current step is a logging step.

        Args:
            name (str): metric name (for checking the log freq for this metric)
            step (int): current step. if not given, will use the one last provided via set_step()

        Raises:
            AssertionError: if no step is given and `set_step` has not been called
            KeyError: if `name` is not a key of `log_freq`
        """
        step = self._resolve_step(step)
        return step % self._log_freq[name] == 0

    def log(
        self,
        name: str,
        data: Scalar,
        step: Optional[int] = None,
    ) -> None:
        """Log scalar data if this is a logging step.

        Args:
            name (str): tag name used to group scalars
            data (Scalar): scalar data to log
            step (int): step value to record. if not given, will use the one last provided via set_step()

        Raises:
            AssertionError: if no step is given and `set_step` has not been called
            KeyError: if `name` is not a key of `log_freq`
        """
        step = self._resolve_step(step)
        if self.is_log_step(name, step):
            self._log(name, data, step)

    def log_dict(
        self, metrics: Mapping[str, Scalar], step: Optional[int] = None
    ) -> None:
        """Log multiple scalar values if this is a logging step.

        Each metric is filtered by its own frequency; `_log_dict` is called
        once with the metrics due at this step, or not at all if none are due.

        Args:
            metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value
            step (int): step value to record. if not given, will use the one last provided via set_step()

        Raises:
            AssertionError: if no step is given and `set_step` has not been called
            KeyError: if any metric name is not a key of `log_freq`
        """
        step = self._resolve_step(step)

        log_step_metrics = {
            name: value
            for name, value in metrics.items()
            if self.is_log_step(name, step)
        }
        if log_step_metrics:
            self._log_dict(log_step_metrics, step)

    @abstractmethod
    def _log(self, name: str, data: Scalar, step: int) -> None:
        """Log scalar data.

        Args:
            name (str): tag name used to group scalars
            data (Scalar): scalar data to log
            step (int): step value to record
        """
        pass

    @abstractmethod
    def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
        """Log multiple scalar values.

        Args:
            payload (Mapping[str, Scalar]): dictionary of tag name and scalar value
            step (int): step value to record
        """
        pass

    def __del__(self) -> None:
        # Best-effort cleanup when the instance is finalized.
        self.close()

    def close(self) -> None:
        """
        Close log resource, flushing if necessary.
        This will automatically be called via __del__ when the instance is garbage collected.
        Logs should not be written after `close` is called.
        """
        pass

src/forge/util/metric_logging.py

Lines changed: 5 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -7,126 +7,19 @@
77
import os
88
import sys
99
import time
10-
from abc import ABC, abstractmethod
1110
from pathlib import Path
1211
from typing import Mapping, Optional
1312

1413
import torch
1514

15+
from forge.interfaces import MetricLogger
1616
from forge.types import Scalar
1717

1818

1919
def get_metric_logger(logger: str = "stdout", **log_config):
    """Instantiate the metric logger registered under `logger`.

    Args:
        logger (str): key into `METRIC_LOGGER_STR_TO_CLS` selecting the logger class
        **log_config: keyword arguments forwarded unchanged to that class's constructor

    Returns:
        The object produced by calling the selected class with `log_config`.

    Raises:
        KeyError: if `logger` is not a key of `METRIC_LOGGER_STR_TO_CLS`
    """
    # NOTE(review): the registry is defined outside this view; its contents are not verified here.
    logger_cls = METRIC_LOGGER_STR_TO_CLS[logger]
    return logger_cls(**log_config)
2121

2222

23-
class MetricLogger(ABC):
    """Abstract metric logger.

    The public `log` and `log_dict` methods filter calls by a per-metric
    logging frequency and delegate the actual writing to the abstract
    `_log` and `_log_dict` hooks, which subclasses must implement.

    Args:
        log_freq (Mapping[str, int]):
            calls to `log` and `log_dict` will be ignored if `step % log_freq[metric_name] != 0`.
            Every metric name passed to a logging method is looked up in this
            mapping, so an unknown name raises `KeyError`.
    """

    def __init__(self, log_freq: Mapping[str, int]):
        self._log_freq = log_freq
        # Default step for log calls; stays None until `set_step` is called.
        self._step: Optional[int] = None

    def set_step(self, step: int) -> None:
        """Subsequent log calls will use this step number by default if not provided to the log call."""
        self._step = step

    def _resolve_step(self, step: Optional[int]) -> int:
        """Return `step`, falling back to the value last given to `set_step`.

        Args:
            step (Optional[int]): explicitly supplied step, or None to use the default

        Raises:
            AssertionError: if `step` is None and `set_step` has not been called
        """
        if step is not None:
            return step
        if self._step is None:
            # Raised explicitly instead of via `assert` so the check survives
            # `python -O`; the exception type and message are unchanged.
            raise AssertionError(
                "`step` arg required if `set_step` has not been called."
            )
        return self._step

    def is_log_step(self, name: str, step: Optional[int] = None) -> bool:
        """Returns true if the current step is a logging step.

        Args:
            name (str): metric name (for checking the log freq for this metric)
            step (int): current step. if not given, will use the one last provided via set_step()

        Raises:
            AssertionError: if no step is given and `set_step` has not been called
            KeyError: if `name` is not a key of `log_freq`
        """
        step = self._resolve_step(step)
        return step % self._log_freq[name] == 0

    def log(
        self,
        name: str,
        data: Scalar,
        step: Optional[int] = None,
    ) -> None:
        """Log scalar data if this is a logging step.

        Args:
            name (str): tag name used to group scalars
            data (Scalar): scalar data to log
            step (int): step value to record. if not given, will use the one last provided via set_step()

        Raises:
            AssertionError: if no step is given and `set_step` has not been called
            KeyError: if `name` is not a key of `log_freq`
        """
        step = self._resolve_step(step)
        if self.is_log_step(name, step):
            self._log(name, data, step)

    def log_dict(
        self, metrics: Mapping[str, Scalar], step: Optional[int] = None
    ) -> None:
        """Log multiple scalar values if this is a logging step.

        Each metric is filtered by its own frequency; `_log_dict` is called
        once with the metrics due at this step, or not at all if none are due.

        Args:
            metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value
            step (int): step value to record. if not given, will use the one last provided via set_step()

        Raises:
            AssertionError: if no step is given and `set_step` has not been called
            KeyError: if any metric name is not a key of `log_freq`
        """
        step = self._resolve_step(step)

        log_step_metrics = {
            name: value
            for name, value in metrics.items()
            if self.is_log_step(name, step)
        }
        if log_step_metrics:
            self._log_dict(log_step_metrics, step)

    @abstractmethod
    def _log(self, name: str, data: Scalar, step: int) -> None:
        """Log scalar data.

        Args:
            name (str): tag name used to group scalars
            data (Scalar): scalar data to log
            step (int): step value to record
        """
        pass

    @abstractmethod
    def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
        """Log multiple scalar values.

        Args:
            payload (Mapping[str, Scalar]): dictionary of tag name and scalar value
            step (int): step value to record
        """
        pass

    def __del__(self) -> None:
        # Best-effort cleanup when the instance is finalized.
        self.close()

    def close(self) -> None:
        """
        Close log resource, flushing if necessary.
        This will automatically be called via __del__ when the instance is garbage collected.
        Logs should not be written after `close` is called.
        """
        pass
128-
129-
13023
class StdoutLogger(MetricLogger):
13124
"""Logger to standard output."""
13225

@@ -264,21 +157,21 @@ def __init__(
264157
from torch.utils.tensorboard import SummaryWriter
265158

266159
self._writer: Optional[SummaryWriter] = None
267-
_, self._rank = get_world_size_and_rank()
160+
rank = _get_rank()
268161

269162
# In case organize_logs is `True`, update log_dir to include a subdirectory for the
270163
# current run
271164
self.log_dir = (
272-
os.path.join(log_dir, f"run_{self._rank}_{time.time()}")
165+
os.path.join(log_dir, f"run_{rank}_{time.time()}")
273166
if organize_logs
274167
else log_dir
275168
)
276169

277170
# Initialize the log writer only if we're on rank 0.
278-
if self._rank == 0:
171+
if rank == 0:
279172
self._writer = SummaryWriter(log_dir=self.log_dir)
280173

281-
def log(self, name: str, data: Scalar, step: int) -> None:
174+
def _log(self, name: str, data: Scalar, step: int) -> None:
    """Write one scalar through `self._writer`, if a writer is present.

    Args:
        name (str): tag passed to the writer's `add_scalar`
        data (Scalar): scalar value to record
        step (int): value forwarded as `global_step`
    """
    writer = self._writer
    if not writer:
        # No writer on this instance, so there is nothing to write to.
        return
    writer.add_scalar(name, data, global_step=step, new_style=True)
284177

0 commit comments

Comments
 (0)