diff --git a/crates/brush-process/src/train_stream.rs b/crates/brush-process/src/train_stream.rs index 01d1ad09..69d9f220 100644 --- a/crates/brush-process/src/train_stream.rs +++ b/crates/brush-process/src/train_stream.rs @@ -13,6 +13,7 @@ use brush_render::{ use brush_rerun::visualize_tools::VisualizeTools; use brush_train::{ RandomSplatsConfig, create_random_splats, + density_control::compute_gaussian_score, eval::eval_stats, msg::RefineStats, splats_into_autodiff, to_init_splats, @@ -194,15 +195,37 @@ pub(crate) async fn train_stream( .await .unwrap(); - let train_t = - (iter as f32 / train_stream_config.train_config.total_steps as f32).clamp(0.0, 1.0); - let refine = if iter > 0 && iter.is_multiple_of(train_stream_config.train_config.refine_every) - && train_t <= 0.95 + && iter < train_stream_config.train_config.growth_stop_iter { splat_slot - .act(0, async |splats| trainer.refine(iter, splats).await) + .act(0, async |splats| { + let gaussian_scores = compute_gaussian_score( + &mut dataloader, + splats.clone(), + train_stream_config.train_config.n_views, + train_stream_config.train_config.high_error_threshold, + ) + .await; + trainer.refine(iter, splats.clone(), gaussian_scores).await + }) + .await + .unwrap() + } else if iter > train_stream_config.train_config.growth_stop_iter + && iter % train_stream_config.train_config.refine_every_final == 0 + { + splat_slot + .act(0, async |splats| { + let gaussian_scores = compute_gaussian_score( + &mut dataloader, + splats.clone(), + train_stream_config.train_config.n_views, + train_stream_config.train_config.high_error_threshold, + ) + .await; + trainer.refine_final(splats.clone(), gaussian_scores).await + }) .await .unwrap() } else { diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index 81adbd2f..9e99447a 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -279,8 +279,14 @@ where // Async readback let num_intersections 
= project_output.read_num_intersections().await; - let (out_img, render_aux, compact_gid_from_isect) = - >::rasterize(&project_output, num_intersections, background, true); + let (out_img, render_aux, compact_gid_from_isect) = >::rasterize( + &project_output, + num_intersections, + background, + true, + false, + None, + ); let wrapped_render_aux = RenderAux::> { num_visible: render_aux.num_visible.clone(), @@ -288,6 +294,7 @@ where visible: as AutodiffBackend>::from_inner(render_aux.visible.clone()), tile_offsets: render_aux.tile_offsets.clone(), img_size: render_aux.img_size, + high_error_count: render_aux.high_error_count, }; let sh_degree = sh_degree_from_coeffs(sh_coeffs_dims[1] as u32); diff --git a/crates/brush-render/src/burn_glue.rs b/crates/brush-render/src/burn_glue.rs index fce4399e..3544e18a 100644 --- a/crates/brush-render/src/burn_glue.rs +++ b/crates/brush-render/src/burn_glue.rs @@ -1,5 +1,5 @@ use burn::tensor::{ - DType, Shape, + DType, Shape, Tensor, ops::{FloatTensor, IntTensor}, }; use burn_cubecl::{BoolElement, fusion::FusionCubeRuntime}; @@ -164,6 +164,8 @@ impl SplatOps for Fusion { num_intersections: u32, background: Vec3, bwd_info: bool, + high_error_info: bool, + high_error_mask: Option<&FloatTensor>, ) -> (FloatTensor, RenderAux, IntTensor) { #[derive(Debug)] struct CustomOp { @@ -171,6 +173,7 @@ impl SplatOps for Fusion { num_intersections: u32, background: Vec3, bwd_info: bool, + high_error_info: bool, project_uniforms: ProjectUniforms, desc: CustomOpIr, } @@ -187,8 +190,15 @@ impl SplatOps for Fusion { num_visible, global_from_compact_gid, cum_tiles_hit, + high_error_mask, ] = inputs; - let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs; + let [ + out_img, + tile_offsets, + compact_gid_from_isect, + visible, + high_error_count, + ] = outputs; let inner_output = ProjectOutput:: { projected_splats: h.get_float_tensor::(projected_splats), @@ -200,11 +210,19 @@ impl SplatOps for Fusion { img_size: self.img_size, }; + 
let high_error_mask = if self.high_error_info { + Some(&h.get_float_tensor::(high_error_mask)) + } else { + None + }; + let (img, aux, compact_gid) = MainBackendBase::rasterize( &inner_output, self.num_intersections, self.background, self.bwd_info, + self.high_error_info, + high_error_mask, ); // Register outputs @@ -212,10 +230,15 @@ impl SplatOps for Fusion { h.register_int_tensor::(&tile_offsets.id, aux.tile_offsets); h.register_int_tensor::(&compact_gid_from_isect.id, compact_gid); h.register_float_tensor::(&visible.id, aux.visible); + h.register_int_tensor::( + &high_error_count.id, + aux.high_error_count, + ); } } let client = project_output.projected_splats.client.clone(); + let device = project_output.projected_splats.client.device(); let img_size = project_output.img_size; let tile_bounds = calc_tile_bounds(img_size); @@ -248,23 +271,52 @@ impl SplatOps for Fusion { }; let visible = TensorIr::uninit(client.create_empty_handle(), visible_shape, DType::F32); + let high_error_count_shape = if bwd_info && high_error_info { + Shape::new([num_points]) + } else { + Shape::new([1]) + }; + let high_error_count = TensorIr::uninit( + client.create_empty_handle(), + high_error_count_shape, + DType::U32, + ); + + let high_error_mask = if bwd_info && high_error_info { + high_error_mask + .expect("Provide high error mask if high error info is required") + .clone() + } else { + Tensor::::zeros([1, 1], device) + .into_primitive() + .tensor() + }; + let input_tensors = [ project_output.projected_splats.clone(), project_output.num_visible.clone(), project_output.global_from_compact_gid.clone(), project_output.cum_tiles_hit.clone(), + high_error_mask, ]; let stream = OperationStreams::with_inputs(&input_tensors); let desc = CustomOpIr::new( "rasterize", &input_tensors.map(|t| t.into_ir()), - &[out_img, tile_offsets, compact_gid_from_isect, visible], + &[ + out_img, + tile_offsets, + compact_gid_from_isect, + visible, + high_error_count, + ], ); let op = CustomOp { img_size, 
num_intersections, background, bwd_info, + high_error_info, project_uniforms: project_output.project_uniforms, desc: desc.clone(), }; @@ -273,7 +325,13 @@ impl SplatOps for Fusion { .register(stream, OperationIr::Custom(desc), op) .outputs(); - let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs; + let [ + out_img, + tile_offsets, + compact_gid_from_isect, + visible, + high_error_count, + ] = outputs; ( out_img, @@ -283,6 +341,7 @@ impl SplatOps for Fusion { visible, tile_offsets, img_size, + high_error_count, }, compact_gid_from_isect, ) diff --git a/crates/brush-render/src/gaussian_splats.rs b/crates/brush-render/src/gaussian_splats.rs index e9fa41c8..ac73cb1d 100644 --- a/crates/brush-render/src/gaussian_splats.rs +++ b/crates/brush-render/src/gaussian_splats.rs @@ -275,8 +275,14 @@ pub async fn render_splats>( let num_intersections = project_output.read_num_intersections().await; let use_float = matches!(texture_mode, TextureMode::Float); - let (out_img, render_aux, _) = - B::rasterize(&project_output, num_intersections, background, use_float); + let (out_img, render_aux, _) = B::rasterize( + &project_output, + num_intersections, + background, + use_float, + false, + None, + ); render_aux.validate(); diff --git a/crates/brush-render/src/lib.rs b/crates/brush-render/src/lib.rs index 7eda03da..d5be03e4 100644 --- a/crates/brush-render/src/lib.rs +++ b/crates/brush-render/src/lib.rs @@ -61,6 +61,8 @@ pub trait SplatOps { num_intersections: u32, background: Vec3, bwd_info: bool, + high_error_info: bool, + high_error_mask: Option<&FloatTensor>, ) -> (FloatTensor, RenderAux, IntTensor); } diff --git a/crates/brush-render/src/render.rs b/crates/brush-render/src/render.rs index b4894e0c..0ecc6088 100644 --- a/crates/brush-render/src/render.rs +++ b/crates/brush-render/src/render.rs @@ -182,6 +182,8 @@ impl SplatOps for MainBackendBase { num_intersections: u32, background: Vec3, bwd_info: bool, + high_error_info: bool, + high_error_mask: 
Option<&FloatTensor>, ) -> (FloatTensor, RenderAux, IntTensor) { let _span = tracing::trace_span!("rasterize").entered(); @@ -281,23 +283,47 @@ impl SplatOps for MainBackendBase { // Get total_splats from the shape of projected_splats let total_splats = project_output.projected_splats.shape.dims[0]; - let (bindings, visible) = if bwd_info { + let (bindings, visible, high_error_count) = if bwd_info { let visible = Self::float_zeros([total_splats].into(), device, FloatDType::F32); - let bindings = Bindings::new() - .with_buffers(vec![ - compact_gid_from_isect.handle.clone().binding(), - tile_offsets.handle.clone().binding(), - project_output.projected_splats.handle.clone().binding(), - out_img.handle.clone().binding(), - project_output - .global_from_compact_gid - .handle - .clone() - .binding(), - visible.handle.clone().binding(), - ]) - .with_metadata(create_meta_binding(rasterize_uniforms)); - (bindings, visible) + if high_error_info { + let high_error_count = + Self::int_zeros([total_splats].into(), device, IntDType::U32); + let high_error_mask = high_error_mask + .expect("Provide high error mask if high error info is required"); + let bindings = Bindings::new() + .with_buffers(vec![ + compact_gid_from_isect.handle.clone().binding(), + tile_offsets.handle.clone().binding(), + project_output.projected_splats.handle.clone().binding(), + out_img.handle.clone().binding(), + project_output + .global_from_compact_gid + .handle + .clone() + .binding(), + visible.handle.clone().binding(), + high_error_mask.handle.clone().binding(), + high_error_count.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(rasterize_uniforms)); + (bindings, visible, high_error_count) + } else { + let bindings = Bindings::new() + .with_buffers(vec![ + compact_gid_from_isect.handle.clone().binding(), + tile_offsets.handle.clone().binding(), + project_output.projected_splats.handle.clone().binding(), + out_img.handle.clone().binding(), + project_output + .global_from_compact_gid 
+ .handle + .clone() + .binding(), + visible.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(rasterize_uniforms)); + (bindings, visible, create_tensor([1], device, DType::U32)) + } } else { let bindings = Bindings::new() .with_buffers(vec![ @@ -307,10 +333,14 @@ impl SplatOps for MainBackendBase { out_img.handle.clone().binding(), ]) .with_metadata(create_meta_binding(rasterize_uniforms)); - (bindings, create_tensor([1], device, DType::F32)) + ( + bindings, + create_tensor([1], device, DType::F32), + create_tensor([1], device, DType::U32), + ) }; - let raster_task = Rasterize::task(bwd_info); + let raster_task = Rasterize::task(bwd_info, high_error_info); // SAFETY: Kernel checked to have no OOB, bounded loops. unsafe { @@ -331,6 +361,7 @@ impl SplatOps for MainBackendBase { visible, tile_offsets, img_size: project_output.img_size, + high_error_count, }, compact_gid_from_isect, ) diff --git a/crates/brush-render/src/render_aux.rs b/crates/brush-render/src/render_aux.rs index 4c879326..f36f65a6 100644 --- a/crates/brush-render/src/render_aux.rs +++ b/crates/brush-render/src/render_aux.rs @@ -92,6 +92,7 @@ pub struct RenderAux { pub visible: FloatTensor, pub tile_offsets: IntTensor, pub img_size: glam::UVec2, + pub high_error_count: IntTensor, } impl RenderAux { diff --git a/crates/brush-render/src/shaders.rs b/crates/brush-render/src/shaders.rs index 54757a0b..9445bcc3 100644 --- a/crates/brush-render/src/shaders.rs +++ b/crates/brush-render/src/shaders.rs @@ -18,6 +18,7 @@ pub struct MapGaussiansToIntersect; #[wgsl_kernel(source = "src/shaders/rasterize.wgsl")] pub struct Rasterize { pub bwd_info: bool, + pub high_error_info: bool, } // Re-export helper types and constants from the kernel modules that use them diff --git a/crates/brush-render/src/shaders/rasterize.wgsl b/crates/brush-render/src/shaders/rasterize.wgsl index 8883feaf..acff4af2 100644 --- a/crates/brush-render/src/shaders/rasterize.wgsl +++ 
b/crates/brush-render/src/shaders/rasterize.wgsl @@ -8,7 +8,13 @@ @group(0) @binding(3) var out_img: array; @group(0) @binding(4) var global_from_compact_gid: array; @group(0) @binding(5) var visible: array; - @group(0) @binding(6) var uniforms: helpers::RasterizeUniforms; + #ifdef HIGH_ERROR_INFO + @group(0) @binding(6) var high_error_mask: array; + @group(0) @binding(7) var high_error_count: array>; + @group(0) @binding(8) var uniforms: helpers::RasterizeUniforms; + #else + @group(0) @binding(6) var uniforms: helpers::RasterizeUniforms; + #endif #else @group(0) @binding(3) var out_img: array; @group(0) @binding(4) var uniforms: helpers::RasterizeUniforms; @@ -30,6 +36,7 @@ var local_batch: array; fn main( @builtin(global_invocation_id) global_id: vec3u, @builtin(local_invocation_index) local_idx: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32 ) { let pix_loc = helpers::map_1d_to_2d(global_id.x, uniforms.tile_bounds.x); let pix_id = pix_loc.x + pix_loc.y * uniforms.img_size.x; @@ -83,6 +90,8 @@ fn main( let sigma = 0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) + conic.y * delta.x * delta.y; let alpha = min(0.999f, color.a * exp(-sigma)); + var err_val = 0u; + if sigma >= 0.0f && alpha >= 1.0f / 255.0f { let next_T = T * (1.0 - alpha); @@ -94,12 +103,28 @@ fn main( #ifdef BWD_INFO // Count visible if contribution is at least somewhat significant. 
visible[load_gid[t]] = 1.0; + #ifdef HIGH_ERROR_INFO + if (high_error_mask[pix_id] > 0.f) { + err_val = 1u; + } + #endif #endif let vis = alpha * T; pix_out += max(color.rgb, vec3f(0.0)) * vis; T = next_T; } + + #ifdef BWD_INFO + #ifdef HIGH_ERROR_INFO + let subgroup_err_sum = subgroupAdd(err_val); + + if (subgroup_err_sum > 0u && subgroup_invocation_id == 0u) { + atomicAdd(&high_error_count[load_gid[t]], subgroup_err_sum); + } + #endif + #endif + } } diff --git a/crates/brush-train/src/config.rs b/crates/brush-train/src/config.rs index b0adcf47..7ddf5d00 100644 --- a/crates/brush-train/src/config.rs +++ b/crates/brush-train/src/config.rs @@ -52,20 +52,19 @@ pub struct TrainConfig { #[arg(long, help_heading = "Refine options", default_value = "10000000")] pub max_splats: u32, - /// Frequency of 'refinement' where gaussians are replaced and densified. This should + /// Frequency of 'refinement' where gaussians are densified. This should /// roughly be the number of images it takes to properly "cover" your scene. #[arg(long, help_heading = "Refine options", default_value = "200")] pub refine_every: u32, + // Frequency of the final refinement passes (pruning only), applied after growth stops. + #[arg(long, help_heading = "Refine options", default_value = "3000")] + pub refine_every_final: u32, + /// Threshold to control splat growth. Lower means faster growth. #[arg(long, help_heading = "Refine options", default_value = "0.003")] pub growth_grad_threshold: f32, - /// What fraction of splats that are deemed as needing to grow do actually grow. - /// Increase this to make splats grow more aggressively. - #[arg(long, help_heading = "Refine options", default_value = "0.2")] - pub growth_select_fraction: f32, - /// Period after which splat growth stops. 
#[arg(long, help_heading = "Refine options", default_value = "15000")] pub growth_stop_iter: u32, @@ -92,6 +91,26 @@ pub struct TrainConfig { #[arg(long, help_heading = "Refine options", default_value = "0.0")] pub lpips_loss_weight: f32, + + // N views used in multi view densification + #[arg(long, help_heading = "Refine options", default_value = "10")] + pub n_views: i32, + + // High error threshold used in multi view densification + #[arg(long, help_heading = "Refine options", default_value = "0.1")] + pub high_error_threshold: f32, + + // Min importance score used in multi view densification + #[arg(long, help_heading = "Refine options", default_value = "5.0")] + pub min_importance_score: f32, + + // Final min opacity + #[arg(long, help_heading = "Refine options", default_value = "0.005")] + pub final_min_opacity: f32, + + // Final max pruning score + #[arg(long, help_heading = "Refine options", default_value = "0.95")] + pub final_max_pruning_score: f32, } impl Default for TrainConfig { diff --git a/crates/brush-train/src/density_control.rs b/crates/brush-train/src/density_control.rs new file mode 100644 index 00000000..54fd85f1 --- /dev/null +++ b/crates/brush-train/src/density_control.rs @@ -0,0 +1,148 @@ +use crate::ssim::Ssim; +use brush_dataset::scene_loader::SceneLoader; +use brush_render::{MainBackend, SplatOps, gaussian_splats::Splats}; +use burn::{ + backend::wgpu::WgpuDevice, + prelude::Backend, + tensor::{Int, Tensor, TensorPrimitive, s}, +}; +use glam::Vec3; +use tracing::{Instrument, trace_span}; + +#[derive(Debug, Clone)] +pub struct GaussianScores { + pub importance_score: Tensor, + pub pruning_score: Tensor, +} + +fn compute_losses( + pred: Tensor, + gt: Tensor, + ssim: &Ssim, +) -> (Tensor, Tensor) { + let l1_loss = (pred.clone() - gt.clone()).abs(); + + let ssim_loss = 1.0 - ssim.ssim(pred, gt); + let photometric_loss: Tensor = l1_loss.clone() * 0.8 + ssim_loss * 0.2; + + let l1_mean = l1_loss.mean_dim(2); + + let l1_min = 
l1_mean.clone().min().reshape([1, 1, 1]); + let l1_max = l1_mean.clone().max().reshape([1, 1, 1]); + + let l1_norm = (l1_mean - l1_min.clone()) / (l1_max - l1_min).clamp_min(1e-9); + + (l1_norm.detach(), photometric_loss.detach()) +} + +async fn compute_view_metrics( + batch_img_tensor: burn::tensor::TensorData, + camera: &brush_render::camera::Camera, + splats: Splats, + ssim: &Ssim, + device: &WgpuDevice, + error_threshold: f32, +) -> (Tensor, Tensor) { + let [img_h, img_w, _] = batch_img_tensor.shape.clone().try_into().unwrap(); + let gt_tensor = Tensor::::from_data(batch_img_tensor, device); + let background = Vec3::ZERO; + + let project_output = MainBackend::project( + camera, + glam::uvec2(img_w as u32, img_h as u32), + splats.means.val().into_primitive().tensor(), + splats.log_scales.val().into_primitive().tensor(), + splats.rotations.val().into_primitive().tensor(), + splats.sh_coeffs.val().into_primitive().tensor(), + splats.raw_opacities.val().into_primitive().tensor(), + splats.render_mode, + ); + + let num_intersections = project_output.read_num_intersections().await; + + let (out_img, _, _) = MainBackend::rasterize( + &project_output, + num_intersections, + background, + true, + false, + None, + ); + + // Compute high error mask + let pred_image = Tensor::from_primitive(TensorPrimitive::Float(out_img)); + let pred_rgb = pred_image.slice(s![.., .., 0..3]); + let gt_rgb = gt_tensor.slice(s![.., .., 0..3]); + + let (l1_rgb_norm, photometric_loss) = compute_losses(pred_rgb, gt_rgb, ssim); + + let high_error_mask = l1_rgb_norm.greater_elem(error_threshold).float(); + let high_error_mask_primitive = high_error_mask.detach().into_primitive().tensor(); + + // Compute high error count + let (_, aux, _) = MainBackend::rasterize( + &project_output, + num_intersections, + background, + true, + true, + Some(&high_error_mask_primitive), + ); + + let high_error_count = Tensor::::from_primitive(aux.high_error_count) + .float() + .detach(); + + let loss_mean = 
photometric_loss.mean().detach(); + + (high_error_count, loss_mean) +} + +pub async fn compute_gaussian_score( + dataloader: &mut SceneLoader, + splats: Splats, + n_views: i32, + error_threshold: f32, +) -> GaussianScores { + let device = splats.device(); + + const SSIM_WINDOW_SIZE: usize = 11; + let ssim = Ssim::new(SSIM_WINDOW_SIZE, 3, &device); + + let num_splats = splats.means.val().shape().dims[0]; + let mut accum_high_error_count = Tensor::::zeros([num_splats], &device); + let mut accum_high_error_metric = Tensor::::zeros([num_splats], &device); + + for _ in 0..n_views { + let batch = dataloader + .next_batch() + .instrument(trace_span!("Wait for next data batch")) + .await; + + let (high_error_count, photometric_loss) = compute_view_metrics( + batch.img_tensor, + &batch.camera, + splats.clone(), + &ssim, + &device, + error_threshold, + ) + .await; + + accum_high_error_count = accum_high_error_count.add(high_error_count.clone()); + + accum_high_error_metric = + accum_high_error_metric.add(high_error_count.mul(photometric_loss)); + } + + let importance_score = accum_high_error_count / (n_views as f32); + + let min = accum_high_error_metric.clone().min(); + let max = accum_high_error_metric.clone().max(); + let pruning_score = (accum_high_error_metric - min.clone()) / (max - min).clamp_min(1e-9); + + GaussianScores { + importance_score, + pruning_score, + } +} diff --git a/crates/brush-train/src/lib.rs b/crates/brush-train/src/lib.rs index 207faba7..085ddc11 100644 --- a/crates/brush-train/src/lib.rs +++ b/crates/brush-train/src/lib.rs @@ -1,12 +1,12 @@ #![recursion_limit = "256"] pub mod config; +pub mod density_control; pub mod eval; pub mod msg; pub mod train; mod adam_scaled; -mod multinomial; mod quat_vec; mod ssim; mod stats; diff --git a/crates/brush-train/src/multinomial.rs b/crates/brush-train/src/multinomial.rs deleted file mode 100644 index 54b69e04..00000000 --- a/crates/brush-train/src/multinomial.rs +++ /dev/null @@ -1,73 +0,0 @@ -pub(crate) fn 
multinomial_sample(weights: &[f32], n: u32) -> Vec { - let mut rng = rand::rng(); - rand::seq::index::sample_weighted( - &mut rng, - weights.len(), - |i| if weights[i].is_nan() { 0.0 } else { weights[i] }, - n as usize, - ) - .unwrap_or_else(|_| { - panic!( - "Failed to sample from weights. Counts: {} Infinities: {} NaN: {}", - weights.len(), - weights.iter().filter(|x| x.is_infinite()).count(), - weights.iter().filter(|x| x.is_nan()).count() - ) - }) - .iter() - .map(|x| x as i32) - .collect() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_multinomial_sampling() { - // Test the complete multinomial sampling workflow (samples indices without replacement) - let weights = vec![0.1, 0.3, 0.4, 0.2]; - let samples = multinomial_sample(&weights, 3); - - assert_eq!(samples.len(), 3); - for &sample in &samples { - assert!(sample >= 0 && sample < weights.len() as i32); - } - // Should not have duplicates (sampling without replacement) - let mut unique_samples = samples.clone(); - unique_samples.sort(); - unique_samples.dedup(); - assert_eq!(unique_samples.len(), samples.len()); - - // Test edge case: sampling all indices - let single_weight = vec![1.0]; - let single_samples = multinomial_sample(&single_weight, 1); - assert_eq!(single_samples.len(), 1); - assert_eq!(single_samples[0], 0); - } - - #[test] - fn test_nan_weight_handling() { - // Test that NaN weights are handled (converted to 0.0) - let weights_with_nan = vec![0.5, f32::NAN, 0.3, 0.2]; - let samples = multinomial_sample(&weights_with_nan, 2); - - assert_eq!(samples.len(), 2); - // Should never sample index 1 (NaN weight becomes 0.0) - assert!(!samples.contains(&1)); - // Should only sample from valid indices - for &sample in &samples { - assert!(sample == 0 || sample == 2 || sample == 3); - } - } - - #[test] - fn test_all_zero_weights() { - // Discovered behavior: returns empty vec when all weights are zero - let zero_weights = vec![0.0, 0.0, 0.0]; - let result = 
multinomial_sample(&zero_weights, 1); - - // Function returns empty vector when it cannot sample any valid indices - assert_eq!(result.len(), 0); - } -} diff --git a/crates/brush-train/src/train.rs b/crates/brush-train/src/train.rs index d6673337..45819250 100644 --- a/crates/brush-train/src/train.rs +++ b/crates/brush-train/src/train.rs @@ -1,8 +1,8 @@ use crate::{ adam_scaled::{AdamScaled, AdamScaledConfig, AdamState}, config::TrainConfig, + density_control::GaussianScores, msg::{RefineStats, TrainStepStats}, - multinomial::multinomial_sample, quat_vec::quaternion_vec_multiply, splat_init::bounds_from_pos, ssim::Ssim, @@ -281,6 +281,7 @@ impl SplatTrainer { &mut self, iter: u32, splats: Splats, + gaussian_scores: GaussianScores, ) -> (Splats, RefineStats) { let device = splats.means.device(); let client = WgpuRuntime::client(&device); @@ -292,13 +293,13 @@ impl SplatTrainer { let max_allowed_bounds = self.bounds.extent.max_element() * 100.0; - // If not refining, update splat to step with gradients applied. - // Prune dead splats. This ALWAYS happen even if we're not "refining" anymore. let mut record = self .optim .take() .expect("Can only refine after optimizer is initialized") .to_record(); + + // Remove splats that are too transparent, big, small or out of bounds. let alpha_mask = splats.opacities().lower_elem(MIN_OPACITY); let scales = splats.scales(); @@ -308,7 +309,6 @@ impl SplatTrainer { .any_dim(1) .squeeze_dim(1); - // Remove splats that are way out of bounds. let center = self.bounds.center; let bound_center = Tensor::<_, 1>::from_floats([center.x, center.y, center.z], &device).reshape([1, 3]); @@ -323,70 +323,121 @@ impl SplatTrainer { .bool_or(bound_mask); let (mut splats, refiner, pruned_count) = - prune_points(splats, &mut record, refiner, prune_mask).await; - let mut split_inds = HashSet::new(); + prune_points(splats, &mut record, refiner, prune_mask.clone()).await; - // Replace dead gaussians. 
- if pruned_count > 0 { - // Sample weighted by opacity from splat visible during optimization. - let resampled_weights = splats.opacities() * refiner.vis_mask().float(); - let resampled_weights = resampled_weights - .into_data_async() - .await - .expect("Failed to get weights") - .into_vec::() - .expect("Failed to read weights"); - let resampled_inds = multinomial_sample(&resampled_weights, pruned_count); - split_inds.extend(resampled_inds); - } + let keep_mask = prune_mask.bool_not().argwhere().squeeze(); + let importance_score = gaussian_scores.importance_score.select(0, keep_mask); + + // Split splats that have a high gradient and importance score + let mut split_inds = HashSet::new(); + let above_threshold = refiner.above_threshold(self.config.growth_grad_threshold); + let above_importance_score = + importance_score.greater_elem(self.config.min_importance_score); + let to_split = above_threshold.bool_and(above_importance_score); - if iter < self.config.growth_stop_iter { - let above_threshold = refiner.above_threshold(self.config.growth_grad_threshold); + let cur_splats = splats.num_splats(); + let capacity = (self.config.max_splats as i64 - cur_splats as i64).max(0) as usize; - let threshold_count = above_threshold - .clone() - .int() - .sum() - .into_scalar_async() + if capacity > 0 { + let mask_data = to_split + .into_data_async() .await - .expect("Failed to get threshold") as u32; - - let grow_count = - (threshold_count as f32 * self.config.growth_select_fraction).round() as u32; - - let sample_high_grad = grow_count.saturating_sub(pruned_count); - - // Only grow to the max nr. of splats. - let cur_splats = splats.num_splats() + split_inds.len() as u32; - let grow_count = sample_high_grad.min(self.config.max_splats - cur_splats); - - // If still growing, sample from indices which are over the threshold. 
- if grow_count > 0 { - let weights = above_threshold.float() * refiner.refine_weight_norm; - let weights = weights - .into_data_async() - .await - .expect("Failed to get weights") - .into_vec::() - .expect("Failed to read weights"); - let growth_inds = multinomial_sample(&weights, grow_count); - split_inds.extend(growth_inds); + .expect("Failed to download split mask") + .into_vec::() + .expect("Failed to read split mask"); + + let mut candidates: Vec = mask_data + .iter() + .enumerate() + .filter_map(|(idx, &flag)| if flag > 0 { Some(idx as i32) } else { None }) + .collect(); + + if candidates.len() > capacity { + use rand::seq::SliceRandom; + let mut rng = rand::rng(); + candidates.shuffle(&mut rng); + candidates.truncate(capacity); } + + split_inds.extend(candidates); } let refine_count = split_inds.len(); splats = self.refine_splats(&device, record, splats, split_inds, iter); + let splat_count = splats.num_splats(); // Update current bounds based on the splats. self.bounds = get_splat_bounds(splats.clone(), BOUND_PERCENTILE).await; client.memory_cleanup(); + ( + splats, + RefineStats { + num_added: refine_count as u32, + num_pruned: pruned_count, + total_splats: splat_count, + }, + ) + } + + pub async fn refine_final( + &mut self, + splats: Splats, + gaussian_scores: GaussianScores, + ) -> (Splats, RefineStats) { + let device = splats.means.device(); + + // Remove splats that are too transparent, big, small or out of bounds. + // The opacity threshold is typically stricter than when growing. + // Additionally we remove splats whose pruning score is large. 
+ let max_allowed_bounds = self.bounds.extent.max_element() * 100.0; + let alpha_mask = splats.opacities().lower_elem(self.config.final_min_opacity); + let scales = splats.scales(); + + let scale_small = scales.clone().lower_elem(1e-10).any_dim(1).squeeze_dim(1); + let scale_big = scales + .greater_elem(max_allowed_bounds) + .any_dim(1) + .squeeze_dim(1); + + let center = self.bounds.center; + let bound_center = + Tensor::<_, 1>::from_floats([center.x, center.y, center.z], &device).reshape([1, 3]); + let splat_dists = (splats.means.val() - bound_center).abs(); + let bound_mask = splat_dists + .greater_elem(max_allowed_bounds) + .any_dim(1) + .squeeze_dim(1); + + let high_error_mask = gaussian_scores + .pruning_score + .clone() + .greater_elem(self.config.final_max_pruning_score); + + let prune_mask = alpha_mask + .clone() + .bool_or(scale_small) + .bool_or(scale_big) + .bool_or(bound_mask) + .bool_or(high_error_mask); + + let mut record = self + .optim + .take() + .expect("Can only refine after optimizer is initialized") + .to_record(); + let refiner = self + .refine_record + .take() + .expect("Can only refine if refine stats are initialized"); + let (splats, _, pruned_count) = + prune_points(splats, &mut record, refiner, prune_mask.clone()).await; let splat_count = splats.num_splats(); ( splats, RefineStats { - num_added: refine_count as u32, + num_added: 0, num_pruned: pruned_count, total_splats: splat_count, }, diff --git a/crates/brush-ui/src/settings_popup.rs b/crates/brush-ui/src/settings_popup.rs index 6254090e..fffa95c6 100644 --- a/crates/brush-ui/src/settings_popup.rs +++ b/crates/brush-ui/src/settings_popup.rs @@ -146,8 +146,10 @@ impl SettingsPopup { let tc = &mut self.args.train_config; slider(ui, &mut tc.refine_every, 50..=300, "Refinement frequency", false); slider(ui, &mut tc.growth_grad_threshold, 0.0001..=0.001, "Growth threshold", true); - slider(ui, &mut tc.growth_select_fraction, 0.01..=0.2, "Growth selection fraction", false); slider(ui, &mut 
tc.growth_stop_iter, 5000..=20000, "Growth stop iteration", false); + slider(ui, &mut tc.refine_every_final, 1000..=3000, "Final refinement frequency", false); + slider(ui, &mut tc.n_views, 10..=50, "Views for multiview refinement", false); + slider(ui, &mut tc.final_min_opacity, 1e-5..=1e-1, "Final min opacity", true); }); ui.collapsing("Losses", |ui| {