diff --git a/csrc/all_to_all/all_to_all.cpp b/csrc/all_to_all/all_to_all.cpp
index 75af1df..8f4a178 100644
--- a/csrc/all_to_all/all_to_all.cpp
+++ b/csrc/all_to_all/all_to_all.cpp
@@ -14,7 +14,8 @@ AllToAll::AllToAll(
     unsigned dpSize,
     size_t hiddenDim,
     size_t hiddenDimBytes,
-    size_t hiddenDimScaleBytes
+    size_t hiddenDimScaleBytes,
+    int max_sm_count
 )
     : maxNumTokens(maxNumTokens),
       numExperts(numExperts),
@@ -27,7 +28,7 @@ AllToAll::AllToAll(
       rank(rank),
       worldSize(worldSize),
       dpSize(dpSize),
-      numSMs(get_sm_count()) {
+      numSMs(max_sm_count > 0 ? max_sm_count : get_sm_count()) {
   PPLX_ASSERT(hiddenDimBytes % 16 == 0, "invalid hidden dim bytes");
   PPLX_ASSERT(hiddenDimScaleBytes % 16 == 0, "invalid hidden dim scale bytes");
 
diff --git a/csrc/all_to_all/all_to_all.h b/csrc/all_to_all/all_to_all.h
index 5d5beff..b44e7b0 100644
--- a/csrc/all_to_all/all_to_all.h
+++ b/csrc/all_to_all/all_to_all.h
@@ -35,7 +35,8 @@ class AllToAll {
       unsigned dpSize,
       size_t hiddenDim,
       size_t hiddenDimBytes,
-      size_t hiddenDimScaleBytes
+      size_t hiddenDimScaleBytes,
+      int max_sm_count = 0
   );
 
   virtual ~AllToAll();
diff --git a/csrc/all_to_all/internode.cpp b/csrc/all_to_all/internode.cpp
index b73e547..81aabbc 100644
--- a/csrc/all_to_all/internode.cpp
+++ b/csrc/all_to_all/internode.cpp
@@ -18,7 +18,8 @@ AllToAllInterNode::AllToAllInterNode(
     unsigned dpSize,
     size_t hiddenDim,
     size_t hiddenDimBytes,
-    size_t hiddenDimScaleBytes
+    size_t hiddenDimScaleBytes,
+    int max_sm_count
 )
     : AllToAll(
           maxNumTokens,
@@ -29,7 +30,8 @@ AllToAllInterNode::AllToAllInterNode(
           dpSize,
           hiddenDim,
           hiddenDimBytes,
-          hiddenDimScaleBytes
+          hiddenDimScaleBytes,
+          max_sm_count
       ),
       maxBatchTokens(numLocalExperts * numDPGroups * maxNumTokens) {
   // Buffers for token counts.
diff --git a/csrc/all_to_all/internode.h b/csrc/all_to_all/internode.h
index 28aa939..91779f1 100644
--- a/csrc/all_to_all/internode.h
+++ b/csrc/all_to_all/internode.h
@@ -24,7 +24,8 @@ class AllToAllInterNode final : public AllToAll {
       unsigned dpSize,
       size_t hiddenDim,
       size_t hiddenDimBytes,
-      size_t hiddenDimScaleBytes
+      size_t hiddenDimScaleBytes,
+      int max_sm_count = 0
   );
 
   ~AllToAllInterNode();
diff --git a/csrc/all_to_all/intranode.cpp b/csrc/all_to_all/intranode.cpp
index bb3d9d8..20fa073 100644
--- a/csrc/all_to_all/intranode.cpp
+++ b/csrc/all_to_all/intranode.cpp
@@ -18,7 +18,8 @@ AllToAllIntraNode::AllToAllIntraNode(
     size_t hiddenDim,
     size_t hiddenDimBytes,
     size_t hiddenDimScaleBytes,
-    std::shared_ptr<Distributed> distributed
+    std::shared_ptr<Distributed> distributed,
+    int max_sm_count
 )
     : AllToAll(
           maxNumTokens,
@@ -29,7 +30,8 @@ AllToAllIntraNode::AllToAllIntraNode(
           dpSize,
           hiddenDim,
           hiddenDimBytes,
-          hiddenDimScaleBytes
+          hiddenDimScaleBytes,
+          max_sm_count
       ) {
   // Determine the per-token buffer size. Allocate extra storage for the index.
diff --git a/csrc/all_to_all/intranode.h b/csrc/all_to_all/intranode.h
index 871347e..44c7f9f 100644
--- a/csrc/all_to_all/intranode.h
+++ b/csrc/all_to_all/intranode.h
@@ -30,7 +30,8 @@ class AllToAllIntraNode final : public AllToAll {
       size_t hiddenDim,
       size_t hiddenDimBytes,
       size_t hiddenDimScaleBytes,
-      std::shared_ptr<Distributed> distributed
+      std::shared_ptr<Distributed> distributed,
+      int max_sm_count = 0
   );
 
   ~AllToAllIntraNode();
diff --git a/csrc/bindings/all_to_all_ops.cpp b/csrc/bindings/all_to_all_ops.cpp
index a96ee97..414a625 100644
--- a/csrc/bindings/all_to_all_ops.cpp
+++ b/csrc/bindings/all_to_all_ops.cpp
@@ -60,7 +60,8 @@ fptr_t create_internode(
     int64_t dpSize,
     int64_t hiddenDim,
     int64_t hiddenDimBytes,
-    int64_t hiddenDimScaleBytes
+    int64_t hiddenDimScaleBytes,
+    int64_t max_sm_count = 0
 ) {
   auto *ptr = new AllToAllInterNode(
       maxNumTokens,
@@ -71,7 +72,8 @@ fptr_t create_internode(
       dpSize,
       hiddenDim,
       hiddenDimBytes,
-      hiddenDimScaleBytes
+      hiddenDimScaleBytes,
+      max_sm_count
   );
   return (fptr_t)ptr;
 }
@@ -86,7 +88,8 @@ fptr_t create_intranode(
     int64_t hiddenDim,
     int64_t hiddenDimBytes,
     int64_t hiddenDimScaleBytes,
-    const std::string &group_name
+    const std::string &group_name,
+    int64_t max_sm_count = 0
 ) {
   auto group = c10d::resolve_process_group(group_name);
   std::shared_ptr<Distributed> distributed = std::make_shared<DistributedTorch>(group);
@@ -100,7 +103,8 @@ fptr_t create_intranode(
       hiddenDim,
       hiddenDimBytes,
       hiddenDimScaleBytes,
-      distributed
+      distributed,
+      max_sm_count
   );
   return (fptr_t)ptr;
 }
diff --git a/src/pplx_kernels/all_to_all.py b/src/pplx_kernels/all_to_all.py
index 61a74d6..f021951 100644
--- a/src/pplx_kernels/all_to_all.py
+++ b/src/pplx_kernels/all_to_all.py
@@ -97,6 +97,7 @@ def intranode(
         hidden_dim_bytes: int,
         hidden_dim_scale_bytes: int,
         group_name: str = "default",
+        max_sm_count: int = 0,
     ) -> "AllToAll":
         assert world_size % dp_size == 0
         assert world_size // dp_size > 1
@@ -114,6 +115,7 @@ def intranode(
             hidden_dim_bytes,
             hidden_dim_scale_bytes,
             group_name,
+            max_sm_count,
         )
         assert ptr != 0
@@ -136,6 +138,7 @@ def internode(
         hidden_dim: int,
         hidden_dim_bytes: int,
         hidden_dim_scale_bytes: int,
+        max_sm_count: int = 0,
     ) -> "AllToAll":
         assert world_size % dp_size == 0
         assert world_size // dp_size > 1
@@ -152,6 +155,7 @@ def internode(
             hidden_dim,
             hidden_dim_bytes,
             hidden_dim_scale_bytes,
+            max_sm_count,
         )
         assert ptr != 0
diff --git a/tests/test_all_to_all.py b/tests/test_all_to_all.py
index f26701a..04b9aac 100644
--- a/tests/test_all_to_all.py
+++ b/tests/test_all_to_all.py
@@ -38,6 +38,13 @@
 )
 
 
+def _get_number_of_gpu_sm() -> int:
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available")
+    device_props = torch.cuda.get_device_properties(0)
+    return device_props.multi_processor_count
+
+
 def _str_1d_tensor(t: torch.Tensor) -> str:
     sl = [f"{x:7.4f}" for x in t.tolist()]
     if len(sl) > 5:
@@ -48,6 +55,7 @@ def _str_1d_tensor(t: torch.Tensor) -> str:
 def _do_test_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
+    max_sm_count: int,
     moe: MoEConfig,
     internode: bool,
     use_compile: bool,
@@ -80,6 +88,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
     else:
         ata = AllToAll.intranode(
@@ -100,6 +109,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
            ),
+            max_sm_count=max_sm_count,
         )
 
     # Generate the same test data on all ranks
@@ -291,6 +301,7 @@ def _worker_test_all_to_all(
     dp_size: int,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
     moe_config: MoEConfig,
     internode: bool,
     use_compile: bool = False,
@@ -305,7 +316,7 @@ def _worker_test_all_to_all(
         out_dtype=getattr(torch, out_dtype),
     )
 
-    _do_test_all_to_all(pgi, dp_size, moe_config, internode, use_compile)
+    _do_test_all_to_all(pgi, dp_size, max_sm_count, moe_config, internode, use_compile)
 
     nvshmem_finalize()
@@ -313,10 +324,13 @@ def _worker_test_all_to_all(
 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize(
+    "max_sm_count", [_get_number_of_gpu_sm(), _get_number_of_gpu_sm() // 2]
+)
 @pytest.mark.parametrize("internode", [True, False])
 @pytest.mark.parametrize("use_compile", [False, True])
 def test_all_to_all_4_gpu(
-    in_dtype: str, out_dtype: str, internode: bool, use_compile: bool
+    in_dtype: str, out_dtype: str, max_sm_count: int, internode: bool, use_compile: bool
 ) -> None:
     world_size = 4
     dp_size = 2
@@ -326,6 +340,7 @@ def test_all_to_all_4_gpu(
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         small_moe,
         internode,
         use_compile,
@@ -336,6 +351,7 @@ def _worker_test_all_to_all_multi_node(
     pgi: ProcessGroupInfo,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
 ) -> None:
     dp_size = 4
     _worker_test_all_to_all(
@@ -343,6 +359,7 @@ def _worker_test_all_to_all_multi_node(
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         medium_moe,
         True,
     )
@@ -352,4 +369,7 @@ def _worker_test_all_to_all_multi_node(
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
 def test_all_to_all_multi_node(in_dtype: str, out_dtype: str) -> None:
-    parallel_launch_from_env(_worker_test_all_to_all_multi_node, in_dtype, out_dtype)
+    max_sm_count = _get_number_of_gpu_sm()
+    parallel_launch_from_env(
+        _worker_test_all_to_all_multi_node, in_dtype, out_dtype, max_sm_count
+    )
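
For context, a minimal usage sketch (not part of the patch) of the new max_sm_count knob from the Python side: 0, the default, keeps the old behavior of sizing to get_sm_count(), while a positive value caps the number of SMs the kernels may use. The import path and the constructor arguments that are not visible in this diff (max_num_tokens, num_experts, experts_per_token, rank, world_size, dp_size) are assumptions, and the NVSHMEM / process-group setup performed by the tests is omitted.

```python
# Usage sketch only -- assumes pplx_kernels exports AllToAll and that the
# NVSHMEM / torch.distributed setup done in tests/test_all_to_all.py has
# already happened. Argument names before hidden_dim are assumptions.
import os

import torch

from pplx_kernels import AllToAll

rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "4"))

# Cap the kernels at half of this device's SMs; max_sm_count=0 (the default)
# falls back to get_sm_count(), i.e. the behavior before this patch.
sm_count = torch.cuda.get_device_properties(0).multi_processor_count

ata = AllToAll.internode(
    max_num_tokens=128,
    num_experts=8,
    experts_per_token=2,
    rank=rank,
    world_size=world_size,
    dp_size=2,
    hidden_dim=2048,
    hidden_dim_bytes=2048 * torch.bfloat16.itemsize,  # must stay divisible by 16
    hidden_dim_scale_bytes=0,  # non-zero only for FP8 inputs, as in the tests
    max_sm_count=sm_count // 2,
)
```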