diff --git a/genmetaballs/src/cuda/bindings.cu b/genmetaballs/src/cuda/bindings.cu index 2a94a2a..7494660 100644 --- a/genmetaballs/src/cuda/bindings.cu +++ b/genmetaballs/src/cuda/bindings.cu @@ -10,6 +10,9 @@ namespace nb = nanobind; +template +void bind_array2d(nb::module_& m, const char* name); + NB_MODULE(_genmetaballs_bindings, m) { /* @@ -80,23 +83,38 @@ NB_MODULE(_genmetaballs_bindings, m) { nb::module_ utils = m.def_submodule("utils"); utils.def("sigmoid", sigmoid, nb::arg("x"), "Compute the sigmoid function: 1 / (1 + exp(-x))"); - nb::class_>(utils, "FloatArray2D") + bind_array2d(utils, "CPUFloatArray2D"); + bind_array2d(utils, "GPUFloatArray2D"); + +} // NB_MODULE(_genmetaballs_bindings) + +template +void bind_array2d(nb::module_& m, const char* name) { + using nb_device_type = + std::conditional_t; + nb::class_>(m, name) .def_static("from_array", - [](const nb::ndarray, nb::c_contig>& array) { - return Array2D(array.data(), array.shape(0), array.shape(1)); + [](const nb::ndarray, nb::c_contig, nb_device_type>& array) { + return Array2D(array.data(), array.shape(0), array.shape(1)); }) // TODO: switch to the array_api in future nanobind release // https://nanobind.readthedocs.io/en/latest/api_extra.html#_CPPv4N8nanobind9array_apiE .def( - "numpy", - [](const Array2D& self) { - return nb::ndarray( + "as_numpy", + [](const Array2D& self) { + return nb::ndarray( self.data(), {self.num_rows(), self.num_cols()}); }, nb::rv_policy::reference_internal) - .def_prop_ro("num_rows", &Array2D::num_rows) - .def_prop_ro("num_cols", &Array2D::num_cols) - .def_prop_ro("ndim", &Array2D::ndim) - .def_prop_ro("size", &Array2D::size); - -} // NB_MODULE(_genmetaballs_bindings) + .def( + "as_jax", + [](const Array2D& self) { + return nb::ndarray( + self.data(), {self.num_rows(), self.num_cols()}); + }, + nb::rv_policy::reference_internal) + .def_prop_ro("num_rows", &Array2D::num_rows) + .def_prop_ro("num_cols", &Array2D::num_cols) + .def_prop_ro("ndim", &Array2D::ndim) + .def_prop_ro("size", &Array2D::size); +} diff --git a/genmetaballs/src/cuda/core/utils.cuh b/genmetaballs/src/cuda/core/utils.cuh index 3f6aeef..1a567df 100644 --- a/genmetaballs/src/cuda/core/utils.cuh +++ b/genmetaballs/src/cuda/core/utils.cuh @@ -21,8 +21,10 @@ CUDA_CALLABLE __forceinline__ float sigmoid(float x) { return 1.0f / (1.0f + expf(-x)); } +enum class MemoryLocation { HOST, DEVICE }; + // Non-owning 2D view into a contiguous array in either host or device memory -template +template class Array2D { private: cuda::std::mdspan< @@ -60,10 +62,3 @@ public: return data_view_.data_handle(); } }; // class Array2D - -// Type deduction guide -// if initialized with (Pointer, int, int), deduce T by looking at what raw_pointer_cast returns -// so we can write Array2D(array_ptr, rows, cols) instead of Array2D(array_ptr, rows, cols) -template -Array2D(Pointer, uint32_t, uint32_t) - -> Array2D::element_type>; diff --git a/genmetaballs/src/genmetaballs/core/__init__.py b/genmetaballs/src/genmetaballs/core/__init__.py index a09878d..e2c7b76 100644 --- a/genmetaballs/src/genmetaballs/core/__init__.py +++ b/genmetaballs/src/genmetaballs/core/__init__.py @@ -3,10 +3,26 @@ TwoParameterConfidence, ZeroParameterConfidence, ) -from genmetaballs._genmetaballs_bindings.utils import FloatArray2D, sigmoid +from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid + + +def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D: + """Create a FloatArray2D on the specified device from an array. + + Args: + data: A 2D array of type float32. + device: 'cpu' or 'gpu' to specify the target device. + """ + if device == "cpu": + return CPUFloatArray2D.from_array(data) + elif device == "gpu": + return GPUFloatArray2D.from_array(data) + else: + raise ValueError(f"Unsupported device type: {device}") + __all__ = [ - "FloatArray2D", + "array2d_float", "ZeroParameterConfidence", "TwoParameterConfidence", "geometry", diff --git a/tests/cpp_tests/test_utils.cu b/tests/cpp_tests/test_utils.cu index a4ebfd3..32c476f 100644 --- a/tests/cpp_tests/test_utils.cu +++ b/tests/cpp_tests/test_utils.cu @@ -94,7 +94,7 @@ TEST(GpuSigmoidTest, SigmoidGPUWithinBounds) { namespace test_utils_gpu { // CUDA kernel to fill Array2D with sequential values -__global__ void fill_array2d_kernel(Array2D array2d) { +__global__ void fill_array2d_kernel(Array2D array2d) { uint32_t i = threadIdx.x; uint32_t j = threadIdx.y; @@ -120,8 +120,11 @@ TYPED_TEST(Array2DTestFixture, CreateAndAccessArray2D) { uint32_t cols = 6; auto data = TypeParam(rows * cols); + constexpr auto device_type = std::is_same_v> + ? MemoryLocation::DEVICE + : MemoryLocation::HOST; // create 2D view into the underlying data on host or device - auto array2d = Array2D(data.data(), rows, cols); + auto array2d = Array2D(data.data(), rows, cols); if constexpr (std::is_same_v>) { for (auto i = 0; i < rows - 1; i++) { @@ -164,7 +167,10 @@ TYPED_TEST(Array2DTestFixture, ViewModifiesUnderlyingData) { uint32_t rows = 3; uint32_t cols = 4; auto data = TypeParam(rows * cols, 0.0f); - auto array2d = Array2D(data.data(), rows, cols); + constexpr auto device_type = std::is_same_v> + ? MemoryLocation::DEVICE + : MemoryLocation::HOST; + auto array2d = Array2D(data.data(), rows, cols); // Modify through view array2d[1][2] = 42.5f; @@ -184,8 +190,11 @@ TYPED_TEST(Array2DTestFixture, MultipleViewsOfSameData) { uint32_t rows = 2; uint32_t cols = 3; auto data = TypeParam(rows * cols, 0.0f); - auto view1 = Array2D(data.data(), rows, cols); - auto view2 = Array2D(data.data(), rows, cols); + constexpr auto device_type = std::is_same_v> + ? MemoryLocation::DEVICE + : MemoryLocation::HOST; + auto view1 = Array2D(data.data(), rows, cols); + auto view2 = Array2D(data.data(), rows, cols); // Modify through view1 view1[0][0] = 100.0f; diff --git a/tests/python_tests/test_utils.py b/tests/python_tests/test_utils.py index bdfbca4..9836fea 100644 --- a/tests/python_tests/test_utils.py +++ b/tests/python_tests/test_utils.py @@ -1,8 +1,10 @@ +import jax +import jax.numpy as jnp import numpy as np import pytest from scipy.special import expit -from genmetaballs.core import FloatArray2D, sigmoid +from genmetaballs.core import array2d_float, sigmoid NUM_RNG_SEEDS_PER_TEST = 5 NUM_N_VALUES_PER_TEST = 5 @@ -60,18 +62,39 @@ def test_sigmoid_edge_cases(x: float) -> None: assert actual <= 1.0 -def test_float_array2d_creation_and_view(): +@pytest.mark.parametrize("device", ["cpu", "gpu"]) +def test_array2d_float_creation_on_jax_devices(device: str): """Test creation of Array2D from a numpy array.""" rows, cols = 4, 5 + data = jnp.arange(rows * cols, dtype=jnp.float32).reshape((rows, cols)) + jax_device = jax.devices(device)[0] + data = jax.device_put(data, device=jax_device) + array_2d = array2d_float(data, device=device) + + assert array_2d.num_rows == rows + assert array_2d.num_cols == cols + assert array_2d.ndim == 2 + + # then try converting back to numpy array via view + data_view = array_2d.as_jax() + assert data_view.device == data.device + assert jnp.allclose(data, data_view) + + # Note: we can't test writability of shared view here since JAX arrays are immutable + + +def test_float_array2d_view_numpy(): + """Test creation of Array2D from a numpy array.""" + rows, cols = 3, 4 data = np.arange(rows * cols, dtype=np.float32).reshape((rows, cols)) - array_2d = FloatArray2D.from_array(data) + array_2d = array2d_float(data, device="cpu") assert array_2d.num_rows == rows assert array_2d.num_cols == cols assert array_2d.ndim == 2 # then try converting back to numpy array via view - data_view = array_2d.numpy() + data_view = array_2d.as_numpy() assert np.allclose(data, data_view) # check that the view is writable and changes reflect back to original data @@ -84,4 +107,4 @@ def test_create_invalid_array2d(): data = np.arange(12, dtype=np.float32).reshape((3, 4)) with pytest.raises(TypeError): - FloatArray2D.from_array(data.reshape((3, 4, 1))) # not 2D + array2d_float(data.reshape((3, 4, 1)), device="cpu") # not 2D