grade_school_math/train.py: 110 changes (71 additions & 39 deletions)
@@ -1,50 +1,82 @@
import torch as th
from dataset import get_examples, GSMDataset
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
-from transformers import GPT2Config, AdamW
-from transformers import get_scheduler
-from tqdm.auto import tqdm
+from transformers import (
+    GPT2Tokenizer, GPT2LMHeadModel, GPT2Config,
+    AdamW, get_scheduler
+)
from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
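# Note: the AdamW imported from transformers is deprecated in favor of torch.optim.AdamW
# and may be missing from newer transformers releases. A near drop-in alternative
# (weight_decay=0.0 matches the transformers default) would be:
#
#     from torch.optim import AdamW
#     optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0.0)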

+# === CLASS 1: Tokenizer Loader ===
+class TokenizerLoader:
+    @staticmethod
+    def load(name="gpt2"):
+        return GPT2Tokenizer.from_pretrained(name)

+# === CLASS 2: Dataset + Dataloader ===
+class DataLoaderBuilder:
+    def __init__(self, tokenizer, batch_size=16):
+        examples = get_examples("train")
+        dataset = GSMDataset(tokenizer, examples)
+        self.loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+    def get_loader(self):
+        return self.loader
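# Assumption (dataset.py is not part of this diff): GSMDataset yields dicts of
# equal-length token tensors, roughly {"input_ids": ..., "attention_mask": ...},
# so the default collate_fn can stack them into the batch dict that Trainer.train()
# moves to the device and feeds to the model below.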

+# === CLASS 3: Model + Optimizer ===
+class ModelBuilder:
+    def __init__(self, model_name="gpt2", lr=1e-5, total_steps=1000):
+        config = GPT2Config.from_pretrained(model_name)
+        self.model = GPT2LMHeadModel.from_pretrained(model_name, config=config)
+        self.device = th.device("cuda" if th.cuda.is_available() else "cpu")
+        self.model.to(self.device)
+        self.model.train()
+        self.optimizer = AdamW(self.model.parameters(), lr=lr)
+        self.scheduler = get_scheduler(
+            "linear", optimizer=self.optimizer,
+            num_warmup_steps=0, num_training_steps=total_steps
+        )
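# The default total_steps=1000 is only a placeholder: for the linear schedule to
# decay the learning rate to zero exactly at the end of training, it must equal the
# actual number of optimizer steps, which main() below computes as len(train_loader) * 20.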

+# === CLASS 4: Trainer ===
+class Trainer:
+    def __init__(self, model, optimizer, scheduler, device):
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.device = device
+
+    def train(self, dataloader, num_epochs):
+        pbar = tqdm(range(num_epochs * len(dataloader)))
+        for epoch in range(num_epochs):
+            for batch in dataloader:
+                self.optimizer.zero_grad()
+                batch = {k: v.to(self.device) for k, v in batch.items()}
+                outputs = self.model(**batch, labels=batch["input_ids"])
+                loss = outputs[0]
+                loss.backward()
+                self.optimizer.step()
+                self.scheduler.step()
+                pbar.update(1)
+                pbar.set_description(f"train_loss: {loss.item():.5f}")
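# Passing labels=batch["input_ids"] works because GPT2LMHeadModel shifts the labels
# internally, so outputs[0] is the standard next-token (causal LM) cross-entropy loss;
# no manual shifting of inputs and targets is needed here.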

+# === MAIN FUNCTION ===
def main():
-    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-    train_examples = get_examples("train")
-    train_dset = GSMDataset(tokenizer, train_examples)
-
-    device = th.device("cuda")
-    config = GPT2Config.from_pretrained("gpt2")
-    model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)
-    model.to(device)
-    model.train()
-
-    train_loader = DataLoader(train_dset, batch_size=16, shuffle=True)
-    optim = AdamW(model.parameters(), lr=1e-5)
-
-    num_epochs = 20
-    num_training_steps = num_epochs * len(train_loader)
-    lr_scheduler = get_scheduler(
-        "linear",
-        optimizer=optim,
-        num_warmup_steps=0,
-        num_training_steps=num_training_steps,
-    )
+    tokenizer = TokenizerLoader.load("gpt2")
+    data_builder = DataLoaderBuilder(tokenizer)
+    train_loader = data_builder.get_loader()

-    pbar = tqdm(range(num_training_steps))
-    for epoch in range(num_epochs):
-        for batch in train_loader:
-            optim.zero_grad()
-            batch = {k: v.to(device) for k, v in batch.items()}
-            outputs = model(**batch, labels=batch["input_ids"])
-            loss = outputs[0]
-            loss.backward()
-            optim.step()
-            lr_scheduler.step()
-            pbar.update(1)
-            pbar.set_description(f"train_loss: {loss.item():.5f}")
-
-    model.save_pretrained("model_ckpts/")
+    total_steps = len(train_loader) * 20
+    model_builder = ModelBuilder("gpt2", lr=1e-5, total_steps=total_steps)
+
+    trainer = Trainer(
+        model_builder.model,
+        model_builder.optimizer,
+        model_builder.scheduler,
+        model_builder.device
+    )
+    trainer.train(train_loader, num_epochs=20)
+
+    model_builder.model.save_pretrained("model_ckpts/")

if __name__ == "__main__":
    main()
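Since train.py saves only the model weights to model_ckpts/ (the tokenizer is never saved), a checkpoint written by either version of the script can be reloaded for sampling. A minimal sketch, assuming the stock "gpt2" tokenizer and a hypothetical prompt; the real prompt format should match whatever GSMDataset fed the model during training:

import torch as th
from transformers import GPT2Tokenizer, GPT2LMHeadModel

device = th.device("cuda" if th.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # not saved by train.py, so reload the base one
model = GPT2LMHeadModel.from_pretrained("model_ckpts/").to(device)
model.eval()

# Hypothetical question; format it the same way the training data was formatted.
question = "If there are 3 cars and each car has 4 wheels, how many wheels are there in total?"
inputs = tokenizer(question, return_tensors="pt").to(device)
with th.no_grad():
    output_ids = model.generate(
        **inputs,
        max_length=256,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))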