Commit 95b50d7

Pull jik876#25: Solve imbalanced GPU memory during multi-GPU distributed training
1 parent 6d0412f commit 95b50d7

File tree

1 file changed: +1 -0 lines changed

train.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ def train(rank, a, h):
                            world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
 
     torch.cuda.manual_seed(h.seed)
+    torch.cuda.set_device(rank)
     device = torch.device('cuda:{:d}'.format(rank))
 
     generator = Generator(h).to(device)
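
For context, a minimal sketch (not the repository's verbatim code) of the worker setup this one-line change touches, showing why calling torch.cuda.set_device(rank) rebalances memory. Only the init_process_group arguments visible in the hunk, torch.cuda.manual_seed(h.seed), and the added set_device call come from the diff; the dist_backend/dist_url config keys, the stand-in Linear model, and the DistributedDataParallel wrapping are assumptions added to keep the example self-contained.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def train(rank, a, h):
    # `a` (CLI args in the original script) is unused in this sketch.
    # Join the process group; world_size and rank mirror the call in the hunk above.
    dist.init_process_group(
        backend=h.dist_config['dist_backend'],    # assumed config key
        init_method=h.dist_config['dist_url'],    # assumed config key
        world_size=h.dist_config['world_size'] * h.num_gpus,
        rank=rank,
    )

    torch.cuda.manual_seed(h.seed)
    # Bind this process to its own GPU before any CUDA allocation happens.
    # Without this call, the default CUDA device stays cuda:0 in every worker,
    # so communication buffers and any tensors moved with a bare .cuda()
    # accumulate on GPU 0, which is what makes its memory usage outgrow the
    # other GPUs during multi-GPU training.
    torch.cuda.set_device(rank)
    device = torch.device('cuda:{:d}'.format(rank))

    # Stand-in model for illustration; the real script builds Generator(h) here.
    model = torch.nn.Linear(80, 80).to(device)
    model = DistributedDataParallel(model, device_ids=[rank])

Placing the set_device call before the device object is created means any library code that allocates on the "current" CUDA device lands on the worker's own GPU rather than defaulting to GPU 0.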
