add control over the number of SMs to be used by the kernel #16

Open · wants to merge 1 commit into base: master
5 changes: 3 additions & 2 deletions csrc/all_to_all/all_to_all.cpp
@@ -14,7 +14,8 @@ AllToAll::AllToAll(
     unsigned dpSize,
     size_t hiddenDim,
     size_t hiddenDimBytes,
-    size_t hiddenDimScaleBytes
+    size_t hiddenDimScaleBytes,
+    int max_sm_count
 )
     : maxNumTokens(maxNumTokens),
       numExperts(numExperts),
@@ -27,7 +28,7 @@ AllToAll::AllToAll(
       rank(rank),
       worldSize(worldSize),
       dpSize(dpSize),
-      numSMs(get_sm_count()) {
+      numSMs(max_sm_count > 0 ? max_sm_count : get_sm_count()) {

   PPLX_ASSERT(hiddenDimBytes % 16 == 0, "invalid hidden dim bytes");
   PPLX_ASSERT(hiddenDimScaleBytes % 16 == 0, "invalid hidden dim scale bytes");
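The constructor change above is the heart of the PR: a positive max_sm_count overrides the device query, while the default of 0 preserves the existing behavior of using every SM reported by get_sm_count(). A minimal Python sketch of that selection rule, with an illustrative helper name and SM counts that are not part of the codebase:

```python
def resolve_num_sms(max_sm_count: int, device_sm_count: int) -> int:
    # Mirrors numSMs(max_sm_count > 0 ? max_sm_count : get_sm_count()):
    # a positive cap wins; 0 (the default) falls back to all device SMs.
    return max_sm_count if max_sm_count > 0 else device_sm_count

assert resolve_num_sms(0, 132) == 132   # default: use every SM
assert resolve_num_sms(64, 132) == 64   # caller-imposed cap
```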
3 changes: 2 additions & 1 deletion csrc/all_to_all/all_to_all.h
@@ -35,7 +35,8 @@ class AllToAll {
     unsigned dpSize,
     size_t hiddenDim,
     size_t hiddenDimBytes,
-    size_t hiddenDimScaleBytes
+    size_t hiddenDimScaleBytes,
+    int max_sm_count = 0
 );

   virtual ~AllToAll();
6 changes: 4 additions & 2 deletions csrc/all_to_all/internode.cpp
@@ -18,7 +18,8 @@ AllToAllInterNode::AllToAllInterNode(
     unsigned dpSize,
     size_t hiddenDim,
     size_t hiddenDimBytes,
-    size_t hiddenDimScaleBytes
+    size_t hiddenDimScaleBytes,
+    int max_sm_count
 )
     : AllToAll(
           maxNumTokens,
@@ -29,7 +30,8 @@ AllToAllInterNode::AllToAllInterNode(
           dpSize,
           hiddenDim,
           hiddenDimBytes,
-          hiddenDimScaleBytes
+          hiddenDimScaleBytes,
+          max_sm_count
       ),
       maxBatchTokens(numLocalExperts * numDPGroups * maxNumTokens) {
   // Buffers for token counts.
3 changes: 2 additions & 1 deletion csrc/all_to_all/internode.h
@@ -24,7 +24,8 @@ class AllToAllInterNode final : public AllToAll {
     unsigned dpSize,
     size_t hiddenDim,
     size_t hiddenDimBytes,
-    size_t hiddenDimScaleBytes
+    size_t hiddenDimScaleBytes,
+    int max_sm_count = 0
 );

   ~AllToAllInterNode();
6 changes: 4 additions & 2 deletions csrc/all_to_all/intranode.cpp
@@ -18,7 +18,8 @@ AllToAllIntraNode::AllToAllIntraNode(
     size_t hiddenDim,
     size_t hiddenDimBytes,
     size_t hiddenDimScaleBytes,
-    std::shared_ptr<Distributed> distributed
+    std::shared_ptr<Distributed> distributed,
+    int max_sm_count
 )
     : AllToAll(
           maxNumTokens,
@@ -29,7 +30,8 @@ AllToAllIntraNode::AllToAllIntraNode(
           dpSize,
           hiddenDim,
           hiddenDimBytes,
-          hiddenDimScaleBytes
+          hiddenDimScaleBytes,
+          max_sm_count
       ) {

   // Determine the per-token buffer size. Allocate extra storage for the index.
3 changes: 2 additions & 1 deletion csrc/all_to_all/intranode.h
@@ -30,7 +30,8 @@ class AllToAllIntraNode final : public AllToAll {
     size_t hiddenDim,
     size_t hiddenDimBytes,
     size_t hiddenDimScaleBytes,
-    std::shared_ptr<Distributed> distributed
+    std::shared_ptr<Distributed> distributed,
+    int max_sm_count = 0
 );

   ~AllToAllIntraNode();
12 changes: 8 additions & 4 deletions csrc/bindings/all_to_all_ops.cpp
@@ -60,7 +60,8 @@ fptr_t create_internode(
     int64_t dpSize,
     int64_t hiddenDim,
     int64_t hiddenDimBytes,
-    int64_t hiddenDimScaleBytes
+    int64_t hiddenDimScaleBytes,
+    int64_t max_sm_count = 0
 ) {
   auto *ptr = new AllToAllInterNode(
       maxNumTokens,
@@ -71,7 +72,8 @@ fptr_t create_internode(
       dpSize,
       hiddenDim,
       hiddenDimBytes,
-      hiddenDimScaleBytes
+      hiddenDimScaleBytes,
+      max_sm_count
   );
   return (fptr_t)ptr;
 }
@@ -86,7 +88,8 @@ fptr_t create_intranode(
     int64_t hiddenDim,
     int64_t hiddenDimBytes,
     int64_t hiddenDimScaleBytes,
-    const std::string &group_name
+    const std::string &group_name,
+    int64_t max_sm_count = 0
 ) {
   auto group = c10d::resolve_process_group(group_name);
   std::shared_ptr<Distributed> distributed = std::make_shared<DistributedTorch>(group);
@@ -100,7 +103,8 @@ fptr_t create_intranode(
       hiddenDim,
       hiddenDimBytes,
       hiddenDimScaleBytes,
-      distributed
+      distributed,
+      max_sm_count
   );
   return (fptr_t)ptr;
 }
4 changes: 4 additions & 0 deletions src/pplx_kernels/all_to_all.py
@@ -97,6 +97,7 @@ def intranode(
         hidden_dim_bytes: int,
         hidden_dim_scale_bytes: int,
         group_name: str = "default",
+        max_sm_count: int = 0,
     ) -> "AllToAll":
         assert world_size % dp_size == 0
         assert world_size // dp_size > 1
@@ -114,6 +115,7 @@ def intranode(
             hidden_dim_bytes,
             hidden_dim_scale_bytes,
             group_name,
+            max_sm_count,
         )
         assert ptr != 0

@@ -136,6 +138,7 @@ def internode(
         hidden_dim: int,
         hidden_dim_bytes: int,
         hidden_dim_scale_bytes: int,
+        max_sm_count: int = 0,
     ) -> "AllToAll":
         assert world_size % dp_size == 0
         assert world_size // dp_size > 1
@@ -152,6 +155,7 @@ def internode(
             hidden_dim,
             hidden_dim_bytes,
             hidden_dim_scale_bytes,
+            max_sm_count,
         )
         assert ptr != 0

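With the new keyword argument exposed here, callers can derive a cap from the device before constructing the all-to-all object. A usage sketch, assuming a CUDA-capable machine; the helper below is illustrative, and the remaining constructor arguments are unchanged and omitted:

```python
import torch

def sm_budget(fraction: float = 0.5) -> int:
    """Pick a max_sm_count value as a fraction of the device's SM count.
    Returning 0 keeps the library default of using every SM."""
    if not torch.cuda.is_available():
        return 0
    total = torch.cuda.get_device_properties(0).multi_processor_count
    return max(1, int(total * fraction))

# e.g. ata = AllToAll.internode(..., max_sm_count=sm_budget())
# (all other arguments exactly as before; see the signatures above)
```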
26 changes: 23 additions & 3 deletions tests/test_all_to_all.py
@@ -38,6 +38,13 @@
 )


+def _get_number_of_gpu_sm() -> int:
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available")
+    device_props = torch.cuda.get_device_properties(0)
+    return device_props.multi_processor_count
+
+
 def _str_1d_tensor(t: torch.Tensor) -> str:
     sl = [f"{x:7.4f}" for x in t.tolist()]
     if len(sl) > 5:
@@ -48,6 +55,7 @@ def _str_1d_tensor(t: torch.Tensor) -> str:
 def _do_test_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
+    max_sm_count: int,
     moe: MoEConfig,
     internode: bool,
     use_compile: bool,
@@ -80,6 +88,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
     else:
         ata = AllToAll.intranode(
@@ -100,6 +109,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )

     # Generate the same test data on all ranks
@@ -291,6 +301,7 @@ def _worker_test_all_to_all(
     dp_size: int,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
     moe_config: MoEConfig,
     internode: bool,
     use_compile: bool = False,
@@ -305,18 +316,21 @@ def _worker_test_all_to_all(
         out_dtype=getattr(torch, out_dtype),
     )

-    _do_test_all_to_all(pgi, dp_size, moe_config, internode, use_compile)
+    _do_test_all_to_all(pgi, dp_size, max_sm_count, moe_config, internode, use_compile)

     nvshmem_finalize()


 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize(
+    "max_sm_count", [_get_number_of_gpu_sm(), _get_number_of_gpu_sm() // 2]
+)
 @pytest.mark.parametrize("internode", [True, False])
 @pytest.mark.parametrize("use_compile", [False, True])
 def test_all_to_all_4_gpu(
-    in_dtype: str, out_dtype: str, internode: bool, use_compile: bool
+    in_dtype: str, out_dtype: str, max_sm_count: int, internode: bool, use_compile: bool
 ) -> None:
     world_size = 4
     dp_size = 2
@@ -326,6 +340,7 @@ def test_all_to_all_4_gpu(
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         small_moe,
         internode,
         use_compile,
@@ -336,13 +351,15 @@ def _worker_test_all_to_all_multi_node(
     pgi: ProcessGroupInfo,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
 ) -> None:
     dp_size = 4
     _worker_test_all_to_all(
         pgi,
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         medium_moe,
         True,
     )
@@ -352,4 +369,7 @@
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
 def test_all_to_all_multi_node(in_dtype: str, out_dtype: str) -> None:
-    parallel_launch_from_env(_worker_test_all_to_all_multi_node, in_dtype, out_dtype)
+    max_sm_count = _get_number_of_gpu_sm()
+    parallel_launch_from_env(
+        _worker_test_all_to_all_multi_node, in_dtype, out_dtype, max_sm_count
+    )
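The 4-GPU test now runs every dtype/internode/compile combination twice, once with the full SM count and once with half of it, so the capped path is exercised alongside the default; the multi-node test passes the full count through to the worker. A quick way to see which values will be parametrized on a given machine, assuming CUDA is available:

```python
import torch

props = torch.cuda.get_device_properties(0)
full = props.multi_processor_count
print(f"max_sm_count will be parametrized over [{full}, {full // 2}]")
```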