@@ -137,11 +137,6 @@ CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplie
137137 return new CudaRasterizer::RasterizerImpl (resizeMultiplier);
138138}
139139
140- void CudaRasterizer::Rasterizer::kill (Rasterizer* rasterizer)
141- {
142- delete rasterizer;
143- }
144-
145140// Mark Gaussians as visible/invisible, based on view frustum testing
146141void CudaRasterizer::RasterizerImpl::markVisible (
147142 int P,
@@ -176,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward(
176171 const float * cam_pos,
177172 const float tan_fovx, float tan_fovy,
178173 const bool prefiltered,
179- int * radii,
180- InternalState* state,
181- float * out_color)
174+ float * out_color,
175+ int * radii)
182176{
183177 const float focal_y = height / (2 .0f * tan_fovy);
184178 const float focal_x = width / (2 .0f * tan_fovx);
185179
186180 // Dynamically resize auxiliary buffers during training
187- if (P > state-> maxP )
181+ if (P > maxP)
188182 {
189- state->maxP = resizeMultiplier * P;
190- state->cov3D .resize (state->maxP * 6 );
191- state->rgb .resize (state->maxP * 3 );
192- state->tiles_touched .resize (state->maxP );
193- state->point_offsets .resize (state->maxP );
194- state->clamped .resize (3 * state->maxP );
195-
196- state->depths .resize (state->maxP );
197- state->means2D .resize (state->maxP );
198- state->conic_opacity .resize (state->maxP );
199-
200- size_t scan_size;
201- cub::DeviceScan::InclusiveSum (nullptr ,
202- scan_size,
203- state->tiles_touched .data ().get (),
204- state->tiles_touched .data ().get (),
205- state->maxP );
206- state->scanning_space .resize (scan_size);
183+ maxP = resizeMultiplier * P;
184+ cov3D.resize (maxP * 6 );
185+ rgb.resize (maxP * 3 );
186+ tiles_touched.resize (maxP);
187+ point_offsets.resize (maxP);
188+ clamped.resize (3 * maxP);
189+
190+ depths.resize (maxP);
191+ means2D.resize (maxP);
192+ conic_opacity.resize (maxP);
193+
194+ cub::DeviceScan::InclusiveSum (nullptr , scan_size, tiles_touched.data ().get (), tiles_touched.data ().get (), maxP);
195+ scanning_space.resize (scan_size);
196+ }
197+
198+ if (radii == nullptr )
199+ {
200+ internal_radii.resize (maxP);
201+ radii = internal_radii.data ().get ();
207202 }
208203
209204 dim3 tile_grid ((width + BLOCK_X - 1 ) / BLOCK_X, (height + BLOCK_Y - 1 ) / BLOCK_Y, 1 );
210205 dim3 block (BLOCK_X, BLOCK_Y, 1 );
211206
212207 // Dynamically resize image-based auxiliary buffers during training
213- if (width * height > state-> maxPixels )
208+ if (width * height > maxPixels)
214209 {
215- state-> maxPixels = width * height;
216- state-> accum_alpha .resize (state-> maxPixels );
217- state-> n_contrib .resize (state-> maxPixels );
218- state-> ranges .resize (tile_grid.x * tile_grid.y );
210+ maxPixels = width * height;
211+ accum_alpha.resize (maxPixels);
212+ n_contrib.resize (maxPixels);
213+ ranges.resize (tile_grid.x * tile_grid.y );
219214 }
220215
221216 if (NUM_CHANNELS != 3 && colors_precomp == nullptr )
@@ -232,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward(
232227 (glm::vec4*)rotations,
233228 opacities,
234229 shs,
235- state-> clamped .data ().get (),
230+ clamped.data ().get (),
236231 cov3D_precomp,
237232 colors_precomp,
238233 viewmatrix, projmatrix,
@@ -241,56 +236,45 @@ void CudaRasterizer::RasterizerImpl::forward(
241236 tan_fovx, tan_fovy,
242237 focal_x, focal_y,
243238 radii,
244- state-> means2D .data ().get (),
245- state-> depths .data ().get (),
246- state-> cov3D .data ().get (),
247- state-> rgb .data ().get (),
248- state-> conic_opacity .data ().get (),
239+ means2D.data ().get (),
240+ depths.data ().get (),
241+ cov3D.data ().get (),
242+ rgb.data ().get (),
243+ conic_opacity.data ().get (),
249244 tile_grid,
250- state-> tiles_touched .data ().get (),
245+ tiles_touched.data ().get (),
251246 prefiltered
252247 );
253248
254249 // Compute prefix sum over full list of touched tile counts by Gaussians
255250 // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
256- size_t scanning_space_size = state->scanning_space .size ();
257- cub::DeviceScan::InclusiveSum (
258- state->scanning_space .data ().get (),
259- scanning_space_size,
260- state->tiles_touched .data ().get (),
261- state->point_offsets .data ().get (),
262- P);
251+ cub::DeviceScan::InclusiveSum (scanning_space.data ().get (), scan_size,
252+ tiles_touched.data ().get (), point_offsets.data ().get (), P);
263253
264254 // Retrieve total number of Gaussian instances to launch and resize aux buffers
265255 int num_needed;
266- cudaMemcpy (&num_needed, state-> point_offsets .data ().get () + P - 1 , sizeof (int ), cudaMemcpyDeviceToHost);
256+ cudaMemcpy (&num_needed, point_offsets.data ().get () + P - 1 , sizeof (int ), cudaMemcpyDeviceToHost);
267257 if (num_needed > point_list_keys_unsorted.size ())
268258 {
269- int resizeNum = resizeMultiplier * num_needed;
270- point_list_keys_unsorted.resize (resizeNum);
271- point_list_keys.resize (resizeNum);
272- point_list_unsorted.resize (resizeNum);
273- size_t sorting_size;
259+ point_list_keys_unsorted.resize (2 * num_needed);
260+ point_list_keys.resize (2 * num_needed);
261+ point_list_unsorted.resize (2 * num_needed);
262+ point_list.resize (2 * num_needed);
274263 cub::DeviceRadixSort::SortPairs (
275264 nullptr , sorting_size,
276265 point_list_keys_unsorted.data ().get (), point_list_keys.data ().get (),
277- point_list_unsorted.data ().get (), state-> point_list .data ().get (),
278- resizeNum );
266+ point_list_unsorted.data ().get (), point_list.data ().get (),
267+ 2 * num_needed );
279268 list_sorting_space.resize (sorting_size);
280269 }
281270
282- if (num_needed > state->point_list .size ())
283- {
284- state->point_list .resize (resizeMultiplier * num_needed);
285- }
286-
287271 // For each instance to be rendered, produce adequate [ tile | depth ] key
288272 // and corresponding duplicated Gaussian indices to be sorted
289273 duplicateWithKeys << <(P + 255 ) / 256 , 256 >> > (
290274 P,
291- state-> means2D .data ().get (),
292- state-> depths .data ().get (),
293- state-> point_offsets .data ().get (),
275+ means2D.data ().get (),
276+ depths.data ().get (),
277+ point_offsets.data ().get (),
294278 point_list_keys_unsorted.data ().get (),
295279 point_list_unsorted.data ().get (),
296280 radii,
@@ -300,45 +284,41 @@ void CudaRasterizer::RasterizerImpl::forward(
300284 int bit = getHigherMsb (tile_grid.x * tile_grid.y );
301285
302286 // Sort complete list of (duplicated) Gaussian indices by keys
303- size_t list_sorting_space_size = list_sorting_space.size ();
304287 cub::DeviceRadixSort::SortPairs (
305288 list_sorting_space.data ().get (),
306- list_sorting_space_size ,
289+ sorting_size ,
307290 point_list_keys_unsorted.data ().get (), point_list_keys.data ().get (),
308- point_list_unsorted.data ().get (),
309- state->point_list .data ().get (),
291+ point_list_unsorted.data ().get (), point_list.data ().get (),
310292 num_needed, 0 , 32 + bit);
311293
312- cudaMemset (state-> ranges .data ().get (), 0 , tile_grid.x * tile_grid.y * sizeof (uint2 ));
294+ cudaMemset (ranges.data ().get (), 0 , tile_grid.x * tile_grid.y * sizeof (uint2 ));
313295
314296 // Identify start and end of per-tile workloads in sorted list
315297 identifyTileRanges << <(num_needed + 255 ) / 256 , 256 >> > (
316298 num_needed,
317299 point_list_keys.data ().get (),
318- state-> ranges .data ().get ()
300+ ranges.data ().get ()
319301 );
320302
321303 // Let each tile blend its range of Gaussians independently in parallel
322- const float * feature_ptr = colors_precomp != nullptr ? colors_precomp : state-> rgb .data ().get ();
304+ const float * feature_ptr = colors_precomp != nullptr ? colors_precomp : rgb.data ().get ();
323305 FORWARD::render (
324306 tile_grid, block,
325- state-> ranges .data ().get (),
326- state-> point_list .data ().get (),
307+ ranges.data ().get (),
308+ point_list.data ().get (),
327309 width, height,
328- state-> means2D .data ().get (),
310+ means2D.data ().get (),
329311 feature_ptr,
330- state-> conic_opacity .data ().get (),
331- state-> accum_alpha .data ().get (),
332- state-> n_contrib .data ().get (),
312+ conic_opacity.data ().get (),
313+ accum_alpha.data ().get (),
314+ n_contrib.data ().get (),
333315 background,
334316 out_color);
335317}
336318
337319// Produce necessary gradients for optimization, corresponding
338320// to forward render pass
339321void CudaRasterizer::RasterizerImpl::backward (
340- const int * radii,
341- const InternalState* state,
342322 const int P, int D, int M,
343323 const float * background,
344324 const int width, int height,
@@ -353,6 +333,7 @@ void CudaRasterizer::RasterizerImpl::backward(
353333 const float * projmatrix,
354334 const float * campos,
355335 const float tan_fovx, float tan_fovy,
336+ const int * radii,
356337 const float * dL_dpix,
357338 float * dL_dmean2D,
358339 float * dL_dconic,
@@ -364,6 +345,11 @@ void CudaRasterizer::RasterizerImpl::backward(
364345 float * dL_dscale,
365346 float * dL_drot)
366347{
348+ if (radii == nullptr )
349+ {
350+ radii = internal_radii.data ().get ();
351+ }
352+
367353 const float focal_y = height / (2 .0f * tan_fovy);
368354 const float focal_x = width / (2 .0f * tan_fovx);
369355
@@ -373,19 +359,19 @@ void CudaRasterizer::RasterizerImpl::backward(
373359 // Compute loss gradients w.r.t. 2D mean position, conic matrix,
374360 // opacity and RGB of Gaussians from per-pixel loss gradients.
375361 // If we were given precomputed colors and not SHs, use them.
376- const float * color_ptr = (colors_precomp != nullptr ) ? colors_precomp : state-> rgb .data ().get ();
362+ const float * color_ptr = (colors_precomp != nullptr ) ? colors_precomp : rgb.data ().get ();
377363 BACKWARD::render (
378364 tile_grid,
379365 block,
380- state-> ranges .data ().get (),
381- state-> point_list .data ().get (),
366+ ranges.data ().get (),
367+ point_list.data ().get (),
382368 width, height,
383369 background,
384- state-> means2D .data ().get (),
385- state-> conic_opacity .data ().get (),
370+ means2D.data ().get (),
371+ conic_opacity.data ().get (),
386372 color_ptr,
387- state-> accum_alpha .data ().get (),
388- state-> n_contrib .data ().get (),
373+ accum_alpha.data ().get (),
374+ n_contrib.data ().get (),
389375 dL_dpix,
390376 (float3 *)dL_dmean2D,
391377 (float4 *)dL_dconic,
@@ -395,12 +381,12 @@ void CudaRasterizer::RasterizerImpl::backward(
395381 // Take care of the rest of preprocessing. Was the precomputed covariance
396382 // given to us or a scales/rot pair? If precomputed, pass that. If not,
397383 // use the one we computed ourselves.
398- const float * cov3D_ptr = (cov3D_precomp != nullptr ) ? cov3D_precomp : state-> cov3D .data ().get ();
384+ const float * cov3D_ptr = (cov3D_precomp != nullptr ) ? cov3D_precomp : cov3D.data ().get ();
399385 BACKWARD::preprocess (P, D, M,
400386 (float3 *)means3D,
401387 radii,
402388 shs,
403- state-> clamped .data ().get (),
389+ clamped.data ().get (),
404390 (glm::vec3*)scales,
405391 (glm::vec4*)rotations,
406392 scale_modifier,
0 commit comments