Skip to content

Commit 9ecf27c

Browse files
committed
Use thread-safe gpu-lite
1 parent 82302f2 commit 9ecf27c

File tree

5 files changed

+174
-4
lines changed

5 files changed

+174
-4
lines changed

tox.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ deps = cmake
125125
commands =
126126
cmake -B {envtmpdir} -S vesin -DVESIN_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug
127127
cmake --build {envtmpdir} --config Debug
128-
ctest --test-dir {envtmpdir} --build-config Debug
128+
ctest --test-dir {envtmpdir} --build-config Debug --output-on-failure
129129

130130

131131
[testenv:fortran-tests]
@@ -137,7 +137,7 @@ deps = cmake
137137
commands =
138138
cmake -B {envtmpdir} -S fortran -DVESIN_FORTRAN_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug
139139
cmake --build {envtmpdir} --config Debug
140-
ctest --test-dir {envtmpdir} --build-config Debug
140+
ctest --test-dir {envtmpdir} --build-config Debug --output-on-failure
141141

142142

143143
[testenv:lint]

vesin/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ FetchContent_Declare(
8181
# GIT_REPOSITORY https://github.com/rubber-duck-debug/gpu-lite.git
8282
# GIT_TAG 78b4bad091e329b332def47ee9692e367a28ea85 # v1.0.0
8383
GIT_REPOSITORY https://github.com/Luthaf/gpu-lite.git
84-
GIT_TAG f59550e6eafdf51a3fd266a50ec080f640c0991c
84+
GIT_TAG e50589e9f78a154425917ebcfee56471f1d067fc
8585
EXCLUDE_FROM_ALL
8686
)
8787
FetchContent_MakeAvailable(gpulite)

vesin/src/vesin_cuda.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ void vesin::cuda::neighbors(
520520
auto* d_overflow_flag = extras->overflow_flag;
521521
size_t max_pairs = extras->max_pairs;
522522

523-
auto& factory = KernelFactory::instance();
523+
auto& factory = KernelFactory::instance(device_id);
524524

525525
if (extras->box_diag == nullptr) {
526526
CUDART_SAFE_CALL(CUDART_INSTANCE.cudaMalloc((void**)&extras->box_diag, sizeof(double) * 3));

vesin/tests/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ else()
2323
set(TEST_COMMAND "")
2424
endif()
2525

26+
find_package(CUDAToolkit)
27+
2628

2729
file(GLOB ALL_TESTS *.cpp)
2830
foreach(_file_ ${ALL_TESTS})
@@ -34,4 +36,9 @@ foreach(_file_ ${ALL_TESTS})
3436
NAME ${_name_}
3537
COMMAND ${TEST_COMMAND} $<TARGET_FILE:${_name_}>
3638
)
39+
40+
if (CUDAToolkit_FOUND)
41+
target_compile_definitions(${_name_} PRIVATE VESIN_TESTS_WITH_CUDA)
42+
target_link_libraries(${_name_} CUDA::cudart)
43+
endif()
3744
endforeach()

vesin/tests/cuda.cpp

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#include <catch2/catch_test_macros.hpp>
2+
3+
#ifdef VESIN_TESTS_WITH_CUDA
4+
5+
#include <cmath>
6+
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
7+
8+
#include <cuda_runtime.h>
9+
10+
#include <vesin.h>
11+
12+
// Turn the status of a CUDA runtime call into a Catch2 failure.
//
// Does nothing when `status` is `cudaSuccess`; otherwise fails the current
// test through `FAIL`, using the text returned by `cudaGetErrorString` as
// the failure message.
void check_cuda(cudaError_t status) {
    if (status == cudaSuccess) {
        return;
    }

    const char* description = cudaGetErrorString(status);
    FAIL(description);
}
18+
19+
void run_cuda_test(int device_id) {
20+
check_cuda(cudaSetDevice(device_id));
21+
22+
double points[][3] = {
23+
{0.0, 0.0, 0.0},
24+
{1.0, 1.0, 1.0},
25+
{2.0, 2.0, 2.0},
26+
};
27+
size_t n_points = 3;
28+
double (*d_points)[3] = nullptr;
29+
check_cuda(cudaMalloc(&d_points, sizeof(double) * n_points * 3));
30+
check_cuda(cudaMemcpy(d_points, points, sizeof(double) * n_points * 3, cudaMemcpyHostToDevice));
31+
32+
double box[3][3] = {
33+
{0.0, 3.0, 3.0},
34+
{3.0, 0.0, 3.0},
35+
{3.0, 3.0, 0.0},
36+
};
37+
double (*d_box)[3] = nullptr;
38+
check_cuda(cudaMalloc(&d_box, sizeof(double) * 9));
39+
check_cuda(cudaMemcpy(d_box, box, sizeof(double) * 9, cudaMemcpyHostToDevice));
40+
41+
bool periodic[3] = {true, true, true};
42+
bool* d_periodic = nullptr;
43+
check_cuda(cudaMalloc(&d_periodic, sizeof(bool) * 3));
44+
check_cuda(cudaMemcpy(d_periodic, periodic, sizeof(bool) * 3, cudaMemcpyHostToDevice));
45+
46+
VesinNeighborList neighbors;
47+
48+
auto options = VesinOptions();
49+
options.cutoff = 3.0;
50+
options.full = false;
51+
options.sorted = false;
52+
options.algorithm = VesinAutoAlgorithm;
53+
options.return_shifts = true;
54+
options.return_distances = true;
55+
options.return_vectors = true;
56+
57+
const char* error_message = nullptr;
58+
auto status = vesin_neighbors(
59+
d_points,
60+
n_points,
61+
d_box,
62+
d_periodic,
63+
{VesinDeviceKind::VesinCUDA, device_id},
64+
options,
65+
&neighbors,
66+
&error_message
67+
);
68+
69+
REQUIRE(error_message == nullptr);
70+
REQUIRE(status == EXIT_SUCCESS);
71+
72+
CHECK(neighbors.length == 5);
73+
CHECK(neighbors.pairs != nullptr);
74+
CHECK(neighbors.shifts != nullptr);
75+
CHECK(neighbors.distances != nullptr);
76+
CHECK(neighbors.vectors != nullptr);
77+
78+
auto* h_pairs = static_cast<size_t (*)[2]>(malloc(sizeof(size_t) * neighbors.length * 2));
79+
check_cuda(cudaMemcpy(h_pairs, neighbors.pairs, sizeof(size_t) * neighbors.length * 2, cudaMemcpyDeviceToHost));
80+
81+
auto* h_shifts = static_cast<int32_t (*)[3]>(malloc(sizeof(int32_t) * neighbors.length * 3));
82+
check_cuda(cudaMemcpy(h_shifts, neighbors.shifts, sizeof(int32_t) * neighbors.length * 3, cudaMemcpyDeviceToHost));
83+
84+
auto* h_distances = static_cast<double*>(malloc(sizeof(double) * neighbors.length));
85+
check_cuda(cudaMemcpy(h_distances, neighbors.distances, sizeof(double) * neighbors.length, cudaMemcpyDeviceToHost));
86+
87+
auto* h_vectors = static_cast<double (*)[3]>(malloc(sizeof(double) * neighbors.length * 3));
88+
check_cuda(cudaMemcpy(h_vectors, neighbors.vectors, sizeof(double) * neighbors.length * 3, cudaMemcpyDeviceToHost));
89+
90+
for (size_t i = 0; i < neighbors.length; ++i) {
91+
if (h_pairs[i][0] == 0 && h_pairs[i][1] == 2) {
92+
// we have three pairs between 0 and 2 with shifts (-1, 0, 0),
93+
// (0, -1, 0), and (0, 0, -1)
94+
CHECK(h_distances[i] == std::sqrt(6.0));
95+
96+
if (h_shifts[i][0] == -1 && h_shifts[i][1] == 0 && h_shifts[i][2] == 0) {
97+
CHECK(h_vectors[i][0] == 2.0);
98+
CHECK(h_vectors[i][1] == -1.0);
99+
CHECK(h_vectors[i][2] == -1.0);
100+
} else if (h_shifts[i][0] == 0 && h_shifts[i][1] == -1 && h_shifts[i][2] == 0) {
101+
CHECK(h_vectors[i][0] == -1.0);
102+
CHECK(h_vectors[i][1] == 2.0);
103+
CHECK(h_vectors[i][2] == -1.0);
104+
} else if (h_shifts[i][0] == 0 && h_shifts[i][1] == 0 && h_shifts[i][2] == -1) {
105+
CHECK(h_vectors[i][0] == -1.0);
106+
CHECK(h_vectors[i][1] == -1.0);
107+
CHECK(h_vectors[i][2] == 2.0);
108+
} else {
109+
FAIL("Unexpected shift for pair (0, 2): (" + std::to_string(h_shifts[i][0]) + ", " + std::to_string(h_shifts[i][1]) + ", " + std::to_string(h_shifts[i][2]) + ")");
110+
}
111+
112+
} else if ((h_pairs[i][0] == 0 && h_pairs[i][1] == 1) || (h_pairs[i][0] == 1 && h_pairs[i][1] == 2)) {
113+
// pairs between 0-1 or 1-2 should have zero shifts, distance
114+
// sqrt(3), and vector (1, 1, 1)
115+
CHECK(h_shifts[i][0] == 0);
116+
CHECK(h_shifts[i][1] == 0);
117+
CHECK(h_shifts[i][2] == 0);
118+
119+
CHECK(h_distances[i] == std::sqrt(3.0));
120+
CHECK(h_vectors[i][0] == 1.0);
121+
CHECK(h_vectors[i][1] == 1.0);
122+
CHECK(h_vectors[i][2] == 1.0);
123+
} else {
124+
FAIL("Unexpected pair: (" + std::to_string(h_pairs[i][0]) + ", " + std::to_string(h_pairs[i][1]) + ")");
125+
}
126+
}
127+
128+
// Clean up
129+
vesin_free(&neighbors);
130+
131+
free(h_pairs);
132+
free(h_shifts);
133+
free(h_distances);
134+
free(h_vectors);
135+
136+
check_cuda(cudaFree(d_points));
137+
check_cuda(cudaFree(d_box));
138+
check_cuda(cudaFree(d_periodic));
139+
}
140+
141+
TEST_CASE("Test CUDA") {
142+
// get the number of CUDA devices
143+
int n_devices = 0;
144+
check_cuda(cudaGetDeviceCount(&n_devices));
145+
REQUIRE(n_devices > 0);
146+
147+
// start multiple threads to test concurrent execution
148+
auto threads = std::vector<std::thread>();
149+
for (int thread_id = 0; thread_id < 10; ++thread_id) {
150+
std::thread t(run_cuda_test, thread_id % n_devices);
151+
threads.push_back(std::move(t));
152+
}
153+
154+
for (auto& t : threads) {
155+
t.join();
156+
}
157+
}
158+
159+
#else
160+
161+
// Empty placeholder test case, compiled only when VESIN_TESTS_WITH_CUDA is
// not defined (this is the `#else` branch of that check).
// NOTE(review): presumably here so the test executable still registers a
// test case when CUDA is unavailable — confirm.
TEST_CASE("CUDA tests are disabled") {}
162+
163+
#endif

0 commit comments

Comments
 (0)