Skip to content

Commit 3ff5ed5

Browse files
committed
Adding launch params as arguments to forward
1 parent bdc718d commit 3ff5ed5

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

genmetaballs/src/cuda/bindings.cu

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,17 @@ NB_MODULE(_genmetaballs_bindings, m) {
228228
bind_array2d<float, MemoryLocation::HOST>(utils, "CPUFloatArray2D");
229229
bind_array2d<float, MemoryLocation::DEVICE>(utils, "GPUFloatArray2D");
230230

231+
// bind dim3, which is used to specify the launch configuration for the kernel
232+
nb::class_<dim3>(utils, "dim3")
233+
.def(nb::init<uint32_t, uint32_t, uint32_t>(), nb::arg("x") = 1, nb::arg("y") = 1,
234+
nb::arg("z") = 1)
235+
.def_prop_ro("x", [](const dim3& self) { return self.x; })
236+
.def_prop_ro("y", [](const dim3& self) { return self.y; })
237+
.def_prop_ro("z", [](const dim3& self) { return self.z; })
238+
.def("__repr__", [](const dim3& self) {
239+
return nb::str("dim3(x={}, y={}, z={})").format(self.x, self.y, self.z);
240+
});
241+
231242
} // NB_MODULE(_genmetaballs_bindings)
232243

233244
template <typename T, MemoryLocation location>
@@ -319,5 +330,5 @@ void bind_render_fmbs(nb::module_& m, const char* name) {
319330
&render_fmbs<AllGetter<MemoryLocation::DEVICE>, LinearIntersector, Blender, Confidence>,
320331
"Render the given FMB scene into the provided image view", nb::arg("fmbs"),
321332
nb::arg("blender"), nb::arg("confidence"), nb::arg("intr"), nb::arg("extr"),
322-
nb::arg("img"));
333+
nb::arg("img"), nb::arg("grid_size"), nb::arg("block_size"));
323334
}

genmetaballs/src/cuda/core/forward.cuh

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99
#include "image.cuh"
1010
#include "utils.cuh"
1111

12-
// TODO: tune this number
13-
constexpr auto NUM_BLOCKS = dim3(4, 4);
14-
constexpr auto THREADS_PER_BLOCK = dim3(16, 16);
1512

1613
CUDA_CALLABLE PixelCoordRange get_pixel_coords(const dim3 thread_idx, const dim3 block_idx,
1714
const dim3 block_dim, const dim3 grid_dim,
@@ -49,7 +46,8 @@ __global__ void render_kernel(const FMBScene<MemoryLocation::DEVICE>& fmbs, cons
4946
template <typename Getter, typename Intersector, typename Blender, typename Confidence>
5047
void render_fmbs(const FMBScene<MemoryLocation::DEVICE>& fmbs, const Blender& blender,
5148
const Confidence& confidence, const Intrinsics& intr, const Pose& extr,
52-
ImageView<MemoryLocation::DEVICE> img) {
49+
ImageView<MemoryLocation::DEVICE> img, const dim3 grid_size,
50+
const dim3 block_size) {
5351
render_kernel<Getter, Intersector, Blender, Confidence>
54-
<<<NUM_BLOCKS, THREADS_PER_BLOCK>>>(fmbs, blender, confidence, intr, extr, img);
52+
<<<grid_size, block_size>>>(fmbs, blender, confidence, intr, extr, img);
5553
}

genmetaballs/src/genmetaballs/core/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
)
1313
from genmetaballs._genmetaballs_bindings.fmb import FMB, CPUFMBScene, GPUFMBScene
1414
from genmetaballs._genmetaballs_bindings.image import CPUImage, GPUImage
15-
from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid
15+
from genmetaballs._genmetaballs_bindings.utils import (
16+
CPUFloatArray2D,
17+
GPUFloatArray2D,
18+
dim3,
19+
sigmoid,
20+
)
1621

1722
type DeviceType = Literal["cpu", "gpu"]
1823

@@ -83,6 +88,8 @@ def render_fmbs(
8388
intr: Intrinsics,
8489
extr: geometry.Pose,
8590
img: GPUImage | None = None,
91+
grid_size: dim3 = dim3(4, 4),
92+
block_size: dim3 = dim3(16, 16),
8693
) -> GPUImage:
8794
"""Render the given FMB scene into the provided image view.
8895
@@ -105,7 +112,7 @@ def render_fmbs(
105112
else:
106113
raise TypeError("Unsupported blender and confidence combination.")
107114

108-
render_func(fmbs, blender, confidence, intr, extr, img.as_view())
115+
render_func(fmbs, blender, confidence, intr, extr, img.as_view(), grid_size, block_size)
109116
return img
110117

111118

@@ -122,6 +129,7 @@ def render_fmbs(
122129
"Camera",
123130
"FourParameterBlender",
124131
"FMB",
132+
"dim3",
125133
"Intrinsics",
126134
"ThreeParameterBlender",
127135
"TwoParameterConfidence",

0 commit comments

Comments
 (0)