Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions genmetaballs/src/cuda/bindings.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <cstdint>
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/operators.h>
#include <nanobind/stl/vector.h>

Expand Down Expand Up @@ -79,4 +80,23 @@ NB_MODULE(_genmetaballs_bindings, m) {
nb::module_ utils = m.def_submodule("utils");
utils.def("sigmoid", sigmoid, nb::arg("x"), "Compute the sigmoid function: 1 / (1 + exp(-x))");

nb::class_<Array2D<float>>(utils, "FloatArray2D")
.def_static("from_array",
[](const nb::ndarray<float, nb::ndim<2>, nb::c_contig>& array) {
return Array2D<float>(array.data(), array.shape(0), array.shape(1));
})
// TODO: switch to the array_api in future nanobind release
// https://nanobind.readthedocs.io/en/latest/api_extra.html#_CPPv4N8nanobind9array_apiE
.def(
"numpy",
[](const Array2D<float>& self) {
return nb::ndarray<float, nb::numpy, nb::c_contig>(
self.data(), {self.num_rows(), self.num_cols()});
},
nb::rv_policy::reference_internal)
.def_prop_ro("num_rows", &Array2D<float>::num_rows)
.def_prop_ro("num_cols", &Array2D<float>::num_cols)
.def_prop_ro("ndim", &Array2D<float>::ndim)
.def_prop_ro("size", &Array2D<float>::size);

} // NB_MODULE(_genmetaballs_bindings)
3 changes: 3 additions & 0 deletions genmetaballs/src/cuda/core/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ public:
CUDA_CALLABLE constexpr auto size() const noexcept {
return data_view_.size();
}
CUDA_CALLABLE constexpr T* data() const noexcept {
return data_view_.data_handle();
}
}; // class Array2D

// Type deduction guide
Expand Down
3 changes: 2 additions & 1 deletion genmetaballs/src/genmetaballs/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
TwoParameterConfidence,
ZeroParameterConfidence,
)
from genmetaballs._genmetaballs_bindings.utils import sigmoid
from genmetaballs._genmetaballs_bindings.utils import FloatArray2D, sigmoid

__all__ = [
"FloatArray2D",
"ZeroParameterConfidence",
"TwoParameterConfidence",
"geometry",
Expand Down
29 changes: 28 additions & 1 deletion tests/python_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from scipy.special import expit

from genmetaballs.core import sigmoid
from genmetaballs.core import FloatArray2D, sigmoid

NUM_RNG_SEEDS_PER_TEST = 5
NUM_N_VALUES_PER_TEST = 5
Expand Down Expand Up @@ -58,3 +58,30 @@ def test_sigmoid_edge_cases(x: float) -> None:
assert np.isclose(actual, expected, rtol=1e-5, atol=1e-6)
assert actual >= 0.0
assert actual <= 1.0


def test_float_array2d_creation_and_view():
"""Test creation of Array2D from a numpy array."""
rows, cols = 4, 5
data = np.arange(rows * cols, dtype=np.float32).reshape((rows, cols))
array_2d = FloatArray2D.from_array(data)

assert array_2d.num_rows == rows
assert array_2d.num_cols == cols
assert array_2d.ndim == 2

# then try converting back to numpy array via view
data_view = array_2d.numpy()
assert np.allclose(data, data_view)

# check that the view is writable and changes reflect back to original data
data_view[0, 0] = 999.0
assert np.isclose(data[0, 0], 999.0)


def test_create_invalid_array2d():
"""Test that creating Array2D with invalid dimensions raises errors."""
data = np.arange(12, dtype=np.float32).reshape((3, 4))

with pytest.raises(TypeError):
FloatArray2D.from_array(data.reshape((3, 4, 1))) # not 2D