
Commit 88bd737: Initial commit (0 parents)


42 files changed (+4218, -0 lines)

.clang-format

Lines changed: 11 additions & 0 deletions
```
---
BasedOnStyle: LLVM
IndentWidth: 2
ColumnLimit: 100
BinPackArguments: false
BinPackParameters: false
ExperimentalAutoDetectBinPacking: false
AllowAllParametersOfDeclarationOnNextLine: false
AlignAfterOpenBracket: BlockIndent
BreakConstructorInitializers: BeforeColon
PackConstructorInitializers: Never
```

.gitignore

Lines changed: 6 additions & 0 deletions
```
build-cmake
build
pplx_kernels/*.so
*.egg-info
*.pyc
data
```

LICENSE

Lines changed: 7 additions & 0 deletions
Copyright (C) 2025 Perplexity AI

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

README.md

Lines changed: 57 additions & 0 deletions
Perplexity MoE Kernels
==========

Installation
-----

```
cd pplx-kernels
pip install -e . -vvv
```

Testing
-----

To build the C++ tests and benchmarks:

```
cd pplx-kernels
mkdir build-cmake
cd build-cmake

TORCH_PREFIX_PATH=$(python3 -c 'import torch; print(torch.utils.cmake_prefix_path)')

cmake ../csrc \
  -GNinja \
  -DCMAKE_PREFIX_PATH=$TORCH_PREFIX_PATH \
  -DTORCH_CUDA_ARCH_LIST=9.0a+PTX \
  -DWITH_TESTS=ON \
  -DWITH_BENCHMARKS=ON

ninja test_all_to_all bench_all_to_all
```

To run the all-to-all tests on one node:

```
NVSHMEM_REMOTE_TRANSPORT=None mpirun -np 4 ./test_all_to_all
```

To run the all-to-all benchmarks on one node:

```
NVSHMEM_REMOTE_TRANSPORT=None mpirun -np 4 ./bench_all_to_all
```

Inter-Node Benchmarks
-----

To test on a 32-device cluster spread across 4 nodes (WORLD_SIZE=32 is 4 nodes with WORLD_LOCAL_SIZE=8 GPUs each), run the following command on every node, setting NODE_RANK to a distinct value from 0 to 3 on each node and pointing MASTER_ADDR at one of the nodes:

```
cd pplx-kernels
pip install -e . -vvv
NVSHMEM_IB_ENABLE_IBGDA=1 NODE_RANK=<rank> WORLD_SIZE=32 WORLD_LOCAL_SIZE=8 MASTER_ADDR=<master-address> MASTER_PORT=29500 python3 -m tests.bench_all_to_all
```

csrc/CMakeLists.txt

Lines changed: 73 additions & 0 deletions
```
cmake_minimum_required(VERSION 3.26)
project(PPLXKernels
  VERSION 0.0.1
  DESCRIPTION "PPLX Kernels"
  LANGUAGES CXX CUDA)

# === Configuration options ===
option(WITH_TESTS "Build tests" OFF)
option(WITH_BENCHMARKS "Build benchmarks" OFF)
set(CMAKE_CUDA_ARCHITECTURES 90a CACHE STRING "CUDA architecture to target")

# === CMake configuration ===
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INCLUDE_CURRENT_DIR ON)

# === Dependencies ===
include(FetchContent)
find_package(CUDAToolkit REQUIRED) # Modern replacement for find_package(CUDA)
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
find_package(Torch REQUIRED)
find_package(NVSHMEM REQUIRED)

if(WITH_TESTS)
  enable_testing()
  find_package(MPI REQUIRED)
  find_package(PkgConfig REQUIRED)
  pkg_check_modules(NCCL nccl)
endif()

# Create imported target for PyTorch
add_library(torch_imported INTERFACE)
add_library(torch::py_limited ALIAS torch_imported)
target_include_directories(torch_imported SYSTEM INTERFACE ${TORCH_INCLUDE_DIRS})
# NOTE(lequn): We don't link against all ${TORCH_LIBRARIES} because we use py_limited_api.
# See: https://github.com/pytorch/pytorch/blob/9017becf1d895999a1c819c9d35b8139c090e7a9/torch/utils/cpp_extension.py#L1256-L1270
target_link_libraries(torch_imported INTERFACE c10 torch torch_cpu c10_cuda torch_cuda CUDA::cudart)

# === Compiler and linker flags ===
add_compile_options(-Wno-deprecated-declarations)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
add_compile_definitions(Py_LIMITED_API=0x03090000)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})

# CUDA-specific compile options function
function(set_cuda_compile_options target)
  target_compile_options(${target} PRIVATE
    $<$<COMPILE_LANGUAGE:CUDA>:--threads=32 -O3>)
endfunction()

# === Library targets ===
add_subdirectory(all_to_all)

# Main shared library
add_library(pplx_kernels SHARED
  bindings/all_to_all_ops.cpp
  bindings/bindings.cpp
  bindings/nvshmem.cpp
)
target_link_libraries(pplx_kernels PUBLIC
  all_to_all_lib
  torch::py_limited
  Python::Module
  CUDA::cuda_driver
  CUDA::cudart
  nvshmem::nvshmem
)
set_target_properties(pplx_kernels PROPERTIES
  LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../pplx_kernels
  CUDA_SEPARABLE_COMPILATION ON
)
```
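
The bindings/*.cpp sources listed above are not part of this excerpt. As a hedged illustration of what the build flags imply: because the extension defines Py_LIMITED_API and deliberately avoids linking all of ${TORCH_LIBRARIES}, operator registration would plausibly go through PyTorch's TORCH_LIBRARY mechanism rather than pybind11 (which does not support the limited API). A minimal sketch with a hypothetical operator name, not the repository's actual bindings:

```
// Hypothetical sketch only; the real bindings/all_to_all_ops.cpp is not
// shown in this diff. TORCH_LIBRARY registers an operator schema in the
// "pplx_kernels" namespace without requiring the full Python C API.
#include <torch/library.h>

TORCH_LIBRARY(pplx_kernels, m) {
  // Placeholder schema for illustration; the actual operator names and
  // signatures live in the bindings sources.
  m.def("all_to_all_dispatch(Tensor x) -> Tensor");
}
```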

csrc/all_to_all/CMakeLists.txt

Lines changed: 44 additions & 0 deletions
```
# All-to-All library

add_library(all_to_all_lib STATIC
  all_to_all.cpp
  internode_scatter.cu
  internode_gather.cu
  internode.cpp
)
target_link_libraries(all_to_all_lib PUBLIC
  nvshmem::nvshmem
  CUDA::cudart
)
set_cuda_compile_options(all_to_all_lib)

if(WITH_TESTS)
  # All-to-All test
  add_executable(test_all_to_all
    test_all_to_all.cpp
  )
  target_link_libraries(test_all_to_all PUBLIC
    all_to_all_lib
    CUDA::cudart
    CUDA::cuda_driver
    MPI::MPI_CXX
    nvshmem::nvshmem
  )
  set_cuda_compile_options(test_all_to_all)
  add_test(NAME AllToAllTest
    COMMAND ${MPIEXEC_EXECUTABLE} -np 4 $<TARGET_FILE:test_all_to_all>)
  set_tests_properties(AllToAllTest PROPERTIES ENVIRONMENT "NVSHMEM_REMOTE_TRANSPORT=None")
endif()

if (WITH_BENCHMARKS)
  add_executable(bench_all_to_all
    bench_all_to_all.cpp
  )
  target_link_libraries(bench_all_to_all PUBLIC
    all_to_all_lib
    CUDA::cudart
    CUDA::cuda_driver
    MPI::MPI_CXX
    nvshmem::nvshmem
  )
endif()
```

csrc/all_to_all/all_to_all.cpp

Lines changed: 67 additions & 0 deletions
```
#include "all_to_all.h"

#include "core/cuda_utils.h"
#include "core/utils.h"

#include <cuda_runtime.h>

using namespace pplx;

namespace {
template <typename T> T *mallocZeroBuffer(size_t size) {
  T *ptr;
  CUDACHECK(cudaMalloc(&ptr, size * sizeof(T)));
  cudaMemset(ptr, 0, size * sizeof(T));
  return ptr;
}
} // namespace

AllToAll::AllToAll(
    size_t maxNumTokens,
    size_t numExperts,
    size_t expertsPerToken,
    unsigned rank,
    unsigned worldSize,
    unsigned dpSize,
    size_t hiddenDim,
    size_t hiddenDimBytes,
    size_t hiddenDimScaleBytes
)
    : maxNumTokens(maxNumTokens),
      numExperts(numExperts),
      numLocalExperts(ceil_div<uint32_t>(numExperts, worldSize)),
      numDPGroups(ceil_div<uint32_t>(worldSize, dpSize)),
      expertsPerToken(expertsPerToken),
      hiddenDim(hiddenDim),
      hiddenDimBytes(hiddenDimBytes),
      hiddenDimScaleBytes(hiddenDimScaleBytes),
      rank(rank),
      worldSize(worldSize),
      dpSize(dpSize),
      maxBatchTokens(numLocalExperts * numDPGroups * maxNumTokens) {

  ROSE_ASSERT(hiddenDimBytes % 16 == 0, "invalid hidden dim bytes");
  ROSE_ASSERT(hiddenDimScaleBytes % 16 == 0, "invalid hidden dim scale bytes");
  const size_t perTokenBytes =
      round_up<size_t>(hiddenDimBytes + hiddenDimScaleBytes + sizeof(uint32_t), 16);
  const size_t maxBatchTokens = numLocalExperts * numDPGroups * maxNumTokens;

  ROSE_ASSERT(numLocalExperts != 0, "numLocalExperts is 0");
  ROSE_ASSERT(numDPGroups > 1, "at least 2 DP groups are required");
  ROSE_ASSERT(hiddenDimScaleBytes <= hiddenDimBytes, "invalid hidden dim bytes");

  // Buffers for token tracking.
  numTokensPerDP = mallocZeroBuffer<uint32_t>(numLocalExperts * numDPGroups);
  sourceIndex = mallocZeroBuffer<uint32_t>(maxBatchTokens);
  sourceExpert = mallocZeroBuffer<uint32_t>(maxBatchTokens);
  sourceOffset = mallocZeroBuffer<uint32_t>(maxBatchTokens);
  sourceGroup = mallocZeroBuffer<uint32_t>(maxBatchTokens);
}

AllToAll::~AllToAll() {
  CUDACHECK(cudaFree(numTokensPerDP));
  CUDACHECK(cudaFree(sourceIndex));
  CUDACHECK(cudaFree(sourceExpert));
  CUDACHECK(cudaFree(sourceOffset));
  CUDACHECK(cudaFree(sourceGroup));
}
```
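
CUDACHECK, ceil_div, and round_up come from core/cuda_utils.h and core/utils.h, which are not included in this commit excerpt. A minimal sketch of the conventional definitions these call sites imply; the repository's actual implementations may differ:

```
// Assumed helpers, reconstructed from usage above; not the repo's code.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Abort on any CUDA error, reporting the failing call site.
#define CUDACHECK(cmd)                                                         \
  do {                                                                         \
    cudaError_t err = (cmd);                                                   \
    if (err != cudaSuccess) {                                                  \
      fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(err),     \
              __FILE__, __LINE__);                                             \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  } while (0)

// Integer division rounding up: ceil_div(7, 2) == 4.
template <typename T> constexpr T ceil_div(T a, T b) { return (a + b - 1) / b; }

// Round x up to the next multiple of n: round_up(37, 16) == 48.
template <typename T> constexpr T round_up(T x, T n) { return ceil_div(x, n) * n; }
```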

csrc/all_to_all/all_to_all.h

Lines changed: 77 additions & 0 deletions
```
#pragma once

#include <cstdint>
#include <cstdlib>

namespace pplx {

/// Specifies which part of a send-and-recv kernel to launch.
enum class SplitMode {
  NONE,
  SEND,
  RECV,
};

/// Base class for all-to-all broadcast kernels.
class AllToAll {
public:
  /// @brief Initializes the all-to-all broadcast kernel.
  ///
  /// @param maxNumTokens The maximum number of tokens per DP group.
  /// @param numExperts The total number of experts spread across all ranks.
  /// @param expertsPerToken The number of experts per token.
  /// @param rank The rank of the current process.
  /// @param worldSize The number of processes in the world.
  /// @param dpSize The size of a DP group.
  /// @param hiddenDim The hidden dimension of X, in elements.
  /// @param hiddenDimBytes The hidden dimension of X, in bytes.
  /// @param hiddenDimScaleBytes The hidden dimension of the scale of X, in
  /// bytes.
  AllToAll(
      size_t maxNumTokens,
      size_t numExperts,
      size_t expertsPerToken,
      unsigned rank,
      unsigned worldSize,
      unsigned dpSize,
      size_t hiddenDim,
      size_t hiddenDimBytes,
      size_t hiddenDimScaleBytes
  );

  virtual ~AllToAll();

protected:
  /// The maximum number of tokens per DP group.
  const size_t maxNumTokens;
  /// The total number of experts spread across all ranks.
  const size_t numExperts;
  /// The number of local experts.
  const size_t numLocalExperts;
  /// The number of DP groups.
  const size_t numDPGroups;
  /// The number of experts per token.
  const size_t expertsPerToken;
  /// The hidden dimension of X, in elements.
  const size_t hiddenDim;
  /// The hidden dimension of X, in bytes.
  const size_t hiddenDimBytes;
  /// The hidden dimension scale of X, in bytes.
  const size_t hiddenDimScaleBytes;
  /// The rank of the current process.
  const unsigned rank;
  /// The number of processes in the world.
  const unsigned worldSize;
  /// The size of a DP group.
  const unsigned dpSize;
  /// The maximum number of tokens in a batch.
  const size_t maxBatchTokens;

  /// @section Internal buffers communicating between scatter and gather.
  uint32_t *numTokensPerDP = nullptr;
  uint32_t *sourceIndex = nullptr;
  uint32_t *sourceExpert = nullptr;
  uint32_t *sourceOffset = nullptr;
  uint32_t *sourceGroup = nullptr;
};

} // namespace pplx
```
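
To make the derived members concrete, here is a small, self-contained example that reproduces the constructor's sharding arithmetic for one illustrative configuration (the numbers are made up for illustration, not taken from the source; ceil_div is as sketched earlier):

```
#include <cstddef>
#include <cstdio>

template <typename T> constexpr T ceil_div(T a, T b) { return (a + b - 1) / b; }

int main() {
  // Illustrative configuration: 256 experts sharded over 32 ranks,
  // 8-way DP groups, up to 128 tokens per DP group.
  const size_t numExperts = 256, maxNumTokens = 128;
  const unsigned worldSize = 32, dpSize = 8;

  const size_t numLocalExperts = ceil_div<size_t>(numExperts, worldSize);     // 8
  const size_t numDPGroups = ceil_div<size_t>(worldSize, dpSize);             // 4
  const size_t maxBatchTokens = numLocalExperts * numDPGroups * maxNumTokens; // 4096

  printf("local experts: %zu, DP groups: %zu, max batch tokens: %zu\n",
         numLocalExperts, numDPGroups, maxBatchTokens);
  return 0;
}
```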
