Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/mscclpp/language/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class Instruction(Enum):
read_reduce_copy_send = "rrcs"
reduce_send = "rs"
copy = "copy"
reduce = "reduce"
reduce = "re"
copy_packet = "cpkt"
transform_to_packet = "tpkt"
reduce_send_packet = "rspkt"
Expand Down
32 changes: 21 additions & 11 deletions src/include/execution_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t
mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag);
}

template <typename T>
template <typename T, bool SendToRemote = true>
MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
T* input, uint32_t* inputOffsets,
T* input, uint32_t* inputOffsets, int nSrcs,
DeviceHandle<MemoryChannel>* memoryChannels, uint8_t* outputChannelIndexes,
uint32_t* outputOffsets, int nOutChannels, uint32_t size) {
const size_t nInt4 = size / sizeof(int4);
Expand All @@ -441,15 +441,17 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
int4* input4 = (int4*)input;
for (size_t idx = threadIdx.x; idx < nInt4; idx += blockDim.x) {
int4 tmp = src4[srcOffset4 + idx];
for (int index = 0; index < nOutChannels; ++index) {
for (int index = 0; index < nSrcs; ++index) {
size_t offset = inputOffsets[index] / sizeof(int4);
int4 val = input4[offset + idx];
tmp = add_vectors<T>(tmp, val);
}
dst4[dstOffset4 + idx] = tmp;
for (int index = 0; index < nOutChannels; ++index) {
size_t offset = outputOffsets[index] / sizeof(int4);
memoryChannels[outputChannelIndexes[index]].write<int4>(offset + idx, tmp);
if (SendToRemote) {
for (int index = 0; index < nOutChannels; ++index) {
size_t offset = outputOffsets[index] / sizeof(int4);
memoryChannels[outputChannelIndexes[index]].write<int4>(offset + idx, tmp);
}
}
}
// handle rest of data
Expand All @@ -458,14 +460,16 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
const size_t endIdx = (srcOffsetByBytes + size) / sizeof(T);
for (size_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x) {
T tmp = src[idx];
for (int index = 0; index < nOutChannels; ++index) {
for (int index = 0; index < nSrcs; ++index) {
size_t offset = inputOffsets[index] / sizeof(T);
tmp = add_elements(tmp, input[offset + idx]);
}
dst[idx] = tmp;
for (int index = 0; index < nOutChannels; ++index) {
size_t offset = outputOffsets[index] / sizeof(T);
memoryChannels[outputChannelIndexes[index]].write<T>(offset + idx, tmp);
if (SendToRemote) {
for (int index = 0; index < nOutChannels; ++index) {
size_t offset = outputOffsets[index] / sizeof(T);
memoryChannels[outputChannelIndexes[index]].write<T>(offset + idx, tmp);
}
}
}
}
Expand Down Expand Up @@ -624,6 +628,12 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize, op.inputOffsets,
op.nInputs, memoryChannels, op.outputChannelIndexes, op.outputOffsets,
op.nOutputs, op.size, flag);
} else if (op.type == OperationType::REDUCE) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
T* tmp = getBuffer(input, output, scratch, op.inputBufferType);
handleReduceSend<T, false>(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, memoryChannels,
op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size);
} else if (op.type == OperationType::REDUCE_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
Expand All @@ -642,7 +652,7 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
T* tmp = getBuffer(input, output, scratch, op.inputBufferType);
handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, memoryChannels,
handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, memoryChannels,
op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size);
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
Expand Down
Loading