-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
78 lines (63 loc) · 2.44 KB
/
train.py
File metadata and controls
78 lines (63 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import os

import torch
from dotenv import load_dotenv
from pytorch_lightning import Trainer, seed_everything
# NOTE: NeptuneLogger must be imported from the SAME package family as
# Trainer. The original mixed `pytorch_lightning` (Trainer) with
# `lightning.pytorch` (NeptuneLogger); those namespaces define distinct
# Logger base classes, so the Trainer would not accept the logger.
from pytorch_lightning.loggers.neptune import NeptuneLogger

from GLUEDataModule import GLUEDataModule
from GLUETransformer import GLUETransformer


def _parse_args():
    """Parse the run's hyperparameters from the command line."""
    parser = argparse.ArgumentParser(description="Process inputs.")
    parser.add_argument('--checkpoint_dir', type=str, required=False,
                        help='directory for model checkpoints')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='learning rate')
    parser.add_argument('--adam_epsilon', type=float, default=1e-8,
                        help='epsilon for Adam optimizer')
    parser.add_argument('--warmup_steps', type=int, default=248,
                        help='number of warmup steps for learning rate schedule')
    parser.add_argument('--weight_decay', type=float, default=0.00934170221511866,
                        help='weight decay for optimization')
    parser.add_argument('--train_batch_size', type=int, default=1024,
                        help='batch size for training')
    parser.add_argument('--eval_batch_size', type=int, default=32,
                        help='batch size for evaluation')
    return parser.parse_args()


def main():
    """Fine-tune DistilBERT on GLUE/MRPC, logging to Neptune.

    Reads NEPTUNE_API_TOKEN and NEPTUNE_PROJECT from the environment
    (via .env), trains for 3 epochs, and optionally saves the whole
    model object to ``<checkpoint_dir>/model.pt``.
    """
    # Neptune credentials come from the environment / .env file.
    load_dotenv()
    my_api_token = os.getenv('NEPTUNE_API_TOKEN')
    my_project = os.getenv('NEPTUNE_PROJECT')

    args = _parse_args()
    print(f"Checkpoint directory: {args.checkpoint_dir}")
    print(f"Learning rate: {args.lr}")
    print(f"Adam Epsilon: {args.adam_epsilon}")
    print(f"Warmup Steps: {args.warmup_steps}")
    print(f"Weight Decay: {args.weight_decay}")
    print(f"Train Batch Size: {args.train_batch_size}")
    print(f"Eval Batch Size: {args.eval_batch_size}")

    # Fix all RNG seeds for reproducibility.
    seed_everything(42)

    dm = GLUEDataModule(
        model_name_or_path="distilbert-base-uncased",
        task_name="mrpc",
    )
    dm.setup("fit")

    # NOTE(review): --warmup_steps is parsed and printed but never passed
    # to the model; confirm whether GLUETransformer accepts a
    # `warmup_steps` argument and wire it through if so.
    model = GLUETransformer(
        model_name_or_path="distilbert-base-uncased",
        num_labels=dm.num_labels,
        eval_splits=dm.eval_splits,
        task_name=dm.task_name,
        learning_rate=args.lr,
        adam_epsilon=args.adam_epsilon,
        weight_decay=args.weight_decay,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
    )

    neptune_logger = NeptuneLogger(
        project=my_project,
        api_token=my_api_token,
    )

    trainer = Trainer(
        max_epochs=3,
        accelerator="auto",
        devices=1 if torch.cuda.is_available() else None,
        logger=neptune_logger,
    )
    trainer.fit(model, datamodule=dm)

    # Save the full model object (not just the state_dict) if requested.
    if args.checkpoint_dir is not None:
        # os.path.join works whether or not the directory ends in a
        # separator; plain string concatenation required a trailing "/".
        model_save_path = os.path.join(args.checkpoint_dir, 'model.pt')
        torch.save(model, model_save_path)


if __name__ == "__main__":
    main()