
Commit 6772c76

Added wandb options, including whether to write grads to logs.
1 parent 4785882 commit 6772c76

File tree

5 files changed (+86 −2 lines):
  launcher.sh
  src/hf_trainer.py
  src/logging_callback.py
  src/training_args.py
  trainval_toolkit.py

launcher.sh

Lines changed: 4 additions & 0 deletions

@@ -12,6 +12,10 @@ NSTEPS=100_000
 torchrun --nproc_per_node $NGPUS \
     trainval.py \
     -e mlm \
+    --wandb-entity-name jbmf \
+    --wandb-project-name tf-encoder \
+    --wandb-run-name mlm \
+    --wandb-log-gradients false \
     --steps $NSTEPS \
     -sb $PATH_TO_LOG \
     --train_data_name $TRAIN_DATA_NAME \

src/hf_trainer.py

Lines changed: 22 additions & 1 deletion

@@ -1,4 +1,5 @@
 import os
+import wandb
 import torch
 from torch.utils.data.dataset import Dataset
 import transformers
@@ -16,6 +17,9 @@
     retrieval_eval,
 )
 from src.datasets_loader import Collator
+from src.logging_callback import LoggingCallback
+
+from functools import partial


 def compute_metrics(eval_pred: PredictionOutput) -> Dict[str, float]:
@@ -221,6 +225,10 @@ def get_trainer(
     log_every: int = 100,
     local_rank: int = 0,
     deepspeed_cfg_path: str = None,
+    wandb_entity_name: str = None,
+    wandb_project_name: str = None,
+    wandb_run_name: str = None,
+    wandb_log_grads: bool = False,
 ) -> CustomTrainer:
     """Instantiates Trainer object.

@@ -234,6 +242,10 @@ def get_trainer(
         log_every (int): Logging interval.
         local_rank (int): Device id for distributed training.
         deepspeed_cfg_path (str, optional): Optional path to deepspeed config.
+        wandb_entity_name (str, optional): Wandb entity. Defaults to None.
+        wandb_project_name (str, optional): Project name for wandb. Defaults to None.
+        wandb_run_name (str, optional): Run name for wandb. Defaults to None.
+        wandb_log_grads (bool, optional): Whether to write grads to wandb logs. Defaults to False.

     Returns:
         CustomTrainer: Trainer object.
@@ -265,18 +277,27 @@ def get_trainer(
         save_strategy="steps",
         save_steps=log_every,
         evaluation_strategy="steps",
-        report_to="wandb",
+        # report_to="wandb",
     )

     encoder = get_encoder(exp_dict=exp_dict)

+    wandb.init(
+        name=wandb_run_name,
+        entity=wandb_entity_name,
+        project=wandb_project_name,
+    )
+
     trainer = CustomTrainer(
         model=encoder,
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=valid_dataset,
         compute_metrics=compute_metrics,
         data_collator=collate_fn,
+        callbacks=[
+            LoggingCallback(log_grads=wandb_log_grads),
+        ],
     )

     return trainer
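
With report_to="wandb" commented out, reporting is now driven by the explicit wandb.init call plus the custom callback rather than the Trainer's built-in integration. A minimal standalone sketch of that initialization using the values from launcher.sh; the mode="offline" argument is an assumption added here only so the snippet runs without W&B credentials, the commit itself does not set it:

    import wandb

    # Same fields get_trainer forwards: run name, entity, project.
    run = wandb.init(
        name="mlm",
        entity="jbmf",
        project="tf-encoder",
        mode="offline",  # assumption: offline mode so no login or network is needed
    )
    run.finish()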

src/logging_callback.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+from transformers.integrations import WandbCallback
+
+
+class LoggingCallback(WandbCallback):
+    """
+    Overriding WandbCallback to optionally turn off gradient logging.
+    """
+
+    def __init__(self, log_grads: bool):
+
+        super().__init__()
+
+        self.log_grads = log_grads
+
+    def setup(self, args, state, model, **kwargs):
+
+        super().setup(args, state, model, **kwargs)
+        _watch_model = "all" if self.log_grads else "parameters"
+        self._wandb.watch(
+            model, log=_watch_model, log_freq=max(100, args.logging_steps)
+        )
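
For context, the two strings map onto wandb.watch's log modes: "all" records both parameter and gradient histograms every log_freq steps, while "parameters" records parameter histograms only, which is what makes log_grads a clean on/off switch for gradient logging. A minimal sketch of the same call outside the Trainer; the toy model and offline mode are assumptions for illustration:

    import torch
    import wandb

    # Assumptions for a local demo: project/name are arbitrary, offline avoids needing credentials.
    wandb.init(project="tf-encoder", name="watch-demo", mode="offline")

    model = torch.nn.Linear(8, 2)  # hypothetical toy model

    log_grads = False
    # "all" -> parameters + gradients, "parameters" -> parameters only, mirroring LoggingCallback.setup.
    wandb.watch(model, log="all" if log_grads else "parameters", log_freq=100)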

src/training_args.py

Lines changed: 35 additions & 1 deletion

@@ -1,8 +1,26 @@
 import argparse


+def parse_bool_flag(s: str) -> bool:
+    """Parse boolean arguments from the command line.
+
+    Args:
+        s (str): Input arg string.
+
+    Returns:
+        bool: Parsed boolean value.
+    """
+    _FALSY_STRINGS = {"off", "false", "0"}
+    _TRUTHY_STRINGS = {"on", "true", "1"}
+    if s.lower() in _FALSY_STRINGS:
+        return False
+    elif s.lower() in _TRUTHY_STRINGS:
+        return True
+    else:
+        raise argparse.ArgumentTypeError("Invalid value for a boolean flag")
+
+
 def parse_args():
-    # Specify arguments regarding save directory and job scheduler
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-e",
@@ -45,6 +63,22 @@ def parse_args():
         type=int,
         help="Number of iterations to wait before logging training scores.",
     )
+    parser.add_argument(
+        "--wandb-entity-name",
+        type=str,
+        default="bigcode",
+        help="Name of wandb entity for reporting.",
+    )
+    parser.add_argument(
+        "--wandb-project-name", type=str, default=None, help="Name of wandb project."
+    )
+    parser.add_argument("--wandb-run-name", type=str, default=None, help="Name of run.")
+    parser.add_argument(
+        "--wandb-log-gradients",
+        type=parse_bool_flag,
+        default="false",
+        help="Whether to write gradients to wandb logs.",
+    )
     parser.add_argument(
         "--dist_url",
         default="env://",
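
A quick usage sketch of parse_bool_flag as an argparse type converter, using a throwaway parser rather than the full parse_args above; note that argparse also runs the converter on the string default "false", so the attribute is a real bool either way:

    import argparse

    from src.training_args import parse_bool_flag

    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb-log-gradients", type=parse_bool_flag, default="false")

    # "on"/"true"/"1" parse to True, "off"/"false"/"0" to False; anything else raises an error.
    assert parser.parse_args(["--wandb-log-gradients", "on"]).wandb_log_gradients is True
    assert parser.parse_args([]).wandb_log_gradients is False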

trainval_toolkit.py

Lines changed: 4 additions & 0 deletions

@@ -57,6 +57,10 @@ def train(exp_dict, savedir, args):
         valid_dataset=gfg_test_data,
         collate_fn=collate_fn,
         log_every=args.log_every,
+        wandb_entity_name=args.wandb_entity_name,
+        wandb_project_name=args.wandb_project_name,
+        wandb_run_name=args.wandb_run_name,
+        wandb_log_grads=args.wandb_log_gradients,
         local_rank=args.local_rank,
         deepspeed_cfg_path=args.deepspeed,
     )
