Skip to content

Commit c830d82

Browse files
committed
Python binding for FMBScene
1 parent 5873487 commit c830d82

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

genmetaballs/src/cuda/bindings.cu

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ template <MemoryLocation location>
2323
void bind_image(nb::module_& m, const char* name);
2424
template <MemoryLocation location>
2525
void bind_image_view(nb::module_& m, const char* name);
26+
template <MemoryLocation location>
27+
void bind_fmb_scene(nb::module_& m, const char* name);
2628

2729
NB_MODULE(_genmetaballs_bindings, m) {
2830

@@ -66,6 +68,8 @@ NB_MODULE(_genmetaballs_bindings, m) {
6668
"apply the inverse covariance matrix to the given vector", nb::arg("vec"))
6769
.def("quadratic_form", &FMB::quadratic_form,
6870
"Evaluate the associated quadratic form at the given vector", nb::arg("vec"));
71+
bind_fmb_scene<MemoryLocation::HOST>(fmb, "CPUFMBScene");
72+
bind_fmb_scene<MemoryLocation::DEVICE>(fmb, "GPUFMBScene");
6973

7074
/*
7175
* Geometry module bindings
@@ -244,3 +248,15 @@ void bind_image(nb::module_& m, const char* name) {
244248
return nb::str("{}(height={}, width={})").format(name, img.num_rows(), img.num_cols());
245249
});
246250
}
251+
252+
template <MemoryLocation location>
253+
void bind_fmb_scene(nb::module_& m, const char* name) {
254+
nb::class_<FMBScene<location>>(m, name)
255+
.def(nb::init<size_t>(), nb::arg("size"))
256+
.def_prop_ro("size", &FMBScene<location>::size)
257+
.def("__getitem__", &FMBScene<location>::get_fmb, nb::arg("idx"),
258+
"Get the (FMB, log_weight) tuple at index i")
259+
.def("__repr__", [=](const FMBScene<location>& scene) {
260+
return nb::str("{}(size={})").format(name, scene.size());
261+
});
262+
}

genmetaballs/src/genmetaballs/core/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
TwoParameterConfidence,
1111
ZeroParameterConfidence,
1212
)
13+
from genmetaballs._genmetaballs_bindings.fmb import CPUFMBScene, GPUFMBScene
1314
from genmetaballs._genmetaballs_bindings.image import CPUImage, GPUImage
1415
from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid
1516

@@ -47,6 +48,21 @@ def make_image(height: int, width: int, device: DeviceType) -> CPUImage | GPUIma
4748
raise ValueError(f"Unsupported device type: {device}")
4849

4950

51+
def make_fmb_scene(size: int, device: DeviceType) -> CPUFMBScene | GPUFMBScene:
52+
"""Create an FMBScene on the specified device.
53+
54+
Args:
55+
size: The number of FMBs in the scene.
56+
device: 'cpu' or 'gpu' to specify the target device.
57+
"""
58+
if device == "cpu":
59+
return CPUFMBScene(size)
60+
elif device == "gpu":
61+
return GPUFMBScene(size)
62+
else:
63+
raise ValueError(f"Unsupported device type: {device}")
64+
65+
5066
__all__ = [
5167
"array2d_float",
5268
"ZeroParameterConfidence",
@@ -60,4 +76,5 @@ def make_image(height: int, width: int, device: DeviceType) -> CPUImage | GPUIma
6076
"FourParameterBlender",
6177
"ThreeParameterBlender",
6278
"make_image",
79+
"make_fmb_scene",
6380
]

tests/python_tests/test_fmb.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from scipy.spatial.distance import mahalanobis
44
from scipy.spatial.transform import Rotation as Rot
55

6-
from genmetaballs.core import fmb, geometry
6+
from genmetaballs.core import fmb, geometry, make_fmb_scene
77

88
FMB = fmb.FMB
99
Pose, Vec3D, Rotation = geometry.Pose, geometry.Vec3D, geometry.Rotation
@@ -38,3 +38,13 @@ def test_fmb_quadratic_form(rng):
3838
FMB(pose, *extent).quadratic_form(Vec3D(*vec)),
3939
mahalanobis(vec, tran, np.linalg.inv(cov)) ** 2,
4040
)
41+
42+
43+
def test_fmb_scene_creation():
44+
cpu_scene = make_fmb_scene(10, device="cpu")
45+
assert isinstance(cpu_scene, fmb.CPUFMBScene)
46+
assert cpu_scene.size == 10
47+
48+
gpu_scene = make_fmb_scene(20, device="gpu")
49+
assert isinstance(gpu_scene, fmb.GPUFMBScene)
50+
assert gpu_scene.size == 20

0 commit comments

Comments
 (0)