Skip to content

Commit 039ba52

Browse files
authored
MET-30: Implement Geometry Primitives (#12)
This PR adds the geometry we scoped during our 11/16/2025 and 11/18/2025 meetings. Specifically, the PR provides: 1. a `Vec3D` type for 3D vector geometry; 2. a `Rotation` type representing 3D rotations; 3. and a `Pose` type representing 3D rigid transformations. Bindings and correctness fuzz tests against `scipy` are available in `bindings.cu` and `test_geometry.py`, respectively. To run the tests simply run `pixi run test`.
1 parent cda95af commit 039ba52

File tree

7 files changed

+340
-41
lines changed

7 files changed

+340
-41
lines changed

.clang-tidy

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Checks: >
1919
-readability-avoid-const-params-in-decls,
2020
-readability-braces-around-statements,
2121
-readability-isolate-declaration,
22+
-readability-math-missing-parentheses,
2223
-cppcoreguidelines-avoid-magic-numbers,
2324
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,
2425
-cppcoreguidelines-pro-bounds-pointer-arithmetic,

genmetaballs/src/cuda/bindings.cu

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,58 @@ namespace nb = nanobind;
1111

1212
NB_MODULE(_genmetaballs_bindings, m) {
1313

14-
// exposing Vec3D
15-
nb::class_<Vec3D>(m, "Vec3D")
14+
/*
15+
* Geometry module bindings
16+
*/
17+
18+
nb::module_ geometry = m.def_submodule("geometry", "Geometry helpers for GenMetaballs");
19+
20+
nb::class_<Vec3D>(geometry, "Vec3D")
1621
.def(nb::init<>())
1722
.def(nb::init<float, float, float>())
18-
.def_rw("x", &Vec3D::x)
19-
.def_rw("y", &Vec3D::y)
20-
.def_rw("z", &Vec3D::z)
23+
.def_ro("x", &Vec3D::x)
24+
.def_ro("y", &Vec3D::y)
25+
.def_ro("z", &Vec3D::z)
2126
.def(nb::self + nb::self)
2227
.def(nb::self - nb::self)
28+
.def(-nb::self)
29+
.def(nb::self * float())
30+
.def(float() * nb::self)
31+
.def(nb::self / float())
2332
.def("__repr__",
2433
[](const Vec3D& v) { return nb::str("Vec3D({}, {}, {})").format(v.x, v.y, v.z); });
2534

26-
// confidence submodule
35+
geometry.def("dot", &dot, "Dot product of two `Vec3D`s", nb::arg("a"), nb::arg("b"));
36+
geometry.def("cross", &cross, "Cross product of two `Vec3D`s", nb::arg("a"), nb::arg("b"));
37+
38+
nb::class_<Rotation>(geometry, "Rotation")
39+
.def(nb::init<>())
40+
.def_static("from_quat", &Rotation::from_quat, "Create rotation from quaternion",
41+
nb::arg("x"), nb::arg("y"), nb::arg("z"), nb::arg("w"))
42+
.def("apply", &Rotation::apply, "Apply rotation to vector", nb::arg("vec"))
43+
.def("compose", &Rotation::compose, "Compose with another rotation", nb::arg("rot"))
44+
.def("inv", &Rotation::inv, "Inverse rotation");
45+
46+
nb::class_<Pose>(geometry, "Pose")
47+
.def(nb::init<>())
48+
.def_static("from_components", &Pose::from_components,
49+
"Create rotation from a rotation and a translation", nb::arg("rot"),
50+
nb::arg("tran"))
51+
.def_prop_ro("rot", &Pose::get_rot, "get the rotation component")
52+
.def_prop_ro("tran", &Pose::get_tran, "get the translation component")
53+
.def("apply", &Pose::apply, "Apply pose to vector", nb::arg("vec"))
54+
.def("compose", &Pose::compose, "Compose with another pose", nb::arg("pose"))
55+
.def("inv", &Pose::inv, "Inverse pose");
56+
57+
nb::class_<Ray>(geometry, "Ray")
58+
.def(nb::init<>())
59+
.def_rw("start", &Ray::start)
60+
.def_rw("direction", &Ray::direction);
61+
62+
/*
63+
* Confidence module bindings
64+
*/
65+
2766
nb::module_ confidence = m.def_submodule("confidence");
2867
nb::class_<ZeroParameterConfidence>(confidence, "ZeroParameterConfidence")
2968
.def(nb::init<>())
@@ -33,7 +72,10 @@ NB_MODULE(_genmetaballs_bindings, m) {
3372
.def(nb::init<float, float>())
3473
.def("get_confidence", &TwoParameterConfidence::get_confidence);
3574

36-
// utils submodule
75+
/*
76+
* Utils module bindings
77+
*/
78+
3779
nb::module_ utils = m.def_submodule("utils");
3880
utils.def("sigmoid", sigmoid, nb::arg("x"), "Compute the sigmoid function: 1 / (1 + exp(-x))");
3981

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,39 @@
1+
#include <cmath>
2+
13
#include "geometry.cuh"
24

3-
Vec3D operator+(const Vec3D a, const Vec3D b) {
4-
return {a.x + b.x, a.y + b.y, a.z + b.z};
5+
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
6+
CUDA_CALLABLE Rotation Rotation::from_quat(float x, float y, float z, float w) {
7+
auto modulus = std::sqrt(x * x + y * y + z * z + w * w);
8+
return Rotation{{x / modulus, y / modulus, z / modulus, w / modulus}};
9+
}
10+
11+
CUDA_CALLABLE Vec3D Rotation::apply(const Vec3D vec) const {
12+
// v' = q * v * q^(-1) for unit quaternions
13+
// where q^(-1) = (-x, -y, -z, w)
14+
Vec3D q = {unit_quat_.x, unit_quat_.y, unit_quat_.z};
15+
float w = unit_quat_.w;
16+
17+
// v' = 2*(q·v)*q + (w²-|q|²)*v + 2*w*(q×v)
18+
float d = dot(q, vec);
19+
Vec3D c = cross(q, vec);
20+
21+
return 2.0f * d * q + (w * w - dot(q, q)) * vec + 2.0f * w * c;
22+
}
23+
24+
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
25+
CUDA_CALLABLE Rotation Rotation::compose(const Rotation& rot) const {
26+
// Quaternion multiplication: q1 * q2
27+
float4 q1 = unit_quat_;
28+
float4 q2 = rot.unit_quat_;
29+
30+
return Rotation{{q1.w * q2.x + q1.x * q2.w + q1.y * q2.z - q1.z * q2.y,
31+
q1.w * q2.y - q1.x * q2.z + q1.y * q2.w + q1.z * q2.x,
32+
q1.w * q2.z + q1.x * q2.y - q1.y * q2.x + q1.z * q2.w,
33+
q1.w * q2.w - q1.x * q2.x - q1.y * q2.y - q1.z * q2.z}};
534
}
635

7-
Vec3D operator-(const Vec3D a, const Vec3D b) {
8-
return {a.x - b.x, a.y - b.y, a.z - b.z};
36+
CUDA_CALLABLE Rotation Rotation::inv() const {
37+
// For unit quaternions, inverse = conjugate: (-x, -y, -z, w)
38+
return Rotation{{-unit_quat_.x, -unit_quat_.y, -unit_quat_.z, unit_quat_.w}};
939
}

genmetaballs/src/cuda/core/geometry.cuh

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,101 @@
22

33
#include <cuda_runtime.h>
44

5+
#include "utils.cuh"
6+
57
using Vec3D = float3;
68

7-
Vec3D operator+(const Vec3D a, const Vec3D b);
8-
Vec3D operator-(const Vec3D a, const Vec3D b);
9+
CUDA_CALLABLE inline Vec3D operator-(const Vec3D a) {
10+
return {-a.x, -a.y, -a.z};
11+
}
912

10-
class Rotation {
13+
CUDA_CALLABLE inline Vec3D operator+(const Vec3D a, const Vec3D b) {
14+
return {a.x + b.x, a.y + b.y, a.z + b.z};
15+
}
16+
17+
CUDA_CALLABLE inline Vec3D operator-(const Vec3D a, const Vec3D b) {
18+
return {a.x - b.x, a.y - b.y, a.z - b.z};
19+
}
20+
21+
CUDA_CALLABLE inline Vec3D operator*(const float scalar, const Vec3D a) {
22+
return {a.x * scalar, a.y * scalar, a.z * scalar};
23+
}
24+
25+
CUDA_CALLABLE inline Vec3D operator*(const Vec3D a, const float scalar) {
26+
return {a.x * scalar, a.y * scalar, a.z * scalar};
27+
}
28+
29+
CUDA_CALLABLE inline Vec3D operator/(const Vec3D a, const float scalar) {
30+
return {a.x / scalar, a.y / scalar, a.z / scalar};
31+
}
1132

33+
CUDA_CALLABLE inline float dot(const Vec3D a, const Vec3D b) {
34+
return a.x * b.x + a.y * b.y + a.z * b.z;
35+
}
36+
37+
CUDA_CALLABLE inline Vec3D cross(const Vec3D a, const Vec3D b) {
38+
return {a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x};
39+
}
40+
41+
class Rotation {
1242
private:
13-
// ...
14-
float rotmat_[9];
43+
float4 unit_quat_;
44+
45+
CUDA_CALLABLE Rotation(float4 unit_quat) : unit_quat_{unit_quat} {};
1546

1647
public:
17-
Vec3D apply(const Vec3D vec) const;
18-
Rotation compose(const Rotation& rot) const;
19-
Rotation inv() const;
48+
CUDA_CALLABLE Rotation() : unit_quat_{0.0f, 0.0f, 0.0f, 1.0f} {};
49+
50+
static CUDA_CALLABLE Rotation from_quat(float x, float y, float z, float w);
51+
52+
CUDA_CALLABLE Vec3D apply(const Vec3D vec) const;
53+
54+
CUDA_CALLABLE Rotation compose(const Rotation& rot) const;
55+
56+
CUDA_CALLABLE Rotation inv() const;
2057
};
2158

22-
struct Pose {
23-
Rotation rot;
24-
Vec3D tran;
59+
class Pose {
60+
private:
61+
Rotation rot_;
62+
Vec3D tran_;
63+
64+
CUDA_CALLABLE Pose(const Rotation rot, const Vec3D tran) : rot_{rot}, tran_{tran} {}
65+
66+
public:
67+
// these member functions are defined in class body to allow for possible inlining
68+
69+
CUDA_CALLABLE Pose() : rot_{Rotation()}, tran_{0.0f, 0.0f, 0.0f} {}
70+
71+
static CUDA_CALLABLE Pose from_components(const Rotation rot, const Vec3D tran) {
72+
return {rot, tran};
73+
}
74+
75+
CUDA_CALLABLE Rotation get_rot() const {
76+
return rot_;
77+
}
78+
79+
CUDA_CALLABLE Vec3D get_tran() const {
80+
return tran_;
81+
}
82+
83+
CUDA_CALLABLE Vec3D apply(const Vec3D vec) const {
84+
return tran_ + rot_.apply(vec);
85+
}
86+
87+
CUDA_CALLABLE Pose compose(const Pose& pose) const {
88+
/*
89+
* If $A_i$ is the matrix corresponding to pose object `p_i`, then
90+
* $A_1A_2$ is the matrix corresponding to the pose object
91+
* `p_1.compose(p2)`.
92+
*/
93+
return {rot_.compose(pose.rot_), rot_.apply(pose.tran_) + tran_};
94+
}
2595

26-
Vec3D apply(const Vec3D vec) const;
27-
Pose compose(const Pose& pose) const;
28-
Pose inv() const;
96+
CUDA_CALLABLE Pose inv() const {
97+
auto rotinv = rot_.inv();
98+
return {rotinv, -rotinv.apply(tran_)};
99+
}
29100
};
30101

31102
struct Ray {

genmetaballs/src/cuda/core/utils.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@ public:
5252
__host__ __device__ constexpr auto size() const noexcept {
5353
return data_view_.size();
5454
}
55-
};
55+
}; // class Array2D

genmetaballs/src/genmetaballs/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from genmetaballs._genmetaballs_bindings import geometry
12
from genmetaballs._genmetaballs_bindings.confidence import (
23
TwoParameterConfidence,
34
ZeroParameterConfidence,
@@ -7,5 +8,6 @@
78
__all__ = [
89
"ZeroParameterConfidence",
910
"TwoParameterConfidence",
11+
"geometry",
1012
"sigmoid",
1113
]

0 commit comments

Comments
 (0)