Skip to content

Commit 542a10f

Browse files
authored
Merge ChannelTrigger with ProxyTrigger (#601)
1 parent 2c04b1b commit 542a10f

File tree

4 files changed

+112
-119
lines changed

4 files changed

+112
-119
lines changed

include/mscclpp/fifo_device.hpp

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,82 @@
1515

1616
namespace mscclpp {
1717

18+
using TriggerType = uint64_t;
19+
constexpr TriggerType TriggerData = 0x1; // Trigger a data transfer.
20+
constexpr TriggerType TriggerFlag = 0x2; // Trigger a signaling.
21+
constexpr TriggerType TriggerSync = 0x4; // Trigger a flush.
22+
23+
constexpr unsigned int TriggerBitsSize = 32;
24+
constexpr unsigned int TriggerBitsOffset = 32;
25+
constexpr unsigned int TriggerBitsMemoryId = 9;
26+
constexpr unsigned int TriggerBitsType = 3;
27+
constexpr unsigned int TriggerBitsSemaphoreId = 10;
28+
constexpr unsigned int TriggerBitsFifoReserved = 1;
29+
1830
/// Pair of 64-bit unsigned integers used as a trigger for the proxy.
1931
/// Used as a work element in the concurrent FIFO.
2032
/// Most significant bit of snd is reserved.
21-
struct alignas(16) ProxyTrigger {
22-
uint64_t fst, snd;
33+
union alignas(16) ProxyTrigger {
34+
struct {
35+
uint64_t fst;
36+
uint64_t snd;
37+
};
38+
// The summation of number of bits must be 128 or less.
39+
struct {
40+
// First 64 bits: value[0]
41+
uint64_t size : TriggerBitsSize;
42+
uint64_t srcOffset : TriggerBitsOffset;
43+
uint64_t : (64 - TriggerBitsSize - TriggerBitsOffset); // ensure 64-bit alignment
44+
// Second 64 bits: value[1]
45+
uint64_t dstOffset : TriggerBitsOffset;
46+
uint64_t srcMemoryId : TriggerBitsMemoryId;
47+
uint64_t dstMemoryId : TriggerBitsMemoryId;
48+
uint64_t type : TriggerBitsType;
49+
uint64_t semaphoreId : TriggerBitsSemaphoreId;
50+
uint64_t : (64 - TriggerBitsOffset - TriggerBitsMemoryId - TriggerBitsMemoryId - TriggerBitsType -
51+
TriggerBitsSemaphoreId - TriggerBitsFifoReserved); // ensure 64-bit alignment
52+
uint64_t reserved : TriggerBitsFifoReserved;
53+
} fields;
54+
55+
#if defined(MSCCLPP_DEVICE_COMPILE)
56+
/// Default constructor.
57+
MSCCLPP_INLINE ProxyTrigger() = default;
58+
59+
/// Constructor.
60+
/// @param type The type of the trigger.
61+
/// @param dstId The destination ID of memory region.
62+
/// @param dstOffset The offset into the destination memory region.
63+
/// @param srcId The source ID of memory region.
64+
/// @param srcOffset The offset into the source memory region.
65+
/// @param bytes The bytes of the transfer.
66+
/// @param semaphoreId The ID of the semaphore.
67+
MSCCLPP_DEVICE_INLINE ProxyTrigger(TriggerType type, uint32_t dstId, uint64_t dstOffset, uint32_t srcId,
68+
uint64_t srcOffset, uint64_t bytes, uint32_t semaphoreId) {
69+
MSCCLPP_ASSERT_DEVICE(type < (1ULL << TriggerBitsType), "type is too large");
70+
MSCCLPP_ASSERT_DEVICE(dstId < (1ULL << TriggerBitsMemoryId), "dstId is too large");
71+
MSCCLPP_ASSERT_DEVICE(dstOffset < (1ULL << TriggerBitsOffset), "dstOffset is too large");
72+
MSCCLPP_ASSERT_DEVICE(srcId < (1ULL << TriggerBitsMemoryId), "srcId is too large");
73+
MSCCLPP_ASSERT_DEVICE(srcOffset < (1ULL << TriggerBitsOffset), "srcOffset is too large");
74+
MSCCLPP_ASSERT_DEVICE(bytes != 0, "bytes must not be zero");
75+
MSCCLPP_ASSERT_DEVICE(bytes < (1ULL << TriggerBitsSize), "bytes is too large");
76+
MSCCLPP_ASSERT_DEVICE(semaphoreId < (1ULL << TriggerBitsSemaphoreId), "semaphoreId is too large");
77+
constexpr uint64_t maskSize = (1ULL << TriggerBitsSize) - 1;
78+
constexpr uint64_t maskSrcOffset = (1ULL << TriggerBitsOffset) - 1;
79+
constexpr uint64_t maskDstOffset = (1ULL << TriggerBitsOffset) - 1;
80+
constexpr uint64_t maskSrcMemoryId = (1ULL << TriggerBitsMemoryId) - 1;
81+
constexpr uint64_t maskDstMemoryId = (1ULL << TriggerBitsMemoryId) - 1;
82+
constexpr uint64_t maskType = (1ULL << TriggerBitsType) - 1;
83+
constexpr uint64_t maskSemaphoreId = (1ULL << TriggerBitsSemaphoreId) - 1;
84+
fst = (((srcOffset & maskSrcOffset) << TriggerBitsSize) + (bytes & maskSize));
85+
snd = (((((((((semaphoreId & maskSemaphoreId) << TriggerBitsType) + ((uint64_t)type & maskType))
86+
<< TriggerBitsMemoryId) +
87+
(dstId & maskDstMemoryId))
88+
<< TriggerBitsMemoryId) +
89+
(srcId & maskSrcMemoryId))
90+
<< TriggerBitsOffset) +
91+
(dstOffset & maskDstOffset));
92+
}
93+
#endif // defined(MSCCLPP_DEVICE_COMPILE)
2394
};
2495

2596
/// Concurrent FIFO where multiple device threads (the number of threads should not exceed the FIFO size) to push
@@ -32,7 +103,7 @@ struct FifoDeviceHandle {
32103
/// @param trigger Trigger to push.
33104
/// @param maxSpinCount Max spin count before assert. Never assert if negative.
34105
/// @return Previous head of the FIFO where the trigger was pushed.
35-
MSCCLPP_DEVICE_INLINE uint64_t push(ProxyTrigger trigger, [[maybe_unused]] int64_t maxSpinCount = 1000000) {
106+
MSCCLPP_DEVICE_INLINE uint64_t push(ProxyTrigger trigger, int64_t maxSpinCount = 1000000) {
36107
uint64_t prevHead = atomicFetchAdd<uint64_t, scopeDevice>(head, 1, memoryOrderRelaxed);
37108

38109
// Flip the last bit for safe polling; host will revert.

include/mscclpp/port_channel_device.hpp

Lines changed: 29 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -17,82 +17,6 @@ using SemaphoreId = uint32_t;
1717
/// actual.
1818
using MemoryId = uint32_t;
1919

20-
using TriggerType = uint64_t;
21-
constexpr TriggerType TriggerData = 0x1; // Trigger a data transfer.
22-
constexpr TriggerType TriggerFlag = 0x2; // Trigger a signaling.
23-
constexpr TriggerType TriggerSync = 0x4; // Trigger a flush.
24-
25-
constexpr unsigned int TriggerBitsSize = 32;
26-
constexpr unsigned int TriggerBitsOffset = 32;
27-
constexpr unsigned int TriggerBitsMemoryId = 9;
28-
constexpr unsigned int TriggerBitsType = 3;
29-
constexpr unsigned int TriggerBitsSemaphoreId = 10;
30-
constexpr unsigned int TriggerBitsFifoReserved = 1;
31-
32-
/// Basic structure of each work element in the FIFO.
33-
union ChannelTrigger {
34-
ProxyTrigger value;
35-
// The summation of number of bits must be 128 or less.
36-
struct {
37-
// First 64 bits: value[0]
38-
uint64_t size : TriggerBitsSize;
39-
uint64_t srcOffset : TriggerBitsOffset;
40-
uint64_t : (64 - TriggerBitsSize - TriggerBitsOffset); // ensure 64-bit alignment
41-
// Second 64 bits: value[1]
42-
uint64_t dstOffset : TriggerBitsOffset;
43-
uint64_t srcMemoryId : TriggerBitsMemoryId;
44-
uint64_t dstMemoryId : TriggerBitsMemoryId;
45-
uint64_t type : TriggerBitsType;
46-
uint64_t semaphoreId : TriggerBitsSemaphoreId;
47-
uint64_t : (64 - TriggerBitsOffset - TriggerBitsMemoryId - TriggerBitsMemoryId - TriggerBitsType -
48-
TriggerBitsSemaphoreId - TriggerBitsFifoReserved); // ensure 64-bit alignment
49-
uint64_t reserved : TriggerBitsFifoReserved;
50-
} fields;
51-
52-
#if defined(MSCCLPP_DEVICE_COMPILE)
53-
/// Default constructor.
54-
MSCCLPP_INLINE ChannelTrigger() = default;
55-
56-
/// Copy constructor.
57-
MSCCLPP_DEVICE_INLINE ChannelTrigger(ProxyTrigger value) : value(value) {}
58-
59-
/// Constructor.
60-
/// @param type The type of the trigger.
61-
/// @param dst The destination memory region.
62-
/// @param dstOffset The offset into the destination memory region.
63-
/// @param src The source memory region.
64-
/// @param srcOffset The offset into the source memory region.
65-
/// @param bytes The bytes of the transfer.
66-
/// @param semaphoreId The ID of the semaphore.
67-
MSCCLPP_DEVICE_INLINE ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src,
68-
uint64_t srcOffset, uint64_t bytes, int semaphoreId) {
69-
MSCCLPP_ASSERT_DEVICE(type < (1ULL << TriggerBitsType), "type is too large");
70-
MSCCLPP_ASSERT_DEVICE(dst < (1ULL << TriggerBitsMemoryId), "dst is too large");
71-
MSCCLPP_ASSERT_DEVICE(dstOffset < (1ULL << TriggerBitsOffset), "dstOffset is too large");
72-
MSCCLPP_ASSERT_DEVICE(src < (1ULL << TriggerBitsMemoryId), "src is too large");
73-
MSCCLPP_ASSERT_DEVICE(srcOffset < (1ULL << TriggerBitsOffset), "srcOffset is too large");
74-
MSCCLPP_ASSERT_DEVICE(bytes != 0, "bytes must not be zero");
75-
MSCCLPP_ASSERT_DEVICE(bytes < (1ULL << TriggerBitsSize), "bytes is too large");
76-
MSCCLPP_ASSERT_DEVICE(semaphoreId < (1ULL << TriggerBitsSemaphoreId), "semaphoreId is too large");
77-
constexpr uint64_t maskSize = (1ULL << TriggerBitsSize) - 1;
78-
constexpr uint64_t maskSrcOffset = (1ULL << TriggerBitsOffset) - 1;
79-
constexpr uint64_t maskDstOffset = (1ULL << TriggerBitsOffset) - 1;
80-
constexpr uint64_t maskSrcMemoryId = (1ULL << TriggerBitsMemoryId) - 1;
81-
constexpr uint64_t maskDstMemoryId = (1ULL << TriggerBitsMemoryId) - 1;
82-
constexpr uint64_t maskType = (1ULL << TriggerBitsType) - 1;
83-
constexpr uint64_t maskSemaphoreId = (1ULL << TriggerBitsSemaphoreId) - 1;
84-
value.fst = (((srcOffset & maskSrcOffset) << TriggerBitsSize) + (bytes & maskSize));
85-
value.snd = (((((((((semaphoreId & maskSemaphoreId) << TriggerBitsType) + ((uint64_t)type & maskType))
86-
<< TriggerBitsMemoryId) +
87-
(dst & maskDstMemoryId))
88-
<< TriggerBitsMemoryId) +
89-
(src & maskSrcMemoryId))
90-
<< TriggerBitsOffset) +
91-
(dstOffset & maskDstOffset));
92-
}
93-
#endif // defined(MSCCLPP_DEVICE_COMPILE)
94-
};
95-
9620
struct BasePortChannelDeviceHandle {
9721
SemaphoreId semaphoreId_;
9822

@@ -111,77 +35,77 @@ struct BasePortChannelDeviceHandle {
11135

11236
#if defined(MSCCLPP_DEVICE_COMPILE)
11337
/// Push a TriggerData to the FIFO.
114-
/// @param dst The destination memory region.
38+
/// @param dstId The ID of destination memory region.
11539
/// @param dstOffset The offset into the destination memory region.
116-
/// @param src The source memory region.
40+
/// @param srcId The ID of source memory region.
11741
/// @param srcOffset The offset into the source memory region.
11842
/// @param size The size of the transfer.
119-
MSCCLPP_DEVICE_INLINE void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size) {
120-
fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
43+
MSCCLPP_DEVICE_INLINE void put(MemoryId dstId, uint64_t dstOffset, MemoryId srcId, uint64_t srcOffset,
44+
uint64_t size) {
45+
fifo_.push({TriggerData, dstId, dstOffset, srcId, srcOffset, size, semaphoreId_});
12146
}
12247

12348
/// Push a TriggerData to the FIFO.
124-
/// @param dst The destination memory region.
125-
/// @param src The source memory region.
49+
/// @param dstId The ID of destination memory region.
50+
/// @param srcId The ID of source memory region.
12651
/// @param offset The common offset into the destination and source memory regions.
12752
/// @param size The size of the transfer.
128-
MSCCLPP_DEVICE_INLINE void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
129-
put(dst, offset, src, offset, size);
53+
MSCCLPP_DEVICE_INLINE void put(MemoryId dstId, MemoryId srcId, uint64_t offset, uint64_t size) {
54+
put(dstId, offset, srcId, offset, size);
13055
}
13156

13257
/// Push a TriggerFlag to the FIFO.
133-
MSCCLPP_DEVICE_INLINE void signal() { fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); }
58+
MSCCLPP_DEVICE_INLINE void signal() { fifo_.push({TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_}); }
13459

13560
/// Push a TriggerData and a TriggerFlag at the same time to the FIFO.
136-
/// @param dst The destination memory region.
61+
/// @param dstId The ID of destination memory region.
13762
/// @param dstOffset The offset into the destination memory region.
138-
/// @param src The source memory region.
63+
/// @param srcId The ID of source memory region.
13964
/// @param srcOffset The offset into the source memory region.
14065
/// @param size The size of the transfer.
141-
MSCCLPP_DEVICE_INLINE void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
66+
MSCCLPP_DEVICE_INLINE void putWithSignal(MemoryId dstId, uint64_t dstOffset, MemoryId srcId, uint64_t srcOffset,
14267
uint64_t size) {
143-
fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
68+
fifo_.push({TriggerData | TriggerFlag, dstId, dstOffset, srcId, srcOffset, size, semaphoreId_});
14469
}
14570

14671
/// Push a TriggerData and a TriggerFlag at the same time to the FIFO.
147-
/// @param dst The destination memory region.
148-
/// @param src The source memory region.
72+
/// @param dstId The ID of destination memory region.
73+
/// @param srcId The ID of source memory region.
14974
/// @param offset The common offset into the destination and source memory regions.
15075
/// @param size The size of the transfer.
151-
MSCCLPP_DEVICE_INLINE void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
152-
putWithSignal(dst, offset, src, offset, size);
76+
MSCCLPP_DEVICE_INLINE void putWithSignal(MemoryId dstId, MemoryId srcId, uint64_t offset, uint64_t size) {
77+
putWithSignal(dstId, offset, srcId, offset, size);
15378
}
15479

15580
/// Push a TriggerData, a TriggerFlag, and a TriggerSync at the same time to the FIFO.
156-
/// @param dst The destination memory region.
81+
/// @param dstId The ID of destination memory region.
15782
/// @param dstOffset The offset into the destination memory region.
158-
/// @param src The source memory region.
83+
/// @param srcId The ID of source memory region.
15984
/// @param srcOffset The offset into the source memory region.
16085
/// @param size The size of the transfer.
16186
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
162-
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
163-
uint64_t size, int64_t maxSpinCount = 1000000) {
164-
uint64_t curFifoHead = fifo_.push(
165-
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, semaphoreId_)
166-
.value);
87+
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(MemoryId dstId, uint64_t dstOffset, MemoryId srcId,
88+
uint64_t srcOffset, uint64_t size, int64_t maxSpinCount = 1000000) {
89+
uint64_t curFifoHead =
90+
fifo_.push({TriggerData | TriggerFlag | TriggerSync, dstId, dstOffset, srcId, srcOffset, size, semaphoreId_});
16791
fifo_.sync(curFifoHead, maxSpinCount);
16892
}
16993

17094
/// Push a TriggerData, a TriggerFlag, and a TriggerSync at the same time to the FIFO.
171-
/// @param dst The destination memory region.
172-
/// @param src The source memory region.
95+
/// @param dstId The ID of destination memory region.
96+
/// @param srcId The ID of source memory region.
17397
/// @param offset The common offset into the destination and source memory regions.
17498
/// @param size The size of the transfer.
17599
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
176-
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size,
100+
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(MemoryId dstId, MemoryId srcId, uint64_t offset, uint64_t size,
177101
int64_t maxSpinCount = 1000000) {
178-
putWithSignalAndFlush(dst, offset, src, offset, size, maxSpinCount);
102+
putWithSignalAndFlush(dstId, offset, srcId, offset, size, maxSpinCount);
179103
}
180104

181105
/// Push a TriggerSync to the FIFO.
182106
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
183107
MSCCLPP_DEVICE_INLINE void flush(int64_t maxSpinCount = 1000000) {
184-
uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, semaphoreId_).value);
108+
uint64_t curFifoHead = fifo_.push({TriggerSync, 0, 0, 0, 0, 1, semaphoreId_});
185109
fifo_.sync(curFifoHead, maxSpinCount);
186110
}
187111

python/mscclpp/fifo_py.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace nb = nanobind;
99
using namespace mscclpp;
1010

1111
void register_fifo(nb::module_& m) {
12-
nb::class_<ProxyTrigger>(m, "ProxyTrigger").def_rw("fst", &ProxyTrigger::fst).def_rw("snd", &ProxyTrigger::snd);
12+
nb::class_<ProxyTrigger>(m, "ProxyTrigger");
1313

1414
nb::class_<FifoDeviceHandle>(m, "FifoDeviceHandle")
1515
.def_rw("triggers", &FifoDeviceHandle::triggers)

src/port_channel.cc

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,27 +78,25 @@ MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_->start(); }
7878

7979
MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); }
8080

81-
ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
82-
ChannelTrigger* trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
83-
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger->fields.semaphoreId];
81+
ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
82+
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger.fields.semaphoreId];
8483

8584
int maxWriteQueueSize = semaphore->connection()->getMaxWriteQueueSize();
8685
auto& numRequests = inflightRequests_[semaphore->connection()];
8786

88-
if (trigger->fields.type & TriggerData) {
89-
RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId];
90-
RegisteredMemory& src = memories_[trigger->fields.srcMemoryId];
91-
semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset,
92-
trigger->fields.size);
87+
if (trigger.fields.type & TriggerData) {
88+
RegisteredMemory& dst = memories_[trigger.fields.dstMemoryId];
89+
RegisteredMemory& src = memories_[trigger.fields.srcMemoryId];
90+
semaphore->connection()->write(dst, trigger.fields.dstOffset, src, trigger.fields.srcOffset, trigger.fields.size);
9391
numRequests++;
9492
}
9593

96-
if (trigger->fields.type & TriggerFlag) {
94+
if (trigger.fields.type & TriggerFlag) {
9795
semaphore->signal();
9896
numRequests++;
9997
}
10098

101-
if (((trigger->fields.type & TriggerSync) && numRequests > 0) ||
99+
if (((trigger.fields.type & TriggerSync) && numRequests > 0) ||
102100
(maxWriteQueueSize != -1 && numRequests > maxWriteQueueSize)) {
103101
semaphore->connection()->flush();
104102
numRequests = 0;

0 commit comments

Comments
 (0)