Removed from the training script (the tail of the inline memory check):

-             loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
-             loss.sum().backward()
-     except RuntimeError:
-         raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
-     model.zero_grad()
-     with torch.cuda.device(args.device):
-         torch.cuda.empty_cache()
-     del im, seq

Added to the training script's imports:

+ from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler, gpu_memory_check
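With the helper factored out, the training script presumably just calls it once before training starts. A minimal sketch of the assumed call site (the CUDA guard and its placement are assumptions, not shown in this hunk):

    # Assumed call site in the training script (illustrative, not part of this diff):
    if torch.cuda.is_available() and not args.no_cuda:
        gpu_memory_check(model, args)  # fail fast if the batch size cannot fit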
Added to the device-setup helper in utils:

+     assert max(args.gpu_devices) < available_gpus, "legal gpu_devices should be in [%s], received [%s]" % (','.join(map(str, range(available_gpus))), ','.join(map(str, args.gpu_devices)))
      return device
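The assertion rejects device ids outside the visible CUDA range. An illustrative failure, with hypothetical values:

    # Illustrative: two visible GPUs, but a misconfigured device list
    available_gpus = 2          # torch.cuda.device_count()
    args.gpu_devices = [0, 2]   # hypothetical misconfiguration
    # max(args.gpu_devices) == 2 is not < 2, so the assert fires:
    # AssertionError: legal gpu_devices should be in [0,1], received [0,2]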
New function added to utils (the code moved out of the training script):

+ def gpu_memory_check(model, args):
+     # check if largest batch can be handled by system
+     try:
+         batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
+         for _ in range(5):
+             im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
+             seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
+             loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
+             loss.sum().backward()
+     except RuntimeError:
+         raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
+     model.zero_grad()
+     with torch.cuda.device(args.device):
+         torch.cuda.empty_cache()
+     del im, seq
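For reference, a minimal sketch of running the check standalone. All config values below are illustrative assumptions; args is written as a Munch because the function uses both attribute access (args.channels) and dict-style args.get(), and model stands for a pix2tex model built elsewhere:

    from munch import Munch

    # Illustrative config covering the keys gpu_memory_check reads (values assumed):
    args = Munch(batchsize=16, micro_batchsize=-1, channels=1,
                 max_height=192, max_width=672, min_height=32,
                 num_tokens=8000, max_seq_len=512,
                 device='cuda:0', gpu_devices=[0])
    gpu_memory_check(model, args)  # model: a pix2tex model created beforehand

Allocating the worst-case batch five times and backpropagating through it is a cheap way to surface an out-of-memory error at startup instead of hours into training; the gradients are cleared and the CUDA cache emptied afterwards, so the probe leaves no residue.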