Skip to content

Commit 55b0c1b

Browse files
author
Bernhard Kerbl
committed
This was a bad idea, undoing it
1 parent e41a365 commit 55b0c1b

File tree

7 files changed

+131
-229
lines changed

7 files changed

+131
-229
lines changed

cuda_rasterizer/rasterizer.h

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,10 @@
55

66
namespace CudaRasterizer
77
{
8-
struct InternalState;
9-
108
class Rasterizer
119
{
1210
public:
1311

14-
virtual InternalState* createInternalState() = 0;
15-
16-
virtual void killInternalState(InternalState*) = 0;
17-
1812
virtual void markVisible(
1913
int P,
2014
float* means3D,
@@ -39,13 +33,10 @@ namespace CudaRasterizer
3933
const float* cam_pos,
4034
const float tan_fovx, float tan_fovy,
4135
const bool prefiltered,
42-
int* radii,
43-
InternalState* state,
44-
float* out_color) = 0;
36+
float* out_color,
37+
int* radii = nullptr) = 0;
4538

4639
virtual void backward(
47-
const int* radii,
48-
const InternalState* state,
4940
const int P, int D, int M,
5041
const float* background,
5142
const int width, int height,
@@ -56,10 +47,11 @@ namespace CudaRasterizer
5647
const float scale_modifier,
5748
const float* rotations,
5849
const float* cov3D_precomp,
59-
const float* viewmatrix,
50+
const float* viewmatrix,
6051
const float* projmatrix,
6152
const float* campos,
6253
const float tan_fovx, float tan_fovy,
54+
const int* radii,
6355
const float* dL_dpix,
6456
float* dL_dmean2D,
6557
float* dL_dconic,
@@ -74,8 +66,6 @@ namespace CudaRasterizer
7466
virtual ~Rasterizer() {};
7567

7668
static Rasterizer* make(int resizeMultipliyer = 2);
77-
78-
static void kill(Rasterizer* rasterizer);
7969
};
8070
};
8171

cuda_rasterizer/rasterizer_impl.cu

Lines changed: 72 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,6 @@ CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplie
137137
return new CudaRasterizer::RasterizerImpl(resizeMultiplier);
138138
}
139139

140-
void CudaRasterizer::Rasterizer::kill(Rasterizer* rasterizer)
141-
{
142-
delete rasterizer;
143-
}
144-
145140
// Mark Gaussians as visible/invisible, based on view frustum testing
146141
void CudaRasterizer::RasterizerImpl::markVisible(
147142
int P,
@@ -176,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward(
176171
const float* cam_pos,
177172
const float tan_fovx, float tan_fovy,
178173
const bool prefiltered,
179-
int* radii,
180-
InternalState* state,
181-
float* out_color)
174+
float* out_color,
175+
int* radii)
182176
{
183177
const float focal_y = height / (2.0f * tan_fovy);
184178
const float focal_x = width / (2.0f * tan_fovx);
185179

186180
// Dynamically resize auxiliary buffers during training
187-
if (P > state->maxP)
181+
if (P > maxP)
188182
{
189-
state->maxP = resizeMultiplier * P;
190-
state->cov3D.resize(state->maxP * 6);
191-
state->rgb.resize(state->maxP * 3);
192-
state->tiles_touched.resize(state->maxP);
193-
state->point_offsets.resize(state->maxP);
194-
state->clamped.resize(3 * state->maxP);
195-
196-
state->depths.resize(state->maxP);
197-
state->means2D.resize(state->maxP);
198-
state->conic_opacity.resize(state->maxP);
199-
200-
size_t scan_size;
201-
cub::DeviceScan::InclusiveSum(nullptr,
202-
scan_size,
203-
state->tiles_touched.data().get(),
204-
state->tiles_touched.data().get(),
205-
state->maxP);
206-
state->scanning_space.resize(scan_size);
183+
maxP = resizeMultiplier * P;
184+
cov3D.resize(maxP * 6);
185+
rgb.resize(maxP * 3);
186+
tiles_touched.resize(maxP);
187+
point_offsets.resize(maxP);
188+
clamped.resize(3 * maxP);
189+
190+
depths.resize(maxP);
191+
means2D.resize(maxP);
192+
conic_opacity.resize(maxP);
193+
194+
cub::DeviceScan::InclusiveSum(nullptr, scan_size, tiles_touched.data().get(), tiles_touched.data().get(), maxP);
195+
scanning_space.resize(scan_size);
196+
}
197+
198+
if (radii == nullptr)
199+
{
200+
internal_radii.resize(maxP);
201+
radii = internal_radii.data().get();
207202
}
208203

209204
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
210205
dim3 block(BLOCK_X, BLOCK_Y, 1);
211206

212207
// Dynamically resize image-based auxiliary buffers during training
213-
if (width * height > state->maxPixels)
208+
if (width * height > maxPixels)
214209
{
215-
state->maxPixels = width * height;
216-
state->accum_alpha.resize(state->maxPixels);
217-
state->n_contrib.resize(state->maxPixels);
218-
state->ranges.resize(tile_grid.x * tile_grid.y);
210+
maxPixels = width * height;
211+
accum_alpha.resize(maxPixels);
212+
n_contrib.resize(maxPixels);
213+
ranges.resize(tile_grid.x * tile_grid.y);
219214
}
220215

221216
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
@@ -232,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward(
232227
(glm::vec4*)rotations,
233228
opacities,
234229
shs,
235-
state->clamped.data().get(),
230+
clamped.data().get(),
236231
cov3D_precomp,
237232
colors_precomp,
238233
viewmatrix, projmatrix,
@@ -241,56 +236,45 @@ void CudaRasterizer::RasterizerImpl::forward(
241236
tan_fovx, tan_fovy,
242237
focal_x, focal_y,
243238
radii,
244-
state->means2D.data().get(),
245-
state->depths.data().get(),
246-
state->cov3D.data().get(),
247-
state->rgb.data().get(),
248-
state->conic_opacity.data().get(),
239+
means2D.data().get(),
240+
depths.data().get(),
241+
cov3D.data().get(),
242+
rgb.data().get(),
243+
conic_opacity.data().get(),
249244
tile_grid,
250-
state->tiles_touched.data().get(),
245+
tiles_touched.data().get(),
251246
prefiltered
252247
);
253248

254249
// Compute prefix sum over full list of touched tile counts by Gaussians
255250
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
256-
size_t scanning_space_size = state->scanning_space.size();
257-
cub::DeviceScan::InclusiveSum(
258-
state->scanning_space.data().get(),
259-
scanning_space_size,
260-
state->tiles_touched.data().get(),
261-
state->point_offsets.data().get(),
262-
P);
251+
cub::DeviceScan::InclusiveSum(scanning_space.data().get(), scan_size,
252+
tiles_touched.data().get(), point_offsets.data().get(), P);
263253

264254
// Retrieve total number of Gaussian instances to launch and resize aux buffers
265255
int num_needed;
266-
cudaMemcpy(&num_needed, state->point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
256+
cudaMemcpy(&num_needed, point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
267257
if (num_needed > point_list_keys_unsorted.size())
268258
{
269-
int resizeNum = resizeMultiplier * num_needed;
270-
point_list_keys_unsorted.resize(resizeNum);
271-
point_list_keys.resize(resizeNum);
272-
point_list_unsorted.resize(resizeNum);
273-
size_t sorting_size;
259+
point_list_keys_unsorted.resize(2 * num_needed);
260+
point_list_keys.resize(2 * num_needed);
261+
point_list_unsorted.resize(2 * num_needed);
262+
point_list.resize(2 * num_needed);
274263
cub::DeviceRadixSort::SortPairs(
275264
nullptr, sorting_size,
276265
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
277-
point_list_unsorted.data().get(), state->point_list.data().get(),
278-
resizeNum);
266+
point_list_unsorted.data().get(), point_list.data().get(),
267+
2 * num_needed);
279268
list_sorting_space.resize(sorting_size);
280269
}
281270

282-
if (num_needed > state->point_list.size())
283-
{
284-
state->point_list.resize(resizeMultiplier * num_needed);
285-
}
286-
287271
// For each instance to be rendered, produce adequate [ tile | depth ] key
288272
// and corresponding duplicated Gaussian indices to be sorted
289273
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
290274
P,
291-
state->means2D.data().get(),
292-
state->depths.data().get(),
293-
state->point_offsets.data().get(),
275+
means2D.data().get(),
276+
depths.data().get(),
277+
point_offsets.data().get(),
294278
point_list_keys_unsorted.data().get(),
295279
point_list_unsorted.data().get(),
296280
radii,
@@ -300,45 +284,41 @@ void CudaRasterizer::RasterizerImpl::forward(
300284
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
301285

302286
// Sort complete list of (duplicated) Gaussian indices by keys
303-
size_t list_sorting_space_size = list_sorting_space.size();
304287
cub::DeviceRadixSort::SortPairs(
305288
list_sorting_space.data().get(),
306-
list_sorting_space_size,
289+
sorting_size,
307290
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
308-
point_list_unsorted.data().get(),
309-
state->point_list.data().get(),
291+
point_list_unsorted.data().get(), point_list.data().get(),
310292
num_needed, 0, 32 + bit);
311293

312-
cudaMemset(state->ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
294+
cudaMemset(ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
313295

314296
// Identify start and end of per-tile workloads in sorted list
315297
identifyTileRanges << <(num_needed + 255) / 256, 256 >> > (
316298
num_needed,
317299
point_list_keys.data().get(),
318-
state->ranges.data().get()
300+
ranges.data().get()
319301
);
320302

321303
// Let each tile blend its range of Gaussians independently in parallel
322-
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : state->rgb.data().get();
304+
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : rgb.data().get();
323305
FORWARD::render(
324306
tile_grid, block,
325-
state->ranges.data().get(),
326-
state->point_list.data().get(),
307+
ranges.data().get(),
308+
point_list.data().get(),
327309
width, height,
328-
state->means2D.data().get(),
310+
means2D.data().get(),
329311
feature_ptr,
330-
state->conic_opacity.data().get(),
331-
state->accum_alpha.data().get(),
332-
state->n_contrib.data().get(),
312+
conic_opacity.data().get(),
313+
accum_alpha.data().get(),
314+
n_contrib.data().get(),
333315
background,
334316
out_color);
335317
}
336318

337319
// Produce necessary gradients for optimization, corresponding
338320
// to forward render pass
339321
void CudaRasterizer::RasterizerImpl::backward(
340-
const int* radii,
341-
const InternalState* state,
342322
const int P, int D, int M,
343323
const float* background,
344324
const int width, int height,
@@ -353,6 +333,7 @@ void CudaRasterizer::RasterizerImpl::backward(
353333
const float* projmatrix,
354334
const float* campos,
355335
const float tan_fovx, float tan_fovy,
336+
const int* radii,
356337
const float* dL_dpix,
357338
float* dL_dmean2D,
358339
float* dL_dconic,
@@ -364,6 +345,11 @@ void CudaRasterizer::RasterizerImpl::backward(
364345
float* dL_dscale,
365346
float* dL_drot)
366347
{
348+
if (radii == nullptr)
349+
{
350+
radii = internal_radii.data().get();
351+
}
352+
367353
const float focal_y = height / (2.0f * tan_fovy);
368354
const float focal_x = width / (2.0f * tan_fovx);
369355

@@ -373,19 +359,19 @@ void CudaRasterizer::RasterizerImpl::backward(
373359
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
374360
// opacity and RGB of Gaussians from per-pixel loss gradients.
375361
// If we were given precomputed colors and not SHs, use them.
376-
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : state->rgb.data().get();
362+
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : rgb.data().get();
377363
BACKWARD::render(
378364
tile_grid,
379365
block,
380-
state->ranges.data().get(),
381-
state->point_list.data().get(),
366+
ranges.data().get(),
367+
point_list.data().get(),
382368
width, height,
383369
background,
384-
state->means2D.data().get(),
385-
state->conic_opacity.data().get(),
370+
means2D.data().get(),
371+
conic_opacity.data().get(),
386372
color_ptr,
387-
state->accum_alpha.data().get(),
388-
state->n_contrib.data().get(),
373+
accum_alpha.data().get(),
374+
n_contrib.data().get(),
389375
dL_dpix,
390376
(float3*)dL_dmean2D,
391377
(float4*)dL_dconic,
@@ -395,12 +381,12 @@ void CudaRasterizer::RasterizerImpl::backward(
395381
// Take care of the rest of preprocessing. Was the precomputed covariance
396382
// given to us or a scales/rot pair? If precomputed, pass that. If not,
397383
// use the one we computed ourselves.
398-
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : state->cov3D.data().get();
384+
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : cov3D.data().get();
399385
BACKWARD::preprocess(P, D, M,
400386
(float3*)means3D,
401387
radii,
402388
shs,
403-
state->clamped.data().get(),
389+
clamped.data().get(),
404390
(glm::vec3*)scales,
405391
(glm::vec4*)rotations,
406392
scale_modifier,

0 commit comments

Comments
 (0)