| 
7 | 7 | import os  | 
8 | 8 | import sys  | 
9 | 9 | import time  | 
10 |  | -from abc import ABC, abstractmethod  | 
11 | 10 | from pathlib import Path  | 
12 | 11 | from typing import Mapping, Optional  | 
13 | 12 | 
 
  | 
14 | 13 | import torch  | 
15 | 14 | 
 
  | 
 | 15 | +from forge.interfaces import MetricLogger  | 
16 | 16 | from forge.types import Scalar  | 
17 | 17 | 
 
  | 
18 | 18 | 
 
  | 
def get_metric_logger(logger: str = "stdout", **log_config):
    """Instantiate a metric logger by its registry name.

    Args:
        logger (str): key into ``METRIC_LOGGER_STR_TO_CLS`` selecting the
            logger implementation (defaults to ``"stdout"``).
        **log_config: keyword arguments forwarded to the chosen logger's
            constructor.

    Returns:
        The constructed metric logger instance.
    """
    logger_cls = METRIC_LOGGER_STR_TO_CLS[logger]
    return logger_cls(**log_config)
21 | 21 | 
 
  | 
22 | 22 | 
 
  | 
class MetricLogger(ABC):
    """Abstract metric logger.

    Calls to `log` and `log_dict` are rate-limited per metric name: a metric
    is only emitted when `step % log_freq[metric_name] == 0`.

    Args:
        log_freq (Mapping[str, int]):
            calls to `log` and `log_dict` will be ignored if `step % log_freq[metric_name] != 0`
    """

    def __init__(self, log_freq: Mapping[str, int]):
        self._log_freq = log_freq
        # Default step used when a log call omits `step`; set via set_step().
        self._step: Optional[int] = None

    def set_step(self, step: int) -> None:
        """Subsequent log calls will use this step number by default if not provided to the log call."""
        self._step = step

    def _resolve_step(self, step: Optional[int]) -> int:
        """Return the explicit `step` if given, else the one stored via `set_step`.

        Raises:
            AssertionError: if `step` is None and `set_step` has never been called.
        """
        if step is None:
            assert (
                self._step is not None
            ), "`step` arg required if `set_step` has not been called."
            step = self._step
        return step

    def is_log_step(self, name: str, step: Optional[int] = None) -> bool:
        """Returns true if the current step is a logging step.

        Args:
            name (str): metric name (for checking the log freq for this metric)
            step (int): current step. if not given, will use the one last provided via set_step()
        """
        step = self._resolve_step(step)
        return step % self._log_freq[name] == 0

    def log(
        self,
        name: str,
        data: Scalar,
        step: Optional[int] = None,
    ) -> None:
        """Log scalar data if this is a logging step.

        Args:
            name (str): tag name used to group scalars
            data (Scalar): scalar data to log
            step (int): step value to record. if not given, will use the one last provided via set_step()
        """
        step = self._resolve_step(step)
        if step % self._log_freq[name] == 0:
            self._log(name, data, step)

    def log_dict(
        self, metrics: Mapping[str, Scalar], step: Optional[int] = None
    ) -> None:
        """Log multiple scalar values if this is a logging step.

        Only the subset of metrics whose log frequency divides `step` is
        forwarded to the backend; nothing is written if that subset is empty.

        Args:
            metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value
            step (int): step value to record. if not given, will use the one last provided via set_step()
        """
        step = self._resolve_step(step)
        log_step_metrics = {
            name: value
            for name, value in metrics.items()
            if step % self._log_freq[name] == 0
        }
        if log_step_metrics:
            self._log_dict(log_step_metrics, step)

    @abstractmethod
    def _log(self, name: str, data: Scalar, step: int) -> None:
        """Log scalar data.

        Args:
            name (str): tag name used to group scalars
            data (Scalar): scalar data to log
            step (int): step value to record
        """
        pass

    @abstractmethod
    def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
        """Log multiple scalar values.

        Args:
            payload (Mapping[str, Scalar]): dictionary of tag name and scalar value
            step (int): step value to record
        """
        pass

    def __del__(self) -> None:
        # Best-effort flush/close when the instance is garbage-collected.
        self.close()

    def close(self) -> None:
        """
        Close log resource, flushing if necessary.
        This will automatically be called via __del__ when the instance goes out of scope.
        Logs should not be written after `close` is called.
        """
        pass
130 | 23 | class StdoutLogger(MetricLogger):  | 
131 | 24 |     """Logger to standard output."""  | 
132 | 25 | 
 
  | 
@@ -264,21 +157,21 @@ def __init__(  | 
264 | 157 |         from torch.utils.tensorboard import SummaryWriter  | 
265 | 158 | 
 
  | 
266 | 159 |         self._writer: Optional[SummaryWriter] = None  | 
267 |  | -        _, self._rank = get_world_size_and_rank()  | 
 | 160 | +        rank = _get_rank()  | 
268 | 161 | 
 
  | 
269 | 162 |         # In case organize_logs is `True`, update log_dir to include a subdirectory for the  | 
270 | 163 |         # current run  | 
271 | 164 |         self.log_dir = (  | 
272 |  | -            os.path.join(log_dir, f"run_{self._rank}_{time.time()}")  | 
 | 165 | +            os.path.join(log_dir, f"run_{rank}_{time.time()}")  | 
273 | 166 |             if organize_logs  | 
274 | 167 |             else log_dir  | 
275 | 168 |         )  | 
276 | 169 | 
 
  | 
277 | 170 |         # Initialize the log writer only if we're on rank 0.  | 
278 |  | -        if self._rank == 0:  | 
 | 171 | +        if rank == 0:  | 
279 | 172 |             self._writer = SummaryWriter(log_dir=self.log_dir)  | 
280 | 173 | 
 
  | 
281 |  | -    def log(self, name: str, data: Scalar, step: int) -> None:  | 
 | 174 | +    def _log(self, name: str, data: Scalar, step: int) -> None:  | 
282 | 175 |         if self._writer:  | 
283 | 176 |             self._writer.add_scalar(name, data, global_step=step, new_style=True)  | 
284 | 177 | 
 
  | 
 | 
0 commit comments