
Commit 443f61a

Reduce mem usage of torch nccl all-reduce
Lowers the memory requirement in tensor-parallel cases. See https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
1 parent d1df979 commit 443f61a

File tree

1 file changed: +3 −0 lines changed

launcher/src/main.rs

Lines changed: 3 additions & 0 deletions
@@ -499,6 +499,9 @@ fn shard_manager(
     env.push(("MASTER_PORT".into(), master_port.to_string().into()));
     env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
 
+    // See https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+    env.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));
+
     // Safetensors load fast
     env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
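
For context, here is a minimal, self-contained Rust sketch of how an environment vector like the one in `shard_manager` can be applied when spawning a shard process. It is not the launcher's actual code: the command name, its arguments, and the hard-coded port are illustrative assumptions; only the env pairs mirror the diff above.

```rust
use std::ffi::OsString;
use std::process::Command;

fn main() {
    // Environment passed to the shard, mirroring the diff above.
    let mut env: Vec<(OsString, OsString)> = Vec::new();
    env.push(("MASTER_PORT".into(), "29500".into()));
    env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
    // Ask PyTorch's NCCL backend not to keep all-reduce inputs alive via
    // recordStream, lowering peak memory in tensor-parallel setups; see the
    // linked discuss.pytorch.org thread.
    env.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));
    // Safetensors load fast
    env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Hypothetical shard launch: the real launcher builds a much longer
    // command line; this only shows how the env pairs are applied.
    let status = Command::new("text-generation-server")
        .args(["serve", "some/model-id"])
        .envs(env)
        .status()
        .expect("failed to spawn shard");
    println!("shard exited with {status}");
}
```

Collecting the variables in a `Vec<(OsString, OsString)>` lets the launcher build the environment incrementally and hand it to `Command::envs` in a single call when the shard is spawned.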
