Skip to content

Commit 215d221

Browse files
committed
Array2D bindings to support passing data to/from numpy
1 parent 4bee915 commit 215d221

File tree

4 files changed

+54
-2
lines changed

4 files changed

+54
-2
lines changed

genmetaballs/src/cuda/bindings.cu

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <cstdint>
22
#include <nanobind/nanobind.h>
3+
#include <nanobind/ndarray.h>
34
#include <nanobind/operators.h>
45
#include <nanobind/stl/vector.h>
56

@@ -79,4 +80,23 @@ NB_MODULE(_genmetaballs_bindings, m) {
7980
nb::module_ utils = m.def_submodule("utils");
8081
utils.def("sigmoid", sigmoid, nb::arg("x"), "Compute the sigmoid function: 1 / (1 + exp(-x))");
8182

83+
nb::class_<Array2D<float>>(utils, "FloatArray2D")
84+
// TODO: switch to the array_api in future nanobind release
85+
// https://nanobind.readthedocs.io/en/latest/api_extra.html#_CPPv4N8nanobind9array_apiE
86+
.def(
87+
"numpy",
88+
[](const Array2D<float>& self) {
89+
return nb::ndarray<float, nb::numpy, nb::c_contig>(
90+
self.data(), {self.num_rows(), self.num_cols()});
91+
},
92+
nb::rv_policy::reference_internal)
93+
.def_static("from_array",
94+
[](const nb::ndarray<float, nb::ndim<2>, nb::c_contig>& array) {
95+
return Array2D<float>(array.data(), array.shape(0), array.shape(1));
96+
})
97+
.def_prop_ro("num_rows", &Array2D<float>::num_rows)
98+
.def_prop_ro("num_cols", &Array2D<float>::num_cols)
99+
.def_prop_ro("ndim", &Array2D<float>::ndim)
100+
.def_prop_ro("size", &Array2D<float>::size);
101+
82102
} // NB_MODULE(_genmetaballs_bindings)

genmetaballs/src/cuda/core/utils.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ public:
5656
CUDA_CALLABLE constexpr auto size() const noexcept {
5757
return data_view_.size();
5858
}
59+
CUDA_CALLABLE constexpr T* data() const noexcept {
60+
return data_view_.data_handle();
61+
}
5962
}; // class Array2D
6063

6164
// Type deduction guide

genmetaballs/src/genmetaballs/core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
TwoParameterConfidence,
44
ZeroParameterConfidence,
55
)
6-
from genmetaballs._genmetaballs_bindings.utils import sigmoid
6+
from genmetaballs._genmetaballs_bindings.utils import FloatArray2D, sigmoid
77

88
__all__ = [
9+
"FloatArray2D",
910
"ZeroParameterConfidence",
1011
"TwoParameterConfidence",
1112
"geometry",

tests/python_tests/test_utils.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
from scipy.special import expit
44

5-
from genmetaballs.core import sigmoid
5+
from genmetaballs.core import FloatArray2D, sigmoid
66

77
NUM_RNG_SEEDS_PER_TEST = 5
88
NUM_N_VALUES_PER_TEST = 5
@@ -58,3 +58,31 @@ def test_sigmoid_edge_cases(x: float) -> None:
5858
assert np.isclose(actual, expected, rtol=1e-5, atol=1e-6)
5959
assert actual >= 0.0
6060
assert actual <= 1.0
61+
62+
63+
def test_float_array2d_creation_and_view():
64+
"""Test creation of Array2D from a numpy array."""
65+
rows, cols = 4, 5
66+
data = np.arange(rows * cols, dtype=np.float32).reshape((rows, cols))
67+
array_2d = FloatArray2D.from_array(data)
68+
69+
assert array_2d.num_rows == rows
70+
assert array_2d.num_cols == cols
71+
assert array_2d.ndim == 2
72+
73+
# then try converting back to numpy array via view
74+
data_view = array_2d.numpy()
75+
print(type(data_view))
76+
assert np.allclose(data, data_view)
77+
78+
# check that the view is writable and changes reflect back to original data
79+
data_view[0, 0] = 999.0
80+
assert np.isclose(data[0, 0], 999.0)
81+
82+
83+
def test_create_invalid_array2d():
84+
"""Test that creating Array2D with invalid dimensions raises errors."""
85+
data = np.arange(12, dtype=np.float32).reshape((3, 4))
86+
87+
with pytest.raises(TypeError):
88+
FloatArray2D.from_array(data.reshape((3, 4, 1))) # not 2D

0 commit comments

Comments
 (0)