Skip to content

Commit 9f974c9

Browse files
committed
small adjustments
1 parent dbf75d9 commit 9f974c9

File tree

4 files changed

+33
-29
lines changed

4 files changed

+33
-29
lines changed

pix2tex/model/settings/config-vit.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ max_seq_len: 512
2929
max_width: 672
3030
min_height: 32
3131
min_width: 32
32-
micro_batchsize: 64
32+
micro_batchsize: -1
3333
model_path: checkpoints_add
3434
name: pix2tex-vit
3535
num_layers: 4

pix2tex/model/settings/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ backbone_layers:
66
betas:
77
- 0.9
88
- 0.999
9-
batchsize: 10
9+
batchsize: 64
1010
bos_token: 1
1111
channels: 1
1212
data: dataset/data/train.pkl

pix2tex/train.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,7 @@
1212
from pix2tex.eval import evaluate
1313
from pix2tex.models import get_model
1414
# from pix2tex.utils import *
15-
from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler
16-
17-
18-
def gpu_memory_check(model, args):
19-
# check if largest batch can be handled by system
20-
try:
21-
batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
22-
for _ in range(5):
23-
im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
24-
seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
25-
loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
26-
loss.sum().backward()
27-
except RuntimeError:
28-
raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
29-
model.zero_grad()
30-
with torch.cuda.device(args.device):torch.cuda.empty_cache()
31-
del im, seq
15+
from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler, gpu_memory_check
3216

3317

3418
def train(args):
@@ -40,13 +24,15 @@ def train(args):
4024
valdataloader.update(**valargs)
4125
device = args.device
4226
model = get_model(args)
43-
gpu_memory_check(model, args)
44-
if args.load_chkpt is not None:
45-
model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
27+
if torch.cuda.is_available() and not args.no_cuda:
28+
gpu_memory_check(model, args)
4629
max_bleu, max_token_acc = 0, 0
4730
out_path = os.path.join(args.model_path, args.name)
4831
os.makedirs(out_path, exist_ok=True)
4932

33+
if args.load_chkpt is not None:
34+
model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
35+
5036
def save_models(e, step=0):
5137
torch.save(model.state_dict(), os.path.join(out_path, '%s_e%02d_step%02d.pth' % (args.name, e+1, step)))
5238
yaml.dump(dict(args), open(os.path.join(out_path, 'config.yaml'), 'w+'))
@@ -88,9 +74,9 @@ def save_models(e, step=0):
8874
wandb.log({'train/epoch': e+1})
8975
except KeyboardInterrupt:
9076
if e >= 2:
91-
save_models(e)
77+
save_models(e, step=i)
9278
raise KeyboardInterrupt
93-
save_models(e)
79+
save_models(e, step=len(dataloader))
9480

9581

9682
if __name__ == '__main__':

pix2tex/utils/utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,26 +52,44 @@ def seed_everything(seed: int):
5252
def parse_args(args, **kwargs) -> Munch:
    """Merge config values and keyword overrides into one ``Munch`` of settings.

    Args:
        args: Mapping of configuration values (e.g. loaded from a config
            file). ``epoch`` defaults to 0 when absent.
        **kwargs: Overrides merged on top of ``args``. ``no_cuda`` and
            ``debug`` default to ``False`` when not given.

    Returns:
        The merged ``Munch`` with the derived entries ``wandb``, ``device``,
        ``max_dimensions``, ``min_dimensions`` and ``decoder_args`` filled in.
    """
    args = Munch({'epoch': 0}, **args)
    kwargs = Munch({'no_cuda': False, 'debug': False}, **kwargs)
    # Read the config's debug flag BEFORE merging: kwargs always carries a
    # ``debug`` default, so after ``update`` the config value would be lost
    # and a ``debug: true`` config could no longer switch off wandb logging.
    config_debug = bool(args.get('debug', False))
    args.update(kwargs)
    # Debug mode is active when either the config or the caller requests it.
    args.debug = bool(kwargs.debug) or config_debug
    args.wandb = not args.debug
    args.device = get_device(args, kwargs.no_cuda)
    args.max_dimensions = [args.max_width, args.max_height]
    args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
    # Normalise a missing or null ``decoder_args`` to an empty dict.
    if 'decoder_args' not in args or args.decoder_args is None:
        args.decoder_args = {}
    return args
6263

6364

64-
def get_device(args, no_cuda=False):
    """Pick the torch device string and normalise ``args.gpu_devices``.

    Args:
        args: Settings object supporting ``.get`` and attribute access. Its
            ``gpu_devices`` entry is overwritten with a list of all visible
            GPU indices when it is missing or empty.
        no_cuda: When true, always select the CPU even if GPUs are visible.

    Returns:
        ``'cpu'`` when no GPU is visible or ``no_cuda`` is set, otherwise
        ``'cuda:<first entry of args.gpu_devices>'``.

    Raises:
        AssertionError: If more GPUs are requested than are visible, or a
            requested index is not below the number of visible GPUs.
    """
    available_gpus = torch.cuda.device_count()
    # Fall back to every visible GPU when the config names none.
    args.gpu_devices = args.gpu_devices if args.get('gpu_devices', False) else list(range(available_gpus))
    if available_gpus == 0 or no_cuda:
        return 'cpu'
    assert available_gpus >= len(args.gpu_devices), "Available %d gpu, but specified gpu %s." % (available_gpus, ','.join(map(str, args.gpu_devices)))
    assert max(args.gpu_devices) < available_gpus, "legal gpu_devices should in [%s], received [%s]" % (','.join(map(str, range(available_gpus))), ','.join(map(str, args.gpu_devices)))
    # gpu_devices is guaranteed non-empty here, so the first entry is the
    # primary device; always return a proper device string (never a bare int).
    return 'cuda:%d' % args.gpu_devices[0]
7374

7475

76+
def gpu_memory_check(model, args):
    """Dry-run the largest configured batch to verify the device can hold it.

    Performs five forward/backward passes on dummy inputs of the maximum
    image size and maximum sequence length, then clears gradients and
    releases the cached GPU memory.

    Args:
        model: Object exposing ``data_parallel(im, device_ids=..., tgt_seq=...)``
            (its result is reduced with ``.sum().backward()``) and
            ``zero_grad()``.
        args: Settings providing ``batchsize``, optional ``micro_batchsize``,
            ``channels``, ``max_height``, ``max_width``, ``num_tokens``,
            ``max_seq_len``, ``device`` and ``gpu_devices``.

    Raises:
        RuntimeError: If any pass fails with a ``RuntimeError`` (e.g. out of
            memory); the original error is chained as the cause.
    """
    # A micro_batchsize of -1 (or a missing entry) means the full batchsize is used.
    micro_batchsize = args.get('micro_batchsize', -1)
    batchsize = args.batchsize if micro_batchsize == -1 else micro_batchsize
    im = seq = loss = None
    try:
        for _ in range(5):
            # Probe with the LARGEST image (max_height x max_width); using
            # min_height for the width would under-estimate the memory need.
            im = torch.empty(batchsize, args.channels, args.max_height, args.max_width, device=args.device).float()
            seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
            loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
            loss.sum().backward()
    except RuntimeError as err:
        raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width)) from err
    model.zero_grad()
    # Drop the probe tensors first so empty_cache can actually return their memory.
    del im, seq, loss
    with torch.cuda.device(args.device):
        torch.cuda.empty_cache()
91+
92+
7593
def token2str(tokens, tokenizer) -> list:
7694
if len(tokens.shape) == 1:
7795
tokens = tokens[None, :]

0 commit comments

Comments
 (0)