Skip to content

Commit d480977

Browse files
authored
discover the number of SMs using cuda runtime (#14)
We can discover the number of SMs at run time through the CUDA runtime API (`cuda_runtime`) instead of hard-coding it as 132.
1 parent ddc6916 commit d480977

File tree

8 files changed

+23
-5
lines changed

8 files changed

+23
-5
lines changed

csrc/all_to_all/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ add_library(all_to_all_common STATIC
44
all_to_all.cpp
55
)
66

7+
# all_to_all.cpp calls get_sm_count() from core/cuda_utils.h, which uses the
# CUDA runtime API, so the library links against CUDA::cudart. PUBLIC also
# propagates the dependency to targets that link all_to_all_common.
target_link_libraries(all_to_all_common PUBLIC
8+
CUDA::cudart
9+
)
10+
711
add_library(all_to_all_intranode_lib STATIC
812
intranode_combine.cu
913
intranode_dispatch.cu

csrc/all_to_all/all_to_all.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "all_to_all.h"
22

3+
#include "core/cuda_utils.h"
34
#include "core/utils.h"
45

56
using namespace pplx;
@@ -25,7 +26,8 @@ AllToAll::AllToAll(
2526
hiddenDimScaleBytes(hiddenDimScaleBytes),
2627
rank(rank),
2728
worldSize(worldSize),
28-
dpSize(dpSize) {
29+
dpSize(dpSize),
30+
numSMs(get_sm_count()) {
2931

3032
ROSE_ASSERT(hiddenDimBytes % 16 == 0, "invalid hidden dim bytes");
3133
ROSE_ASSERT(hiddenDimScaleBytes % 16 == 0, "invalid hidden dim scale bytes");

csrc/all_to_all/all_to_all.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class AllToAll {
6363
const unsigned worldSize;
6464
/// The size of a DP group.
6565
const unsigned dpSize;
66+
/// The number of streaming multiprocessors (SMs) on the device.
67+
const int numSMs;
6668
};
6769

6870
} // namespace pplx

csrc/all_to_all/internode_combine.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ void AllToAllInterNode::combine(
162162
const size_t numLocalExperts = numExperts / worldSize;
163163
const size_t numDPGroups = worldSize / dpSize;
164164
const size_t batchNumTokens = numLocalExperts * numDPGroups * maxNumTokens;
165-
const size_t numBlocks = std::min(132ul, batchNumTokens);
165+
const size_t numBlocks = std::min(static_cast<size_t>(numSMs), batchNumTokens);
166166

167167
assert(hiddenDimBytes % 16 == 0);
168168

csrc/all_to_all/internode_dispatch.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ void AllToAllInterNode::dispatch(
270270
std::max(
271271
ceil_div<unsigned>(numExperts, NUM_WARPS), (unsigned)(maxNumTokens * expertsPerToken)
272272
),
273-
132u
273+
static_cast<unsigned>(numSMs)
274274
);
275275
dim3 dimGrid(numBlocks, 1, 1);
276276
dim3 dimBlock(NUM_WARPS * 32, 1, 1);

csrc/all_to_all/intranode_combine.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ void AllToAllIntraNode::combine(
178178
const size_t numLocalExperts = numExperts / worldSize;
179179
const size_t numDPGroups = worldSize / dpSize;
180180
const size_t batchNumTokens = numLocalExperts * numDPGroups * maxNumTokens;
181-
const size_t numBlocks = std::min(132ul, batchNumTokens);
181+
const size_t numBlocks = std::min(static_cast<size_t>(numSMs), batchNumTokens);
182182

183183
assert(hiddenDimBytes % 16 == 0);
184184

csrc/all_to_all/intranode_dispatch.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ void AllToAllIntraNode::dispatch(
278278
std::max(
279279
ceil_div<unsigned>(numExperts, NUM_WARPS), (unsigned)(maxNumTokens * expertsPerToken)
280280
),
281-
132u
281+
static_cast<unsigned>(numSMs)
282282
);
283283
dim3 dimGrid(numBlocks, 1, 1);
284284
dim3 dimBlock(NUM_WARPS * 32, 1, 1);

csrc/core/cuda_utils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,14 @@ template <typename T> T *mallocZeroBuffer(size_t size) {
2020
cudaMemset(ptr, 0, size * sizeof(T));
2121
return ptr;
2222
}
23+
24+
/// Returns the number of multiprocessors (SMs) reported by the CUDA runtime
/// for the device that is current on the calling host thread.
/// Both runtime calls go through CUDACHECK, which is defined outside this view
/// (presumably it reports or aborts on a non-success status -- TODO confirm).
inline int get_sm_count() {
25+
int device;
26+
CUDACHECK(cudaGetDevice(&device));
27+
int numSMs;
28+
CUDACHECK(cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, device));
29+
30+
return numSMs;
31+
}
32+
2333
} // namespace pplx

0 commit comments

Comments
 (0)