diff --git a/.azure-pipelines/templates/ut-size-alignment.yaml b/.azure-pipelines/templates/ut-size-alignment.yaml
new file mode 100644
index 000000000..2e471958a
--- /dev/null
+++ b/.azure-pipelines/templates/ut-size-alignment.yaml
@@ -0,0 +1,151 @@
+# .azure-pipelines/templates/ut-size-alignment.yaml
+# ----------------------------------------
+# A step template that runs unit tests covering input-size alignment, using PyTorch and MSCCLPP.
+#
+# Parameters:
+#   subscription     – Azure subscription to use for VMSS start/stop
+#   vmssName         – Name of the VMSS to use
+#   sshKeySecureFile – The secureFile name for your SSH key
+#   pytorchImage     – PyTorch Docker image to use for unit tests
+
+parameters:
+- name: subscription
+  type: string
+- name: vmssName
+  type: string
+- name: sshKeySecureFile
+  type: string
+- name: pytorchImage
+  type: string
+  default: "mcr.microsoft.com/mirror/nvcr/nvidia/pytorch:25.03-py3"
+
+steps:
+- task: Bash@3
+  name: Build
+  displayName: Build MSCCLPP
+  inputs:
+    targetType: inline
+    script: |
+      mkdir build && cd build
+      cmake -DCMAKE_BUILD_TYPE=Release -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON ..
+      make -j
+    workingDirectory: '$(System.DefaultWorkingDirectory)'
+
+- task: Bash@3
+  name: InstallPackages
+  displayName: Install Packages
+  inputs:
+    targetType: inline
+    script: |
+      sudo apt-get update -y
+      sudo apt-get install pssh -y
+      curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash
+
+- task: DownloadSecureFile@1
+  name: SshKeyFile
+  displayName: Download key file
+  inputs:
+    secureFile: ${{ parameters.sshKeySecureFile }}
+
+- task: AzureCLI@2
+  name: StartVMSS
+  displayName: Start VMSS
+  inputs:
+    azureSubscription: ${{ parameters.subscription }}
+    scriptType: bash
+    scriptLocation: inlineScript
+    inlineScript: |
+      az vmss start --name ${{ parameters.vmssName }} --resource-group mscclpp
+
+- task: Bash@3
+  name: DeployBenchmarkEnv
+  displayName: Deploy Benchmark Environment
+  inputs:
+    targetType: inline
+    script: |
+      set -e
+      HOSTFILE=$(System.DefaultWorkingDirectory)/test/deploy/hostfile_ci
+      ROOT_DIR=$(System.DefaultWorkingDirectory)
+      SSH_OPTION="StrictHostKeyChecking=no"
+      KeyFilePath=${SSHKEYFILE_SECUREFILEPATH}
+
+      # Get the host IP
+      HOST_IP=$(head -1 ${HOSTFILE} | awk '{print $1}')
+
+      # Pull the PyTorch image on the remote machine
+      ssh -i ${KeyFilePath} -o ${SSH_OPTION} ${HOST_IP} "sudo docker pull ${{ parameters.pytorchImage }}"
+
+      # Stop any existing container
+      ssh -i ${KeyFilePath} -o ${SSH_OPTION} ${HOST_IP} "sudo docker rm -f mscclpp-benchmark || true"
+
+      # Copy MSCCLPP build to the remote machine
+      scp -i ${KeyFilePath} -o ${SSH_OPTION} -r ${ROOT_DIR}/build ${HOST_IP}:/tmp/mscclpp-build
+      scp -i ${KeyFilePath} -o ${SSH_OPTION} -r ${ROOT_DIR}/test ${HOST_IP}:/tmp/mscclpp-test
+
+      # Start the PyTorch container with MSCCLPP mounted
+      ssh -i ${KeyFilePath} -o ${SSH_OPTION} ${HOST_IP} "sudo docker run -d --name mscclpp-benchmark \
+        --gpus all \
+        --ipc=host \
+        --ulimit memlock=-1 \
+        --ulimit stack=67108864 \
+        -v /tmp/mscclpp-build:/root/mscclpp/build \
+        -v /tmp/mscclpp-test:/root/mscclpp/test \
+        ${{ parameters.pytorchImage }} \
+        sleep infinity"
+    workingDirectory: '$(System.DefaultWorkingDirectory)'
+
+- task: Bash@3
+  name: AllReduceBenchmarkMultipleSizes
+  displayName: Run AllReduce Benchmark Test (multiple sizes to test alignment)
+  inputs:
+    targetType: inline
+    script: |
+      set -e
+      HOSTFILE=$(System.DefaultWorkingDirectory)/test/deploy/hostfile_ci
+      SSH_OPTION="StrictHostKeyChecking=no"
+      KeyFilePath=${SSHKEYFILE_SECUREFILEPATH}
+      HOST_IP=$(head -1 ${HOSTFILE} | awk '{print $1}')
+
+      : > azureuser@${HOST_IP}
+      tail -f azureuser@${HOST_IP} &
+      CHILD_PID=$!
+
+      # Run with different element counts
+      for nelem in 10556576 10556587 10556592 10556608 1048576 9999999 12345678; do
+        echo "Running AllReduce benchmark with nelem=${nelem}"
+        ssh -i ${KeyFilePath} -o ${SSH_OPTION} ${HOST_IP} "sudo docker exec -t mscclpp-benchmark bash -c \" \
+          set -e; \
+          cd /root/mscclpp; \
+          LD_PRELOAD=/root/mscclpp/build/apps/nccl/libmscclpp_nccl.so torchrun --nproc_per_node=8 test/torch/correctness_test.py --collective allreduce --nelem ${nelem} --dtype float\""
+      done
+
+      kill $CHILD_PID || true
+    workingDirectory: '$(System.DefaultWorkingDirectory)'
+
+- task: Bash@3
+  name: Cleanup
+  displayName: Cleanup Benchmark Container
+  condition: always()
+  inputs:
+    targetType: inline
+    script: |
+      set -e
+      HOSTFILE=$(System.DefaultWorkingDirectory)/test/deploy/hostfile_ci
+      SSH_OPTION="StrictHostKeyChecking=no"
+      KeyFilePath=${SSHKEYFILE_SECUREFILEPATH}
+      HOST_IP=$(head -1 ${HOSTFILE} | awk '{print $1}')
+
+      ssh -i ${KeyFilePath} -o ${SSH_OPTION} ${HOST_IP} "sudo docker rm -f mscclpp-benchmark || true"
+      ssh -i ${KeyFilePath} -o ${SSH_OPTION} ${HOST_IP} "sudo rm -rf /tmp/mscclpp-build /tmp/mscclpp-test || true"
+    workingDirectory: '$(System.DefaultWorkingDirectory)'
+
+- task: AzureCLI@2
+  name: StopVMSS
+  displayName: Deallocate VMSS
+  condition: always()
+  inputs:
+    azureSubscription: ${{ parameters.subscription }}
+    scriptType: bash
+    scriptLocation: inlineScript
+    inlineScript: |
+      az vmss deallocate --name ${{ parameters.vmssName }} --resource-group mscclpp
diff --git a/.azure-pipelines/ut-size-alignment.yaml b/.azure-pipelines/ut-size-alignment.yaml
new file mode 100644
index 000000000..8ccd0700e
--- /dev/null
+++ b/.azure-pipelines/ut-size-alignment.yaml
@@ -0,0 +1,42 @@
+trigger:
+  branches:
+    include:
+    - main
+    - release/*
+  paths:
+    exclude:
+    - .devcontainer/**
+    - .github/**
+    - docker/**
+    - docs/**
+    - '**/*.md'
+
+pr:
+  branches:
+    include:
+    - main
+    - release/*
+  drafts: false
+  paths:
+    exclude:
+    - .devcontainer/**
+    - .github/**
+    - docker/**
+    - docs/**
+    - '**/*.md'
+
+jobs:
+- job: BenchmarkTestH100
+  displayName: Benchmark Test H100
+  pool:
+    name: msccl-ci-h100
+  container:
+    image: ghcr.io/microsoft/mscclpp/mscclpp:base-dev-cuda12.4
+
+  steps:
+  - template: templates/ut-size-alignment.yaml
+    parameters:
+      subscription: mscclpp-ci-h100
+      vmssName: mscclpp-h100-ci
+      sshKeySecureFile: mscclpp.pem
+      pytorchImage: mcr.microsoft.com/mirror/nvcr/nvidia/pytorch:25.03-py3
diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index 82adc323b..618f10795 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -963,8 +963,20 @@ __global__ void __launch_bounds__(1024, 1)
   int nBlocksForReduce = nRanksPerNode;
   int copyReduceRatio = nBlocksForCopy / nBlocksForReduce;
   size_t scratchSizePerRank = scratchBufferSize / nRanksPerNode;
-  size_t sizePerRank = size / nRanksPerNode;
-  assert(sizePerRank % alignment == 0);
+
+  // Pad size to be divisible by (nRanksPerNode * alignment)
+  size_t paddingNeeded =
+      (nRanksPerNode * alignment - (size % (nRanksPerNode * alignment))) % (nRanksPerNode * alignment);
+  size_t paddedSize = size + paddingNeeded;
+  size_t sizePerRank = paddedSize / nRanksPerNode;
+
+  // Calculate actual size this rank should process (without padding)
+  size_t actualSizeThisRank = sizePerRank;
+  if (rank == nRanksPerNode - 1) {
+    // Last rank might have less actual data due to padding
+    actualSizeThisRank = size - (sizePerRank * (nRanksPerNode - 1));
+  }
+
   uint32_t sizePerBlock =
       ((sizePerRank + (nBlocksForCopy - 1)) / nBlocksForCopy + alignment - 1) / alignment * alignment;
   uint32_t lastBlockSize = sizePerRank - (nBlocksForCopy - 1) * sizePerBlock;
@@ -1008,7 +1020,17 @@
       uint32_t scratchOffset = scratchIt * unitSize + bid * scratchSizePerBlock + i * scratchSizePerRank;
       char* srcData = (char*)src + blockOffset;
       char* dstData = (char*)scratch + scratchOffset;
-      mscclpp::copy(dstData, srcData, iterSize, tid, blockDim.x);
+      // Calculate actual copy size - don't copy beyond actual data on last rank
+      size_t actualCopySize = iterSize;
+      if (i == nRanksPerNode - 1 && blockOffset + iterSize > i * sizePerRank + actualSizeThisRank) {
+        // On last rank, clamp to actual data size
+        actualCopySize = (i * sizePerRank + actualSizeThisRank > blockOffset)
+                             ? (i * sizePerRank + actualSizeThisRank - blockOffset)
+                             : 0;
+      }
+      if (actualCopySize > 0) {
+        mscclpp::copy(dstData, srcData, actualCopySize, tid, blockDim.x);
+      }
     }
     __syncthreads();
     if (tid < nPeers) {
@@ -1067,7 +1089,16 @@
           i * scratchSizePerRank;
       char* srcData = (char*)scratch + scratchOffset;
       char* dstData = (char*)dst + blockOffset;
-      mscclpp::copy(dstData, srcData, iterSize, tid, blockDim.x);
+
+      size_t actualCopySize = iterSize;
+      if (i == nRanksPerNode - 1 && blockOffset + iterSize > i * sizePerRank + actualSizeThisRank) {
+        actualCopySize = (i * sizePerRank + actualSizeThisRank > blockOffset)
+                             ? (i * sizePerRank + actualSizeThisRank - blockOffset)
+                             : 0;
+      }
+      if (actualCopySize > 0) {
+        mscclpp::copy(dstData, srcData, actualCopySize, tid, blockDim.x);
+      }
     }
     __syncthreads();
     if (tid == 0) {
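To make the padding arithmetic in the allreduce.hpp hunks above concrete: instead of asserting that size / nRanksPerNode is alignment-divisible, the kernel now rounds size up to a multiple of nRanksPerNode * alignment, gives every rank an aligned sizePerRank chunk, and clamps the last rank's copies to the real payload. Below is a minimal host-side sketch of that arithmetic, assuming 8 ranks per node (matching --nproc_per_node=8 in the CI) and a 16-byte alignment (an assumed value; the kernel receives alignment from its caller). All names in the sketch are illustrative, not kernel identifiers.

# Illustrative reproduction of the padding/clamping arithmetic above.
# NRANKS and ALIGNMENT are assumptions for this sketch, not kernel constants.
NRANKS = 8        # ranks per node, as in --nproc_per_node=8
ALIGNMENT = 16    # assumed kernel alignment, in bytes

def split_with_padding(size: int):
    unit = NRANKS * ALIGNMENT
    padding = (unit - size % unit) % unit         # pad up to a multiple of unit
    size_per_rank = (size + padding) // NRANKS    # aligned chunk per rank
    assert size_per_rank % ALIGNMENT == 0         # the old assert now holds by construction
    last_rank_actual = size - size_per_rank * (NRANKS - 1)  # real bytes in the last chunk
    return size_per_rank, last_rank_actual

def clamp_copy(block_offset: int, iter_size: int, size_per_rank: int, last_rank_actual: int):
    # Mirrors the kernel's last-rank clamp: never copy past the real payload.
    end = (NRANKS - 1) * size_per_rank + last_rank_actual
    if block_offset + iter_size <= end:
        return iter_size
    return end - block_offset if end > block_offset else 0

size = 9999999 * 4  # nelem=9999999 float32 elements from the CI loop
per_rank, last_actual = split_with_padding(size)
print(per_rank, last_actual)  # 5000000 4999996
print(clamp_copy(block_offset=7 * per_rank + last_actual - 16, iter_size=4096,
                 size_per_rank=per_rank, last_rank_actual=last_actual))  # 16

Under these assumed values, nelem=9999999 (39,999,996 bytes) gets 4 bytes of padding, so sizePerRank is 5,000,000 while the last rank's real payload is 4,999,996 bytes, and a copy straddling that boundary is clamped instead of reading past the input buffer.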
diff --git a/test/torch/correctness_test.py b/test/torch/correctness_test.py
index ca50064e2..6fbd1f912 100644
--- a/test/torch/correctness_test.py
+++ b/test/torch/correctness_test.py
@@ -65,7 +65,9 @@ def _init_dist():
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
     local_rank = int(os.environ.get("LOCAL_RANK", os.environ["RANK"]))
-    dist.init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=local_rank)
+    dist.init_process_group(
+        backend=backend, rank=rank, world_size=world_size, device_id=torch.device(f"cuda:{local_rank}")
+    )
     torch.cuda.set_device(local_rank)
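A closing observation: the element counts in the CI loop deliberately mix sizes that satisfy the old kernel assert with sizes that do not. A quick check, under the same assumed 8-rank and 16-byte-alignment values as the sketch above:

# Classify the CI element counts against the old assert
# (sizePerRank % alignment == 0, with integer division by nRanksPerNode).
NRANKS, ALIGNMENT, DTYPE_BYTES = 8, 16, 4  # assumed values; float32 payload

for nelem in [10556576, 10556587, 10556592, 10556608, 1048576, 9999999, 12345678]:
    size_per_rank = (nelem * DTYPE_BYTES) // NRANKS
    status = "aligned" if size_per_rank % ALIGNMENT == 0 else "needs padding"
    print(f"nelem={nelem}: {status}")

Three of the seven counts (10556576, 10556608, 1048576) are aligned under these assumptions; the other four exercise the new padding and clamping path.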