diff --git a/.gitignore b/.gitignore
index e80e5f6..4290488 100644
--- a/.gitignore
+++ b/.gitignore
@@ -139,4 +139,4 @@ pix2tex/model/checkpoints/**
 .vscode
 .DS_Store
 test/*
-
+*.prf
diff --git a/pix2tex/eval.py b/pix2tex/eval.py
index 8742988..a3d3695 100644
--- a/pix2tex/eval.py
+++ b/pix2tex/eval.py
@@ -27,6 +27,118 @@ def detokenize(tokens, tokenizer):
     return toks
 
 
+def evaluate_step(model: Model, dataset_tokenizer, data_batch, args: Munch, name: str = 'test'):
+    """One step to evaluate the model. Returns the scores on one data batch.
+
+    Args:
+        model (torch.nn.Module): the model
+        dataset_tokenizer: tokenizer of the Im2LatexDataset
+        data_batch: one test data batch (seq, im)
+        args (Munch): arguments
+
+    Returns:
+        Tuple[float, float, float]: BLEU score of batch, normed edit distance, token accuracy
+    """
+    (seq, im) = data_batch
+    edit_dists = []
+    log = {}
+    bleu_score, edit_distance, token_accuracy = 0, 1, 0
+
+    dec = model.generate(im, temperature=args.get('temperature', .2))
+    pred = detokenize(dec, dataset_tokenizer)
+    truth = detokenize(seq['input_ids'], dataset_tokenizer)
+
+    # BLEU score
+    bleu_score = metrics.bleu_score(pred, [alternatives(x) for x in truth])
+
+    # edit distance
+    for predi, truthi in zip(token2str(dec, dataset_tokenizer), token2str(seq['input_ids'], dataset_tokenizer)):
+        ts = post_process(truthi)
+        if len(ts) > 0:
+            edit_dists.append(distance(post_process(predi), ts)/len(ts))
+    edit_distance = np.mean(edit_dists) if len(edit_dists) > 0 else 1
+
+    # token accuracy
+    tgt_seq = seq['input_ids'][:, 1:]
+    shape_diff = dec.shape[1]-tgt_seq.shape[1]
+    if shape_diff < 0:
+        dec = torch.nn.functional.pad(dec, (0, -shape_diff), "constant", args.pad_token)
+    elif shape_diff > 0:
+        tgt_seq = torch.nn.functional.pad(tgt_seq, (0, shape_diff), "constant", args.pad_token)
+    mask = torch.logical_or(tgt_seq != args.pad_token, dec != args.pad_token)
+    token_accuracy = (dec == tgt_seq)[mask].float().mean().item()
+
+    log[name+'/bleu'] = bleu_score
+    log[name+'/token_acc'] = token_accuracy
+    log[name+'/edit_distance'] = edit_distance
+
+    if args.wandb:
+        pred = token2str(dec, dataset_tokenizer)
+        truth = token2str(seq['input_ids'], dataset_tokenizer)
+        table = wandb.Table(columns=["Truth", "Prediction"])
+        for k in range(min([len(pred), args.test_samples])):
+            table.add_data(post_process(truth[k]), post_process(pred[k]))
+        log[name+'/examples'] = table
+        wandb.log(log)
+    return bleu_score, edit_distance, token_accuracy
+
+
+def evaluate_step__(model: Model, dataset_tokenizer, data_batch, args: Munch, name: str = 'test'):
+    """One step to evaluate the model. Returns bleu score on the data batch
+
+    Args:
+        model (torch.nn.Module): the model
+        data_batch: test data batch
+        args (Munch): arguments
+
+    Returns:
+        Tuple[float, float, float]: BLEU score of batch, normed edit distance, token accuracy
+    """
+    (seq, im) = data_batch
+    bleus, edit_dists, token_acc = [], [], []
+    bleu_score, edit_distance, token_accuracy = 0, 1, 0
+    log = {}
+
+    # loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
+    dec = model.generate(im, temperature=args.get('temperature', .2))
+    pred = detokenize(dec, dataset_tokenizer)
+    truth = detokenize(seq['input_ids'], dataset_tokenizer)
+    bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
+    for predi, truthi in zip(token2str(dec, dataset_tokenizer), token2str(seq['input_ids'], dataset_tokenizer)):
+        ts = post_process(truthi)
+        if len(ts) > 0:
+            edit_dists.append(distance(post_process(predi), ts)/len(ts))
+    # dec = dec.cpu()
+    tgt_seq = seq['input_ids'][:, 1:]
+    shape_diff = dec.shape[1]-tgt_seq.shape[1]
+    if shape_diff < 0:
+        dec = torch.nn.functional.pad(dec, (0, -shape_diff), "constant", args.pad_token)
+    elif shape_diff > 0:
+        tgt_seq = torch.nn.functional.pad(tgt_seq, (0, shape_diff), "constant", args.pad_token)
+    mask = torch.logical_or(tgt_seq != args.pad_token, dec != args.pad_token)
+    tok_acc = (dec == tgt_seq)[mask].float().mean().item()
+    token_acc.append(tok_acc)
+
+    if len(bleus) > 0:
+        bleu_score = np.mean(bleus)
+        log[name+'/bleu'] = bleu_score
+    if len(edit_dists) > 0:
+        edit_distance = np.mean(edit_dists)
+        log[name+'/edit_distance'] = edit_distance
+    if len(token_acc) > 0:
+        token_accuracy = np.mean(token_acc)
+        log[name+'/token_acc'] = token_accuracy
+    if args.wandb:
+        pred = token2str(dec, dataset_tokenizer)
+        truth = token2str(seq['input_ids'], dataset_tokenizer)
+        table = wandb.Table(columns=["Truth", "Prediction"])
+        for k in range(min([len(pred), args.test_samples])):
+            table.add_data(post_process(truth[k]), post_process(pred[k]))
+        log[name+'/examples'] = table
+        wandb.log(log)
+    return bleu_score, edit_distance, token_accuracy
+
+
 @torch.no_grad()
 def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
     """evaluates the model. Returns bleu score on the dataset
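Note on usage: evaluate_step scores a single (seq, im) batch, and the Lightning validation_step added in pix2tex/train.py below is its only caller in this diff. For reference, a minimal standalone sketch of driving it outside Lightning follows; the config and pickle paths are placeholders, the config handling simply mirrors what pix2tex.train already does, and everything is pinned to the CPU because outside Lightning nothing moves the batch to the model's device.

    # Sketch only: score one validation batch with evaluate_step (paths are placeholders).
    import yaml
    import torch
    from munch import Munch
    from pix2tex.dataset.dataset import Im2LatexDataset
    from pix2tex.models import get_model
    from pix2tex.utils import parse_args
    from pix2tex.eval import evaluate_step

    with open('pix2tex/model/settings/config.yaml') as f:
        args = parse_args(Munch(yaml.safe_load(f)))
    args.wandb = False    # keep the example free of wandb logging
    args.device = 'cpu'   # avoid having to move the batch to a GPU by hand

    val = Im2LatexDataset().load('pix2tex/dataset/data/val.pkl')
    val.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)

    model = get_model(args)
    model.eval()
    with torch.no_grad():
        batch = next(iter(val))  # one (seq, im) pair
        bleu, edit_dist, tok_acc = evaluate_step(model, val.tokenizer, batch, args, name='val')
    print(f'BLEU {bleu:.3f}, edit distance {edit_dist:.3f}, token accuracy {tok_acc:.3f}')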
diff --git a/pix2tex/train.py b/pix2tex/train.py
index bd2f599..4ff1e6e 100644
--- a/pix2tex/train.py
+++ b/pix2tex/train.py
@@ -1,5 +1,6 @@
 from pix2tex.dataset.dataset import Im2LatexDataset
 import os
+import sys
 import argparse
 import logging
 import yaml
@@ -9,10 +10,14 @@
 from tqdm.auto import tqdm
 import wandb
 import torch.nn as nn
 
-from pix2tex.eval import evaluate
+from pix2tex.eval import evaluate, evaluate_step
 from pix2tex.models import get_model
 # from pix2tex.utils import *
 from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler, gpu_memory_check
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar, OnExceptionCheckpoint
+from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
+from pytorch_lightning.loggers import CSVLogger
 
 
 def train(args):
@@ -79,6 +84,126 @@ def save_models(e, step=0):
         save_models(e, step=len(dataloader))
 
 
+class DataModule(pl.LightningDataModule):
+    """LightningDataModule wrapping the pickled Im2LatexDataset train/val splits."""
+
+    def __init__(self, args, **kwargs):
+        super().__init__()
+        self.args = args
+
+        train_dataloader = Im2LatexDataset().load(args.data)
+        train_dataloader.update(**args, test=False)
+        val_dataloader = Im2LatexDataset().load(args.valdata)
+        val_args = args.copy()
+        val_args.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
+        val_dataloader.update(**val_args)
+        dataset_tokenizer = val_dataloader.tokenizer
+
+        self.dataset_tokenizer = dataset_tokenizer
+        self.train_data = train_dataloader
+        self.valid_data = val_dataloader
+
+    def train_dataloader(self):
+        return self.train_data
+
+    def val_dataloader(self):
+        return self.valid_data
+
+
+class OCR_Model(pl.LightningModule):
+    """LightningModule wrapping the pix2tex model; the training step mirrors the original train() loop."""
+
+    def __init__(self, args, dataset_tokenizer, **kwargs):
+        super().__init__()
+        self.args = args
+        self.dataset_tokenizer = dataset_tokenizer
+
+        model = get_model(args)
+        if args.load_chkpt is not None:
+            model.load_state_dict(torch.load(args.load_chkpt))
+        self.model = model
+        if torch.cuda.is_available() and not args.no_cuda:
+            gpu_memory_check(model, args)
+
+        microbatch = args.get('micro_batchsize', -1)
+        if microbatch == -1:
+            microbatch = args.batchsize
+        self.microbatch = microbatch
+
+    def forward(self, x):
+        return self.model(x)
+
+    def configure_optimizers(self):
+        args = self.args
+        opt = get_optimizer(args.optimizer)(self.model.parameters(), args.lr, betas=args.betas)
+        scheduler = get_scheduler(args.scheduler)(opt, step_size=args.lr_step, gamma=args.gamma)
+        return {
+            "optimizer": opt,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "monitor": "train_loss",
+            }
+        }
+
+    def training_step(self, train_batch, batch_idx):
+        args = self.args
+        (seq, im) = train_batch
+        if seq is not None and im is not None:
+            total_loss = 0
+            for j in range(0, len(im), self.microbatch):
+                tgt_seq, tgt_mask = seq['input_ids'][j:j+self.microbatch], seq['attention_mask'][j:j+self.microbatch].bool()
+                loss = self.model.data_parallel(im[j:j+self.microbatch], device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*self.microbatch/args.batchsize
+                total_loss += loss
+            # NOTE: under automatic optimization no gradients exist at this point,
+            # so gradient clipping is configured on the Trainer (gradient_clip_val=1) instead
+            if args.wandb:
+                wandb.log({'train/loss': total_loss})
+
+            self.log('train_loss', total_loss, on_epoch=True, on_step=False, prog_bar=True)
+            return total_loss
+
+    def validation_step(self, val_batch, batch_idx):
+        bleu_score, edit_distance, token_accuracy = evaluate_step(self.model, self.dataset_tokenizer, val_batch, self.args, name='val')
+        metric_dict = {'bleu_score': bleu_score, 'edit_distance': edit_distance, 'token_accuracy': token_accuracy}
+        self.log_dict(metric_dict, on_epoch=True, on_step=False, prog_bar=True)
+        return metric_dict
+
+    def on_train_epoch_end(self):
+        if self.args.wandb:
+            wandb.log({'train/epoch': self.current_epoch+1})
+
+
+class OCR():
+    """Ties the DataModule, the LightningModule, the callbacks and the Trainer together."""
+
+    def __init__(self, args):
+        self.args = args
+        self.logger = CSVLogger(save_dir='pl_logs', name='')
+        self.out_path = os.path.join(args.model_path, args.name)
+        os.makedirs(self.out_path, exist_ok=True)
+        self.data_model_setup()
+        self.callbacks_setup()
+
+    def data_model_setup(self):
+        self.Data = DataModule(self.args)
+        dataset_tokenizer = self.Data.dataset_tokenizer
+        self.Model = OCR_Model(self.args, dataset_tokenizer)
+
+    def callbacks_setup(self):
+        save_name = f'pl_{self.args.name}' + '_{epoch}_{step}'
+
+        # NOTE: currently lightning doesn't support multiple monitor metrics
+        save_ckpt = ModelCheckpoint(monitor='bleu_score', mode='max', filename=save_name, dirpath=self.out_path,
+                                    every_n_epochs=self.args.save_freq, save_top_k=10, save_last=True)
+
+        # NOTE: this f-string is evaluated here, while current_epoch and global_step are still 0,
+        # so the exception checkpoint is always written as pl_pix2tex_0_0.ckpt
+        exp_save_name = f'pl_pix2tex_{self.Model.current_epoch}_{self.Model.global_step}'
+        excpt = OnExceptionCheckpoint(dirpath=self.out_path, filename=exp_save_name)
+        bar = RichProgressBar(leave=True, theme=RichProgressBarTheme(
+            description='green_yellow', progress_bar='green1', progress_bar_finished='green1'))
+        self.callbacks = [save_ckpt, excpt, bar]
+
+    def fit(self):
+        args = self.args
+        accelerator = 'gpu' if torch.cuda.is_available() and not args.no_cuda else 'cpu'
+        trainer = pl.Trainer(accelerator=accelerator, callbacks=self.callbacks, logger=self.logger,
+                             max_epochs=args.epochs, gradient_clip_val=1, val_check_interval=args.sample_freq)
+        trainer.fit(self.Model, self.Data)
+
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Train model')
     parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
@@ -99,4 +224,7 @@ def save_models(e, step=0):
         args.id = wandb.util.generate_id()
     wandb.init(config=dict(args), resume='allow', name=args.name, id=args.id)
     args = Munch(wandb.config)
-    train(args)
+    # train(args)
+
+    ocr = OCR(args)
+    ocr.fit()
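Note on checkpoints: ModelCheckpoint and OnExceptionCheckpoint write full Lightning checkpoints, not the bare state_dict files that save_models used to produce and that args.load_chkpt expects. A small conversion sketch follows, assuming the usual Lightning layout in which the weights sit under the 'state_dict' key with a 'model.' prefix (the network is stored as self.model on OCR_Model); both file paths are placeholders.

    # Sketch only: convert a Lightning .ckpt back into a plain pix2tex state_dict.
    import torch

    ckpt = torch.load('checkpoints/pix2tex/last.ckpt', map_location='cpu')  # placeholder path
    state_dict = {k[len('model.'):]: v
                  for k, v in ckpt['state_dict'].items()
                  if k.startswith('model.')}
    torch.save(state_dict, 'checkpoints/pix2tex/pix2tex_weights.pth')  # placeholder path
    # The resulting .pth can then be passed back in via args.load_chkpt.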
diff --git a/run.sh b/run.sh
new file mode 100755
index 0000000..26ef95a
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+
+if [ "$1" = "setup" ]; then
+    echo "Setting up python virtual environment"
+    python3 -m venv venv
+    echo "Entering virtual environment"
+    source ./venv/bin/activate
+
+    pip3 install 'pix2tex[train]'
+    pip3 install pytorch-lightning rich
+
+    # install wandb and log in
+    pip3 install wandb
+    wandb login
+
+elif [ "$1" = "generate" ]; then
+    echo "Generating image datasets"
+    # e.g. python3 -m pix2tex.dataset.dataset --equations path_to_textfile --images path_to_images --out dataset.pkl
+    python3 -m pix2tex.dataset.dataset --equations pix2tex/dataset/data/math.txt --images pix2tex/dataset/data/train --out pix2tex/dataset/data/train.pkl
+    python3 -m pix2tex.dataset.dataset --equations pix2tex/dataset/data/math.txt --images pix2tex/dataset/data/val --out pix2tex/dataset/data/val.pkl
+
+elif [ "$1" = "train" ]; then
+    echo "Training model"
+    python3 -m pix2tex.train --config pix2tex/model/settings/config.yaml
+
+else
+    echo "Invalid argument"
+    echo "Usage: ./run.sh [setup|generate|train]"
+fi