Skip to content

Commit c6abbab

Browse files
committed
Unit test for constructing FMBScene from value
1 parent 3e63cb8 commit c6abbab

File tree

3 files changed

+46
-4
lines changed

3 files changed

+46
-4
lines changed

genmetaballs/src/cuda/bindings.cu

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,17 @@ void bind_fmb_scene(nb::module_& m, const char* name) {
258258
"Construct FMBScene from a list of FMBs and corresponding log weights")
259259
.def_prop_ro("size", &FMBScene<location>::size)
260260
.def("__len__", &FMBScene<location>::size)
261-
.def("__getitem__", &FMBScene<location>::get_fmb, nb::arg("idx"),
262-
"Get the (FMB, log_weight) tuple at index i")
261+
.def(
262+
"__getitem__",
263+
// Convert cuda::std::tuple to std::tuple for nanobind
264+
[](const FMBScene<location>& scene, size_t idx) {
265+
const auto& [fmb, log_weight] = scene[idx];
266+
// for device data, the types would be thrust::device_reference, which cannot be
267+
// returned directly to Python. The static cast forces a copy (to host) to be made.
268+
return std::make_tuple(static_cast<const FMB&>(fmb),
269+
static_cast<const float&>(log_weight));
270+
},
271+
"Get the (FMB, log_weight) tuple at index i")
263272
.def("__repr__", [=](const FMBScene<location>& scene) {
264273
return nb::str("{}(size={})").format(name, scene.size());
265274
});

genmetaballs/src/genmetaballs/core/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def make_fmb_scene(size: int, device: DeviceType) -> CPUFMBScene | GPUFMBScene:
6565

6666
# TODO: create a wrapper class for FMBScene and turn the factory functions into
6767
# class methods
68-
def fmb_scene_from_values(
68+
def make_fmb_scene_from_values(
6969
fmbs: list[fmb.FMB], log_weights: list[float], device: DeviceType
7070
) -> CPUFMBScene | GPUFMBScene:
7171
if device == "cpu":
@@ -87,7 +87,10 @@ def fmb_scene_from_values(
8787
"intersector",
8888
"sigmoid",
8989
"FourParameterBlender",
90+
"FMB",
91+
"Intrinsics",
9092
"ThreeParameterBlender",
9193
"make_image",
9294
"make_fmb_scene",
95+
"make_fmb_scene_from_values",
9396
]

tests/python_tests/test_fmb.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from scipy.spatial.distance import mahalanobis
44
from scipy.spatial.transform import Rotation as Rot
55

6-
from genmetaballs.core import fmb, geometry, make_fmb_scene
6+
from genmetaballs.core import fmb, geometry, make_fmb_scene, make_fmb_scene_from_values
77

88
FMB = fmb.FMB
99
Pose, Vec3D, Rotation = geometry.Pose, geometry.Vec3D, geometry.Rotation
@@ -48,3 +48,33 @@ def test_fmb_scene_creation():
4848
gpu_scene = make_fmb_scene(20, device="gpu")
4949
assert isinstance(gpu_scene, fmb.GPUFMBScene)
5050
assert len(gpu_scene) == 20
51+
52+
53+
@pytest.mark.parametrize("device", ["cpu", "gpu"])
54+
def test_fmb_scene_creation_from_lists(rng, device):
55+
fmbs = []
56+
log_weights = []
57+
gt_translations = []
58+
gt_extents = []
59+
num_balls = 15
60+
for _ in range(num_balls):
61+
quat = rng.uniform(size=4).astype(np.float32)
62+
tran, extent = rng.uniform(size=(2, 3)).astype(np.float32)
63+
pose = Pose.from_components(Rotation.from_quat(*quat), Vec3D(*tran))
64+
fmbs.append(FMB(pose, *extent))
65+
log_weights.append(rng.uniform())
66+
gt_translations.append(tran)
67+
gt_extents.append(extent)
68+
69+
scene = make_fmb_scene_from_values(fmbs, log_weights, device=device)
70+
71+
assert len(scene) == num_balls
72+
# Verify that we can retrieve each FMB and log weight correctly
73+
for i in range(num_balls):
74+
fmb_i, log_weight = scene[i]
75+
translation = fmb_i.pose.tran
76+
assert np.allclose([translation.x, translation.y, translation.z], gt_translations[i])
77+
78+
fmb_extent = fmb_i.extent
79+
assert np.allclose(fmb_extent, gt_extents[i])
80+
assert np.isclose(log_weight, log_weights[i])

0 commit comments

Comments
 (0)