@@ -428,9 +428,9 @@ MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t
428428 mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x , blockDim.x , flag);
429429}
430430
431- template <typename T>
431+ template <typename T, bool SendToRemote = true >
432432MSCCLPP_DEVICE_INLINE void handleReduceSend (T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
433- T* input, uint32_t * inputOffsets,
433+ T* input, uint32_t * inputOffsets, int nSrcs,
434434 DeviceHandle<MemoryChannel>* memoryChannels, uint8_t * outputChannelIndexes,
435435 uint32_t * outputOffsets, int nOutChannels, uint32_t size) {
436436 const size_t nInt4 = size / sizeof (int4);
@@ -441,15 +441,17 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
441441 int4* input4 = (int4*)input;
442442 for (size_t idx = threadIdx.x ; idx < nInt4; idx += blockDim.x ) {
443443 int4 tmp = src4[srcOffset4 + idx];
444- for (int index = 0 ; index < nOutChannels ; ++index) {
444+ for (int index = 0 ; index < nSrcs ; ++index) {
445445 size_t offset = inputOffsets[index] / sizeof (int4);
446446 int4 val = input4[offset + idx];
447447 tmp = add_vectors<T>(tmp, val);
448448 }
449449 dst4[dstOffset4 + idx] = tmp;
450- for (int index = 0 ; index < nOutChannels; ++index) {
451- size_t offset = outputOffsets[index] / sizeof (int4);
452- memoryChannels[outputChannelIndexes[index]].write <int4>(offset + idx, tmp);
450+ if (SendToRemote) {
451+ for (int index = 0 ; index < nOutChannels; ++index) {
452+ size_t offset = outputOffsets[index] / sizeof (int4);
453+ memoryChannels[outputChannelIndexes[index]].write <int4>(offset + idx, tmp);
454+ }
453455 }
454456 }
455457 // handle rest of data
@@ -458,14 +460,16 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
458460 const size_t endIdx = (srcOffsetByBytes + size) / sizeof (T);
459461 for (size_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x ) {
460462 T tmp = src[idx];
461- for (int index = 0 ; index < nOutChannels ; ++index) {
463+ for (int index = 0 ; index < nSrcs ; ++index) {
462464 size_t offset = inputOffsets[index] / sizeof (T);
463465 tmp = add_elements (tmp, input[offset + idx]);
464466 }
465467 dst[idx] = tmp;
466- for (int index = 0 ; index < nOutChannels; ++index) {
467- size_t offset = outputOffsets[index] / sizeof (T);
468- memoryChannels[outputChannelIndexes[index]].write <T>(offset + idx, tmp);
468+ if (SendToRemote) {
469+ for (int index = 0 ; index < nOutChannels; ++index) {
470+ size_t offset = outputOffsets[index] / sizeof (T);
471+ memoryChannels[outputChannelIndexes[index]].write <T>(offset + idx, tmp);
472+ }
469473 }
470474 }
471475}
@@ -624,6 +628,12 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
624628 handleReduceSendPacket<T, PacketType>(dst, op.dstOffset , src, op.srcOffset , scratch, scratchSize, op.inputOffsets ,
625629 op.nInputs , memoryChannels, op.outputChannelIndexes , op.outputOffsets ,
626630 op.nOutputs , op.size , flag);
631+ } else if (op.type == OperationType::REDUCE) {
632+ T* dst = getBuffer (input, output, scratch, op.dstBufferType );
633+ T* src = getBuffer (input, output, scratch, op.srcBufferType );
634+ T* tmp = getBuffer (input, output, scratch, op.inputBufferType );
635+ handleReduceSend<T, false >(dst, op.dstOffset , src, op.srcOffset , tmp, op.inputOffsets , op.nInputs , memoryChannels,
636+ op.outputChannelIndexes , op.outputOffsets , op.nOutputs , op.size );
627637 } else if (op.type == OperationType::REDUCE_PACKET) {
628638 T* dst = getBuffer (input, output, scratch, op.dstBufferType );
629639 T* src = getBuffer (input, output, scratch, op.srcBufferType );
@@ -642,7 +652,7 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
642652 T* dst = getBuffer (input, output, scratch, op.dstBufferType );
643653 T* src = getBuffer (input, output, scratch, op.srcBufferType );
644654 T* tmp = getBuffer (input, output, scratch, op.inputBufferType );
645- handleReduceSend (dst, op.dstOffset , src, op.srcOffset , tmp, op.inputOffsets , memoryChannels,
655+ handleReduceSend (dst, op.dstOffset , src, op.srcOffset , tmp, op.inputOffsets , op. nInputs , memoryChannels,
646656 op.outputChannelIndexes , op.outputOffsets , op.nOutputs , op.size );
647657 }
648658#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
0 commit comments