diff --git a/pix2tex/dataset/dataset.py b/pix2tex/dataset/dataset.py
index aa6dec5..7e7fc37 100644
--- a/pix2tex/dataset/dataset.py
+++ b/pix2tex/dataset/dataset.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import IterableDataset, DataLoader
 import numpy as np
 import imagesize
 import logging
@@ -15,10 +16,10 @@
 from pix2tex.utils.utils import in_model_path
 from pix2tex.dataset.transforms import train_transform, test_transform
+import math

-
-class Im2LatexDataset:
+class Im2LatexDataset(IterableDataset):
     keep_smaller_batches = False
     shuffle = True
     batchsize = 16
@@ -33,6 +34,7 @@ class Im2LatexDataset:
     eos_token_id = 2
     transform = train_transform
     data = defaultdict(lambda: [])
+    permutation = None

     def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, batchsize=16, max_seq_len=1024,
                  max_dimensions=(1024, 512), min_dimensions=(32, 32), pad=False, keep_smaller_batches=False, test=False):
@@ -42,7 +44,7 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
             equations (str, optional): Path to equations. Defaults to None.
             images (str, optional): Directory where images are saved. Defaults to None.
             tokenizer (str, optional): Path to saved tokenizer. Defaults to None.
-            shuffle (bool, opitonal): Defaults to True. 
+            shuffle (bool, opitonal): Defaults to True.
             batchsize (int, optional): Defaults to 16.
             max_seq_len (int, optional): Defaults to 1024.
             max_dimensions (tuple(int, int), optional): Maximal dimensions the model can handle
@@ -75,13 +77,14 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
                     self.data[(width, height)].append((eqs[self.indices[i]], im))
         except KeyboardInterrupt:
             pass
+        # formula&image pairs grouped by image size
         self.data = dict(self.data)
         self._get_size()
-
+        self._shuffle()
         iter(self)

     def __len__(self):
-        return self.size
+        return self.size  # total number of batches given the batchsize

     def __iter__(self):
         self.i = 0
@@ -89,18 +92,24 @@ def __iter__(self):
         self.pairs = []
         for k in self.data:
             info = np.array(self.data[k], dtype=object)
-            p = torch.randperm(len(info)) if self.shuffle else torch.arange(len(info))
             for i in range(0, len(info), self.batchsize):
-                batch = info[p[i:i+self.batchsize]]
+                batch = info[i:i+self.batchsize]
                 if len(batch.shape) == 1:
                     batch = batch[None, :]
                 if len(batch) < self.batchsize and not self.keep_smaller_batches:
                     continue
                 self.pairs.append(batch)
-        if self.shuffle:
-            self.pairs = np.random.permutation(np.array(self.pairs, dtype=object))
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            # configure the dataset to only process the split workload
+            per_worker = int(math.ceil(self.size/float(worker_info.num_workers)))
+            worker_id = worker_info.id
+            self.start = worker_id * per_worker
+            self.end = min(self.start + per_worker, self.size)
         else:
-            self.pairs = np.array(self.pairs, dtype=object)
+            self.start, self.end = 0, self.size
+
+        self.pairs = np.array(self.pairs, dtype=object)[self.permutation[self.start:self.end]]
         self.size = len(self.pairs)
         return self

@@ -121,6 +130,8 @@ def prepare_data(self, batch):
         """

         eqs, ims = batch.T
+        # for im in ims:
+        #     print(im)
         tok = self.tokenizer(list(eqs), return_token_type_ids=False)
         # pad with bos and eos token
         for k, p in zip(tok, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
@@ -155,6 +166,15 @@ def _get_size(self):
         for k in self.data:
             div, mod = divmod(len(self.data[k]), self.batchsize)
             self.size += div  # + (1 if mod > 0 else 0)
+        if self.permutation is None or len(self.permutation) != self.size:
+            self._shuffle()
+
+    def _shuffle(self):
+        if self.shuffle:
+            self.permutation = np.random.permutation(self.size)
+        else:
+            self.permutation = np.arange(self.size)
+        return self

     def load(self, filename, args=[]):
         """returns a pickled version of a dataset
@@ -169,6 +189,7 @@ def load(self, filename, args=[]):
                     filename = os.path.realpath(tempf)
         with open(filename, 'rb') as file:
             x = pickle.load(file)
+        x._get_size()
         return x

     def combine(self, x):
@@ -216,7 +237,19 @@ def update(self, **kwargs):
                     tokenizer_file = os.path.realpath(tokenizer_file)
             self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
         self._get_size()
-        iter(self)
+        return iter(self)
+
+
+class Dataloader(DataLoader):
+    def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, drop_last=True, num_workers=0, pin_memory=False):
+        self.dataset = dataset
+        self.tokenizer = dataset.tokenizer
+        self.dataset.update(batchsize=batch_size, shuffle=shuffle, keep_smaller_batches=not drop_last)
+        super().__init__(self.dataset, num_workers=num_workers, shuffle=False, batch_size=None, pin_memory=pin_memory)
+
+    def __iter__(self):
+        self.dataset._shuffle()
+        return super().__iter__()


 def generate_tokenizer(equations, output, vocab_size):
diff --git a/pix2tex/eval.py b/pix2tex/eval.py
index 8742988..b47f486 100644
--- a/pix2tex/eval.py
+++ b/pix2tex/eval.py
@@ -1,4 +1,4 @@
-from pix2tex.dataset.dataset import Im2LatexDataset
+from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
 import argparse
 import logging
 import yaml
@@ -28,12 +28,12 @@ def detokenize(tokens, tokenizer):


 @torch.no_grad()
-def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
+def evaluate(model: Model, dataset: Dataloader, args: Munch, num_batches: int = None, name: str = 'test'):
     """evaluates the model. Returns bleu score on the dataset

     Args:
         model (torch.nn.Module): the model
-        dataset (Im2LatexDataset): test dataset
+        dataset (Dataloader): test dataset
         args (Munch): arguments
         num_batches (int): How many batches to evaluate on. Defaults to None (all batches).
         name (str, optional): name of the test e.g. val or test for wandb. Defaults to 'test'.
@@ -46,7 +46,7 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
     log = {}
     bleus, edit_dists, token_acc = [], [], []
     bleu_score, edit_distance, token_accuracy = 0, 1, 0
-    pbar = tqdm(enumerate(iter(dataset)), total=len(dataset))
+    pbar = tqdm(enumerate(dataset), total=len(dataset))
     for i, (seq, im) in pbar:
         if seq is None or im is None:
             continue
diff --git a/pix2tex/model/settings/config-vit.yaml b/pix2tex/model/settings/config-vit.yaml
index 3d94e84..162880a 100644
--- a/pix2tex/model/settings/config-vit.yaml
+++ b/pix2tex/model/settings/config-vit.yaml
@@ -1,4 +1,5 @@
 gpu_devices: null #[0,1,2,3,4,5,6,7]
+num_workers: 0
 betas:
 - 0.9
 - 0.999
diff --git a/pix2tex/model/settings/config.yaml b/pix2tex/model/settings/config.yaml
index a579f9e..c19dca5 100644
--- a/pix2tex/model/settings/config.yaml
+++ b/pix2tex/model/settings/config.yaml
@@ -1,4 +1,5 @@
 gpu_devices: null #[0,1,2,3,4,5,6,7]
+num_workers: 0
 backbone_layers:
 - 2
 - 3
diff --git a/pix2tex/model/settings/debug.yaml b/pix2tex/model/settings/debug.yaml
index 94e3b77..7026fa2 100644
--- a/pix2tex/model/settings/debug.yaml
+++ b/pix2tex/model/settings/debug.yaml
@@ -65,3 +65,7 @@ pad: False
 pad_token: 0
 bos_token: 1
 eos_token: 2
+
+#devices(GPU&CPU)
+num_workers: 0
+gpu_devices: null #[0,1,2,3,4,5,6,7]
\ No newline at end of file
diff --git a/pix2tex/train.py b/pix2tex/train.py
index bd2f599..309cd18 100644
--- a/pix2tex/train.py
+++ b/pix2tex/train.py
@@ -1,4 +1,4 @@
-from pix2tex.dataset.dataset import Im2LatexDataset
+from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
 import os
 import argparse
 import logging
@@ -16,12 +16,10 @@


 def train(args):
-    dataloader = Im2LatexDataset().load(args.data)
-    dataloader.update(**args, test=False)
-    valdataloader = Im2LatexDataset().load(args.valdata)
-    valargs = args.copy()
-    valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
-    valdataloader.update(**valargs)
+    train_dataset = Im2LatexDataset().load(args.data).update(**args, test=False)
+    train_dataloader = Dataloader(train_dataset, batch_size=args.batchsize, num_workers=args.num_workers, pin_memory=args.pin_memory)
+    val_dataset = Im2LatexDataset().load(args.valdata).update(**args, test=True)
+    val_dataloader = Dataloader(val_dataset, batch_size=args.testbatchsize, num_workers=args.num_workers, drop_last=False, pin_memory=args.pin_memory)
     device = args.device
     model = get_model(args)
     if torch.cuda.is_available() and not args.no_cuda:
@@ -47,7 +45,7 @@ def save_models(e, step=0):
     try:
         for e in range(args.epoch, args.epochs):
             args.epoch = e
-            dset = tqdm(iter(dataloader))
+            dset = tqdm(train_dataloader)
             for i, (seq, im) in enumerate(dset):
                 if seq is not None and im is not None:
                     opt.zero_grad()
@@ -63,20 +61,20 @@ def save_models(e, step=0):
                     dset.set_description('Loss: %.4f' % total_loss)
                     if args.wandb:
                         wandb.log({'train/loss': total_loss})
-                if (i+1+len(dataloader)*e) % args.sample_freq == 0:
-                    bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
+                if (i+1+len(train_dataloader)*e) % args.sample_freq == 0:
+                    bleu_score, edit_distance, token_accuracy = evaluate(model, val_dataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
                     if bleu_score > max_bleu and token_accuracy > max_token_acc:
                         max_bleu, max_token_acc = bleu_score, token_accuracy
                         save_models(e, step=i)
             if (e+1) % args.save_freq == 0:
-                save_models(e, step=len(dataloader))
+                save_models(e, step=len(train_dataloader))
             if args.wandb:
                 wandb.log({'train/epoch': e+1})
     except KeyboardInterrupt:
         if e >= 2:
             save_models(e, step=i)
         raise KeyboardInterrupt
-    save_models(e, step=len(dataloader))
+    save_models(e, step=len(train_dataloader))


 if __name__ == '__main__':
diff --git a/pix2tex/utils/utils.py b/pix2tex/utils/utils.py
index 2b5f920..e07ac4c 100644
--- a/pix2tex/utils/utils.py
+++ b/pix2tex/utils/utils.py
@@ -55,6 +55,8 @@ def parse_args(args, **kwargs) -> Munch:
     args.update(kwargs)
     args.wandb = not kwargs.debug and not args.debug
     args.device = get_device(args, kwargs.no_cuda)
+    args.num_workers = args.get('num_workers', 0)
+    args.pin_memory = args.get('pin_memory', False)
     args.max_dimensions = [args.max_width, args.max_height]
     args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
     if 'decoder_args' not in args or args.decoder_args is None:
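
For orientation, a minimal usage sketch of the new Dataloader wrapper, mirroring the train.py changes above. This is not part of the patch; the pickle and tokenizer paths are placeholders and the keyword values are only illustrative.

    from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader

    # load a pickled dataset and attach a tokenizer (both paths are placeholders)
    train_set = Im2LatexDataset().load('path/to/train.pkl').update(tokenizer='path/to/tokenizer.json', test=False)

    # the wrapper forwards batchsize/shuffle/keep_smaller_batches to the dataset and
    # passes batch_size=None to torch's DataLoader, since the dataset yields whole batches
    train_loader = Dataloader(train_set, batch_size=16, shuffle=True, num_workers=2)

    for seq, im in train_loader:
        if seq is None or im is None:  # train.py/eval.py guard against empty batches
            continue
        print(seq['input_ids'].shape, im.shape)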