Commit 0487e8a

add control over the number of SMs to be used by the kernel
1 parent: 80354e0

9 files changed: +52 −16 lines

csrc/all_to_all/all_to_all.cpp

Lines changed: 3 additions & 2 deletions

@@ -14,7 +14,8 @@ AllToAll::AllToAll(
     unsigned dpSize,
     size_t hiddenDim,
     size_t hiddenDimBytes,
-    size_t hiddenDimScaleBytes
+    size_t hiddenDimScaleBytes,
+    int max_sm_count
 )
     : maxNumTokens(maxNumTokens),
       numExperts(numExperts),
@@ -27,7 +28,7 @@ AllToAll::AllToAll(
       rank(rank),
       worldSize(worldSize),
       dpSize(dpSize),
-      numSMs(get_sm_count()) {
+      numSMs(max_sm_count > 0 ? max_sm_count : get_sm_count()) {

   PPLX_ASSERT(hiddenDimBytes % 16 == 0, "invalid hidden dim bytes");
   PPLX_ASSERT(hiddenDimScaleBytes % 16 == 0, "invalid hidden dim scale bytes");

csrc/all_to_all/all_to_all.h

Lines changed: 2 additions & 1 deletion

@@ -35,7 +35,8 @@ class AllToAll {
       unsigned dpSize,
       size_t hiddenDim,
       size_t hiddenDimBytes,
-      size_t hiddenDimScaleBytes
+      size_t hiddenDimScaleBytes,
+      int max_sm_count = 0
   );

   virtual ~AllToAll();

csrc/all_to_all/internode.cpp

Lines changed: 4 additions & 2 deletions

@@ -18,7 +18,8 @@ AllToAllInterNode::AllToAllInterNode(
     unsigned dpSize,
     size_t hiddenDim,
     size_t hiddenDimBytes,
-    size_t hiddenDimScaleBytes
+    size_t hiddenDimScaleBytes,
+    int max_sm_count
 )
     : AllToAll(
           maxNumTokens,
@@ -29,7 +30,8 @@ AllToAllInterNode::AllToAllInterNode(
           dpSize,
           hiddenDim,
           hiddenDimBytes,
-          hiddenDimScaleBytes
+          hiddenDimScaleBytes,
+          max_sm_count
       ),
       maxBatchTokens(numLocalExperts * numDPGroups * maxNumTokens) {
   // Buffers for token counts.

csrc/all_to_all/internode.h

Lines changed: 2 additions & 1 deletion

@@ -24,7 +24,8 @@ class AllToAllInterNode final : public AllToAll {
       unsigned dpSize,
       size_t hiddenDim,
       size_t hiddenDimBytes,
-      size_t hiddenDimScaleBytes
+      size_t hiddenDimScaleBytes,
+      int max_sm_count = 0
   );

   ~AllToAllInterNode();

csrc/all_to_all/intranode.cpp

Lines changed: 4 additions & 2 deletions

@@ -18,7 +18,8 @@ AllToAllIntraNode::AllToAllIntraNode(
     size_t hiddenDim,
     size_t hiddenDimBytes,
     size_t hiddenDimScaleBytes,
-    std::shared_ptr<Distributed> distributed
+    std::shared_ptr<Distributed> distributed,
+    int max_sm_count
 )
     : AllToAll(
           maxNumTokens,
@@ -29,7 +30,8 @@ AllToAllIntraNode::AllToAllIntraNode(
           dpSize,
           hiddenDim,
           hiddenDimBytes,
-          hiddenDimScaleBytes
+          hiddenDimScaleBytes,
+          max_sm_count
       ) {

   // Determine the per-token buffer size. Allocate extra storage for the index.

csrc/all_to_all/intranode.h

Lines changed: 2 additions & 1 deletion

@@ -30,7 +30,8 @@ class AllToAllIntraNode final : public AllToAll {
       size_t hiddenDim,
       size_t hiddenDimBytes,
       size_t hiddenDimScaleBytes,
-      std::shared_ptr<Distributed> distributed
+      std::shared_ptr<Distributed> distributed,
+      int max_sm_count = 0
   );

   ~AllToAllIntraNode();

csrc/bindings/all_to_all_ops.cpp

Lines changed: 8 additions & 4 deletions

@@ -60,7 +60,8 @@ fptr_t create_internode(
     int64_t dpSize,
     int64_t hiddenDim,
     int64_t hiddenDimBytes,
-    int64_t hiddenDimScaleBytes
+    int64_t hiddenDimScaleBytes,
+    int64_t max_sm_count = 0
 ) {
   auto *ptr = new AllToAllInterNode(
       maxNumTokens,
@@ -71,7 +72,8 @@ fptr_t create_internode(
       dpSize,
       hiddenDim,
       hiddenDimBytes,
-      hiddenDimScaleBytes
+      hiddenDimScaleBytes,
+      max_sm_count
   );
   return (fptr_t)ptr;
 }
@@ -86,7 +88,8 @@ fptr_t create_intranode(
     int64_t hiddenDim,
     int64_t hiddenDimBytes,
     int64_t hiddenDimScaleBytes,
-    const std::string &group_name
+    const std::string &group_name,
+    int64_t max_sm_count = 0
 ) {
   auto group = c10d::resolve_process_group(group_name);
   std::shared_ptr<Distributed> distributed = std::make_shared<DistributedTorch>(group);
@@ -100,7 +103,8 @@ fptr_t create_intranode(
       hiddenDim,
       hiddenDimBytes,
       hiddenDimScaleBytes,
-      distributed
+      distributed,
+      max_sm_count
   );
   return (fptr_t)ptr;
 }

src/pplx_kernels/all_to_all.py

Lines changed: 4 additions & 0 deletions

@@ -97,6 +97,7 @@ def intranode(
         hidden_dim_bytes: int,
         hidden_dim_scale_bytes: int,
         group_name: str = "default",
+        max_sm_count: int = 0,
     ) -> "AllToAll":
         assert world_size % dp_size == 0
         assert world_size // dp_size > 1
@@ -114,6 +115,7 @@ def intranode(
             hidden_dim_bytes,
             hidden_dim_scale_bytes,
             group_name,
+            max_sm_count,
         )
         assert ptr != 0

@@ -136,6 +138,7 @@ def internode(
         hidden_dim: int,
         hidden_dim_bytes: int,
         hidden_dim_scale_bytes: int,
+        max_sm_count: int = 0,
     ) -> "AllToAll":
         assert world_size % dp_size == 0
         assert world_size // dp_size > 1
@@ -152,6 +155,7 @@ def internode(
             hidden_dim,
             hidden_dim_bytes,
             hidden_dim_scale_bytes,
+            max_sm_count,
         )
         assert ptr != 0

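
The Python wrappers expose the new knob as an optional keyword that defaults to 0, in which case the C++ side falls back to get_sm_count() and keeps the previous behaviour. A minimal usage sketch follows; the import path and the constructor arguments that are not visible in this diff (max_num_tokens, num_experts, experts_per_token, rank) are assumptions based on the C++ signature, and the call presumes the usual process-group/NVSHMEM setup has already been done.

# Illustrative sketch only: names marked "assumed" are not confirmed by this diff.
import torch
from pplx_kernels.all_to_all import AllToAll  # assumed import path

hidden_dim = 4096
ata = AllToAll.internode(
    max_num_tokens=128,        # assumed parameter name
    num_experts=16,            # assumed parameter name
    experts_per_token=2,       # assumed parameter name
    rank=0,                    # assumed parameter name
    world_size=8,
    dp_size=2,
    hidden_dim=hidden_dim,
    hidden_dim_bytes=hidden_dim * torch.bfloat16.itemsize,
    hidden_dim_scale_bytes=0,
    max_sm_count=32,           # new in this commit: cap the kernels to 32 SMs
)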

tests/test_all_to_all.py

Lines changed: 23 additions & 3 deletions

@@ -38,6 +38,13 @@
 )


+def _get_number_of_gpu_sm() -> int:
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available")
+    device_props = torch.cuda.get_device_properties(0)
+    return device_props.multi_processor_count
+
+
 def _str_1d_tensor(t: torch.Tensor) -> str:
     sl = [f"{x:7.4f}" for x in t.tolist()]
     if len(sl) > 5:
@@ -48,6 +55,7 @@ def _str_1d_tensor(t: torch.Tensor) -> str:
 def _do_test_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
+    max_sm_count: int,
     moe: MoEConfig,
     internode: bool,
     use_compile: bool,
@@ -80,6 +88,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
     else:
         ata = AllToAll.intranode(
@@ -100,6 +109,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )

     # Generate the same test data on all ranks
@@ -291,6 +301,7 @@ def _worker_test_all_to_all(
     dp_size: int,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
     moe_config: MoEConfig,
     internode: bool,
     use_compile: bool = False,
@@ -305,18 +316,21 @@ def _worker_test_all_to_all(
         out_dtype=getattr(torch, out_dtype),
     )

-    _do_test_all_to_all(pgi, dp_size, moe_config, internode, use_compile)
+    _do_test_all_to_all(pgi, dp_size, max_sm_count, moe_config, internode, use_compile)

     nvshmem_finalize()


 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize(
+    "max_sm_count", [_get_number_of_gpu_sm(), _get_number_of_gpu_sm() // 2]
+)
 @pytest.mark.parametrize("internode", [True, False])
 @pytest.mark.parametrize("use_compile", [False, True])
 def test_all_to_all_4_gpu(
-    in_dtype: str, out_dtype: str, internode: bool, use_compile: bool
+    in_dtype: str, out_dtype: str, max_sm_count: int, internode: bool, use_compile: bool
 ) -> None:
     world_size = 4
     dp_size = 2
@@ -326,6 +340,7 @@ def test_all_to_all_4_gpu(
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         small_moe,
         internode,
         use_compile,
@@ -336,13 +351,15 @@ def _worker_test_all_to_all_multi_node(
     pgi: ProcessGroupInfo,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
 ) -> None:
     dp_size = 4
     _worker_test_all_to_all(
         pgi,
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         medium_moe,
         True,
     )
@@ -352,4 +369,7 @@ def _worker_test_all_to_all_multi_node(
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
 def test_all_to_all_multi_node(in_dtype: str, out_dtype: str) -> None:
-    parallel_launch_from_env(_worker_test_all_to_all_multi_node, in_dtype, out_dtype)
+    max_sm_count = _get_number_of_gpu_sm()
+    parallel_launch_from_env(
+        _worker_test_all_to_all_multi_node, in_dtype, out_dtype, max_sm_count
+    )
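
The tests exercise both the full SM count and half of it. Outside the test suite, a caller could derive the value the same way and, for example, hold back part of the device for kernels that overlap with the all-to-all. The sketch below relies only on the public torch.cuda API and the fallback semantics added in this commit; the 50% split and the helper name are illustrative, not a recommendation.

import torch

def pick_max_sm_count(fraction: float = 0.5) -> int:
    # Same device query as the _get_number_of_gpu_sm() helper added in this commit.
    if not torch.cuda.is_available():
        return 0  # 0 lets the library fall back to get_sm_count()
    total_sms = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count
    # Reserve the remaining SMs, e.g. for compute overlapping the dispatch/combine
    # kernels; the fraction here is purely an illustrative choice.
    return max(1, int(total_sms * fraction))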
