@@ -171,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward(
171171 const float * cam_pos,
172172 const float tan_fovx, float tan_fovy,
173173 const bool prefiltered,
174- float * out_color,
175- int * radii)
174+ int * radii,
175+ InternalState* state,
176+ float * out_color)
176177{
177178 const float focal_y = height / (2 .0f * tan_fovy);
178179 const float focal_x = width / (2 .0f * tan_fovx);
179180
180181 // Dynamically resize auxiliary buffers during training
181- if (P > maxP)
182- {
183- maxP = resizeMultiplier * P;
184- cov3D.resize (maxP * 6 );
185- rgb.resize (maxP * 3 );
186- tiles_touched.resize (maxP);
187- point_offsets.resize (maxP);
188- clamped.resize (3 * maxP);
189-
190- depths.resize (maxP);
191- means2D.resize (maxP);
192- conic_opacity.resize (maxP);
193-
194- cub::DeviceScan::InclusiveSum (nullptr , scan_size, tiles_touched.data ().get (), tiles_touched.data ().get (), maxP);
195- scanning_space.resize (scan_size);
196- }
197-
198- if (radii == nullptr )
182+ if (P > state->maxP )
199183 {
200- internal_radii.resize (maxP);
201- radii = internal_radii.data ().get ();
184+ state->maxP = resizeMultiplier * P;
185+ state->cov3D .resize (state->maxP * 6 );
186+ state->rgb .resize (state->maxP * 3 );
187+ state->tiles_touched .resize (state->maxP );
188+ state->point_offsets .resize (state->maxP );
189+ state->clamped .resize (3 * state->maxP );
190+
191+ state->depths .resize (state->maxP );
192+ state->means2D .resize (state->maxP );
193+ state->conic_opacity .resize (state->maxP );
194+
195+ size_t scan_size;
196+ cub::DeviceScan::InclusiveSum (nullptr ,
197+ scan_size,
198+ state->tiles_touched .data ().get (),
199+ state->tiles_touched .data ().get (),
200+ state->maxP );
201+ state->scanning_space .resize (scan_size);
202202 }
203203
204204 dim3 tile_grid ((width + BLOCK_X - 1 ) / BLOCK_X, (height + BLOCK_Y - 1 ) / BLOCK_Y, 1 );
205205 dim3 block (BLOCK_X, BLOCK_Y, 1 );
206206
207207 // Dynamically resize image-based auxiliary buffers during training
208- if (width * height > maxPixels)
208+ if (width * height > state-> maxPixels )
209209 {
210- maxPixels = width * height;
211- accum_alpha.resize (maxPixels);
212- n_contrib.resize (maxPixels);
213- ranges.resize (tile_grid.x * tile_grid.y );
210+ state-> maxPixels = width * height;
211+ state-> accum_alpha .resize (state-> maxPixels );
212+ state-> n_contrib .resize (state-> maxPixels );
213+ state-> ranges .resize (tile_grid.x * tile_grid.y );
214214 }
215215
216216 if (NUM_CHANNELS != 3 && colors_precomp == nullptr )
@@ -227,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward(
227227 (glm::vec4*)rotations,
228228 opacities,
229229 shs,
230- clamped.data ().get (),
230+ state-> clamped .data ().get (),
231231 cov3D_precomp,
232232 colors_precomp,
233233 viewmatrix, projmatrix,
@@ -236,45 +236,56 @@ void CudaRasterizer::RasterizerImpl::forward(
236236 tan_fovx, tan_fovy,
237237 focal_x, focal_y,
238238 radii,
239- means2D.data ().get (),
240- depths.data ().get (),
241- cov3D.data ().get (),
242- rgb.data ().get (),
243- conic_opacity.data ().get (),
239+ state-> means2D .data ().get (),
240+ state-> depths .data ().get (),
241+ state-> cov3D .data ().get (),
242+ state-> rgb .data ().get (),
243+ state-> conic_opacity .data ().get (),
244244 tile_grid,
245- tiles_touched.data ().get (),
245+ state-> tiles_touched .data ().get (),
246246 prefiltered
247247 );
248248
249249 // Compute prefix sum over full list of touched tile counts by Gaussians
250250 // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
251- cub::DeviceScan::InclusiveSum (scanning_space.data ().get (), scan_size,
252- tiles_touched.data ().get (), point_offsets.data ().get (), P);
251+ size_t scanning_space_size = state->scanning_space .size ();
252+ cub::DeviceScan::InclusiveSum (
253+ state->scanning_space .data ().get (),
254+ scanning_space_size,
255+ state->tiles_touched .data ().get (),
256+ state->point_offsets .data ().get (),
257+ P);
253258
254259 // Retrieve total number of Gaussian instances to launch and resize aux buffers
255260 int num_needed;
256- cudaMemcpy (&num_needed, point_offsets.data ().get () + P - 1 , sizeof (int ), cudaMemcpyDeviceToHost);
261+ cudaMemcpy (&num_needed, state-> point_offsets .data ().get () + P - 1 , sizeof (int ), cudaMemcpyDeviceToHost);
257262 if (num_needed > point_list_keys_unsorted.size ())
258263 {
259- point_list_keys_unsorted.resize (2 * num_needed);
260- point_list_keys.resize (2 * num_needed);
261- point_list_unsorted.resize (2 * num_needed);
262- point_list.resize (2 * num_needed);
264+ int resizeNum = resizeMultiplier * num_needed;
265+ point_list_keys_unsorted.resize (resizeNum);
266+ point_list_keys.resize (resizeNum);
267+ point_list_unsorted.resize (resizeNum);
268+ size_t sorting_size;
263269 cub::DeviceRadixSort::SortPairs (
264270 nullptr , sorting_size,
265271 point_list_keys_unsorted.data ().get (), point_list_keys.data ().get (),
266- point_list_unsorted.data ().get (), point_list.data ().get (),
267- 2 * num_needed );
272+ point_list_unsorted.data ().get (), state-> point_list .data ().get (),
273+ resizeNum );
268274 list_sorting_space.resize (sorting_size);
269275 }
270276
277+ if (num_needed > state->point_list .size ())
278+ {
279+ state->point_list .resize (resizeMultiplier * num_needed);
280+ }
281+
271282 // For each instance to be rendered, produce adequate [ tile | depth ] key
272283 // and corresponding duplicated Gaussian indices to be sorted
273284 duplicateWithKeys << <(P + 255 ) / 256 , 256 >> > (
274285 P,
275- means2D.data ().get (),
276- depths.data ().get (),
277- point_offsets.data ().get (),
286+ state-> means2D .data ().get (),
287+ state-> depths .data ().get (),
288+ state-> point_offsets .data ().get (),
278289 point_list_keys_unsorted.data ().get (),
279290 point_list_unsorted.data ().get (),
280291 radii,
@@ -284,41 +295,45 @@ void CudaRasterizer::RasterizerImpl::forward(
284295 int bit = getHigherMsb (tile_grid.x * tile_grid.y );
285296
286297 // Sort complete list of (duplicated) Gaussian indices by keys
298+ size_t list_sorting_space_size = list_sorting_space.size ();
287299 cub::DeviceRadixSort::SortPairs (
288300 list_sorting_space.data ().get (),
289- sorting_size ,
301+ list_sorting_space_size ,
290302 point_list_keys_unsorted.data ().get (), point_list_keys.data ().get (),
291- point_list_unsorted.data ().get (), point_list.data ().get (),
303+ point_list_unsorted.data ().get (),
304+ state->point_list .data ().get (),
292305 num_needed, 0 , 32 + bit);
293306
294- cudaMemset (ranges.data ().get (), 0 , tile_grid.x * tile_grid.y * sizeof (uint2 ));
307+ cudaMemset (state-> ranges .data ().get (), 0 , tile_grid.x * tile_grid.y * sizeof (uint2 ));
295308
296309 // Identify start and end of per-tile workloads in sorted list
297310 identifyTileRanges << <(num_needed + 255 ) / 256 , 256 >> > (
298311 num_needed,
299312 point_list_keys.data ().get (),
300- ranges.data ().get ()
313+ state-> ranges .data ().get ()
301314 );
302315
303316 // Let each tile blend its range of Gaussians independently in parallel
304- const float * feature_ptr = colors_precomp != nullptr ? colors_precomp : rgb.data ().get ();
317+ const float * feature_ptr = colors_precomp != nullptr ? colors_precomp : state-> rgb .data ().get ();
305318 FORWARD::render (
306319 tile_grid, block,
307- ranges.data ().get (),
308- point_list.data ().get (),
320+ state-> ranges .data ().get (),
321+ state-> point_list .data ().get (),
309322 width, height,
310- means2D.data ().get (),
323+ state-> means2D .data ().get (),
311324 feature_ptr,
312- conic_opacity.data ().get (),
313- accum_alpha.data ().get (),
314- n_contrib.data ().get (),
325+ state-> conic_opacity .data ().get (),
326+ state-> accum_alpha .data ().get (),
327+ state-> n_contrib .data ().get (),
315328 background,
316329 out_color);
317330}
318331
319332// Produce necessary gradients for optimization, corresponding
320333// to forward render pass
321334void CudaRasterizer::RasterizerImpl::backward (
335+ const int * radii,
336+ const InternalState* state,
322337 const int P, int D, int M,
323338 const float * background,
324339 const int width, int height,
@@ -333,7 +348,6 @@ void CudaRasterizer::RasterizerImpl::backward(
333348 const float * projmatrix,
334349 const float * campos,
335350 const float tan_fovx, float tan_fovy,
336- const int * radii,
337351 const float * dL_dpix,
338352 float * dL_dmean2D,
339353 float * dL_dconic,
@@ -345,11 +359,6 @@ void CudaRasterizer::RasterizerImpl::backward(
345359 float * dL_dscale,
346360 float * dL_drot)
347361{
348- if (radii == nullptr )
349- {
350- radii = internal_radii.data ().get ();
351- }
352-
353362 const float focal_y = height / (2 .0f * tan_fovy);
354363 const float focal_x = width / (2 .0f * tan_fovx);
355364
@@ -359,19 +368,19 @@ void CudaRasterizer::RasterizerImpl::backward(
359368 // Compute loss gradients w.r.t. 2D mean position, conic matrix,
360369 // opacity and RGB of Gaussians from per-pixel loss gradients.
361370 // If we were given precomputed colors and not SHs, use them.
362- const float * color_ptr = (colors_precomp != nullptr ) ? colors_precomp : rgb.data ().get ();
371+ const float * color_ptr = (colors_precomp != nullptr ) ? colors_precomp : state-> rgb .data ().get ();
363372 BACKWARD::render (
364373 tile_grid,
365374 block,
366- ranges.data ().get (),
367- point_list.data ().get (),
375+ state-> ranges .data ().get (),
376+ state-> point_list .data ().get (),
368377 width, height,
369378 background,
370- means2D.data ().get (),
371- conic_opacity.data ().get (),
379+ state-> means2D .data ().get (),
380+ state-> conic_opacity .data ().get (),
372381 color_ptr,
373- accum_alpha.data ().get (),
374- n_contrib.data ().get (),
382+ state-> accum_alpha .data ().get (),
383+ state-> n_contrib .data ().get (),
375384 dL_dpix,
376385 (float3 *)dL_dmean2D,
377386 (float4 *)dL_dconic,
@@ -381,12 +390,12 @@ void CudaRasterizer::RasterizerImpl::backward(
381390 // Take care of the rest of preprocessing. Was the precomputed covariance
382391 // given to us or a scales/rot pair? If precomputed, pass that. If not,
383392 // use the one we computed ourselves.
384- const float * cov3D_ptr = (cov3D_precomp != nullptr ) ? cov3D_precomp : cov3D.data ().get ();
393+ const float * cov3D_ptr = (cov3D_precomp != nullptr ) ? cov3D_precomp : state-> cov3D .data ().get ();
385394 BACKWARD::preprocess (P, D, M,
386395 (float3 *)means3D,
387396 radii,
388397 shs,
389- clamped.data ().get (),
398+ state-> clamped .data ().get (),
390399 (glm::vec3*)scales,
391400 (glm::vec4*)rotations,
392401 scale_modifier,
0 commit comments