
Commit 5c9dead

add control over the number of SMs to be used by the kernel
1 parent ed71b87 commit 5c9dead

9 files changed (+54, -16 lines)

csrc/all_to_all/all_to_all.cpp

Lines changed: 3 additions & 2 deletions
@@ -14,7 +14,8 @@ AllToAll::AllToAll(
     unsigned dpSize,
     size_t hiddenDim,
     size_t hiddenDimBytes,
-    size_t hiddenDimScaleBytes
+    size_t hiddenDimScaleBytes,
+    int max_sm_count
 )
     : maxNumTokens(maxNumTokens),
       numExperts(numExperts),
@@ -27,7 +28,7 @@ AllToAll::AllToAll(
       rank(rank),
       worldSize(worldSize),
       dpSize(dpSize),
-      numSMs(get_sm_count()) {
+      numSMs(max_sm_count > 0 ? max_sm_count : get_sm_count()) {
 
   PPLX_ASSERT(hiddenDimBytes % 16 == 0, "invalid hidden dim bytes");
   PPLX_ASSERT(hiddenDimScaleBytes % 16 == 0, "invalid hidden dim scale bytes");
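
The new numSMs initializer treats max_sm_count as a cap only when it is positive; the default of 0 preserves the previous behavior of using every SM the device reports. A minimal Python sketch of the same selection rule, using torch to query the device (the helper name resolve_sm_count is illustrative, not part of this commit):

    import torch

    def resolve_sm_count(max_sm_count: int = 0) -> int:
        # Mirrors numSMs(max_sm_count > 0 ? max_sm_count : get_sm_count()):
        # a positive value caps the kernel's SM usage, 0 falls back to all SMs.
        device_sm_count = torch.cuda.get_device_properties(0).multi_processor_count
        return max_sm_count if max_sm_count > 0 else device_sm_count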

csrc/all_to_all/all_to_all.h

Lines changed: 2 additions & 1 deletion
@@ -35,7 +35,8 @@ class AllToAll {
       unsigned dpSize,
       size_t hiddenDim,
       size_t hiddenDimBytes,
-      size_t hiddenDimScaleBytes
+      size_t hiddenDimScaleBytes,
+      int max_sm_count = 0
   );
 
   virtual ~AllToAll();

csrc/all_to_all/internode.cpp

Lines changed: 4 additions & 2 deletions
@@ -18,7 +18,8 @@ AllToAllInterNode::AllToAllInterNode(
     unsigned dpSize,
     size_t hiddenDim,
     size_t hiddenDimBytes,
-    size_t hiddenDimScaleBytes
+    size_t hiddenDimScaleBytes,
+    int max_sm_count
 )
     : AllToAll(
           maxNumTokens,
@@ -29,7 +30,8 @@ AllToAllInterNode::AllToAllInterNode(
           dpSize,
           hiddenDim,
           hiddenDimBytes,
-          hiddenDimScaleBytes
+          hiddenDimScaleBytes,
+          max_sm_count
       ),
       maxBatchTokens(numLocalExperts * numDPGroups * maxNumTokens) {
   // Buffers for token counts.

csrc/all_to_all/internode.h

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@ class AllToAllInterNode final : public AllToAll {
       unsigned dpSize,
       size_t hiddenDim,
       size_t hiddenDimBytes,
-      size_t hiddenDimScaleBytes
+      size_t hiddenDimScaleBytes,
+      int max_sm_count = 0
   );
 
   ~AllToAllInterNode();

csrc/all_to_all/intranode.cpp

Lines changed: 4 additions & 2 deletions
@@ -18,7 +18,8 @@ AllToAllIntraNode::AllToAllIntraNode(
     size_t hiddenDim,
     size_t hiddenDimBytes,
     size_t hiddenDimScaleBytes,
-    std::shared_ptr<Distributed> distributed
+    std::shared_ptr<Distributed> distributed,
+    int max_sm_count
 )
     : AllToAll(
           maxNumTokens,
@@ -29,7 +30,8 @@ AllToAllIntraNode::AllToAllIntraNode(
          dpSize,
          hiddenDim,
          hiddenDimBytes,
-         hiddenDimScaleBytes
+         hiddenDimScaleBytes,
+         max_sm_count
       ) {
 
   // Determine the per-token buffer size. Allocate extra storage for the index.

csrc/all_to_all/intranode.h

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,8 @@ class AllToAllIntraNode final : public AllToAll {
       size_t hiddenDim,
       size_t hiddenDimBytes,
       size_t hiddenDimScaleBytes,
-      std::shared_ptr<Distributed> distributed
+      std::shared_ptr<Distributed> distributed,
+      int max_sm_count = 0
   );
 
   ~AllToAllIntraNode();

csrc/bindings/all_to_all_ops.cpp

Lines changed: 8 additions & 4 deletions
@@ -59,7 +59,8 @@ fptr_t create_internode(
     int64_t dpSize,
     int64_t hiddenDim,
     int64_t hiddenDimBytes,
-    int64_t hiddenDimScaleBytes
+    int64_t hiddenDimScaleBytes,
+    int64_t max_sm_count = 0
 ) {
   auto *ptr = new AllToAllInterNode(
       maxNumTokens,
@@ -70,7 +71,8 @@ fptr_t create_internode(
       dpSize,
       hiddenDim,
       hiddenDimBytes,
-      hiddenDimScaleBytes
+      hiddenDimScaleBytes,
+      max_sm_count
   );
   return (fptr_t)ptr;
 }
@@ -85,7 +87,8 @@ fptr_t create_intranode(
     int64_t hiddenDim,
     int64_t hiddenDimBytes,
     int64_t hiddenDimScaleBytes,
-    const std::string &group_name
+    const std::string &group_name,
+    int64_t max_sm_count = 0
 ) {
   auto group = c10d::resolve_process_group(group_name);
   std::shared_ptr<Distributed> distributed = std::make_shared<DistributedTorch>(group);
@@ -99,7 +102,8 @@
       hiddenDim,
       hiddenDimBytes,
       hiddenDimScaleBytes,
-      distributed
+      distributed,
+      max_sm_count
   );
   return (fptr_t)ptr;
 }

src/pplx_kernels/all_to_all.py

Lines changed: 4 additions & 0 deletions
@@ -97,6 +97,7 @@ def intranode(
         hidden_dim_bytes: int,
         hidden_dim_scale_bytes: int,
         group_name: str = "default",
+        max_sm_count: int = 0,
     ) -> "AllToAll":
         assert world_size % dp_size == 0
         assert world_size // dp_size > 1
@@ -114,6 +115,7 @@ def intranode(
             hidden_dim_bytes,
             hidden_dim_scale_bytes,
             group_name,
+            max_sm_count,
         )
         assert ptr != 0
 
@@ -136,6 +138,7 @@ def internode(
         hidden_dim: int,
         hidden_dim_bytes: int,
         hidden_dim_scale_bytes: int,
+        max_sm_count: int = 0,
     ) -> "AllToAll":
         assert world_size % dp_size == 0
         assert world_size // dp_size > 1
@@ -152,6 +155,7 @@ def internode(
             hidden_dim,
             hidden_dim_bytes,
             hidden_dim_scale_bytes,
+            max_sm_count,
         )
         assert ptr != 0
 
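
With the Python wrappers updated, the cap can be chosen per AllToAll instance. A usage sketch follows; of the keyword arguments shown, only hidden_dim, hidden_dim_bytes, hidden_dim_scale_bytes, and max_sm_count appear in this diff, while the remaining parameter names, all values, and the import path are illustrative assumptions (the distributed/NVSHMEM setup the constructors require is omitted):

    import torch
    from pplx_kernels import AllToAll  # import path assumed

    # Reserve half of the device's SMs for the all-to-all kernels, leaving the
    # rest free for overlapped compute; max_sm_count=0 (default) uses all SMs.
    half_sms = torch.cuda.get_device_properties(0).multi_processor_count // 2

    ata = AllToAll.internode(
        max_num_tokens=128,        # assumed parameter name, example value
        num_experts=8,             # assumed parameter name, example value
        experts_per_token=2,       # assumed parameter name, example value
        rank=0,                    # assumed parameter name, example value
        world_size=4,              # name appears in this diff; example value
        dp_size=2,                 # name appears in this diff; example value
        hidden_dim=2048,
        hidden_dim_bytes=2048 * torch.bfloat16.itemsize,
        hidden_dim_scale_bytes=0,
        max_sm_count=half_sms,     # the new knob added by this commit
    )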

tests/test_all_to_all.py

Lines changed: 25 additions & 3 deletions
@@ -38,6 +38,13 @@
 )
 
 
+def _get_number_of_gpu_sm() -> int:
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available")
+    device_props = torch.cuda.get_device_properties(0)
+    return device_props.multi_processor_count
+
+
 def _str_1d_tensor(t: torch.Tensor) -> str:
     sl = [f"{x:7.4f}" for x in t.tolist()]
     if len(sl) > 5:
@@ -48,6 +55,7 @@ def _str_1d_tensor(t: torch.Tensor) -> str:
 def _do_test_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
+    max_sm_count: int,
     moe: MoEConfig,
     internode: bool,
 ) -> None:
@@ -79,6 +87,7 @@
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
     else:
         ata = AllToAll.intranode(
@@ -99,6 +108,7 @@
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
 
     # Generate the same test data on all ranks
@@ -283,6 +293,7 @@ def _worker_test_all_to_all(
     dp_size: int,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
     moe_config: MoEConfig,
     internode: bool,
 ) -> None:
@@ -295,16 +306,21 @@
         in_dtype=getattr(torch, in_dtype),
         out_dtype=getattr(torch, out_dtype),
     )
-    _do_test_all_to_all(pgi, dp_size, moe_config, internode)
+    _do_test_all_to_all(pgi, dp_size, max_sm_count, moe_config, internode)
 
     nvshmem_finalize()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize(
+    "max_sm_count", [_get_number_of_gpu_sm(), _get_number_of_gpu_sm() // 2]
+)
 @pytest.mark.parametrize("internode", [True, False])
-def test_all_to_all_4_gpu(in_dtype: str, out_dtype: str, internode: bool) -> None:
+def test_all_to_all_4_gpu(
+    in_dtype: str, out_dtype: str, max_sm_count: int, internode: bool
+) -> None:
     world_size = 4
     dp_size = 2
     parallel_launch(
@@ -313,6 +329,7 @@ def test_all_to_all_4_gpu(
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         small_moe,
         internode,
     )
@@ -322,13 +339,15 @@ def _worker_test_all_to_all_multi_node(
     pgi: ProcessGroupInfo,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
 ) -> None:
     dp_size = 4
     _worker_test_all_to_all(
         pgi,
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         medium_moe,
         True,
     )
@@ -338,4 +357,7 @@
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
 def test_all_to_all_multi_node(in_dtype: str, out_dtype: str) -> None:
-    parallel_launch_from_env(_worker_test_all_to_all_multi_node, in_dtype, out_dtype)
+    max_sm_count = _get_number_of_gpu_sm()
+    parallel_launch_from_env(
+        _worker_test_all_to_all_multi_node, in_dtype, out_dtype, max_sm_count
+    )
