Skip to content

Commit ac5cc64

Browse files
Reduce Operation Support to the Executor (#484)
Co-authored-by: Binyang Li <binyli@microsoft.com>
1 parent b406246 commit ac5cc64

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

python/mscclpp/language/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class Instruction(Enum):
8080
read_reduce_copy_send = "rrcs"
8181
reduce_send = "rs"
8282
copy = "copy"
83-
reduce = "reduce"
83+
reduce = "re"
8484
copy_packet = "cpkt"
8585
transform_to_packet = "tpkt"
8686
reduce_send_packet = "rspkt"

src/include/execution_kernel.hpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,9 @@ MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t
428428
mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag);
429429
}
430430

431-
template <typename T>
431+
template <typename T, bool SendToRemote = true>
432432
MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
433-
T* input, uint32_t* inputOffsets,
433+
T* input, uint32_t* inputOffsets, int nSrcs,
434434
DeviceHandle<MemoryChannel>* memoryChannels, uint8_t* outputChannelIndexes,
435435
uint32_t* outputOffsets, int nOutChannels, uint32_t size) {
436436
const size_t nInt4 = size / sizeof(int4);
@@ -441,15 +441,17 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
441441
int4* input4 = (int4*)input;
442442
for (size_t idx = threadIdx.x; idx < nInt4; idx += blockDim.x) {
443443
int4 tmp = src4[srcOffset4 + idx];
444-
for (int index = 0; index < nOutChannels; ++index) {
444+
for (int index = 0; index < nSrcs; ++index) {
445445
size_t offset = inputOffsets[index] / sizeof(int4);
446446
int4 val = input4[offset + idx];
447447
tmp = add_vectors<T>(tmp, val);
448448
}
449449
dst4[dstOffset4 + idx] = tmp;
450-
for (int index = 0; index < nOutChannels; ++index) {
451-
size_t offset = outputOffsets[index] / sizeof(int4);
452-
memoryChannels[outputChannelIndexes[index]].write<int4>(offset + idx, tmp);
450+
if (SendToRemote) {
451+
for (int index = 0; index < nOutChannels; ++index) {
452+
size_t offset = outputOffsets[index] / sizeof(int4);
453+
memoryChannels[outputChannelIndexes[index]].write<int4>(offset + idx, tmp);
454+
}
453455
}
454456
}
455457
// handle rest of data
@@ -458,14 +460,16 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
458460
const size_t endIdx = (srcOffsetByBytes + size) / sizeof(T);
459461
for (size_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x) {
460462
T tmp = src[idx];
461-
for (int index = 0; index < nOutChannels; ++index) {
463+
for (int index = 0; index < nSrcs; ++index) {
462464
size_t offset = inputOffsets[index] / sizeof(T);
463465
tmp = add_elements(tmp, input[offset + idx]);
464466
}
465467
dst[idx] = tmp;
466-
for (int index = 0; index < nOutChannels; ++index) {
467-
size_t offset = outputOffsets[index] / sizeof(T);
468-
memoryChannels[outputChannelIndexes[index]].write<T>(offset + idx, tmp);
468+
if (SendToRemote) {
469+
for (int index = 0; index < nOutChannels; ++index) {
470+
size_t offset = outputOffsets[index] / sizeof(T);
471+
memoryChannels[outputChannelIndexes[index]].write<T>(offset + idx, tmp);
472+
}
469473
}
470474
}
471475
}
@@ -624,6 +628,12 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
624628
handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize, op.inputOffsets,
625629
op.nInputs, memoryChannels, op.outputChannelIndexes, op.outputOffsets,
626630
op.nOutputs, op.size, flag);
631+
} else if (op.type == OperationType::REDUCE) {
632+
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
633+
T* src = getBuffer(input, output, scratch, op.srcBufferType);
634+
T* tmp = getBuffer(input, output, scratch, op.inputBufferType);
635+
handleReduceSend<T, false>(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, memoryChannels,
636+
op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size);
627637
} else if (op.type == OperationType::REDUCE_PACKET) {
628638
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
629639
T* src = getBuffer(input, output, scratch, op.srcBufferType);
@@ -642,7 +652,7 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
642652
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
643653
T* src = getBuffer(input, output, scratch, op.srcBufferType);
644654
T* tmp = getBuffer(input, output, scratch, op.inputBufferType);
645-
handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, memoryChannels,
655+
handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, memoryChannels,
646656
op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size);
647657
}
648658
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900

0 commit comments

Comments
 (0)