Commit ddc6916

Release updates to pplx_kernels (#11)
* Intra-node dispatch/combine with NVLink
* Performance improvements to inter-node dispatch/combine
* FP16 changes were pulled into the internal repo, then re-released here
1 parent be55fd7 commit ddc6916

31 files changed: +1981, -520 lines
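The commit splits the all-to-all path into an NVLink-only intra-node implementation and the existing NVSHMEM inter-node implementation (see the CMake and source diffs below). As a rough orientation, here is a minimal C++ sketch of that kind of split; the derived-class names and the chooseAllToAll() helper are illustrative assumptions, not code from this commit:

```cpp
#include <cstddef>
#include <memory>

namespace sketch {

// Shared state, mirroring a few of the fields kept in pplx::AllToAll.
class AllToAll {
public:
  AllToAll(unsigned rank, unsigned worldSize) : rank(rank), worldSize(worldSize) {}
  virtual ~AllToAll() = default;

protected:
  const unsigned rank;
  const unsigned worldSize;
};

// Intra-node variant: peers reached over NVLink (intranode_*.cu in this commit).
class AllToAllIntraNode final : public AllToAll {
public:
  using AllToAll::AllToAll;
};

// Inter-node variant: peers reached through NVSHMEM (internode_*.cu).
class AllToAllInterNode final : public AllToAll {
public:
  using AllToAll::AllToAll;
};

// Hypothetical selection helper: take the NVLink-only path when every rank
// lives on the same node, otherwise fall back to the inter-node path.
inline std::unique_ptr<AllToAll>
chooseAllToAll(unsigned rank, unsigned worldSize, unsigned ranksPerNode) {
  if (worldSize <= ranksPerNode) {
    return std::make_unique<AllToAllIntraNode>(rank, worldSize);
  }
  return std::make_unique<AllToAllInterNode>(rank, worldSize);
}

} // namespace sketch
```

The real entry points are in the intranode.cpp/intranode_*.cu and internode.cpp/internode_*.cu files added and reorganized in the diffs below.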

README.md

Lines changed: 3 additions & 3 deletions
@@ -90,7 +90,7 @@ cd pplx-kernels
 mkdir build-cmake
 cd build-cmake

-TORCH_PREFIX_PATH=$(python3 -c 'import torch; print(torch.utils.cmake_prefix_path)')
+export TORCH_PREFIX_PATH=$(python3 -c 'import torch; print(torch.utils.cmake_prefix_path)')

 cmake ../csrc \
   -GNinja \
@@ -105,12 +105,12 @@ ninja test_all_to_all bench_all_to_all
 To run the all-to-all tests on one node:

 ```bash
-NVSHMEM_REMOTE_TRANSPORT=none mpirun -np 4 ./test_all_to_all
+NVSHMEM_REMOTE_TRANSPORT=none mpirun -np 4 ./all_to_all/test_all_to_all
 ```


 To run the all-to-all benchmarks on one node:

 ```bash
-NVSHMEM_REMOTE_TRANSPORT=none mpirun -np 4 ./bench_all_to_all
+NVSHMEM_REMOTE_TRANSPORT=none mpirun -np 4 ./all_to_all/bench_all_to_all
 ```

csrc/CMakeLists.txt

Lines changed: 5 additions & 3 deletions
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.26)
+cmake_minimum_required(VERSION 3.22)
 project(PPLXKernels
   VERSION 0.0.1
   DESCRIPTION "PPLX Kernels"
@@ -52,6 +52,7 @@ endfunction()

 # === Library targets ===
 add_subdirectory(all_to_all)
+add_subdirectory(core)

 # Main shared library
 add_library(pplx_kernels SHARED
@@ -60,13 +61,14 @@ add_library(pplx_kernels SHARED
   bindings/nvshmem.cpp
 )
 target_link_libraries(pplx_kernels PUBLIC
-  all_to_all_lib
+  all_to_all_internode_lib
+  all_to_all_intranode_lib
+  core_lib
   torch::py_limited
   Python::Module
   CUDA::cuda_driver
   CUDA::cudart
   nvshmem::nvshmem_host
-  nvshmem::nvshmem_device
 )
 set_target_properties(pplx_kernels PROPERTIES
   LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../src/pplx_kernels

csrc/all_to_all/CMakeLists.txt

Lines changed: 29 additions & 12 deletions
@@ -1,30 +1,46 @@
 # All-to-All library

-add_library(all_to_all_lib STATIC
+add_library(all_to_all_common STATIC
   all_to_all.cpp
-  internode_dispatch.cu
+)
+
+add_library(all_to_all_intranode_lib STATIC
+  intranode_combine.cu
+  intranode_dispatch.cu
+  intranode.cpp
+)
+target_link_libraries(all_to_all_intranode_lib PUBLIC
+  all_to_all_common
+  nvshmem::nvshmem
+  CUDA::cudart
+)
+set_cuda_compile_options(all_to_all_intranode_lib)
+
+add_library(all_to_all_internode_lib STATIC
   internode_combine.cu
+  internode_dispatch.cu
   internode.cpp
 )
-target_link_libraries(all_to_all_lib PUBLIC
-  nvshmem::nvshmem_host
-  nvshmem::nvshmem_device
+target_link_libraries(all_to_all_internode_lib PUBLIC
+  all_to_all_common
+  nvshmem::nvshmem
   CUDA::cudart
 )
-set_cuda_compile_options(all_to_all_lib)
+set_cuda_compile_options(all_to_all_internode_lib)

 if(WITH_TESTS)
   # All-to-All test
   add_executable(test_all_to_all
     test_all_to_all.cpp
   )
   target_link_libraries(test_all_to_all PUBLIC
-    all_to_all_lib
+    all_to_all_intranode_lib
+    all_to_all_internode_lib
+    core_lib
     CUDA::cudart
     CUDA::cuda_driver
     MPI::MPI_CXX
-    nvshmem::nvshmem_host
-    nvshmem::nvshmem_device
+    nvshmem::nvshmem
   )
   set_cuda_compile_options(test_all_to_all)
   add_test(NAME AllToAllTest
@@ -37,11 +53,12 @@ if (WITH_BENCHMARKS)
     bench_all_to_all.cpp
   )
   target_link_libraries(bench_all_to_all PUBLIC
-    all_to_all_lib
+    all_to_all_intranode_lib
+    all_to_all_internode_lib
+    core_lib
     CUDA::cudart
     CUDA::cuda_driver
     MPI::MPI_CXX
-    nvshmem::nvshmem_host
-    nvshmem::nvshmem_device
+    nvshmem::nvshmem
   )
 endif()

csrc/all_to_all/all_to_all.cpp

Lines changed: 2 additions & 29 deletions
@@ -1,21 +1,9 @@
 #include "all_to_all.h"

-#include "core/cuda_utils.h"
 #include "core/utils.h"

-#include <cuda_runtime.h>
-
 using namespace pplx;

-namespace {
-template <typename T> T *mallocZeroBuffer(size_t size) {
-  T *ptr;
-  CUDACHECK(cudaMalloc(&ptr, size * sizeof(T)));
-  cudaMemset(ptr, 0, size * sizeof(T));
-  return ptr;
-}
-} // namespace
-
 AllToAll::AllToAll(
     size_t maxNumTokens,
     size_t numExperts,
@@ -37,31 +25,16 @@ AllToAll::AllToAll(
       hiddenDimScaleBytes(hiddenDimScaleBytes),
       rank(rank),
       worldSize(worldSize),
-      dpSize(dpSize),
-      maxBatchTokens(numLocalExperts * numDPGroups * maxNumTokens) {
+      dpSize(dpSize) {

   ROSE_ASSERT(hiddenDimBytes % 16 == 0, "invalid hidden dim bytes");
   ROSE_ASSERT(hiddenDimScaleBytes % 16 == 0, "invalid hidden dim scale bytes");
   const size_t perTokenBytes =
       round_up<size_t>(hiddenDimBytes + hiddenDimScaleBytes + sizeof(uint32_t), 16);
-  const size_t maxBatchTokens = numLocalExperts * numDPGroups * maxNumTokens;

   ROSE_ASSERT(numLocalExperts != 0, "numLocalExperts is 0");
   ROSE_ASSERT(numDPGroups > 1, "at least 2 DP groups are required");
   ROSE_ASSERT(hiddenDimScaleBytes <= hiddenDimBytes, "invalid hidden dim bytes");
-
-  // Buffers for token tracking.
-  numTokensPerDP = mallocZeroBuffer<uint32_t>(numLocalExperts * numDPGroups);
-  sourceIndex = mallocZeroBuffer<uint32_t>(maxBatchTokens);
-  sourceExpert = mallocZeroBuffer<uint32_t>(maxBatchTokens);
-  sourceOffset = mallocZeroBuffer<uint32_t>(maxBatchTokens);
-  sourceGroup = mallocZeroBuffer<uint32_t>(maxBatchTokens);
 }

-AllToAll::~AllToAll() {
-  CUDACHECK(cudaFree(numTokensPerDP));
-  CUDACHECK(cudaFree(sourceIndex));
-  CUDACHECK(cudaFree(sourceExpert));
-  CUDACHECK(cudaFree(sourceOffset));
-  CUDACHECK(cudaFree(sourceGroup));
-}
+AllToAll::~AllToAll() {}

csrc/all_to_all/all_to_all.h

Lines changed: 0 additions & 9 deletions
@@ -63,15 +63,6 @@ class AllToAll {
   const unsigned worldSize;
   /// The size of a DP group.
   const unsigned dpSize;
-  /// The maximum number of tokens in a batch.
-  const size_t maxBatchTokens;
-
-  /// @section Internal buffers communicating between dispatch and combine.
-  uint32_t *numTokensPerDP = nullptr;
-  uint32_t *sourceIndex = nullptr;
-  uint32_t *sourceExpert = nullptr;
-  uint32_t *sourceOffset = nullptr;
-  uint32_t *sourceGroup = nullptr;
 };

 } // namespace pplx
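
For context on the all_to_all.cpp and all_to_all.h deletions above: the base class previously cudaMalloc'ed and zeroed five bookkeeping buffers in its constructor and freed them in its destructor; this commit removes that ownership from the shared base, presumably leaving the intra-node and inter-node implementations to manage their own buffers. Below is a minimal sketch of the removed pattern as an RAII wrapper; DeviceBuffer is an illustrative name, not a type from this repository:

```cpp
#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>

// Owns a zero-initialized device allocation and frees it on destruction.
// The removed code additionally wrapped the CUDA calls in CUDACHECK().
template <typename T> class DeviceBuffer {
public:
  explicit DeviceBuffer(size_t count) : count_(count) {
    // Same allocate-then-zero sequence as the removed mallocZeroBuffer<T>().
    cudaMalloc(&ptr_, count * sizeof(T));
    cudaMemset(ptr_, 0, count * sizeof(T));
  }
  ~DeviceBuffer() { cudaFree(ptr_); }

  DeviceBuffer(const DeviceBuffer &) = delete;
  DeviceBuffer &operator=(const DeviceBuffer &) = delete;

  T *get() const { return ptr_; }
  size_t size() const { return count_; }

private:
  T *ptr_ = nullptr;
  size_t count_ = 0;
};

// Usage, sized like the removed buffers
// (maxBatchTokens = numLocalExperts * numDPGroups * maxNumTokens):
//   DeviceBuffer<uint32_t> sourceIndex(maxBatchTokens);
```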
