Skip to content

Commit 903204a

Browse files
fixing docstrings
1 parent b1dd753 commit 903204a

File tree

2 files changed

+44
-6
lines changed

2 files changed

+44
-6
lines changed

apps/sft/llama3_8b.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
# enable_profiling: false
44

55
metrics:
6-
logger: wandb
7-
project: dsafjkdsafjlskdjf
6+
logger: tensorboard
87
freq:
98
loss: 10
109

src/forge/util/metric_logging.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,24 @@ def is_log_step(self, name: str, step: int) -> bool:
3838
return step % self._freq[name] == 0
3939

4040
def log(self, name: str, data: Scalar, step: int) -> None:
41+
"""Log the metric if it is a logging step.
42+
43+
Args:
44+
name (str): metric name
45+
data (Scalar): metric value
46+
step (int): current step
47+
"""
4148
if not self.is_log_step(name, step):
4249
return
4350
print(f"Step {step} | {name}:{data}")
4451

4552
def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
53+
"""Log the metrics for which this is currently a logging step.
54+
55+
Args:
56+
metrics (Mapping[str, Scalar]): dict of metric names and values
57+
step (int): current step
58+
"""
4659
log_step_metrics = {
4760
name: value
4861
for name, value in metrics.items()
@@ -76,8 +89,8 @@ class TensorBoardLogger(MetricLogger):
7689
**kwargs: additional arguments
7790
7891
Example:
79-
>>> from torchtune.training.metric_logging import TensorBoardLogger
80-
>>> logger = TensorBoardLogger(log_dir="my_log_dir")
92+
>>> from forge.util.metric_logging import TensorBoardLogger
93+
>>> logger = TensorBoardLogger(freq={"loss": 10}, log_dir="my_log_dir")
8194
>>> logger.log("my_metric", 1.0, 1)
8295
>>> logger.log_dict({"my_metric": 1.0}, 1)
8396
>>> logger.close()
@@ -123,10 +136,23 @@ def is_log_step(self, name: str, step: int) -> bool:
123136
return step % self._freq[name] == 0
124137

125138
def log(self, name: str, data: Scalar, step: int) -> None:
139+
"""Log the metric if it is a logging step.
140+
141+
Args:
142+
name (str): metric name
143+
data (Scalar): metric value
144+
step (int): current step
145+
"""
126146
if self._writer:
127147
self._writer.add_scalar(name, data, global_step=step, new_style=True)
128148

129149
def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
150+
"""Log the metrics for which this is currently a logging step.
151+
152+
Args:
153+
metrics (Mapping[str, Scalar]): dict of metric names and values
154+
step (int): current step
155+
"""
130156
for name, data in metrics.items():
131157
if self.is_log_step(name, step):
132158
self.log(name, data, step)
@@ -153,8 +179,8 @@ class WandBLogger(MetricLogger):
153179
**kwargs: additional arguments to pass to wandb.init
154180
155181
Example:
156-
>>> from torchtune.training.metric_logging import WandBLogger
157-
>>> logger = WandBLogger(log_dir="wandb", project="my_project", entity="my_entity", group="my_group")
182+
>>> from forge.util.metric_logging import WandBLogger
183+
>>> logger = WandBLogger(freq={"loss": 10}, log_dir="wandb", project="my_project")
158184
>>> logger.log("my_metric", 1.0, 1)
159185
>>> logger.log_dict({"my_metric": 1.0}, 1)
160186
>>> logger.close()
@@ -218,10 +244,23 @@ def is_log_step(self, name: str, step: int) -> bool:
218244
return step % self._freq[name] == 0
219245

220246
def log(self, name: str, data: Scalar, step: int) -> None:
247+
"""Log the metric if it is a logging step.
248+
249+
Args:
250+
name (str): metric name
251+
data (Scalar): metric value
252+
step (int): current step
253+
"""
221254
if self._wandb.run and self.is_log_step(name, step):
222255
self._wandb.log({name: data, "step": step})
223256

224257
def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
258+
"""Log the metrics for which this is currently a logging step.
259+
260+
Args:
261+
metrics (Mapping[str, Scalar]): dict of metric names and values
262+
step (int): current step
263+
"""
225264
log_step_metrics = {
226265
name: value
227266
for name, value in metrics.items()

0 commit comments

Comments
 (0)