Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pix2tex/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def update(self, **kwargs):
class Dataloader(DataLoader):
def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, *args, **kwargs):
self.dataset = dataset
self.dataset.update(batchsize=batch_size, shuffle=shuffle)
self.dataset.update(batchsize=batch_size, shuffle=shuffle, *args, **kwargs)
super().__init__(self.dataset, *args, shuffle=False, batch_size=None, **kwargs)

def __iter__(self):
Expand Down
6 changes: 3 additions & 3 deletions pix2tex/eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pix2tex.dataset.dataset import Im2LatexDataset
from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
import argparse
import logging
import yaml
Expand Down Expand Up @@ -28,12 +28,12 @@ def detokenize(tokens, tokenizer):


@torch.no_grad()
def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
def evaluate(model: Model, dataset: Dataloader, args: Munch, num_batches: int = None, name: str = 'test'):
"""evaluates the model. Returns bleu score on the dataset

Args:
model (torch.nn.Module): the model
dataset (Im2LatexDataset): test dataset
dataset (Dataloader): test dataset
args (Munch): arguments
num_batches (int): How many batches to evaluate on. Defaults to None (all batches).
name (str, optional): name of the test e.g. val or test for wandb. Defaults to 'test'.
Expand Down
1 change: 1 addition & 0 deletions pix2tex/model/settings/config-vit.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
gpu_devices: null #[0,1,2,3,4,5,6,7]
num_workers: 0
betas:
- 0.9
- 0.999
Expand Down
1 change: 1 addition & 0 deletions pix2tex/model/settings/config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
gpu_devices: null #[0,1,2,3,4,5,6,7]
num_workers: 0
backbone_layers:
- 2
- 3
Expand Down
4 changes: 4 additions & 0 deletions pix2tex/model/settings/debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,7 @@ pad: False
pad_token: 0
bos_token: 1
eos_token: 2

# devices (GPU & CPU)
num_workers: 0
gpu_devices: null #[0,1,2,3,4,5,6,7]
20 changes: 10 additions & 10 deletions pix2tex/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pix2tex.dataset.dataset import Im2LatexDataset
from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
import os
import argparse
import logging
Expand All @@ -16,12 +16,12 @@


def train(args):
dataloader = Im2LatexDataset().load(args.data)
dataloader.update(**args, test=False)
valdataloader = Im2LatexDataset().load(args.valdata)
train_dataset = Im2LatexDataset().load(args.data)
train_dataloader = Dataloader(train_dataset, **args, test=False)
val_dataset = Im2LatexDataset().load(args.valdata)
valargs = args.copy()
valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
valdataloader.update(**valargs)
val_dataloader = Dataloader(val_dataset, **valargs)
device = args.device
model = get_model(args)
if torch.cuda.is_available() and not args.no_cuda:
Expand All @@ -47,7 +47,7 @@ def save_models(e, step=0):
try:
for e in range(args.epoch, args.epochs):
args.epoch = e
dset = tqdm(iter(dataloader))
dset = tqdm(iter(train_dataloader))
for i, (seq, im) in enumerate(dset):
if seq is not None and im is not None:
opt.zero_grad()
Expand All @@ -63,20 +63,20 @@ def save_models(e, step=0):
dset.set_description('Loss: %.4f' % total_loss)
if args.wandb:
wandb.log({'train/loss': total_loss})
if (i+1+len(dataloader)*e) % args.sample_freq == 0:
bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
if (i+1+len(train_dataloader)*e) % args.sample_freq == 0:
bleu_score, edit_distance, token_accuracy = evaluate(model, val_dataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
if bleu_score > max_bleu and token_accuracy > max_token_acc:
max_bleu, max_token_acc = bleu_score, token_accuracy
save_models(e, step=i)
if (e+1) % args.save_freq == 0:
save_models(e, step=len(dataloader))
save_models(e, step=len(train_dataloader))
if args.wandb:
wandb.log({'train/epoch': e+1})
except KeyboardInterrupt:
if e >= 2:
save_models(e, step=i)
raise KeyboardInterrupt
save_models(e, step=len(dataloader))
save_models(e, step=len(train_dataloader))


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions pix2tex/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def parse_args(args, **kwargs) -> Munch:
args.update(kwargs)
args.wandb = not kwargs.debug and not args.debug
args.device = get_device(args, kwargs.no_cuda)
args.num_workers = args.get('num_workers', 0)
args.max_dimensions = [args.max_width, args.max_height]
args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
if 'decoder_args' not in args or args.decoder_args is None:
Expand Down