Skip to content

Commit bccbb2e

Browse files
author
Bernhard Kerbl
committed
InternalState explicit now
1 parent ffee75d commit bccbb2e

File tree

6 files changed

+170
-110
lines changed

6 files changed

+170
-110
lines changed

cuda_rasterizer/rasterizer.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55

66
namespace CudaRasterizer
77
{
8+
struct InternalState;
9+
810
class Rasterizer
911
{
1012
public:
1113

14+
virtual InternalState* createInternalState() = 0;
15+
16+
virtual void killInternalState(InternalState*) = 0;
17+
1218
virtual void markVisible(
1319
int P,
1420
float* means3D,
@@ -33,10 +39,13 @@ namespace CudaRasterizer
3339
const float* cam_pos,
3440
const float tan_fovx, float tan_fovy,
3541
const bool prefiltered,
36-
float* out_color,
37-
int* radii = nullptr) = 0;
42+
int* radii,
43+
InternalState* state,
44+
float* out_color) = 0;
3845

3946
virtual void backward(
47+
const int* radii,
48+
const InternalState* state,
4049
const int P, int D, int M,
4150
const float* background,
4251
const int width, int height,
@@ -47,11 +56,10 @@ namespace CudaRasterizer
4756
const float scale_modifier,
4857
const float* rotations,
4958
const float* cov3D_precomp,
50-
const float* viewmatrix,
59+
const float* viewmatrix,
5160
const float* projmatrix,
5261
const float* campos,
5362
const float tan_fovx, float tan_fovy,
54-
const int* radii,
5563
const float* dL_dpix,
5664
float* dL_dmean2D,
5765
float* dL_dconic,

cuda_rasterizer/rasterizer_impl.cu

Lines changed: 81 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -171,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward(
171171
const float* cam_pos,
172172
const float tan_fovx, float tan_fovy,
173173
const bool prefiltered,
174-
float* out_color,
175-
int* radii)
174+
int* radii,
175+
InternalState* state,
176+
float* out_color)
176177
{
177178
const float focal_y = height / (2.0f * tan_fovy);
178179
const float focal_x = width / (2.0f * tan_fovx);
179180

180181
// Dynamically resize auxiliary buffers during training
181-
if (P > maxP)
182-
{
183-
maxP = resizeMultiplier * P;
184-
cov3D.resize(maxP * 6);
185-
rgb.resize(maxP * 3);
186-
tiles_touched.resize(maxP);
187-
point_offsets.resize(maxP);
188-
clamped.resize(3 * maxP);
189-
190-
depths.resize(maxP);
191-
means2D.resize(maxP);
192-
conic_opacity.resize(maxP);
193-
194-
cub::DeviceScan::InclusiveSum(nullptr, scan_size, tiles_touched.data().get(), tiles_touched.data().get(), maxP);
195-
scanning_space.resize(scan_size);
196-
}
197-
198-
if (radii == nullptr)
182+
if (P > state->maxP)
199183
{
200-
internal_radii.resize(maxP);
201-
radii = internal_radii.data().get();
184+
state->maxP = resizeMultiplier * P;
185+
state->cov3D.resize(state->maxP * 6);
186+
state->rgb.resize(state->maxP * 3);
187+
state->tiles_touched.resize(state->maxP);
188+
state->point_offsets.resize(state->maxP);
189+
state->clamped.resize(3 * state->maxP);
190+
191+
state->depths.resize(state->maxP);
192+
state->means2D.resize(state->maxP);
193+
state->conic_opacity.resize(state->maxP);
194+
195+
size_t scan_size;
196+
cub::DeviceScan::InclusiveSum(nullptr,
197+
scan_size,
198+
state->tiles_touched.data().get(),
199+
state->tiles_touched.data().get(),
200+
state->maxP);
201+
state->scanning_space.resize(scan_size);
202202
}
203203

204204
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
205205
dim3 block(BLOCK_X, BLOCK_Y, 1);
206206

207207
// Dynamically resize image-based auxiliary buffers during training
208-
if (width * height > maxPixels)
208+
if (width * height > state->maxPixels)
209209
{
210-
maxPixels = width * height;
211-
accum_alpha.resize(maxPixels);
212-
n_contrib.resize(maxPixels);
213-
ranges.resize(tile_grid.x * tile_grid.y);
210+
state->maxPixels = width * height;
211+
state->accum_alpha.resize(state->maxPixels);
212+
state->n_contrib.resize(state->maxPixels);
213+
state->ranges.resize(tile_grid.x * tile_grid.y);
214214
}
215215

216216
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
@@ -227,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward(
227227
(glm::vec4*)rotations,
228228
opacities,
229229
shs,
230-
clamped.data().get(),
230+
state->clamped.data().get(),
231231
cov3D_precomp,
232232
colors_precomp,
233233
viewmatrix, projmatrix,
@@ -236,45 +236,56 @@ void CudaRasterizer::RasterizerImpl::forward(
236236
tan_fovx, tan_fovy,
237237
focal_x, focal_y,
238238
radii,
239-
means2D.data().get(),
240-
depths.data().get(),
241-
cov3D.data().get(),
242-
rgb.data().get(),
243-
conic_opacity.data().get(),
239+
state->means2D.data().get(),
240+
state->depths.data().get(),
241+
state->cov3D.data().get(),
242+
state->rgb.data().get(),
243+
state->conic_opacity.data().get(),
244244
tile_grid,
245-
tiles_touched.data().get(),
245+
state->tiles_touched.data().get(),
246246
prefiltered
247247
);
248248

249249
// Compute prefix sum over full list of touched tile counts by Gaussians
250250
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
251-
cub::DeviceScan::InclusiveSum(scanning_space.data().get(), scan_size,
252-
tiles_touched.data().get(), point_offsets.data().get(), P);
251+
size_t scanning_space_size = state->scanning_space.size();
252+
cub::DeviceScan::InclusiveSum(
253+
state->scanning_space.data().get(),
254+
scanning_space_size,
255+
state->tiles_touched.data().get(),
256+
state->point_offsets.data().get(),
257+
P);
253258

254259
// Retrieve total number of Gaussian instances to launch and resize aux buffers
255260
int num_needed;
256-
cudaMemcpy(&num_needed, point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
261+
cudaMemcpy(&num_needed, state->point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
257262
if (num_needed > point_list_keys_unsorted.size())
258263
{
259-
point_list_keys_unsorted.resize(2 * num_needed);
260-
point_list_keys.resize(2 * num_needed);
261-
point_list_unsorted.resize(2 * num_needed);
262-
point_list.resize(2 * num_needed);
264+
int resizeNum = resizeMultiplier * num_needed;
265+
point_list_keys_unsorted.resize(resizeNum);
266+
point_list_keys.resize(resizeNum);
267+
point_list_unsorted.resize(resizeNum);
268+
size_t sorting_size;
263269
cub::DeviceRadixSort::SortPairs(
264270
nullptr, sorting_size,
265271
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
266-
point_list_unsorted.data().get(), point_list.data().get(),
267-
2 * num_needed);
272+
point_list_unsorted.data().get(), state->point_list.data().get(),
273+
resizeNum);
268274
list_sorting_space.resize(sorting_size);
269275
}
270276

277+
if (num_needed > state->point_list.size())
278+
{
279+
state->point_list.resize(resizeMultiplier * num_needed);
280+
}
281+
271282
// For each instance to be rendered, produce adequate [ tile | depth ] key
272283
// and corresponding duplicated Gaussian indices to be sorted
273284
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
274285
P,
275-
means2D.data().get(),
276-
depths.data().get(),
277-
point_offsets.data().get(),
286+
state->means2D.data().get(),
287+
state->depths.data().get(),
288+
state->point_offsets.data().get(),
278289
point_list_keys_unsorted.data().get(),
279290
point_list_unsorted.data().get(),
280291
radii,
@@ -284,41 +295,45 @@ void CudaRasterizer::RasterizerImpl::forward(
284295
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
285296

286297
// Sort complete list of (duplicated) Gaussian indices by keys
298+
size_t list_sorting_space_size = list_sorting_space.size();
287299
cub::DeviceRadixSort::SortPairs(
288300
list_sorting_space.data().get(),
289-
sorting_size,
301+
list_sorting_space_size,
290302
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
291-
point_list_unsorted.data().get(), point_list.data().get(),
303+
point_list_unsorted.data().get(),
304+
state->point_list.data().get(),
292305
num_needed, 0, 32 + bit);
293306

294-
cudaMemset(ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
307+
cudaMemset(state->ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
295308

296309
// Identify start and end of per-tile workloads in sorted list
297310
identifyTileRanges << <(num_needed + 255) / 256, 256 >> > (
298311
num_needed,
299312
point_list_keys.data().get(),
300-
ranges.data().get()
313+
state->ranges.data().get()
301314
);
302315

303316
// Let each tile blend its range of Gaussians independently in parallel
304-
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : rgb.data().get();
317+
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : state->rgb.data().get();
305318
FORWARD::render(
306319
tile_grid, block,
307-
ranges.data().get(),
308-
point_list.data().get(),
320+
state->ranges.data().get(),
321+
state->point_list.data().get(),
309322
width, height,
310-
means2D.data().get(),
323+
state->means2D.data().get(),
311324
feature_ptr,
312-
conic_opacity.data().get(),
313-
accum_alpha.data().get(),
314-
n_contrib.data().get(),
325+
state->conic_opacity.data().get(),
326+
state->accum_alpha.data().get(),
327+
state->n_contrib.data().get(),
315328
background,
316329
out_color);
317330
}
318331

319332
// Produce necessary gradients for optimization, corresponding
320333
// to forward render pass
321334
void CudaRasterizer::RasterizerImpl::backward(
335+
const int* radii,
336+
const InternalState* state,
322337
const int P, int D, int M,
323338
const float* background,
324339
const int width, int height,
@@ -333,7 +348,6 @@ void CudaRasterizer::RasterizerImpl::backward(
333348
const float* projmatrix,
334349
const float* campos,
335350
const float tan_fovx, float tan_fovy,
336-
const int* radii,
337351
const float* dL_dpix,
338352
float* dL_dmean2D,
339353
float* dL_dconic,
@@ -345,11 +359,6 @@ void CudaRasterizer::RasterizerImpl::backward(
345359
float* dL_dscale,
346360
float* dL_drot)
347361
{
348-
if (radii == nullptr)
349-
{
350-
radii = internal_radii.data().get();
351-
}
352-
353362
const float focal_y = height / (2.0f * tan_fovy);
354363
const float focal_x = width / (2.0f * tan_fovx);
355364

@@ -359,19 +368,19 @@ void CudaRasterizer::RasterizerImpl::backward(
359368
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
360369
// opacity and RGB of Gaussians from per-pixel loss gradients.
361370
// If we were given precomputed colors and not SHs, use them.
362-
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : rgb.data().get();
371+
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : state->rgb.data().get();
363372
BACKWARD::render(
364373
tile_grid,
365374
block,
366-
ranges.data().get(),
367-
point_list.data().get(),
375+
state->ranges.data().get(),
376+
state->point_list.data().get(),
368377
width, height,
369378
background,
370-
means2D.data().get(),
371-
conic_opacity.data().get(),
379+
state->means2D.data().get(),
380+
state->conic_opacity.data().get(),
372381
color_ptr,
373-
accum_alpha.data().get(),
374-
n_contrib.data().get(),
382+
state->accum_alpha.data().get(),
383+
state->n_contrib.data().get(),
375384
dL_dpix,
376385
(float3*)dL_dmean2D,
377386
(float4*)dL_dconic,
@@ -381,12 +390,12 @@ void CudaRasterizer::RasterizerImpl::backward(
381390
// Take care of the rest of preprocessing. Was the precomputed covariance
382391
// given to us or a scales/rot pair? If precomputed, pass that. If not,
383392
// use the one we computed ourselves.
384-
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : cov3D.data().get();
393+
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : state->cov3D.data().get();
385394
BACKWARD::preprocess(P, D, M,
386395
(float3*)means3D,
387396
radii,
388397
shs,
389-
clamped.data().get(),
398+
state->clamped.data().get(),
390399
(glm::vec3*)scales,
391400
(glm::vec4*)rotations,
392401
scale_modifier,

0 commit comments

Comments
 (0)