42 changes: 30 additions & 12 deletions genmetaballs/src/cuda/bindings.cu
@@ -10,6 +10,9 @@

namespace nb = nanobind;

template <typename T, MemoryLocation location>
void bind_array2d(nb::module_& m, const char* name);

NB_MODULE(_genmetaballs_bindings, m) {

/*
@@ -80,23 +83,38 @@ NB_MODULE(_genmetaballs_bindings, m) {
nb::module_ utils = m.def_submodule("utils");
utils.def("sigmoid", sigmoid, nb::arg("x"), "Compute the sigmoid function: 1 / (1 + exp(-x))");

nb::class_<Array2D<float>>(utils, "FloatArray2D")
bind_array2d<float, MemoryLocation::HOST>(utils, "CPUFloatArray2D");
bind_array2d<float, MemoryLocation::DEVICE>(utils, "GPUFloatArray2D");

} // NB_MODULE(_genmetaballs_bindings)

template <typename T, MemoryLocation location>
void bind_array2d(nb::module_& m, const char* name) {
using nb_device_type =
std::conditional_t<location == MemoryLocation::HOST, nb::device::cpu, nb::device::cuda>;
nb::class_<Array2D<T, location>>(m, name)
.def_static("from_array",
[](const nb::ndarray<float, nb::ndim<2>, nb::c_contig>& array) {
return Array2D<float>(array.data(), array.shape(0), array.shape(1));
[](const nb::ndarray<T, nb::ndim<2>, nb::c_contig, nb_device_type>& array) {
return Array2D<T, location>(array.data(), array.shape(0), array.shape(1));
})
// TODO: switch to the array_api in future nanobind release
// https://nanobind.readthedocs.io/en/latest/api_extra.html#_CPPv4N8nanobind9array_apiE
.def(
"numpy",
[](const Array2D<float>& self) {
return nb::ndarray<float, nb::numpy, nb::c_contig>(
"as_numpy",
[](const Array2D<T, location>& self) {
return nb::ndarray<T, nb::numpy, nb::c_contig, nb_device_type>(
self.data(), {self.num_rows(), self.num_cols()});
},
nb::rv_policy::reference_internal)
.def_prop_ro("num_rows", &Array2D<float>::num_rows)
.def_prop_ro("num_cols", &Array2D<float>::num_cols)
.def_prop_ro("ndim", &Array2D<float>::ndim)
.def_prop_ro("size", &Array2D<float>::size);

} // NB_MODULE(_genmetaballs_bindings)
.def(
"as_jax",
[](const Array2D<T, location>& self) {
return nb::ndarray<T, nb::jax, nb::c_contig, nb_device_type>(
self.data(), {self.num_rows(), self.num_cols()});
},
nb::rv_policy::reference_internal)
.def_prop_ro("num_rows", &Array2D<T, location>::num_rows)
.def_prop_ro("num_cols", &Array2D<T, location>::num_cols)
.def_prop_ro("ndim", &Array2D<T, location>::ndim)
.def_prop_ro("size", &Array2D<T, location>::size);
}
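
For orientation, a minimal usage sketch (not part of the diff) of the new bindings from Python, using only names introduced above (CPUFloatArray2D, from_array, as_numpy); the writable-view behaviour mirrors what the Python test further down asserts:

import numpy as np

from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D

# build a non-owning host-side view over a contiguous float32 array
data = np.arange(12, dtype=np.float32).reshape(3, 4)
arr = CPUFloatArray2D.from_array(data)
assert arr.num_rows == 3 and arr.num_cols == 4 and arr.ndim == 2

# as_numpy() returns a view over the same buffer (rv_policy::reference_internal),
# so writes through the view are visible in the original array
view = arr.as_numpy()
view[0, 0] = 42.0
assert data[0, 0] == 42.0
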
11 changes: 3 additions & 8 deletions genmetaballs/src/cuda/core/utils.cuh
@@ -21,8 +21,10 @@ CUDA_CALLABLE __forceinline__ float sigmoid(float x) {
return 1.0f / (1.0f + expf(-x));
}

enum class MemoryLocation { HOST, DEVICE };

// Non-owning 2D view into a contiguous array in either host or device memory
template <typename T>
template <typename T, MemoryLocation location>
class Array2D {
private:
cuda::std::mdspan<
@@ -60,10 +62,3 @@ public:
return data_view_.data_handle();
}
}; // class Array2D

// Type deduction guide
// if initialized with (Pointer, int, int), deduce T by looking at what raw_pointer_cast returns
// so we can write Array2D(array_ptr, rows, cols) instead of Array2D<Type>(array_ptr, rows, cols)
template <typename Pointer>
Array2D(Pointer, uint32_t, uint32_t)
-> Array2D<typename std::pointer_traits<Pointer>::element_type>;
20 changes: 18 additions & 2 deletions genmetaballs/src/genmetaballs/core/__init__.py
@@ -3,10 +3,26 @@
TwoParameterConfidence,
ZeroParameterConfidence,
)
from genmetaballs._genmetaballs_bindings.utils import FloatArray2D, sigmoid
from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid


def array2d_float(data, device: str) -> CPUFloatArray2D | GPUFloatArray2D:
"""Create a CPUFloatArray2D or GPUFloatArray2D view of a 2D array on the specified device.

Args:
data: A 2D array of type float32.
device: 'cpu' or 'gpu' to specify the target device.
"""
if device == "cpu":
return CPUFloatArray2D.from_array(data)
elif device == "gpu":
return GPUFloatArray2D.from_array(data)
else:
raise ValueError(f"Unsupported device type: {device}")


__all__ = [
"FloatArray2D",
"array2d_float",
"ZeroParameterConfidence",
"TwoParameterConfidence",
"geometry",
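
As a quick illustration (not part of the diff), the new array2d_float helper might be used as below; the "gpu" branch assumes a CUDA-capable JAX installation, matching the parametrised test added later in this PR:

import jax
import jax.numpy as jnp

from genmetaballs.core import array2d_float

data = jnp.arange(20, dtype=jnp.float32).reshape(4, 5)

# host-side view backed by the CPU binding
cpu_arr = array2d_float(jax.device_put(data, jax.devices("cpu")[0]), device="cpu")

# device-side view backed by the CUDA binding (needs a GPU-enabled JAX)
gpu_arr = array2d_float(jax.device_put(data, jax.devices("gpu")[0]), device="gpu")

assert cpu_arr.ndim == gpu_arr.ndim == 2
assert gpu_arr.num_rows == 4 and gpu_arr.num_cols == 5
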
19 changes: 14 additions & 5 deletions tests/cpp_tests/test_utils.cu
@@ -94,7 +94,7 @@ TEST(GpuSigmoidTest, SigmoidGPUWithinBounds) {

namespace test_utils_gpu {
// CUDA kernel to fill Array2D with sequential values
__global__ void fill_array2d_kernel(Array2D<float> array2d) {
__global__ void fill_array2d_kernel(Array2D<float, MemoryLocation::DEVICE> array2d) {
uint32_t i = threadIdx.x;
uint32_t j = threadIdx.y;

@@ -120,8 +120,11 @@ TYPED_TEST(Array2DTestFixture, CreateAndAccessArray2D) {
uint32_t cols = 6;

auto data = TypeParam(rows * cols);
constexpr auto device_type = std::is_same_v<TypeParam, thrust::device_vector<float>>
? MemoryLocation::DEVICE
: MemoryLocation::HOST;
// create 2D view into the underlying data on host or device
auto array2d = Array2D(data.data(), rows, cols);
auto array2d = Array2D<float, device_type>(data.data(), rows, cols);

if constexpr (std::is_same_v<TypeParam, std::vector<float>>) {
for (auto i = 0; i < rows - 1; i++) {
@@ -164,7 +167,10 @@ TYPED_TEST(Array2DTestFixture, ViewModifiesUnderlyingData) {
uint32_t rows = 3;
uint32_t cols = 4;
auto data = TypeParam(rows * cols, 0.0f);
auto array2d = Array2D(data.data(), rows, cols);
constexpr auto device_type = std::is_same_v<TypeParam, thrust::device_vector<float>>
? MemoryLocation::DEVICE
: MemoryLocation::HOST;
auto array2d = Array2D<float, device_type>(data.data(), rows, cols);

// Modify through view
array2d[1][2] = 42.5f;
@@ -184,8 +190,11 @@ TYPED_TEST(Array2DTestFixture, MultipleViewsOfSameData) {
uint32_t rows = 2;
uint32_t cols = 3;
auto data = TypeParam(rows * cols, 0.0f);
auto view1 = Array2D(data.data(), rows, cols);
auto view2 = Array2D(data.data(), rows, cols);
constexpr auto device_type = std::is_same_v<TypeParam, thrust::device_vector<float>>
? MemoryLocation::DEVICE
: MemoryLocation::HOST;
auto view1 = Array2D<float, device_type>(data.data(), rows, cols);
auto view2 = Array2D<float, device_type>(data.data(), rows, cols);

// Modify through view1
view1[0][0] = 100.0f;
33 changes: 28 additions & 5 deletions tests/python_tests/test_utils.py
@@ -1,8 +1,10 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from scipy.special import expit

from genmetaballs.core import FloatArray2D, sigmoid
from genmetaballs.core import array2d_float, sigmoid

NUM_RNG_SEEDS_PER_TEST = 5
NUM_N_VALUES_PER_TEST = 5
@@ -60,18 +62,39 @@ def test_sigmoid_edge_cases(x: float) -> None:
assert actual <= 1.0


def test_float_array2d_creation_and_view():
@pytest.mark.parametrize("device", ["cpu", "gpu"])
def test_array2d_float_creation_on_jax_devices(device: str):
"""Test creation of Array2D from a numpy array."""
rows, cols = 4, 5
data = jnp.arange(rows * cols, dtype=jnp.float32).reshape((rows, cols))
jax_device = jax.devices(device)[0]
data = jax.device_put(data, device=jax_device)
array_2d = array2d_float(data, device=device)

assert array_2d.num_rows == rows
assert array_2d.num_cols == cols
assert array_2d.ndim == 2

# then try converting back to a JAX array via the as_jax view
data_view = array_2d.as_jax()
assert data_view.device == data.device
assert jnp.allclose(data, data_view)

# Note: we can't test writability of shared view here since JAX arrays are immutable


def test_float_array2d_view_numpy():
"""Test creation of Array2D from a numpy array."""
rows, cols = 3, 4
data = np.arange(rows * cols, dtype=np.float32).reshape((rows, cols))
array_2d = FloatArray2D.from_array(data)
array_2d = array2d_float(data, device="cpu")

assert array_2d.num_rows == rows
assert array_2d.num_cols == cols
assert array_2d.ndim == 2

# then try converting back to numpy array via view
data_view = array_2d.numpy()
data_view = array_2d.as_numpy()
assert np.allclose(data, data_view)

# check that the view is writable and changes reflect back to original data
@@ -84,4 +107,4 @@ def test_create_invalid_array2d():
data = np.arange(12, dtype=np.float32).reshape((3, 4))

with pytest.raises(TypeError):
FloatArray2D.from_array(data.reshape((3, 4, 1))) # not 2D
array2d_float(data.reshape((3, 4, 1)), device="cpu") # not 2D