Skip to content

Commit dbf75d9

Browse files
authored
initial device context at args.device
if User A use gpu6,7 and User B use gpu0. Then UserB kills all process at gpu0 but User A's training also stopped. because `torch.cuda.empty_cache()` default initialize at rank0. Reference: pytorch/pytorch#25752 (comment)
1 parent 0938894 commit dbf75d9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pix2tex/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def gpu_memory_check(model, args):
2727
except RuntimeError:
2828
raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
2929
model.zero_grad()
30-
torch.cuda.empty_cache()
30+
with torch.cuda.device(args.device):torch.cuda.empty_cache()
3131
del im, seq
3232

3333

0 commit comments

Comments
 (0)