Skip to content

Commit d7c6d1c

Browse files
authored
fix ddp rank0 memory bug (#41)
1 parent aad7a74 commit d7c6d1c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/pytorch/llm/src/llm_sft.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ class SftArguments:
111111

112112
def __post_init__(self):
113113
if is_dist():
114-
rank = get_dist_setting()[0]
114+
rank, local_rank, _, _ = get_dist_setting()
115+
torch.cuda.set_device(local_rank)
115116
self.seed += rank # Avoid the same dropout
116117
if self.ddp_backend is None:
117118
self.ddp_backend = 'nccl'

0 commit comments

Comments
 (0)