Skip to content

Commit 2b87ea5

Browse files
authored
Add DeviceType to Array2D type annotation (#23)
Recreating #21 because I accidentally closed it and GitHub doesn't let me reopen the PR. ---- (Closes MET-14) Note that this PR is currently set to base on `xiaoyan/array2d-nanobind` instead of `master` so that only incremental diffs are displayed. ## Summary of Changes Similar to raw pointers, our current `Array2D` type does not record *where* the underlying data comes from, as the memory access patterns are roughly the same on CPU & GPU. However, mentally tracking the location of the memory can be error-prone. In addition, many other Python array/tensor libraries store the device type explicitly, and we won't be able to easily convert to them without knowing where our memory is. As such, I'm adding a `DeviceType` template parameter to our `Array2D` to keep track of the location of the memory. With this change, we are finally able to take CPU/GPU buffers from the Python side and return them correctly without running into segfaults. ## Test Plans You can find examples of creating `Array2D` from CPU/GPU memory buffers with numpy and JAX in the included `test_utils.py`. As always, to run all the tests: ```bash pixi run test ```
1 parent d224feb commit 2b87ea5

File tree

5 files changed

+93
-32
lines changed

5 files changed

+93
-32
lines changed

genmetaballs/src/cuda/bindings.cu

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
namespace nb = nanobind;
1212

13+
template <typename T, MemoryLocation location>
14+
void bind_array2d(nb::module_& m, const char* name);
15+
1316
NB_MODULE(_genmetaballs_bindings, m) {
1417

1518
/*
@@ -80,23 +83,38 @@ NB_MODULE(_genmetaballs_bindings, m) {
8083
nb::module_ utils = m.def_submodule("utils");
8184
utils.def("sigmoid", sigmoid, nb::arg("x"), "Compute the sigmoid function: 1 / (1 + exp(-x))");
8285

83-
nb::class_<Array2D<float>>(utils, "FloatArray2D")
86+
bind_array2d<float, MemoryLocation::HOST>(utils, "CPUFloatArray2D");
87+
bind_array2d<float, MemoryLocation::DEVICE>(utils, "GPUFloatArray2D");
88+
89+
} // NB_MODULE(_genmetaballs_bindings)
90+
91+
template <typename T, MemoryLocation location>
92+
void bind_array2d(nb::module_& m, const char* name) {
93+
using nb_device_type =
94+
std::conditional_t<location == MemoryLocation::HOST, nb::device::cpu, nb::device::cuda>;
95+
nb::class_<Array2D<T, location>>(m, name)
8496
.def_static("from_array",
85-
[](const nb::ndarray<float, nb::ndim<2>, nb::c_contig>& array) {
86-
return Array2D<float>(array.data(), array.shape(0), array.shape(1));
97+
[](const nb::ndarray<T, nb::ndim<2>, nb::c_contig, nb_device_type>& array) {
98+
return Array2D<T, location>(array.data(), array.shape(0), array.shape(1));
8799
})
88100
// TODO: switch to the array_api in future nanobind release
89101
// https://nanobind.readthedocs.io/en/latest/api_extra.html#_CPPv4N8nanobind9array_apiE
90102
.def(
91-
"numpy",
92-
[](const Array2D<float>& self) {
93-
return nb::ndarray<float, nb::numpy, nb::c_contig>(
103+
"as_numpy",
104+
[](const Array2D<T, location>& self) {
105+
return nb::ndarray<T, nb::numpy, nb::c_contig, nb_device_type>(
94106
self.data(), {self.num_rows(), self.num_cols()});
95107
},
96108
nb::rv_policy::reference_internal)
97-
.def_prop_ro("num_rows", &Array2D<float>::num_rows)
98-
.def_prop_ro("num_cols", &Array2D<float>::num_cols)
99-
.def_prop_ro("ndim", &Array2D<float>::ndim)
100-
.def_prop_ro("size", &Array2D<float>::size);
101-
102-
} // NB_MODULE(_genmetaballs_bindings)
109+
.def(
110+
"as_jax",
111+
[](const Array2D<T, location>& self) {
112+
return nb::ndarray<T, nb::jax, nb::c_contig, nb_device_type>(
113+
self.data(), {self.num_rows(), self.num_cols()});
114+
},
115+
nb::rv_policy::reference_internal)
116+
.def_prop_ro("num_rows", &Array2D<T, location>::num_rows)
117+
.def_prop_ro("num_cols", &Array2D<T, location>::num_cols)
118+
.def_prop_ro("ndim", &Array2D<T, location>::ndim)
119+
.def_prop_ro("size", &Array2D<T, location>::size);
120+
}

genmetaballs/src/cuda/core/utils.cuh

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ CUDA_CALLABLE __forceinline__ float sigmoid(float x) {
2121
return 1.0f / (1.0f + expf(-x));
2222
}
2323

24+
enum class MemoryLocation { HOST, DEVICE };
25+
2426
// Non-owning 2D view into a contiguous array in either host or device memory
25-
template <typename T>
27+
template <typename T, MemoryLocation location>
2628
class Array2D {
2729
private:
2830
cuda::std::mdspan<
@@ -60,10 +62,3 @@ public:
6062
return data_view_.data_handle();
6163
}
6264
}; // class Array2D
63-
64-
// Type deduction guide
65-
// if initialized with (Pointer, int, int), deduce T by looking at what raw_pointer_cast returns
66-
// so we can write Array2D(array_ptr, rows, cols) instead of Array2D<Type>(array_ptr, rows, cols)
67-
template <typename Pointer>
68-
Array2D(Pointer, uint32_t, uint32_t)
69-
-> Array2D<typename std::pointer_traits<Pointer>::element_type>;

genmetaballs/src/genmetaballs/core/__init__.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,26 @@
33
TwoParameterConfidence,
44
ZeroParameterConfidence,
55
)
6-
from genmetaballs._genmetaballs_bindings.utils import FloatArray2D, sigmoid
6+
from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid
7+
8+
9+
def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D:
10+
"""Create a FloatArray2D on the specified device from an array.
11+
12+
Args:
13+
data: A 2D array of type float32.
14+
device: 'cpu' or 'gpu' to specify the target device.
15+
"""
16+
if device == "cpu":
17+
return CPUFloatArray2D.from_array(data)
18+
elif device == "gpu":
19+
return GPUFloatArray2D.from_array(data)
20+
else:
21+
raise ValueError(f"Unsupported device type: {device}")
22+
723

824
__all__ = [
9-
"FloatArray2D",
25+
"array2d_float",
1026
"ZeroParameterConfidence",
1127
"TwoParameterConfidence",
1228
"geometry",

tests/cpp_tests/test_utils.cu

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ TEST(GpuSigmoidTest, SigmoidGPUWithinBounds) {
9494

9595
namespace test_utils_gpu {
9696
// CUDA kernel to fill Array2D with sequential values
97-
__global__ void fill_array2d_kernel(Array2D<float> array2d) {
97+
__global__ void fill_array2d_kernel(Array2D<float, MemoryLocation::DEVICE> array2d) {
9898
uint32_t i = threadIdx.x;
9999
uint32_t j = threadIdx.y;
100100

@@ -120,8 +120,11 @@ TYPED_TEST(Array2DTestFixture, CreateAndAccessArray2D) {
120120
uint32_t cols = 6;
121121

122122
auto data = TypeParam(rows * cols);
123+
constexpr auto device_type = std::is_same_v<TypeParam, thrust::device_vector<float>>
124+
? MemoryLocation::DEVICE
125+
: MemoryLocation::HOST;
123126
// create 2D view into the underlying data on host or device
124-
auto array2d = Array2D(data.data(), rows, cols);
127+
auto array2d = Array2D<float, device_type>(data.data(), rows, cols);
125128

126129
if constexpr (std::is_same_v<TypeParam, std::vector<float>>) {
127130
for (auto i = 0; i < rows - 1; i++) {
@@ -164,7 +167,10 @@ TYPED_TEST(Array2DTestFixture, ViewModifiesUnderlyingData) {
164167
uint32_t rows = 3;
165168
uint32_t cols = 4;
166169
auto data = TypeParam(rows * cols, 0.0f);
167-
auto array2d = Array2D(data.data(), rows, cols);
170+
constexpr auto device_type = std::is_same_v<TypeParam, thrust::device_vector<float>>
171+
? MemoryLocation::DEVICE
172+
: MemoryLocation::HOST;
173+
auto array2d = Array2D<float, device_type>(data.data(), rows, cols);
168174

169175
// Modify through view
170176
array2d[1][2] = 42.5f;
@@ -184,8 +190,11 @@ TYPED_TEST(Array2DTestFixture, MultipleViewsOfSameData) {
184190
uint32_t rows = 2;
185191
uint32_t cols = 3;
186192
auto data = TypeParam(rows * cols, 0.0f);
187-
auto view1 = Array2D(data.data(), rows, cols);
188-
auto view2 = Array2D(data.data(), rows, cols);
193+
constexpr auto device_type = std::is_same_v<TypeParam, thrust::device_vector<float>>
194+
? MemoryLocation::DEVICE
195+
: MemoryLocation::HOST;
196+
auto view1 = Array2D<float, device_type>(data.data(), rows, cols);
197+
auto view2 = Array2D<float, device_type>(data.data(), rows, cols);
189198

190199
// Modify through view1
191200
view1[0][0] = 100.0f;

tests/python_tests/test_utils.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import jax
2+
import jax.numpy as jnp
13
import numpy as np
24
import pytest
35
from scipy.special import expit
46

5-
from genmetaballs.core import FloatArray2D, sigmoid
7+
from genmetaballs.core import array2d_float, sigmoid
68

79
NUM_RNG_SEEDS_PER_TEST = 5
810
NUM_N_VALUES_PER_TEST = 5
@@ -60,18 +62,39 @@ def test_sigmoid_edge_cases(x: float) -> None:
6062
assert actual <= 1.0
6163

6264

63-
def test_float_array2d_creation_and_view():
65+
@pytest.mark.parametrize("device", ["cpu", "gpu"])
66+
def test_array2d_float_creation_on_jax_devices(device: str):
6467
"""Test creation of Array2D from a JAX array placed on the given device."""
6568
rows, cols = 4, 5
69+
data = jnp.arange(rows * cols, dtype=jnp.float32).reshape((rows, cols))
70+
jax_device = jax.devices(device)[0]
71+
data = jax.device_put(data, device=jax_device)
72+
array_2d = array2d_float(data, device=device)
73+
74+
assert array_2d.num_rows == rows
75+
assert array_2d.num_cols == cols
76+
assert array_2d.ndim == 2
77+
78+
# then try converting back to a JAX array via view
79+
data_view = array_2d.as_jax()
80+
assert data_view.device == data.device
81+
assert jnp.allclose(data, data_view)
82+
83+
# Note: we can't test writability of shared view here since JAX arrays are immutable
84+
85+
86+
def test_float_array2d_view_numpy():
87+
"""Test creation of Array2D from a numpy array."""
88+
rows, cols = 3, 4
6689
data = np.arange(rows * cols, dtype=np.float32).reshape((rows, cols))
67-
array_2d = FloatArray2D.from_array(data)
90+
array_2d = array2d_float(data, device="cpu")
6891

6992
assert array_2d.num_rows == rows
7093
assert array_2d.num_cols == cols
7194
assert array_2d.ndim == 2
7295

7396
# then try converting back to numpy array via view
74-
data_view = array_2d.numpy()
97+
data_view = array_2d.as_numpy()
7598
assert np.allclose(data, data_view)
7699

77100
# check that the view is writable and changes reflect back to original data
@@ -84,4 +107,4 @@ def test_create_invalid_array2d():
84107
data = np.arange(12, dtype=np.float32).reshape((3, 4))
85108

86109
with pytest.raises(TypeError):
87-
FloatArray2D.from_array(data.reshape((3, 4, 1))) # not 2D
110+
array2d_float(data.reshape((3, 4, 1)), device="cpu") # not 2D

0 commit comments

Comments
 (0)