From 626ad0fb2b3dc1ae0a66fdc33a57d71a1e7c6f0d Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 19 Jan 2026 15:22:05 +0100 Subject: [PATCH 01/29] Readback for intersections --- crates/brush-render-bwd/src/burn_glue.rs | 1 - crates/brush-render/src/burn_glue.rs | 30 ++++--------- crates/brush-render/src/get_tile_offset.rs | 1 + crates/brush-render/src/lib.rs | 7 --- crates/brush-render/src/render.rs | 45 ++++++++++++------- crates/brush-render/src/render_aux.rs | 36 ++++----------- crates/brush-render/src/shaders/helpers.wgsl | 3 +- .../shaders/map_gaussian_to_intersects.wgsl | 7 --- crates/brush-rerun/src/visualize_tools.rs | 11 ----- crates/brush-train/src/msg.rs | 1 - crates/brush-train/src/train.rs | 2 - 11 files changed, 49 insertions(+), 95 deletions(-) diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index 72b88723..9f7c661a 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -198,7 +198,6 @@ impl + SplatForward, C: CheckpointStrategy> let wrapped_aux = RenderAux:: { projected_splats: ::from_inner(aux.projected_splats.clone()), - num_intersections: aux.num_intersections, tile_offsets: aux.tile_offsets.clone(), compact_gid_from_isect: aux.compact_gid_from_isect.clone(), global_from_compact_gid: aux.global_from_compact_gid.clone(), diff --git a/crates/brush-render/src/burn_glue.rs b/crates/brush-render/src/burn_glue.rs index 62bfb1a3..3a6f360e 100644 --- a/crates/brush-render/src/burn_glue.rs +++ b/crates/brush-render/src/burn_glue.rs @@ -9,12 +9,8 @@ use burn_wgpu::WgpuRuntime; use glam::Vec3; use crate::{ - MainBackendBase, SplatForward, - camera::Camera, - gaussian_splats::SplatRenderMode, - render::{calc_tile_bounds, max_intersections}, - render_aux::RenderAux, - shaders, + MainBackendBase, SplatForward, camera::Camera, gaussian_splats::SplatRenderMode, + render::calc_tile_bounds, render_aux::RenderAux, shaders, }; impl SplatForward for Fusion { @@ -54,7 +50,6 @@ impl SplatForward for Fusion { // Aux projected_splats, uniforms_buffer, - num_intersections, tile_offsets, compact_gid_from_isect, global_from_compact_gid, @@ -81,10 +76,6 @@ impl SplatForward for Fusion { aux.projected_splats, ); h.register_int_tensor::(&uniforms_buffer.id, aux.uniforms_buffer); - h.register_int_tensor::( - &num_intersections.id, - aux.num_intersections, - ); h.register_int_tensor::(&tile_offsets.id, aux.tile_offsets); h.register_int_tensor::( &compact_gid_from_isect.id, @@ -106,7 +97,6 @@ impl SplatForward for Fusion { let proj_size = size_of::() / 4; let uniforms_size = size_of::() / 4; let tile_bounds = calc_tile_bounds(img_size); - let max_intersects = max_intersections(img_size, num_points as u32); // If render_u32_buffer is true, we render a packed buffer of u32 values, otherwise // render RGBA f32 values. @@ -134,18 +124,17 @@ impl SplatForward for Fusion { Shape::new([uniforms_size]), DType::U32, ); - let num_intersections = - TensorIr::uninit(client.create_empty_handle(), Shape::new([1]), DType::U32); let tile_offsets = TensorIr::uninit( client.create_empty_handle(), Shape::new([tile_bounds.y as usize, tile_bounds.x as usize, 2]), DType::U32, ); - let compact_gid_from_isect = TensorIr::uninit( - client.create_empty_handle(), - Shape::new([max_intersects as usize]), - DType::U32, - ); + + // This is not actually size 0, but, it's dynamic. This is just a dummy handle so we just + // set a dummy size of 0. + let compact_gid_from_isect = + TensorIr::uninit(client.create_empty_handle(), Shape::new([0]), DType::U32); + let global_from_compact_gid = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points]), @@ -162,7 +151,6 @@ impl SplatForward for Fusion { out_img, projected_splats, uniforms_buffer, - num_intersections, tile_offsets, compact_gid_from_isect, global_from_compact_gid, @@ -188,7 +176,6 @@ impl SplatForward for Fusion { // Aux projected_splats, uniforms_buffer, - num_intersections, tile_offsets, compact_gid_from_isect, global_from_compact_gid, @@ -200,7 +187,6 @@ impl SplatForward for Fusion { RenderAux:: { projected_splats, uniforms_buffer, - num_intersections, tile_offsets, compact_gid_from_isect, global_from_compact_gid, diff --git a/crates/brush-render/src/get_tile_offset.rs b/crates/brush-render/src/get_tile_offset.rs index d16561bf..21e19a8c 100644 --- a/crates/brush-render/src/get_tile_offset.rs +++ b/crates/brush-render/src/get_tile_offset.rs @@ -22,6 +22,7 @@ fn check_tile_boundary( // Write the end of the last tile. tile_offsets[tid * 2 + 1] = isect_id + 1; } + if tid != prev_tid { // Write the end of the previous tile. tile_offsets[prev_tid * 2 + 1] = isect_id; diff --git a/crates/brush-render/src/lib.rs b/crates/brush-render/src/lib.rs index 7ef89d2c..915cc287 100644 --- a/crates/brush-render/src/lib.rs +++ b/crates/brush-render/src/lib.rs @@ -36,15 +36,8 @@ pub type MainBackend = Fusion; #[derive(Debug, Clone)] pub struct RenderStats { pub num_visible: u32, - pub num_intersections: u32, } -// The maximum number of intersections that can be rendered. -// -// With 2D dispatch support, we can now handle more than the original 65535 workgroup limit. -// Doubled from the original 512 * 65535 to allow higher resolution rendering. -const INTERSECTS_UPPER_BOUND: u32 = 2 * 512 * 65535; - pub trait SplatForward { /// Render splats to a buffer. /// diff --git a/crates/brush-render/src/render.rs b/crates/brush-render/src/render.rs index 97dad971..08afa686 100644 --- a/crates/brush-render/src/render.rs +++ b/crates/brush-render/src/render.rs @@ -1,5 +1,5 @@ use crate::{ - INTERSECTS_UPPER_BOUND, MainBackendBase, SplatForward, + MainBackendBase, SplatForward, camera::Camera, dim_check::DimCheck, gaussian_splats::SplatRenderMode, @@ -14,7 +14,9 @@ use brush_kernel::create_uniform_buffer; use brush_kernel::{CubeCount, calc_cube_count_1d}; use brush_prefix_sum::prefix_sum; use brush_sort::radix_argsort; -use burn::tensor::{DType, IntDType, ops::FloatTensor}; +#[cfg(not(target_family = "wasm"))] +use burn::Tensor; +use burn::tensor::{DType, IntDType, Slice, ops::FloatTensor}; use burn::tensor::{ FloatDType, ops::{FloatTensorOps, IntTensorOps}, @@ -22,8 +24,8 @@ use burn::tensor::{ use burn_cubecl::cubecl::server::Bindings; use burn_cubecl::kernel::into_contiguous; -use burn_wgpu::CubeDim; use burn_wgpu::WgpuRuntime; +use burn_wgpu::{CubeDim, CubeTensor}; use glam::{Vec3, uvec2}; use std::mem::offset_of; @@ -40,15 +42,17 @@ pub(crate) fn calc_tile_bounds(img_size: glam::UVec2) -> glam::UVec2 { // dispatch to avoid this. // Estimating the max number of intersects can be a bad hack though... The worst case scenario is so massive // that it's easy to run out of memory... How do we actually properly deal with this :/ -pub fn max_intersections(img_size: glam::UVec2, num_splats: u32) -> u32 { +pub fn max_intersections( + img_size: glam::UVec2, + num_splats: u32, + _num_intersections: CubeTensor, +) -> u32 { // Divide screen into tiles. let tile_bounds = calc_tile_bounds(img_size); - // Assume on average each splat is maximally covering half x half the screen, - // and adjust for the variance such that we're fairly certain we have enough intersections. let num_tiles = tile_bounds[0] * tile_bounds[1]; let max_possible = num_tiles.saturating_mul(num_splats); - // clamp to max nr. of dispatches. - max_possible.min(INTERSECTS_UPPER_BOUND) + // clamp to some max nr. or we run out of memory. + max_possible.min(2 * 512 * 65535) } // Implement forward functions for the inner wgpu backend. @@ -108,7 +112,6 @@ impl SplatForward for MainBackendBase { // Tile rendering setup. let sh_degree = sh_degree_from_coeffs(sh_coeffs.shape.dims[1] as u32); let total_splats = means.shape.dims[0]; - let max_intersects = max_intersections(img_size, total_splats as u32); let uniforms = shaders::helpers::RenderUniforms { viewmat: glam::Mat4::from(camera.world_to_local()).to_cols_array_2d(), @@ -119,10 +122,10 @@ impl SplatForward for MainBackendBase { tile_bounds: tile_bounds.into(), sh_degree, total_splats: total_splats as u32, - max_intersects, background: [background.x, background.y, background.z, 1.0], // Nb: Bit of a hack as these aren't _really_ uniforms but are written to by the shaders. num_visible: 0, + pad_a: 0, }; // Nb: This contains both static metadata and some dynamic data so can't pass this as metadata to execute. In the future @@ -232,17 +235,29 @@ impl SplatForward for MainBackendBase { } }); - // TODO: Only need to do this up to num_visible gaussians really. + // TODO: Only need to do this up to num_visible gaussians really. Would need a + // prefix sum with a dynamic total length, and get the num at the dynamic length. let cum_tiles_hit = tracing::trace_span!("PrefixSumGaussHits") .in_scope(|| prefix_sum(splat_intersect_counts)); + let num_intersections = Self::int_slice( + cum_tiles_hit.clone(), + &[Slice::new(cum_tiles_hit.shape[0] as isize - 1, None, 1)], + ); + #[cfg(target_family = "wasm")] + let max_intersects = + max_intersections(img_size, total_splats as u32, num_intersections.clone()); + #[cfg(not(target_family = "wasm"))] + let max_intersects = { + use burn::tensor::{ElementConversion, Int}; + let intersects: Tensor = + Tensor::from_primitive(num_intersections.clone()); + intersects.into_scalar().elem::() + }; let tile_id_from_isect = create_tensor([max_intersects as usize], device, DType::U32); let compact_gid_from_isect = create_tensor([max_intersects as usize], device, DType::U32); - // Zero this out, as the kernel _might_ not run at all if no gaussians are visible. - let num_intersections = Self::int_zeros([1].into(), device, IntDType::U32); - tracing::trace_span!("MapGaussiansToIntersect").in_scope(|| { // SAFETY: Kernel checked to have no OOB, bounded loops. unsafe { @@ -256,7 +271,6 @@ impl SplatForward for MainBackendBase { cum_tiles_hit.handle.binding(), tile_id_from_isect.handle.clone().binding(), compact_gid_from_isect.handle.clone().binding(), - num_intersections.handle.clone().binding(), ]), ) .expect("Failed to render splats"); @@ -383,7 +397,6 @@ impl SplatForward for MainBackendBase { RenderAux { uniforms_buffer, tile_offsets, - num_intersections, projected_splats, compact_gid_from_isect, global_from_compact_gid, diff --git a/crates/brush-render/src/render_aux.rs b/crates/brush-render/src/render_aux.rs index 3a9d3588..0ac3f579 100644 --- a/crates/brush-render/src/render_aux.rs +++ b/crates/brush-render/src/render_aux.rs @@ -17,7 +17,6 @@ pub struct RenderAux { /// The packed projected splat information, see `ProjectedSplat` in helpers.wgsl pub projected_splats: FloatTensor, pub uniforms_buffer: IntTensor, - pub num_intersections: IntTensor, pub tile_offsets: IntTensor, pub compact_gid_from_isect: IntTensor, pub global_from_compact_gid: IntTensor, @@ -35,10 +34,6 @@ impl RenderAux { (max - min).reshape([ty as usize, tx as usize]) } - pub fn num_intersections(&self) -> Tensor { - Tensor::from_primitive(self.num_intersections.clone()) - } - pub fn num_visible(&self) -> Tensor { let num_vis_field_offset = offset_of!(shaders::helpers::RenderUniforms, num_visible) / 4; Tensor::from_primitive(self.uniforms_buffer.clone()).slice(s![num_vis_field_offset]) @@ -49,35 +44,22 @@ impl RenderAux { { use burn::tensor::{ElementConversion, TensorPrimitive}; - use crate::{ - INTERSECTS_UPPER_BOUND, render::max_intersections, validation::validate_tensor_val, - }; + use crate::validation::validate_tensor_val; if std::env::args().any(|a| a == "--bench") { return; } - let num_intersects: Tensor = self.num_intersections(); let compact_gid_from_isect: Tensor = Tensor::from_primitive(self.compact_gid_from_isect.clone()); let num_visible: Tensor = self.num_visible(); - let num_intersections = num_intersects.into_scalar().elem::(); let num_points = compact_gid_from_isect.dims()[0] as u32; let num_visible = num_visible.into_scalar().elem::() as u32; - let img_size = self.img_size; - - let max_intersects = max_intersections(img_size, num_points); - - assert!( - num_intersections < max_intersects as i32, - "Too many intersections, estimated too low of a number. {num_intersections} / {max_intersects}" - ); - - assert!( - num_intersections < INTERSECTS_UPPER_BOUND as i32, - "Too many intersections, Brush currently can't handle this. {num_intersections} > {INTERSECTS_UPPER_BOUND}" - ); + let num_intersections: Tensor = + Tensor::from_primitive(self.tile_offsets.clone()); + let num_intersections = + num_intersections.slice(s![-1]).into_scalar().elem::() as u32; assert!( num_visible <= num_points, @@ -106,7 +88,7 @@ impl RenderAux { .expect("Failed to fetch tile offsets"); for &offsets in &tile_offsets { assert!( - offsets as i32 <= num_intersections, + offsets <= num_intersections, "Tile offsets exceed bounds. Value: {offsets}, num_intersections: {num_intersections}" ); } @@ -114,8 +96,8 @@ impl RenderAux { if num_intersections > 0 { for i in 0..(tile_offsets.len() - 1) / 2 { // Check pairs of start/end points. - let start = tile_offsets[i * 2] as i32; - let end = tile_offsets[i * 2 + 1] as i32; + let start = tile_offsets[i * 2]; + let end = tile_offsets[i * 2 + 1]; assert!( start < num_intersections && end <= num_intersections, "Invalid elements in tile offsets. Start {start} ending at {end}" @@ -125,7 +107,7 @@ impl RenderAux { "Invalid elements in tile offsets. Start {start} ending at {end}" ); assert!( - end - start <= num_visible as i32, + end - start <= num_visible, "One tile has more hits than total visible splats. Start {start} ending at {end}" ); } diff --git a/crates/brush-render/src/shaders/helpers.wgsl b/crates/brush-render/src/shaders/helpers.wgsl index 7a610942..4da56fe2 100644 --- a/crates/brush-render/src/shaders/helpers.wgsl +++ b/crates/brush-render/src/shaders/helpers.wgsl @@ -67,8 +67,9 @@ struct RenderUniforms { num_visible: u32, #endif + pad_a: u32, + total_splats: u32, - max_intersects: u32, // Nb: Alpha is ignored atm. background: vec4f, diff --git a/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl b/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl index 927f258b..c8c97400 100644 --- a/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl +++ b/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl @@ -9,7 +9,6 @@ @group(0) @binding(2) var splat_cum_hit_counts: array; @group(0) @binding(3) var tile_id_from_isect: array; @group(0) @binding(4) var compact_gid_from_isect: array; - @group(0) @binding(5) var num_intersections: array; #endif const WG_SIZE: u32 = 256u; @@ -23,12 +22,6 @@ fn main( ) { let compact_gid = helpers::get_global_id(wid, num_wgs, lid, WG_SIZE); -#ifndef PREPASS - if compact_gid == 0u { - num_intersections[0] = splat_cum_hit_counts[uniforms.num_visible]; - } -#endif - if compact_gid >= uniforms.num_visible { return; } diff --git a/crates/brush-rerun/src/visualize_tools.rs b/crates/brush-rerun/src/visualize_tools.rs index ce119049..4d0a7a58 100644 --- a/crates/brush-rerun/src/visualize_tools.rs +++ b/crates/brush-rerun/src/visualize_tools.rs @@ -245,17 +245,6 @@ mod visualize_tools_impl { .log("lr/coeffs", &rerun::Scalars::new(vec![stats.lr_coeffs]))?; self.rec .log("lr/opac", &rerun::Scalars::new(vec![stats.lr_opac]))?; - - self.rec.log( - "splats/num_intersects", - &rerun::Scalars::new(vec![ - stats - .num_intersections - .into_scalar_async() - .await? - .elem::(), - ]), - )?; self.rec.log( "splats/splats_visible", &rerun::Scalars::new(vec![ diff --git a/crates/brush-train/src/msg.rs b/crates/brush-train/src/msg.rs index 780d8c37..adac9804 100644 --- a/crates/brush-train/src/msg.rs +++ b/crates/brush-train/src/msg.rs @@ -13,7 +13,6 @@ pub struct RefineStats { pub struct TrainStepStats { pub pred_image: Tensor, - pub num_intersections: Tensor, pub num_visible: Tensor, pub loss: Tensor, diff --git a/crates/brush-train/src/train.rs b/crates/brush-train/src/train.rs index c7b80fb4..d43d940f 100644 --- a/crates/brush-train/src/train.rs +++ b/crates/brush-train/src/train.rs @@ -135,7 +135,6 @@ impl SplatTrainer { let median_scale = self.bounds.median_size(); let num_visible = aux.num_visible().inner(); - let num_intersections = aux.num_intersections().inner(); let pred_rgb = pred_image.clone().slice(s![.., .., 0..3]); let gt_rgb = gt_tensor.clone().slice(s![.., .., 0..3]); @@ -273,7 +272,6 @@ impl SplatTrainer { let stats = TrainStepStats { pred_image: pred_image.inner(), num_visible, - num_intersections, loss: loss.inner(), lr_mean, lr_rotation, From 3cd5731f0af7475aa4ead5f77ef6bd2a3a9efb25 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 19 Jan 2026 16:43:40 +0100 Subject: [PATCH 02/29] Split pass --- crates/brush-render-bwd/src/burn_glue.rs | 46 ++ crates/brush-render/src/burn_glue.rs | 280 +++++++++++- crates/brush-render/src/lib.rs | 33 +- crates/brush-render/src/render.rs | 401 +++++++++--------- crates/brush-render/src/render_aux.rs | 170 ++++++-- crates/brush-render/src/shaders.rs | 1 + .../src/shaders/project_visible.wgsl | 33 +- 7 files changed, 722 insertions(+), 242 deletions(-) diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index 9f7c661a..d094d580 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -65,6 +65,29 @@ pub trait SplatBackwardOps { ) -> SplatGrads; } +/// State from the ProjectPrepare pass needed for backward computation. +#[derive(Debug, Clone)] +pub struct ProjectBackwardState { + pub(crate) means: FloatTensor, + pub(crate) quats: FloatTensor, + pub(crate) log_scales: FloatTensor, + pub(crate) raw_opac: FloatTensor, + pub(crate) projected_splats: FloatTensor, + pub(crate) uniforms_buffer: IntTensor, + pub(crate) global_from_compact_gid: IntTensor, + pub(crate) render_mode: SplatRenderMode, + pub(crate) sh_degree: u32, +} + +/// State from the Rasterize pass needed for backward computation. +#[derive(Debug, Clone)] +pub struct RasterizeBackwardState { + pub(crate) out_img: FloatTensor, + pub(crate) compact_gid_from_isect: IntTensor, + pub(crate) tile_offsets: IntTensor, +} + +/// Combined backward state for compatibility with existing code. #[derive(Debug, Clone)] pub struct GaussianBackwardState { pub(crate) means: FloatTensor, @@ -81,6 +104,29 @@ pub struct GaussianBackwardState { pub(crate) sh_degree: u32, } +impl GaussianBackwardState { + /// Construct combined state from project and rasterize backward states. + pub fn from_parts( + project: ProjectBackwardState, + rasterize: RasterizeBackwardState, + ) -> Self { + Self { + means: project.means, + quats: project.quats, + log_scales: project.log_scales, + raw_opac: project.raw_opac, + projected_splats: project.projected_splats, + uniforms_buffer: project.uniforms_buffer, + global_from_compact_gid: project.global_from_compact_gid, + render_mode: project.render_mode, + sh_degree: project.sh_degree, + out_img: rasterize.out_img, + compact_gid_from_isect: rasterize.compact_gid_from_isect, + tile_offsets: rasterize.tile_offsets, + } + } +} + #[derive(Debug)] struct RenderBackwards; diff --git a/crates/brush-render/src/burn_glue.rs b/crates/brush-render/src/burn_glue.rs index 3a6f360e..58ab0515 100644 --- a/crates/brush-render/src/burn_glue.rs +++ b/crates/brush-render/src/burn_glue.rs @@ -9,8 +9,9 @@ use burn_wgpu::WgpuRuntime; use glam::Vec3; use crate::{ - MainBackendBase, SplatForward, camera::Camera, gaussian_splats::SplatRenderMode, - render::calc_tile_bounds, render_aux::RenderAux, shaders, + MainBackendBase, SplatForward, SplatProjectPrepare, SplatRasterize, + camera::Camera, gaussian_splats::SplatRenderMode, + render::calc_tile_bounds, render_aux::{ProjectAux, RasterizeAux, RenderAux}, shaders, }; impl SplatForward for Fusion { @@ -196,3 +197,278 @@ impl SplatForward for Fusion { ) } } + +// Fusion implementation for SplatProjectPrepare +impl SplatProjectPrepare for Fusion { + fn project_prepare( + cam: &Camera, + img_size: glam::UVec2, + means: FloatTensor, + log_scales: FloatTensor, + quats: FloatTensor, + sh_coeffs: FloatTensor, + opacity: FloatTensor, + render_mode: SplatRenderMode, + background: Vec3, + ) -> ProjectAux { + #[derive(Debug)] + struct CustomOp { + cam: Camera, + img_size: glam::UVec2, + render_mode: SplatRenderMode, + background: Vec3, + desc: CustomOpIr, + } + + impl Operation> for CustomOp { + fn execute( + &self, + h: &mut HandleContainer>>, + ) { + let (inputs, outputs) = self.desc.as_fixed(); + + let [means, log_scales, quats, sh_coeffs, opacity] = inputs; + let [ + projected_splats, + uniforms_buffer, + global_from_compact_gid, + cum_tiles_hit, + ] = outputs; + + let aux = MainBackendBase::project_prepare( + &self.cam, + self.img_size, + h.get_float_tensor::(means), + h.get_float_tensor::(log_scales), + h.get_float_tensor::(quats), + h.get_float_tensor::(sh_coeffs), + h.get_float_tensor::(opacity), + self.render_mode, + self.background, + ); + + // Register outputs + h.register_float_tensor::( + &projected_splats.id, + aux.projected_splats, + ); + h.register_int_tensor::(&uniforms_buffer.id, aux.uniforms_buffer); + h.register_int_tensor::( + &global_from_compact_gid.id, + aux.global_from_compact_gid, + ); + h.register_int_tensor::(&cum_tiles_hit.id, aux.cum_tiles_hit); + } + } + + let client = means.client.clone(); + let num_points = means.shape[0]; + + let proj_size = size_of::() / 4; + let uniforms_size = size_of::() / 4; + + let projected_splats = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_points, proj_size]), + DType::F32, + ); + let uniforms_buffer = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([uniforms_size]), + DType::U32, + ); + let global_from_compact_gid = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_points]), + DType::U32, + ); + // cum_tiles_hit has size num_points + 1 + let cum_tiles_hit = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_points + 1]), + DType::U32, + ); + + let input_tensors = [means, log_scales, quats, sh_coeffs, opacity]; + let stream = OperationStreams::with_inputs(&input_tensors); + let desc = CustomOpIr::new( + "project_prepare", + &input_tensors.map(|t| t.into_ir()), + &[ + projected_splats, + uniforms_buffer, + global_from_compact_gid, + cum_tiles_hit, + ], + ); + let op = CustomOp { + cam: cam.clone(), + img_size, + render_mode, + background, + desc: desc.clone(), + }; + + let outputs = client + .register(stream, OperationIr::Custom(desc), op) + .outputs(); + + let [ + projected_splats, + uniforms_buffer, + global_from_compact_gid, + cum_tiles_hit, + ] = outputs; + + ProjectAux:: { + projected_splats, + uniforms_buffer, + global_from_compact_gid, + cum_tiles_hit, + img_size, + } + } +} + +// Fusion implementation for SplatRasterize +impl SplatRasterize for Fusion { + fn rasterize( + project_aux: &ProjectAux, + num_intersections: u32, + background: Vec3, + bwd_info: bool, + ) -> (FloatTensor, RasterizeAux) { + #[derive(Debug)] + struct CustomOp { + img_size: glam::UVec2, + num_intersections: u32, + background: Vec3, + bwd_info: bool, + desc: CustomOpIr, + } + + impl Operation> for CustomOp { + fn execute( + &self, + h: &mut HandleContainer>>, + ) { + let (inputs, outputs) = self.desc.as_fixed(); + + let [ + projected_splats, + uniforms_buffer, + global_from_compact_gid, + cum_tiles_hit, + ] = inputs; + let [ + out_img, + tile_offsets, + compact_gid_from_isect, + visible, + ] = outputs; + + let inner_aux = ProjectAux:: { + projected_splats: h.get_float_tensor::(projected_splats), + uniforms_buffer: h.get_int_tensor::(uniforms_buffer), + global_from_compact_gid: h.get_int_tensor::(global_from_compact_gid), + cum_tiles_hit: h.get_int_tensor::(cum_tiles_hit), + img_size: self.img_size, + }; + + let (img, aux) = MainBackendBase::rasterize( + &inner_aux, + self.num_intersections, + self.background, + self.bwd_info, + ); + + // Register outputs + h.register_float_tensor::(&out_img.id, img); + h.register_int_tensor::(&tile_offsets.id, aux.tile_offsets); + h.register_int_tensor::( + &compact_gid_from_isect.id, + aux.compact_gid_from_isect, + ); + h.register_float_tensor::(&visible.id, aux.visible); + } + } + + let client = project_aux.projected_splats.client.clone(); + let img_size = project_aux.img_size; + let tile_bounds = calc_tile_bounds(img_size); + + let num_points = project_aux.projected_splats.shape[0]; + + let channels = if bwd_info { 4 } else { 1 }; + let out_img = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([img_size.y as usize, img_size.x as usize, channels]), + if bwd_info { DType::F32 } else { DType::U32 }, + ); + + let tile_offsets = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([tile_bounds.y as usize, tile_bounds.x as usize, 2]), + DType::U32, + ); + + // Use actual num_intersections for buffer size + let compact_gid_from_isect = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_intersections as usize]), + DType::U32, + ); + + let visible_shape = if bwd_info { + Shape::new([num_points]) + } else { + Shape::new([1]) + }; + let visible = TensorIr::uninit(client.create_empty_handle(), visible_shape, DType::F32); + + let input_tensors = [ + project_aux.projected_splats.clone(), + project_aux.uniforms_buffer.clone(), + project_aux.global_from_compact_gid.clone(), + project_aux.cum_tiles_hit.clone(), + ]; + let stream = OperationStreams::with_inputs(&input_tensors); + let desc = CustomOpIr::new( + "rasterize", + &input_tensors.map(|t| t.into_ir()), + &[ + out_img, + tile_offsets, + compact_gid_from_isect, + visible, + ], + ); + let op = CustomOp { + img_size, + num_intersections, + background, + bwd_info, + desc: desc.clone(), + }; + + let outputs = client + .register(stream, OperationIr::Custom(desc), op) + .outputs(); + + let [ + out_img, + tile_offsets, + compact_gid_from_isect, + visible, + ] = outputs; + + ( + out_img, + RasterizeAux:: { + tile_offsets, + compact_gid_from_isect, + visible, + }, + ) + } +} diff --git a/crates/brush-render/src/lib.rs b/crates/brush-render/src/lib.rs index 915cc287..33153598 100644 --- a/crates/brush-render/src/lib.rs +++ b/crates/brush-render/src/lib.rs @@ -8,7 +8,7 @@ use burn_wgpu::WgpuRuntime; use camera::Camera; use clap::ValueEnum; use glam::Vec3; -use render_aux::RenderAux; +use render_aux::{ProjectAux, RasterizeAux, RenderAux}; use crate::gaussian_splats::SplatRenderMode; pub use crate::gaussian_splats::render_splats; @@ -61,6 +61,37 @@ pub trait SplatForward { ) -> (FloatTensor, RenderAux); } +/// First pass of split rendering pipeline: culling, depth sort, projection, intersection counting, prefix sum. +/// +/// Returns [`ProjectAux`] which contains data needed for [`SplatRasterize::rasterize`], +/// including `cum_tiles_hit` which allows sync readback of the exact number of intersections. +pub trait SplatProjectPrepare { + fn project_prepare( + camera: &Camera, + img_size: glam::UVec2, + means: FloatTensor, + log_scales: FloatTensor, + quats: FloatTensor, + sh_coeffs: FloatTensor, + raw_opacities: FloatTensor, + render_mode: SplatRenderMode, + background: Vec3, + ) -> ProjectAux; +} + +/// Second pass of split rendering pipeline: intersection filling, tile sort, tile offsets, rasterization. +/// +/// Takes the output of [`SplatProjectPrepare::project_prepare`] along with the actual +/// `num_intersections` value from userland readback. +pub trait SplatRasterize { + fn rasterize( + project_aux: &ProjectAux, + num_intersections: u32, + background: Vec3, + bwd_info: bool, + ) -> (FloatTensor, RasterizeAux); +} + #[derive( Default, ValueEnum, Clone, Copy, Eq, PartialEq, Debug, serde::Serialize, serde::Deserialize, )] diff --git a/crates/brush-render/src/render.rs b/crates/brush-render/src/render.rs index 08afa686..43ce4a56 100644 --- a/crates/brush-render/src/render.rs +++ b/crates/brush-render/src/render.rs @@ -1,10 +1,10 @@ use crate::{ - MainBackendBase, SplatForward, + MainBackendBase, SplatForward, SplatProjectPrepare, SplatRasterize, camera::Camera, dim_check::DimCheck, gaussian_splats::SplatRenderMode, get_tile_offset::{CHECKS_PER_ITER, get_tile_offsets}, - render_aux::RenderAux, + render_aux::{ProjectAux, RasterizeAux, RenderAux}, sh::sh_degree_from_coeffs, shaders::{self, MapGaussiansToIntersect, ProjectSplats, ProjectVisible, Rasterize}, }; @@ -14,9 +14,7 @@ use brush_kernel::create_uniform_buffer; use brush_kernel::{CubeCount, calc_cube_count_1d}; use brush_prefix_sum::prefix_sum; use brush_sort::radix_argsort; -#[cfg(not(target_family = "wasm"))] -use burn::Tensor; -use burn::tensor::{DType, IntDType, Slice, ops::FloatTensor}; +use burn::tensor::{DType, IntDType, ops::FloatTensor}; use burn::tensor::{ FloatDType, ops::{FloatTensorOps, IntTensorOps}, @@ -25,7 +23,7 @@ use burn_cubecl::cubecl::server::Bindings; use burn_cubecl::kernel::into_contiguous; use burn_wgpu::WgpuRuntime; -use burn_wgpu::{CubeDim, CubeTensor}; +use burn_wgpu::CubeDim; use glam::{Vec3, uvec2}; use std::mem::offset_of; @@ -36,28 +34,9 @@ pub(crate) fn calc_tile_bounds(img_size: glam::UVec2) -> glam::UVec2 { ) } -// On wasm, we cannot do a sync readback at all. -// Instead, can just estimate a max number of intersects. All the kernels only handle the actual -// number of intersects, and spin up empty threads for the rest atm. In the future, could use indirect -// dispatch to avoid this. -// Estimating the max number of intersects can be a bad hack though... The worst case scenario is so massive -// that it's easy to run out of memory... How do we actually properly deal with this :/ -pub fn max_intersections( - img_size: glam::UVec2, - num_splats: u32, - _num_intersections: CubeTensor, -) -> u32 { - // Divide screen into tiles. - let tile_bounds = calc_tile_bounds(img_size); - let num_tiles = tile_bounds[0] * tile_bounds[1]; - let max_possible = num_tiles.saturating_mul(num_splats); - // clamp to some max nr. or we run out of memory. - max_possible.min(2 * 512 * 65535) -} - -// Implement forward functions for the inner wgpu backend. -impl SplatForward for MainBackendBase { - fn render_splats( +// Implement the first pass: ProjectPrepare +impl SplatProjectPrepare for MainBackendBase { + fn project_prepare( camera: &Camera, img_size: glam::UVec2, means: FloatTensor, @@ -67,8 +46,7 @@ impl SplatForward for MainBackendBase { raw_opacities: FloatTensor, render_mode: SplatRenderMode, background: Vec3, - bwd_info: bool, - ) -> (FloatTensor, RenderAux) { + ) -> ProjectAux { assert!( img_size[0] > 0 && img_size[1] > 0, "Can't render images with 0 size." @@ -84,7 +62,7 @@ impl SplatForward for MainBackendBase { let device = &means.device.clone(); let client = means.client.clone(); - let _span = tracing::trace_span!("render_forward").entered(); + let _span = tracing::trace_span!("project_prepare").entered(); // Check whether input dimensions are valid. DimCheck::new() @@ -97,18 +75,6 @@ impl SplatForward for MainBackendBase { // Divide screen into tiles. let tile_bounds = calc_tile_bounds(img_size); - // A note on some confusing naming that'll be used throughout this function: - // Gaussians are stored in various states of buffers, eg. at the start they're all in one big buffer, - // then we sparsely store some results, then sort gaussian based on depths, etc. - // Overall this means there's lots of indices flying all over the place, and it's hard to keep track - // what is indexing what. So, for some sanity, try to match a few "gaussian ids" (gid) variable names. - // - Global Gaussian ID - global_gid - // - Compacted Gaussian ID - compact_gid - // - Per tile intersection depth sorted ID - tiled_gid - // - Sorted by tile per tile intersection depth sorted ID - sorted_tiled_gid - // Then, various buffers map between these, which are named x_from_y_gid, eg. - // global_from_compact_gid. - // Tile rendering setup. let sh_degree = sh_degree_from_coeffs(sh_coeffs.shape.dims[1] as u32); let total_splats = means.shape.dims[0]; @@ -123,19 +89,16 @@ impl SplatForward for MainBackendBase { sh_degree, total_splats: total_splats as u32, background: [background.x, background.y, background.z, 1.0], - // Nb: Bit of a hack as these aren't _really_ uniforms but are written to by the shaders. num_visible: 0, pad_a: 0, }; - // Nb: This contains both static metadata and some dynamic data so can't pass this as metadata to execute. In the future - // should separate the two. let uniforms_buffer = create_uniform_buffer(uniforms, device, &client); let client = &means.client.clone(); - let mip_splat = matches!(render_mode, SplatRenderMode::Mip); + // Step 1: ProjectSplats - culling pass let (global_from_compact_gid, num_visible) = { let global_from_presort_gid = Self::int_zeros([total_splats].into(), device, IntDType::U32); @@ -168,29 +131,28 @@ impl SplatForward for MainBackendBase { &[(num_vis_field_offset..num_vis_field_offset + 1).into()], ); + // Step 2: DepthSort let (_, global_from_compact_gid) = tracing::trace_span!("DepthSort").in_scope(|| { - // Interpret the depth as a u32. This is fine for a radix sort, as long as the depth > 0.0, - // which we know to be the case given how we cull splats. radix_argsort(depths, global_from_presort_gid, &num_visible, 32) }); (global_from_compact_gid, num_visible) }; - // Create a buffer of 'projected' splats, that is, - // project XY, projected conic, and converted color. + // Step 3: ProjectVisible with intersection counting let proj_size = size_of::() / size_of::(); let projected_splats = create_tensor([total_splats, proj_size], device, DType::F32); + let splat_intersect_counts = + Self::int_zeros([total_splats + 1].into(), device, IntDType::U32); - tracing::trace_span!("ProjectVisible").in_scope(|| { - // Create a buffer to determine how many threads to dispatch for all visible splats. + tracing::trace_span!("ProjectVisibleWithCounting").in_scope(|| { let num_vis_wg = create_dispatch_buffer_1d(num_visible.clone(), ProjectVisible::WORKGROUP_SIZE[0]); // SAFETY: Kernel checked to have no OOB, bounded loops. unsafe { client .launch_unchecked( - ProjectVisible::task(mip_splat), + ProjectVisible::task(mip_splat, true), // count_intersections = true CubeCount::Dynamic(num_vis_wg.handle.binding()), Bindings::new().with_buffers(vec![ uniforms_buffer.clone().handle.binding(), @@ -201,152 +163,168 @@ impl SplatForward for MainBackendBase { raw_opacities.handle.binding(), global_from_compact_gid.handle.clone().binding(), projected_splats.handle.clone().binding(), + splat_intersect_counts.handle.clone().binding(), ]), ) .expect("Failed to render splats"); } }); - // Each intersection maps to a gaussian. - let (tile_offsets, compact_gid_from_isect, num_intersections) = { - let num_tiles = tile_bounds.x * tile_bounds.y; - - let splat_intersect_counts = - Self::int_zeros([total_splats + 1].into(), device, IntDType::U32); - - let num_vis_map_wg = - create_dispatch_buffer_1d(num_visible, MapGaussiansToIntersect::WORKGROUP_SIZE[0]); - - // First do a prepass to compute the tile counts, then fill in intersection counts. - tracing::trace_span!("MapGaussiansToIntersectPrepass").in_scope(|| { - // SAFETY: Kernel checked to have no OOB, bounded loops. - unsafe { - client - .launch_unchecked( - MapGaussiansToIntersect::task(true), - CubeCount::Dynamic(num_vis_map_wg.handle.clone().binding()), - Bindings::new().with_buffers(vec![ - uniforms_buffer.handle.clone().binding(), - projected_splats.handle.clone().binding(), - splat_intersect_counts.handle.clone().binding(), - ]), - ) - .expect("Failed to render splats"); - } - }); + // Step 4: PrefixSum to get cumulative tile hits + let cum_tiles_hit = tracing::trace_span!("PrefixSumGaussHits") + .in_scope(|| prefix_sum(splat_intersect_counts)); - // TODO: Only need to do this up to num_visible gaussians really. Would need a - // prefix sum with a dynamic total length, and get the num at the dynamic length. - let cum_tiles_hit = tracing::trace_span!("PrefixSumGaussHits") - .in_scope(|| prefix_sum(splat_intersect_counts)); - let num_intersections = Self::int_slice( - cum_tiles_hit.clone(), - &[Slice::new(cum_tiles_hit.shape[0] as isize - 1, None, 1)], - ); - #[cfg(target_family = "wasm")] - let max_intersects = - max_intersections(img_size, total_splats as u32, num_intersections.clone()); - #[cfg(not(target_family = "wasm"))] - let max_intersects = { - use burn::tensor::{ElementConversion, Int}; - let intersects: Tensor = - Tensor::from_primitive(num_intersections.clone()); - intersects.into_scalar().elem::() - }; - - let tile_id_from_isect = create_tensor([max_intersects as usize], device, DType::U32); - let compact_gid_from_isect = - create_tensor([max_intersects as usize], device, DType::U32); - - tracing::trace_span!("MapGaussiansToIntersect").in_scope(|| { - // SAFETY: Kernel checked to have no OOB, bounded loops. - unsafe { - client - .launch_unchecked( - MapGaussiansToIntersect::task(false), - CubeCount::Dynamic(num_vis_map_wg.handle.clone().binding()), - Bindings::new().with_buffers(vec![ - uniforms_buffer.handle.clone().binding(), - projected_splats.handle.clone().binding(), - cum_tiles_hit.handle.binding(), - tile_id_from_isect.handle.clone().binding(), - compact_gid_from_isect.handle.clone().binding(), - ]), - ) - .expect("Failed to render splats"); - } - }); + // Sanity check + assert!( + uniforms_buffer.is_contiguous(), + "Uniforms must be contiguous" + ); + assert!( + global_from_compact_gid.is_contiguous(), + "Global from compact gid must be contiguous" + ); + assert!( + projected_splats.is_contiguous(), + "Projected splats must be contiguous" + ); - // We're sorting by tile ID, but we know beforehand what the maximum value - // can be. We don't need to sort all the leading 0 bits! - let bits = u32::BITS - num_tiles.leading_zeros(); - - let (tile_id_from_isect, compact_gid_from_isect) = tracing::trace_span!("Tile sort") - .in_scope(|| { - radix_argsort( - tile_id_from_isect, - compact_gid_from_isect, - &num_intersections, - bits, - ) - }); - - let cube_dim = CubeDim::new_1d(256); - let num_vis_map_wg = - create_dispatch_buffer_1d(num_intersections.clone(), 256 * CHECKS_PER_ITER); - let cube_count = CubeCount::Dynamic(num_vis_map_wg.handle.binding()); - - // Tiles without splats will be written as having a range of [0, 0]. - let tile_offsets = Self::int_zeros( - [tile_bounds.y as usize, tile_bounds.x as usize, 2].into(), - device, - IntDType::U32, - ); + ProjectAux { + projected_splats, + uniforms_buffer, + global_from_compact_gid, + cum_tiles_hit, + img_size, + } + } +} + +// Implement the second pass: Rasterize +impl SplatRasterize for MainBackendBase { + fn rasterize( + project_aux: &ProjectAux, + num_intersections: u32, + _background: Vec3, // Background is read from uniforms_buffer + bwd_info: bool, + ) -> (FloatTensor, RasterizeAux) { + let _span = tracing::trace_span!("rasterize").entered(); + + let device = &project_aux.projected_splats.device.clone(); + let client = project_aux.projected_splats.client.clone(); + let img_size = project_aux.img_size; + + // Divide screen into tiles. + let tile_bounds = calc_tile_bounds(img_size); + let num_tiles = tile_bounds.x * tile_bounds.y; + + // Get num_visible from uniforms buffer + let num_vis_field_offset = offset_of!(shaders::helpers::RenderUniforms, num_visible) / 4; + let num_visible = Self::int_slice( + project_aux.uniforms_buffer.clone(), + &[(num_vis_field_offset..num_vis_field_offset + 1).into()], + ); + + + // Step 1: Allocate intersection buffers with exact size (minimum 1 to avoid zero-size allocation) + let buffer_size = (num_intersections as usize).max(1); + let tile_id_from_isect = create_tensor([buffer_size], device, DType::U32); + let compact_gid_from_isect = create_tensor([buffer_size], device, DType::U32); + + // Step 2: MapGaussiansToIntersect (fill pass, prepass=false) + let num_vis_map_wg = + create_dispatch_buffer_1d(num_visible.clone(), MapGaussiansToIntersect::WORKGROUP_SIZE[0]); - // SAFETY: Safe kernel. + tracing::trace_span!("MapGaussiansToIntersect").in_scope(|| { + // SAFETY: Kernel checked to have no OOB, bounded loops. unsafe { - get_tile_offsets::launch_unchecked::( - client, - cube_count, - cube_dim, - tile_id_from_isect.as_tensor_arg(1), - tile_offsets.as_tensor_arg(1), - num_intersections.as_tensor_arg(1), - ) - .expect("Failed to render splats"); + client + .launch_unchecked( + MapGaussiansToIntersect::task(false), + CubeCount::Dynamic(num_vis_map_wg.handle.clone().binding()), + Bindings::new().with_buffers(vec![ + project_aux.uniforms_buffer.handle.clone().binding(), + project_aux.projected_splats.handle.clone().binding(), + project_aux.cum_tiles_hit.handle.clone().binding(), + tile_id_from_isect.handle.clone().binding(), + compact_gid_from_isect.handle.clone().binding(), + ]), + ) + .expect("Failed to render splats"); } + }); - (tile_offsets, compact_gid_from_isect, num_intersections) - }; + // Step 3: Tile sort - use static dispatch with actual num_intersections + let bits = u32::BITS - num_tiles.leading_zeros(); + + // Create a tensor holding num_intersections for the sort + // Get the last element from cum_tiles_hit + let cum_len = project_aux.cum_tiles_hit.shape[0]; + let num_intersections_tensor = Self::int_slice( + project_aux.cum_tiles_hit.clone(), + &[(cum_len - 1..cum_len).into()], + ); - let _span = tracing::trace_span!("Rasterize").entered(); + let (tile_id_from_isect, compact_gid_from_isect) = tracing::trace_span!("Tile sort") + .in_scope(|| { + radix_argsort( + tile_id_from_isect, + compact_gid_from_isect, + &num_intersections_tensor, + bits, + ) + }); - let out_dim = if bwd_info { - 4 - } else { - // Channels are packed into 4 bytes, aka one float. - 1 - }; + // Step 4: GetTileOffsets + let cube_dim = CubeDim::new_1d(256); + let num_vis_map_wg = + create_dispatch_buffer_1d(num_intersections_tensor.clone(), 256 * CHECKS_PER_ITER); + let cube_count = CubeCount::Dynamic(num_vis_map_wg.handle.binding()); + let tile_offsets = Self::int_zeros( + [tile_bounds.y as usize, tile_bounds.x as usize, 2].into(), + device, + IntDType::U32, + ); + + // SAFETY: Safe kernel. + unsafe { + get_tile_offsets::launch_unchecked::( + &client, + cube_count, + cube_dim, + tile_id_from_isect.as_tensor_arg(1), + tile_offsets.as_tensor_arg(1), + num_intersections_tensor.as_tensor_arg(1), + ) + .expect("Failed to render splats"); + } + + // Step 5: Rasterize + let out_dim = if bwd_info { 4 } else { 1 }; let out_img = create_tensor( [img_size.y as usize, img_size.x as usize, out_dim], device, DType::F32, ); + // Update background in uniforms - we need to create a modified uniforms buffer + // For now, we'll pass the background through the existing buffer structure + // The rasterize kernel reads background from uniforms, so we need to ensure it's set + let mut bindings = Bindings::new().with_buffers(vec![ - uniforms_buffer.handle.clone().binding(), + project_aux.uniforms_buffer.handle.clone().binding(), compact_gid_from_isect.handle.clone().binding(), tile_offsets.handle.clone().binding(), - projected_splats.handle.clone().binding(), + project_aux.projected_splats.handle.clone().binding(), out_img.handle.clone().binding(), ]); + // Get total_splats from the shape of projected_splats + let total_splats = project_aux.projected_splats.shape.dims[0]; + let visible = if bwd_info { let visible = Self::float_zeros([total_splats].into(), device, FloatDType::F32); - // Add the buffer to the bindings bindings = bindings.with_buffers(vec![ - global_from_compact_gid.handle.clone().binding(), + project_aux.global_from_compact_gid.handle.clone().binding(), visible.handle.clone().binding(), ]); visible @@ -354,8 +332,6 @@ impl SplatForward for MainBackendBase { create_tensor([1], device, DType::F32) }; - // Compile the kernel, including/excluding info for backwards pass. - // see the BWD_INFO define in the rasterize shader. let raster_task = Rasterize::task(bwd_info); // SAFETY: Kernel checked to have no OOB, bounded loops. @@ -369,40 +345,69 @@ impl SplatForward for MainBackendBase { .expect("Failed to render splats"); } - // Sanity check the buffers. - assert!( - uniforms_buffer.is_contiguous(), - "Uniforms must be contiguous" - ); - assert!( - tile_offsets.is_contiguous(), - "Tile offsets must be contiguous" - ); - assert!( - global_from_compact_gid.is_contiguous(), - "Global from compact gid must be contiguous" - ); + // Sanity checks + assert!(tile_offsets.is_contiguous(), "Tile offsets must be contiguous"); assert!(visible.is_contiguous(), "Visible must be contiguous"); - assert!( - projected_splats.is_contiguous(), - "Projected splats must be contiguous" - ); - assert!( - num_intersections.is_contiguous(), - "Num intersections must be contiguous" - ); ( out_img, - RenderAux { - uniforms_buffer, + RasterizeAux { tile_offsets, - projected_splats, compact_gid_from_isect, - global_from_compact_gid, visible, - img_size, }, ) } } + +// Implement backwards-compatible render_splats using the split pipeline +impl SplatForward for MainBackendBase { + fn render_splats( + camera: &Camera, + img_size: glam::UVec2, + means: FloatTensor, + log_scales: FloatTensor, + quats: FloatTensor, + sh_coeffs: FloatTensor, + raw_opacities: FloatTensor, + render_mode: SplatRenderMode, + background: Vec3, + bwd_info: bool, + ) -> (FloatTensor, RenderAux) { + // First pass: project and prepare (includes background in uniforms) + let project_aux = Self::project_prepare( + camera, + img_size, + means, + log_scales, + quats, + sh_coeffs, + raw_opacities, + render_mode, + background, + ); + + // Sync readback of num_intersections + #[cfg(not(target_family = "wasm"))] + let num_intersections = project_aux.num_intersections(); + + #[cfg(target_family = "wasm")] + let num_intersections = { + // On wasm, estimate max intersections + let tile_bounds = calc_tile_bounds(img_size); + let num_tiles = tile_bounds[0] * tile_bounds[1]; + let total_splats = project_aux.projected_splats.shape.dims[0] as u32; + let max_possible = num_tiles.saturating_mul(total_splats); + max_possible.min(2 * 512 * 65535) + }; + + // Second pass: rasterize + let (out_img, rasterize_aux) = + Self::rasterize(&project_aux, num_intersections, background, bwd_info); + + // Combine into RenderAux for backwards compatibility + let render_aux = RenderAux::from_parts(project_aux, rasterize_aux); + + (out_img, render_aux) + } +} diff --git a/crates/brush-render/src/render_aux.rs b/crates/brush-render/src/render_aux.rs index 0ac3f579..a156a6d8 100644 --- a/crates/brush-render/src/render_aux.rs +++ b/crates/brush-render/src/render_aux.rs @@ -12,6 +12,48 @@ use burn::{ use crate::shaders::{self, helpers::TILE_WIDTH}; +/// Output of the ProjectPrepare pass. +/// +/// Contains all data needed to perform the Rasterize pass, including +/// the `cum_tiles_hit` buffer which can be used to extract the exact +/// number of intersections via sync readback. +#[derive(Debug, Clone)] +pub struct ProjectAux { + /// The packed projected splat information, see `ProjectedSplat` in helpers.wgsl + pub projected_splats: FloatTensor, + pub uniforms_buffer: IntTensor, + pub global_from_compact_gid: IntTensor, + /// Cumulative sum of tiles hit per splat. Last element contains total num_intersections. + pub cum_tiles_hit: IntTensor, + pub img_size: glam::UVec2, +} + +impl ProjectAux { + /// Extract the total number of intersections from the cum_tiles_hit buffer. + /// + /// This requires a sync readback from the GPU. + #[cfg(not(target_family = "wasm"))] + pub fn num_intersections(&self) -> u32 { + use burn::tensor::ElementConversion; + let cum: Tensor = Tensor::from_primitive(self.cum_tiles_hit.clone()); + let len = cum.dims()[0]; + cum.slice(s![len - 1..len]).into_scalar().elem::() + } + + pub fn num_visible(&self) -> Tensor { + let num_vis_field_offset = offset_of!(shaders::helpers::RenderUniforms, num_visible) / 4; + Tensor::from_primitive(self.uniforms_buffer.clone()).slice(s![num_vis_field_offset]) + } +} + +/// Output of the Rasterize pass. +#[derive(Debug, Clone)] +pub struct RasterizeAux { + pub tile_offsets: IntTensor, + pub compact_gid_from_isect: IntTensor, + pub visible: FloatTensor, +} + #[derive(Debug, Clone)] pub struct RenderAux { /// The packed projected splat information, see `ProjectedSplat` in helpers.wgsl @@ -24,6 +66,21 @@ pub struct RenderAux { pub img_size: glam::UVec2, } +impl RenderAux { + /// Combine ProjectAux and RasterizeAux into a RenderAux for backwards compatibility. + pub fn from_parts(project: ProjectAux, rasterize: RasterizeAux) -> Self { + Self { + projected_splats: project.projected_splats, + uniforms_buffer: project.uniforms_buffer, + global_from_compact_gid: project.global_from_compact_gid, + tile_offsets: rasterize.tile_offsets, + compact_gid_from_isect: rasterize.compact_gid_from_isect, + visible: rasterize.visible, + img_size: project.img_size, + } + } +} + impl RenderAux { pub fn calc_tile_depth(&self) -> Tensor { let tile_offsets: Tensor = Tensor::from_primitive(self.tile_offsets.clone()); @@ -54,16 +111,33 @@ impl RenderAux { Tensor::from_primitive(self.compact_gid_from_isect.clone()); let num_visible: Tensor = self.num_visible(); - let num_points = compact_gid_from_isect.dims()[0] as u32; + // Get num_intersections from the last element of the flattened tile_offsets + // tile_offsets has shape [ty, tx, 2] where each [i, j, :] is [start, end] for that tile + // The last element (end offset of the last tile) is the total number of intersections + let tile_offsets_3d: Tensor = Tensor::from_primitive(self.tile_offsets.clone()); + let [ty, tx, _] = tile_offsets_3d.dims(); + // Get the end offset of the last tile: tile_offsets[ty-1, tx-1, 1] + let num_intersections = tile_offsets_3d + .slice(s![ty - 1..ty, tx - 1..tx, 1..2]) + .reshape([1]) + .into_scalar() + .elem::() as u32; + + // Get total_splats from the uniforms buffer for validation + let total_splats_field_offset = + offset_of!(shaders::helpers::RenderUniforms, total_splats) / 4; + let total_splats: Tensor = + Tensor::from_primitive(self.uniforms_buffer.clone()); + let total_splats = total_splats + .slice(s![total_splats_field_offset..total_splats_field_offset + 1]) + .into_scalar() + .elem::() as u32; + let num_visible = num_visible.into_scalar().elem::() as u32; - let num_intersections: Tensor = - Tensor::from_primitive(self.tile_offsets.clone()); - let num_intersections = - num_intersections.slice(s![-1]).into_scalar().elem::() as u32; assert!( - num_visible <= num_points, - "Something went wrong when calculating the number of visible gaussians. {num_visible} > {num_points}" + num_visible <= total_splats, + "Something went wrong when calculating the number of visible gaussians. {num_visible} > {total_splats}" ); // Projected splats is only valid up to num_visible and undefined for other values. @@ -76,24 +150,27 @@ impl RenderAux { validate_tensor_val(&projected_splats, "projected_splats", None, None); } - let visible: Tensor = + let visible: Tensor = Tensor::from_primitive(TensorPrimitive::Float(self.visible.clone())); - validate_tensor_val(&visible, "visible", None, None); - - let tile_offsets: Tensor = Tensor::from_primitive(self.tile_offsets.clone()); - - let tile_offsets = tile_offsets - .into_data() - .into_vec::() - .expect("Failed to fetch tile offsets"); - for &offsets in &tile_offsets { - assert!( - offsets <= num_intersections, - "Tile offsets exceed bounds. Value: {offsets}, num_intersections: {num_intersections}" - ); - } + let visible_2d: Tensor = visible.unsqueeze_dim(1); + validate_tensor_val(&visible_2d, "visible", None, None); + // Only validate tile_offsets when there are intersections to validate if num_intersections > 0 { + let tile_offsets: Tensor = + Tensor::from_primitive(self.tile_offsets.clone()); + + let tile_offsets = tile_offsets + .into_data() + .into_vec::() + .expect("Failed to fetch tile offsets"); + for &offsets in &tile_offsets { + assert!( + offsets <= num_intersections, + "Tile offsets exceed bounds. Value: {offsets}, num_intersections: {num_intersections}" + ); + } + for i in 0..(tile_offsets.len() - 1) / 2 { // Check pairs of start/end points. let start = tile_offsets[i * 2]; @@ -113,19 +190,29 @@ impl RenderAux { } } - if num_intersections > 0 { - let compact_gid_from_isect = &compact_gid_from_isect + // Skip validation of compact_gid_from_isect when shape is 0 (fusion placeholder) + let declared_size = compact_gid_from_isect.dims()[0]; + if num_intersections > 0 && declared_size > 0 { + let data = compact_gid_from_isect .slice([0..num_intersections as usize]) - .into_data() + .into_data(); + + // Handle both I32 and U32 tensor types + let compact_gid_vec: Vec = data + .clone() .into_vec::() + .or_else(|_| { + data.into_vec::() + .map(|v| v.into_iter().map(|x| x as u32).collect()) + }) .expect("Failed to fetch compact_gid_from_isect"); - for (i, &compact_gid) in compact_gid_from_isect.iter().enumerate() { + for (i, compact_gid) in compact_gid_vec.iter().enumerate() { assert!( - compact_gid < num_visible, + *compact_gid < num_visible, "Invalid gaussian ID in intersection buffer. {compact_gid} out of {num_visible}. At {i} out of {num_intersections} intersections. \n - {compact_gid_from_isect:?} + {compact_gid_vec:?} \n\n\n" ); @@ -133,18 +220,21 @@ impl RenderAux { } // assert that every ID in global_from_compact_gid is valid. - let global_from_compact_gid: Tensor = - Tensor::from_primitive(self.global_from_compact_gid.clone()); - let global_from_compact_gid = &global_from_compact_gid - .into_data() - .into_vec::() - .expect("Failed to fetch global_from_compact_gid")[0..num_visible as usize]; - - for &global_gid in global_from_compact_gid { - assert!( - global_gid < num_points, - "Invalid gaussian ID in global_from_compact_gid buffer. {global_gid} out of {num_points}" - ); + // Only validate when there are visible splats + if num_visible > 0 && total_splats > 0 { + let global_from_compact_gid: Tensor = + Tensor::from_primitive(self.global_from_compact_gid.clone()); + let global_from_compact_gid = &global_from_compact_gid + .into_data() + .into_vec::() + .expect("Failed to fetch global_from_compact_gid")[0..num_visible as usize]; + + for &global_gid in global_from_compact_gid { + assert!( + global_gid < total_splats, + "Invalid gaussian ID in global_from_compact_gid buffer. {global_gid} out of {total_splats}" + ); + } } } } diff --git a/crates/brush-render/src/shaders.rs b/crates/brush-render/src/shaders.rs index b06e6ef9..2e3c2b71 100644 --- a/crates/brush-render/src/shaders.rs +++ b/crates/brush-render/src/shaders.rs @@ -10,6 +10,7 @@ pub struct ProjectSplats { #[wgsl_kernel(source = "src/shaders/project_visible.wgsl")] pub struct ProjectVisible { mip_splatting: bool, + count_intersections: bool, } #[wgsl_kernel(source = "src/shaders/map_gaussian_to_intersects.wgsl")] diff --git a/crates/brush-render/src/shaders/project_visible.wgsl b/crates/brush-render/src/shaders/project_visible.wgsl index 3853fed0..4d9b0e1a 100644 --- a/crates/brush-render/src/shaders/project_visible.wgsl +++ b/crates/brush-render/src/shaders/project_visible.wgsl @@ -15,6 +15,10 @@ struct IsectInfo { @group(0) @binding(6) var global_from_compact_gid: array; @group(0) @binding(7) var projected: array; +#ifdef COUNT_INTERSECTIONS + @group(0) @binding(8) var splat_intersect_counts: array; +#endif + struct ShCoeffs { b0_c0: vec3f, @@ -246,9 +250,36 @@ fn main( let viewdir = normalize(mean - uniforms.camera_position.xyz); var color = sh_coeffs_to_color(sh_degree, viewdir, sh) + vec3f(0.5); + let conic_packed = vec3f(conic[0][0], conic[0][1], conic[1][1]); + projected[compact_gid] = helpers::create_projected_splat( mean2d, - vec3f(conic[0][0], conic[0][1], conic[1][1]), + conic_packed, vec4f(color, opac) ); + +#ifdef COUNT_INTERSECTIONS + // Count intersections for this splat (merged from map_gaussian_to_intersects prepass) + let power_threshold = log(opac * 255.0); + let extent = helpers::compute_bbox_extent(cov2d, power_threshold); + let tile_bbox = helpers::get_tile_bbox(mean2d, extent, uniforms.tile_bounds); + let tile_bbox_min = tile_bbox.xy; + let tile_bbox_max = tile_bbox.zw; + + var num_tiles_hit = 0u; + let tile_bbox_width = tile_bbox_max.x - tile_bbox_min.x; + let num_tiles_bbox = (tile_bbox_max.y - tile_bbox_min.y) * tile_bbox_width; + + for (var tile_idx = 0u; tile_idx < num_tiles_bbox; tile_idx++) { + let tx = (tile_idx % tile_bbox_width) + tile_bbox_min.x; + let ty = (tile_idx / tile_bbox_width) + tile_bbox_min.y; + + let rect = helpers::tile_rect(vec2u(tx, ty)); + if helpers::will_primitive_contribute(rect, mean2d, conic_packed, power_threshold) { + num_tiles_hit += 1u; + } + } + + splat_intersect_counts[compact_gid + 1u] = num_tiles_hit; +#endif } From 84b10534c0ab50ef39107a0c79a273df86b9e1f5 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 19 Jan 2026 21:10:21 +0100 Subject: [PATCH 03/29] Cleanup split pass --- crates/brush-bench-test/src/reference.rs | 15 +- crates/brush-render-bwd/src/burn_glue.rs | 99 ++++-- crates/brush-render-bwd/src/render_bwd.rs | 67 ++-- .../src/shaders/project_backwards.wgsl | 7 +- .../src/shaders/rasterize_backwards.wgsl | 29 +- crates/brush-render-bwd/src/tests.rs | 2 - crates/brush-render/src/burn_glue.rs | 282 +++------------- crates/brush-render/src/gaussian_splats.rs | 28 +- crates/brush-render/src/get_tile_offset.rs | 43 +-- crates/brush-render/src/lib.rs | 54 +-- crates/brush-render/src/render.rs | 310 +++++++----------- crates/brush-render/src/render_aux.rs | 225 ++++++------- crates/brush-render/src/shaders.rs | 8 +- crates/brush-render/src/shaders/helpers.wgsl | 21 +- .../shaders/map_gaussian_to_intersects.wgsl | 32 +- .../src/shaders/project_forward.wgsl | 23 +- .../src/shaders/project_visible.wgsl | 16 +- .../brush-render/src/shaders/rasterize.wgsl | 19 +- crates/brush-render/src/tests/mod.rs | 77 ++++- crates/brush-sort/src/lib.rs | 90 +++-- crates/brush-train/src/eval.rs | 44 ++- crates/brush-train/src/train.rs | 8 +- crates/brush-ui/src/scene.rs | 2 +- 23 files changed, 669 insertions(+), 832 deletions(-) diff --git a/crates/brush-bench-test/src/reference.rs b/crates/brush-bench-test/src/reference.rs index d9b907c4..6b662075 100644 --- a/crates/brush-bench-test/src/reference.rs +++ b/crates/brush-bench-test/src/reference.rs @@ -133,10 +133,9 @@ async fn test_reference() -> Result<()> { Vec3::ZERO, ); - let (out, aux) = ( - Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)), - diff_out.aux, - ); + let out: Tensor = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); + let project_aux = diff_out.project_aux; + let rasterize_aux = diff_out.rasterize_aux; if let Some(rec) = rec.as_ref() { rec.set_time_sequence("test case", i as i64); @@ -148,17 +147,17 @@ async fn test_reference() -> Result<()> { )?; rec.log( "images/tile_depth", - &aux.calc_tile_depth().into_rerun().await, + &rasterize_aux.calc_tile_depth().into_rerun().await, )?; } - let num_visible: Tensor = aux.num_visible(); + let num_visible: Tensor = project_aux.get_num_visible(); let num_visible = num_visible.into_scalar_async().await.unwrap() as usize; let global_from_compact_gid: Tensor = - Tensor::from_primitive(aux.global_from_compact_gid.clone()); + Tensor::from_primitive(project_aux.global_from_compact_gid.clone()); let gs_ids = global_from_compact_gid.clone().slice([0..num_visible]); let projected_splats = - Tensor::from_primitive(TensorPrimitive::Float(aux.projected_splats.clone())); + Tensor::from_primitive(TensorPrimitive::Float(project_aux.projected_splats.clone())); let xys: Tensor = projected_splats.clone().slice([0..num_visible, 0..2]); let xys_ref = safetensor_to_burn::(&tensors.tensor("xys")?, &device); diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index d094d580..52d969b9 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -1,9 +1,10 @@ use brush_render::{ - MainBackendBase, SplatForward, + MainBackendBase, SplatOps, camera::Camera, gaussian_splats::{SplatRenderMode, Splats}, - render_aux::RenderAux, + render_aux::{ProjectAux, RasterizeAux, validate_render_output}, sh::{sh_coeffs_for_degree, sh_degree_from_coeffs}, + shaders::helpers::ProjectUniforms, }; use burn::{ backend::{ @@ -32,7 +33,7 @@ use glam::Vec3; use crate::render_bwd::SplatGrads; -/// Like [`SplatForward`], but for backends that support differentiation. +/// Like [`SplatOps`], but for backends that support differentiation. /// /// This shouldn't be a separate trait, but atm is needed because of orphan trait rules. pub trait SplatForwardDiff { @@ -65,7 +66,7 @@ pub trait SplatBackwardOps { ) -> SplatGrads; } -/// State from the ProjectPrepare pass needed for backward computation. +/// State from the `ProjectPrepare` pass needed for backward computation. #[derive(Debug, Clone)] pub struct ProjectBackwardState { pub(crate) means: FloatTensor, @@ -73,10 +74,12 @@ pub struct ProjectBackwardState { pub(crate) log_scales: FloatTensor, pub(crate) raw_opac: FloatTensor, pub(crate) projected_splats: FloatTensor, - pub(crate) uniforms_buffer: IntTensor, + pub(crate) project_uniforms: ProjectUniforms, + pub(crate) num_visible: IntTensor, pub(crate) global_from_compact_gid: IntTensor, pub(crate) render_mode: SplatRenderMode, pub(crate) sh_degree: u32, + pub(crate) background: Vec3, } /// State from the Rasterize pass needed for backward computation. @@ -96,12 +99,14 @@ pub struct GaussianBackwardState { pub(crate) raw_opac: FloatTensor, pub(crate) out_img: FloatTensor, pub(crate) projected_splats: FloatTensor, - pub(crate) uniforms_buffer: IntTensor, + pub(crate) project_uniforms: ProjectUniforms, + pub(crate) num_visible: IntTensor, pub(crate) compact_gid_from_isect: IntTensor, pub(crate) global_from_compact_gid: IntTensor, pub(crate) tile_offsets: IntTensor, pub(crate) render_mode: SplatRenderMode, pub(crate) sh_degree: u32, + pub(crate) background: Vec3, } impl GaussianBackwardState { @@ -116,10 +121,12 @@ impl GaussianBackwardState { log_scales: project.log_scales, raw_opac: project.raw_opac, projected_splats: project.projected_splats, - uniforms_buffer: project.uniforms_buffer, + project_uniforms: project.project_uniforms, + num_visible: project.num_visible, global_from_compact_gid: project.global_from_compact_gid, render_mode: project.render_mode, sh_degree: project.sh_degree, + background: project.background, out_img: rasterize.out_img, compact_gid_from_isect: rasterize.compact_gid_from_isect, tile_offsets: rasterize.tile_offsets, @@ -190,13 +197,14 @@ impl> Backward for RenderBackw pub struct SplatOutputDiff { pub img: FloatTensor, - pub aux: RenderAux, + pub project_aux: ProjectAux, + pub rasterize_aux: RasterizeAux, pub refine_weight_holder: Tensor, } // Implement -impl + SplatForward, C: CheckpointStrategy> - SplatForwardDiff for Autodiff +impl + SplatOps, C: CheckpointStrategy> SplatForwardDiff + for Autodiff { fn render_splats( camera: &Camera, @@ -228,8 +236,8 @@ impl + SplatForward, C: CheckpointStrategy> .compute_bound() .stateful(); - // Render complete forward pass. - let (out_img, aux) = >::render_splats( + // First pass: project + let project_aux = >::project( camera, img_size, means.clone().into_primitive(), @@ -238,18 +246,32 @@ impl + SplatForward, C: CheckpointStrategy> sh_coeffs.clone().into_primitive(), raw_opacity.clone().into_primitive(), render_mode, - background, - true, ); - let wrapped_aux = RenderAux:: { - projected_splats: ::from_inner(aux.projected_splats.clone()), - tile_offsets: aux.tile_offsets.clone(), - compact_gid_from_isect: aux.compact_gid_from_isect.clone(), - global_from_compact_gid: aux.global_from_compact_gid.clone(), - uniforms_buffer: aux.uniforms_buffer.clone(), - visible: ::from_inner(aux.visible), - img_size: aux.img_size, + // Sync readback of num_intersections + let num_intersections = project_aux.num_intersections(); + + // Second pass: rasterize (with bwd_info = true) + let (out_img, rasterize_aux) = + >::rasterize(&project_aux, num_intersections, background, true); + + // Create wrapped aux structs for Autodiff backend + let wrapped_project_aux = ProjectAux:: { + project_uniforms: project_aux.project_uniforms, + projected_splats: ::from_inner( + project_aux.projected_splats.clone(), + ), + num_visible: project_aux.num_visible.clone(), + global_from_compact_gid: project_aux.global_from_compact_gid.clone(), + cum_tiles_hit: project_aux.cum_tiles_hit.clone(), + img_size: project_aux.img_size, + }; + + let wrapped_rasterize_aux = RasterizeAux:: { + tile_offsets: rasterize_aux.tile_offsets.clone(), + compact_gid_from_isect: rasterize_aux.compact_gid_from_isect.clone(), + visible: ::from_inner(rasterize_aux.visible.clone()), + img_size: rasterize_aux.img_size, }; match prep_nodes { @@ -265,19 +287,22 @@ impl + SplatForward, C: CheckpointStrategy> [1] as u32, ), out_img: out_img.clone(), - projected_splats: aux.projected_splats, - uniforms_buffer: aux.uniforms_buffer, - tile_offsets: aux.tile_offsets, - compact_gid_from_isect: aux.compact_gid_from_isect, + projected_splats: project_aux.projected_splats, + project_uniforms: project_aux.project_uniforms, + num_visible: project_aux.num_visible, + tile_offsets: rasterize_aux.tile_offsets, + compact_gid_from_isect: rasterize_aux.compact_gid_from_isect, render_mode, - global_from_compact_gid: aux.global_from_compact_gid, + global_from_compact_gid: project_aux.global_from_compact_gid, + background, }; let out_img = prep.finish(state, out_img); SplatOutputDiff { img: out_img, - aux: wrapped_aux, + project_aux: wrapped_project_aux, + rasterize_aux: wrapped_rasterize_aux, refine_weight_holder, } } @@ -286,7 +311,8 @@ impl + SplatForward, C: CheckpointStrategy> // keeping any state. SplatOutputDiff { img: prep.finish(out_img), - aux: wrapped_aux, + project_aux: wrapped_project_aux, + rasterize_aux: wrapped_rasterize_aux, refine_weight_holder, } } @@ -304,6 +330,8 @@ impl SplatBackwardOps for Fusion { desc: CustomOpIr, render_mode: SplatRenderMode, sh_degree: u32, + background: Vec3, + project_uniforms: ProjectUniforms, } impl Operation> for CustomOp { @@ -321,7 +349,7 @@ impl SplatBackwardOps for Fusion { raw_opac, out_img, projected_splats, - uniforms_buffer, + num_visible, tile_offsets, compact_gid_from_isect, global_from_compact_gid, @@ -336,7 +364,8 @@ impl SplatBackwardOps for Fusion { raw_opac: h.get_float_tensor::(raw_opac), out_img: h.get_float_tensor::(out_img), projected_splats: h.get_float_tensor::(projected_splats), - uniforms_buffer: h.get_int_tensor::(uniforms_buffer), + project_uniforms: self.project_uniforms, + num_visible: h.get_int_tensor::(num_visible), tile_offsets: h.get_int_tensor::(tile_offsets), compact_gid_from_isect: h .get_int_tensor::(compact_gid_from_isect), @@ -344,6 +373,7 @@ impl SplatBackwardOps for Fusion { .get_int_tensor::(global_from_compact_gid), sh_degree: self.sh_degree, render_mode: self.render_mode, + background: self.background, }; let grads = @@ -405,7 +435,7 @@ impl SplatBackwardOps for Fusion { state.raw_opac, state.out_img, state.projected_splats, - state.uniforms_buffer, + state.num_visible, state.tile_offsets, state.compact_gid_from_isect, state.global_from_compact_gid, @@ -430,10 +460,11 @@ impl SplatBackwardOps for Fusion { stream, OperationIr::Custom(desc.clone()), CustomOp { - // state, desc, sh_degree: state.sh_degree, render_mode: state.render_mode, + background: state.background, + project_uniforms: state.project_uniforms, }, ) .outputs(); @@ -480,6 +511,6 @@ where splats.render_mode, background, ); - result.aux.validate_values(); + validate_render_output(&result.project_aux, &result.rasterize_aux); result } diff --git a/crates/brush-render-bwd/src/render_bwd.rs b/crates/brush-render-bwd/src/render_bwd.rs index c63574be..22e05d07 100644 --- a/crates/brush-render-bwd/src/render_bwd.rs +++ b/crates/brush-render-bwd/src/render_bwd.rs @@ -1,5 +1,6 @@ -use brush_kernel::{CubeCount, calc_cube_count_1d}; +use brush_kernel::{CubeCount, calc_cube_count_1d, create_meta_binding}; use brush_render::gaussian_splats::SplatRenderMode; +use brush_render::shaders::helpers::RasterizeUniforms; use brush_wgsl::wgsl_kernel; use brush_render::MainBackendBase; @@ -59,7 +60,6 @@ impl SplatBackwardOps for MainBackendBase { // We're in charge of these, SHOULD be contiguous but might as well. let projected_splats = into_contiguous(state.projected_splats); - let uniforms_buffer = into_contiguous(state.uniforms_buffer); let compact_gid_from_isect = into_contiguous(state.compact_gid_from_isect); let global_from_compact_gid = into_contiguous(state.global_from_compact_gid); let tile_offsets = into_contiguous(state.tile_offsets); @@ -101,6 +101,13 @@ impl SplatBackwardOps for MainBackendBase { .div_ceil(brush_render::shaders::helpers::TILE_WIDTH), ); + // Create RasterizeUniforms for the backward rasterize pass (passed via with_metadata) + let rasterize_uniforms = RasterizeUniforms { + tile_bounds: tile_bounds.into(), + img_size: img_size.into(), + background: [state.background.x, state.background.y, state.background.z, 1.0], + }; + let hard_floats = client .properties() .type_usage(StorageType::Atomic(ElemType::Float(FloatKind::F32))) @@ -117,18 +124,19 @@ impl SplatBackwardOps for MainBackendBase { .launch_unchecked( RasterizeBackwards::task(hard_floats, webgpu), CubeCount::Static(tile_bounds.x * tile_bounds.y, 1, 1), - Bindings::new().with_buffers(vec![ - uniforms_buffer.handle.clone().binding(), - compact_gid_from_isect.handle.binding(), - global_from_compact_gid.handle.clone().binding(), - tile_offsets.handle.binding(), - projected_splats.handle.binding(), - state.out_img.handle.binding(), - v_output.handle.binding(), - v_grads.handle.clone().binding(), - v_raw_opac.handle.clone().binding(), - v_refine_weight.handle.clone().binding(), - ]), + Bindings::new() + .with_buffers(vec![ + compact_gid_from_isect.handle.binding(), + global_from_compact_gid.handle.clone().binding(), + tile_offsets.handle.binding(), + projected_splats.handle.binding(), + state.out_img.handle.binding(), + v_output.handle.binding(), + v_grads.handle.clone().binding(), + v_raw_opac.handle.clone().binding(), + v_refine_weight.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(rasterize_uniforms)), ) .expect("Failed to bwd-diff splats"); } @@ -140,21 +148,22 @@ impl SplatBackwardOps for MainBackendBase { client.launch_unchecked( ProjectBackwards::task(mip_splat), calc_cube_count_1d(num_points as u32, ProjectBackwards::WORKGROUP_SIZE[0]), - Bindings::new().with_buffers( - vec![ - uniforms_buffer.handle.binding(), - means.handle.binding(), - log_scales.handle.binding(), - quats.handle.binding(), - raw_opac.handle.binding(), - global_from_compact_gid.handle.binding(), - v_grads.handle.binding(), - v_means.handle.clone().binding(), - v_scales.handle.clone().binding(), - v_quats.handle.clone().binding(), - v_coeffs.handle.clone().binding(), - v_raw_opac.handle.clone().binding(), - ]), + Bindings::new() + .with_buffers(vec![ + state.num_visible.handle.binding(), + means.handle.binding(), + log_scales.handle.binding(), + quats.handle.binding(), + raw_opac.handle.binding(), + global_from_compact_gid.handle.binding(), + v_grads.handle.binding(), + v_means.handle.clone().binding(), + v_scales.handle.clone().binding(), + v_quats.handle.clone().binding(), + v_coeffs.handle.clone().binding(), + v_raw_opac.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(state.project_uniforms)), ).expect("Failed to bwd-diff splats"); }); diff --git a/crates/brush-render-bwd/src/shaders/project_backwards.wgsl b/crates/brush-render-bwd/src/shaders/project_backwards.wgsl index 1c569646..8c8beb82 100644 --- a/crates/brush-render-bwd/src/shaders/project_backwards.wgsl +++ b/crates/brush-render-bwd/src/shaders/project_backwards.wgsl @@ -1,6 +1,6 @@ #import helpers; -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; +@group(0) @binding(0) var num_visible: u32; @group(0) @binding(1) var means: array; @group(0) @binding(2) var log_scales: array; @@ -17,6 +17,9 @@ @group(0) @binding(10) var v_coeffs: array; @group(0) @binding(11) var v_opacs: array; +// Uniforms via with_metadata (always last binding) +@group(0) @binding(12) var uniforms: helpers::ProjectUniforms; + const SH_C0: f32 = 0.2820947917738781f; struct ShCoeffs { @@ -299,7 +302,7 @@ fn main( @builtin(local_invocation_index) lid: u32, ) { let compact_gid = helpers::get_global_id(wid, num_wgs, lid, WG_SIZE); - if compact_gid >= uniforms.num_visible { + if compact_gid >= num_visible { return; } diff --git a/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl b/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl index 9e28b003..be337345 100644 --- a/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl +++ b/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl @@ -1,17 +1,18 @@ #import helpers; -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; -@group(0) @binding(1) var compact_gid_from_isect: array; -@group(0) @binding(2) var global_from_compact_gid: array; -@group(0) @binding(3) var tile_offsets: array; -@group(0) @binding(4) var projected: array; -@group(0) @binding(5) var output: array; -@group(0) @binding(6) var v_output: array; +@group(0) @binding(0) var compact_gid_from_isect: array; +@group(0) @binding(1) var global_from_compact_gid: array; +@group(0) @binding(2) var tile_offsets: array; +@group(0) @binding(3) var projected: array; +@group(0) @binding(4) var output: array; +@group(0) @binding(5) var v_output: array; #ifdef HARD_FLOAT - @group(0) @binding(7) var v_splats: array>; - @group(0) @binding(8) var v_opacs: array>; - @group(0) @binding(9) var v_refines: array>; + @group(0) @binding(6) var v_splats: array>; + @group(0) @binding(7) var v_opacs: array>; + @group(0) @binding(8) var v_refines: array>; + // Uniforms via with_metadata (always last binding) + @group(0) @binding(9) var uniforms: helpers::RasterizeUniforms; fn write_grads_atomic(id: u32, grads: f32) { atomicAdd(&v_splats[id], grads); @@ -23,9 +24,11 @@ atomicAdd(&v_opacs[id], grads); } #else - @group(0) @binding(7) var v_splats: array>; - @group(0) @binding(8) var v_opacs: array>; - @group(0) @binding(9) var v_refines: array>; + @group(0) @binding(6) var v_splats: array>; + @group(0) @binding(7) var v_opacs: array>; + @group(0) @binding(8) var v_refines: array>; + // Uniforms via with_metadata (always last binding) + @group(0) @binding(9) var uniforms: helpers::RasterizeUniforms; fn add_bitcast(cur: u32, add: f32) -> u32 { return bitcast(bitcast(cur) + add); diff --git a/crates/brush-render-bwd/src/tests.rs b/crates/brush-render-bwd/src/tests.rs index 655239c7..ebdb58b3 100644 --- a/crates/brush-render-bwd/src/tests.rs +++ b/crates/brush-render-bwd/src/tests.rs @@ -46,7 +46,6 @@ fn diffs_at_all() { SplatRenderMode::Default, Vec3::ZERO, ); - result.aux.validate_values(); let output: Tensor = Tensor::from_primitive(TensorPrimitive::Float(result.img)); let rgb = output.clone().slice([0..32, 0..32, 0..3]); @@ -115,5 +114,4 @@ fn diffs_many_splats() { SplatRenderMode::Default, Vec3::ZERO, ); - result.aux.validate_values(); } diff --git a/crates/brush-render/src/burn_glue.rs b/crates/brush-render/src/burn_glue.rs index 58ab0515..da02c661 100644 --- a/crates/brush-render/src/burn_glue.rs +++ b/crates/brush-render/src/burn_glue.rs @@ -9,13 +9,17 @@ use burn_wgpu::WgpuRuntime; use glam::Vec3; use crate::{ - MainBackendBase, SplatForward, SplatProjectPrepare, SplatRasterize, - camera::Camera, gaussian_splats::SplatRenderMode, - render::calc_tile_bounds, render_aux::{ProjectAux, RasterizeAux, RenderAux}, shaders, + MainBackendBase, SplatOps, + camera::Camera, + gaussian_splats::SplatRenderMode, + render::calc_tile_bounds, + render_aux::{ProjectAux, RasterizeAux}, + sh::sh_degree_from_coeffs, + shaders::{self, helpers::ProjectUniforms}, }; -impl SplatForward for Fusion { - fn render_splats( +impl SplatOps for Fusion { + fn project( cam: &Camera, img_size: glam::UVec2, means: FloatTensor, @@ -24,199 +28,12 @@ impl SplatForward for Fusion { sh_coeffs: FloatTensor, opacity: FloatTensor, render_mode: SplatRenderMode, - background: Vec3, - bwd_info: bool, - ) -> (FloatTensor, RenderAux) { - #[derive(Debug)] - struct CustomOp { - cam: Camera, - img_size: glam::UVec2, - render_mode: SplatRenderMode, - bwd_info: bool, - background: Vec3, - desc: CustomOpIr, - } - - impl Operation> for CustomOp { - fn execute( - &self, - h: &mut HandleContainer>>, - ) { - let (inputs, outputs) = self.desc.as_fixed(); - - let [means, log_scales, quats, sh_coeffs, opacity] = inputs; - let [ - // Img - out_img, - // Aux - projected_splats, - uniforms_buffer, - tile_offsets, - compact_gid_from_isect, - global_from_compact_gid, - visible, - ] = outputs; - - let (img, aux) = MainBackendBase::render_splats( - &self.cam, - self.img_size, - h.get_float_tensor::(means), - h.get_float_tensor::(log_scales), - h.get_float_tensor::(quats), - h.get_float_tensor::(sh_coeffs), - h.get_float_tensor::(opacity), - self.render_mode, - self.background, - self.bwd_info, - ); - - // Register output. - h.register_float_tensor::(&out_img.id, img); - h.register_float_tensor::( - &projected_splats.id, - aux.projected_splats, - ); - h.register_int_tensor::(&uniforms_buffer.id, aux.uniforms_buffer); - h.register_int_tensor::(&tile_offsets.id, aux.tile_offsets); - h.register_int_tensor::( - &compact_gid_from_isect.id, - aux.compact_gid_from_isect, - ); - h.register_int_tensor::( - &global_from_compact_gid.id, - aux.global_from_compact_gid, - ); - - h.register_float_tensor::(&visible.id, aux.visible); - } - } - - let client = means.client.clone(); - - let num_points = means.shape[0]; - - let proj_size = size_of::() / 4; - let uniforms_size = size_of::() / 4; - let tile_bounds = calc_tile_bounds(img_size); - - // If render_u32_buffer is true, we render a packed buffer of u32 values, otherwise - // render RGBA f32 values. - let channels = if bwd_info { 4 } else { 1 }; - - let out_img = TensorIr::uninit( - client.create_empty_handle(), - Shape::new([img_size.y as usize, img_size.x as usize, channels]), - if bwd_info { DType::F32 } else { DType::U32 }, - ); - - let visible_shape = if bwd_info { - Shape::new([num_points]) - } else { - Shape::new([1]) - }; - - let projected_splats = TensorIr::uninit( - client.create_empty_handle(), - Shape::new([num_points, proj_size]), - DType::F32, - ); - let uniforms_buffer = TensorIr::uninit( - client.create_empty_handle(), - Shape::new([uniforms_size]), - DType::U32, - ); - let tile_offsets = TensorIr::uninit( - client.create_empty_handle(), - Shape::new([tile_bounds.y as usize, tile_bounds.x as usize, 2]), - DType::U32, - ); - - // This is not actually size 0, but, it's dynamic. This is just a dummy handle so we just - // set a dummy size of 0. - let compact_gid_from_isect = - TensorIr::uninit(client.create_empty_handle(), Shape::new([0]), DType::U32); - - let global_from_compact_gid = TensorIr::uninit( - client.create_empty_handle(), - Shape::new([num_points]), - DType::U32, - ); - let visible = TensorIr::uninit(client.create_empty_handle(), visible_shape, DType::F32); - - let input_tensors = [means, log_scales, quats, sh_coeffs, opacity]; - let stream = OperationStreams::with_inputs(&input_tensors); - let desc = CustomOpIr::new( - "render_splats", - &input_tensors.map(|t| t.into_ir()), - &[ - out_img, - projected_splats, - uniforms_buffer, - tile_offsets, - compact_gid_from_isect, - global_from_compact_gid, - visible, - ], - ); - let op = CustomOp { - cam: cam.clone(), - img_size, - bwd_info, - background, - render_mode, - desc: desc.clone(), - }; - - let outputs = client - .register(stream, OperationIr::Custom(desc), op) - .outputs(); - - let [ - // Img - out_img, - // Aux - projected_splats, - uniforms_buffer, - tile_offsets, - compact_gid_from_isect, - global_from_compact_gid, - visible, - ] = outputs; - - ( - out_img, - RenderAux:: { - projected_splats, - uniforms_buffer, - tile_offsets, - compact_gid_from_isect, - global_from_compact_gid, - visible, - img_size, - }, - ) - } -} - -// Fusion implementation for SplatProjectPrepare -impl SplatProjectPrepare for Fusion { - fn project_prepare( - cam: &Camera, - img_size: glam::UVec2, - means: FloatTensor, - log_scales: FloatTensor, - quats: FloatTensor, - sh_coeffs: FloatTensor, - opacity: FloatTensor, - render_mode: SplatRenderMode, - background: Vec3, ) -> ProjectAux { #[derive(Debug)] struct CustomOp { cam: Camera, img_size: glam::UVec2, render_mode: SplatRenderMode, - background: Vec3, desc: CustomOpIr, } @@ -230,12 +47,12 @@ impl SplatProjectPrepare for Fusion { let [means, log_scales, quats, sh_coeffs, opacity] = inputs; let [ projected_splats, - uniforms_buffer, + num_visible, global_from_compact_gid, cum_tiles_hit, ] = outputs; - let aux = MainBackendBase::project_prepare( + let aux = MainBackendBase::project( &self.cam, self.img_size, h.get_float_tensor::(means), @@ -244,15 +61,14 @@ impl SplatProjectPrepare for Fusion { h.get_float_tensor::(sh_coeffs), h.get_float_tensor::(opacity), self.render_mode, - self.background, ); - // Register outputs + // Register outputs (project_uniforms is stored on ProjectAux directly) h.register_float_tensor::( &projected_splats.id, aux.projected_splats, ); - h.register_int_tensor::(&uniforms_buffer.id, aux.uniforms_buffer); + h.register_int_tensor::(&num_visible.id, aux.num_visible); h.register_int_tensor::( &global_from_compact_gid.id, aux.global_from_compact_gid, @@ -263,32 +79,43 @@ impl SplatProjectPrepare for Fusion { let client = means.client.clone(); let num_points = means.shape[0]; + let sh_degree = sh_degree_from_coeffs(sh_coeffs.shape[1] as u32); + let tile_bounds = calc_tile_bounds(img_size); let proj_size = size_of::() / 4; - let uniforms_size = size_of::() / 4; let projected_splats = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, proj_size]), DType::F32, ); - let uniforms_buffer = TensorIr::uninit( - client.create_empty_handle(), - Shape::new([uniforms_size]), - DType::U32, - ); + let num_visible = + TensorIr::uninit(client.create_empty_handle(), Shape::new([1]), DType::U32); let global_from_compact_gid = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points]), DType::U32, ); - // cum_tiles_hit has size num_points + 1 let cum_tiles_hit = TensorIr::uninit( client.create_empty_handle(), - Shape::new([num_points + 1]), + Shape::new([num_points]), DType::U32, ); + // Create project_uniforms from camera and img_size (stored on ProjectAux directly) + let project_uniforms = ProjectUniforms { + viewmat: glam::Mat4::from(cam.world_to_local()).to_cols_array_2d(), + camera_position: [cam.position.x, cam.position.y, cam.position.z, 0.0], + focal: cam.focal(img_size).into(), + pixel_center: cam.center(img_size).into(), + img_size: img_size.into(), + tile_bounds: tile_bounds.into(), + sh_degree, + total_splats: num_points as u32, + pad_a: 0, + pad_b: 0, + }; + let input_tensors = [means, log_scales, quats, sh_coeffs, opacity]; let stream = OperationStreams::with_inputs(&input_tensors); let desc = CustomOpIr::new( @@ -296,7 +123,7 @@ impl SplatProjectPrepare for Fusion { &input_tensors.map(|t| t.into_ir()), &[ projected_splats, - uniforms_buffer, + num_visible, global_from_compact_gid, cum_tiles_hit, ], @@ -305,7 +132,6 @@ impl SplatProjectPrepare for Fusion { cam: cam.clone(), img_size, render_mode, - background, desc: desc.clone(), }; @@ -315,23 +141,21 @@ impl SplatProjectPrepare for Fusion { let [ projected_splats, - uniforms_buffer, + num_visible, global_from_compact_gid, cum_tiles_hit, ] = outputs; ProjectAux:: { projected_splats, - uniforms_buffer, + project_uniforms, + num_visible, global_from_compact_gid, cum_tiles_hit, img_size, } } -} -// Fusion implementation for SplatRasterize -impl SplatRasterize for Fusion { fn rasterize( project_aux: &ProjectAux, num_intersections: u32, @@ -344,6 +168,7 @@ impl SplatRasterize for Fusion { num_intersections: u32, background: Vec3, bwd_info: bool, + project_uniforms: ProjectUniforms, desc: CustomOpIr, } @@ -356,21 +181,18 @@ impl SplatRasterize for Fusion { let [ projected_splats, - uniforms_buffer, + num_visible, global_from_compact_gid, cum_tiles_hit, ] = inputs; - let [ - out_img, - tile_offsets, - compact_gid_from_isect, - visible, - ] = outputs; + let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs; let inner_aux = ProjectAux:: { projected_splats: h.get_float_tensor::(projected_splats), - uniforms_buffer: h.get_int_tensor::(uniforms_buffer), - global_from_compact_gid: h.get_int_tensor::(global_from_compact_gid), + project_uniforms: self.project_uniforms, + num_visible: h.get_int_tensor::(num_visible), + global_from_compact_gid: h + .get_int_tensor::(global_from_compact_gid), cum_tiles_hit: h.get_int_tensor::(cum_tiles_hit), img_size: self.img_size, }; @@ -428,7 +250,7 @@ impl SplatRasterize for Fusion { let input_tensors = [ project_aux.projected_splats.clone(), - project_aux.uniforms_buffer.clone(), + project_aux.num_visible.clone(), project_aux.global_from_compact_gid.clone(), project_aux.cum_tiles_hit.clone(), ]; @@ -436,18 +258,14 @@ impl SplatRasterize for Fusion { let desc = CustomOpIr::new( "rasterize", &input_tensors.map(|t| t.into_ir()), - &[ - out_img, - tile_offsets, - compact_gid_from_isect, - visible, - ], + &[out_img, tile_offsets, compact_gid_from_isect, visible], ); let op = CustomOp { img_size, num_intersections, background, bwd_info, + project_uniforms: project_aux.project_uniforms, desc: desc.clone(), }; @@ -455,12 +273,7 @@ impl SplatRasterize for Fusion { .register(stream, OperationIr::Custom(desc), op) .outputs(); - let [ - out_img, - tile_offsets, - compact_gid_from_isect, - visible, - ] = outputs; + let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs; ( out_img, @@ -468,6 +281,7 @@ impl SplatRasterize for Fusion { tile_offsets, compact_gid_from_isect, visible, + img_size, }, ) } diff --git a/crates/brush-render/src/gaussian_splats.rs b/crates/brush-render/src/gaussian_splats.rs index 16085311..2a78592b 100644 --- a/crates/brush-render/src/gaussian_splats.rs +++ b/crates/brush-render/src/gaussian_splats.rs @@ -9,12 +9,15 @@ use glam::Vec3; use tracing::trace_span; use crate::{ - SplatForward, + SplatOps, camera::Camera, - render_aux::RenderAux, + render_aux::{ProjectAux, RasterizeAux, validate_render_output}, sh::{sh_coeffs_for_degree, sh_degree_from_coeffs}, }; +#[cfg(target_family = "wasm")] +use crate::render::calc_tile_bounds; + #[derive( Module, Clone, Copy, Debug, Eq, PartialEq, ValueEnum, serde::Serialize, serde::Deserialize, )] @@ -234,13 +237,13 @@ impl Splats { /// /// NB: This doesn't work on a differentiable backend. Use /// [`brush_render_bwd::render_splats`] for that. -pub fn render_splats>( +pub fn render_splats>( splats: &Splats, camera: &Camera, img_size: glam::UVec2, background: Vec3, splat_scale: Option, -) -> (Tensor, RenderAux) { +) -> (Tensor, ProjectAux, RasterizeAux) { splats.validate_values(); let mut scales = splats.log_scales.val(); @@ -250,7 +253,8 @@ pub fn render_splats>( scales = scales + scale.ln(); }; - let (img, aux) = B::render_splats( + // First pass: project + let project_aux = B::project( camera, img_size, splats.means.val().into_primitive().tensor(), @@ -259,12 +263,16 @@ pub fn render_splats>( splats.sh_coeffs.val().into_primitive().tensor(), splats.raw_opacities.val().into_primitive().tensor(), splats.render_mode, - background, - false, ); - let img = Tensor::from_primitive(TensorPrimitive::Float(img)); - aux.validate_values(); + let num_intersections = project_aux.num_intersections(); + + // Second pass: rasterize + let (out_img, rasterize_aux) = B::rasterize(&project_aux, num_intersections, background, false); + + let img = Tensor::from_primitive(TensorPrimitive::Float(out_img)); + + validate_render_output(&project_aux, &rasterize_aux); - (img, aux) + (img, project_aux, rasterize_aux) } diff --git a/crates/brush-render/src/get_tile_offset.rs b/crates/brush-render/src/get_tile_offset.rs index 21e19a8c..6223f277 100644 --- a/crates/brush-render/src/get_tile_offset.rs +++ b/crates/brush-render/src/get_tile_offset.rs @@ -7,31 +7,6 @@ use burn_cubecl::cubecl::prelude::{ pub(crate) const CHECKS_PER_ITER: u32 = 8; -#[cube] -fn check_tile_boundary( - tile_id_from_isect: &Tensor, - tile_offsets: &mut Tensor, - isect_id: u32, - inter: u32, -) { - if isect_id < inter { - let prev_tid = tile_id_from_isect[isect_id - 1]; - let tid = tile_id_from_isect[isect_id]; - - if isect_id == inter - 1 { - // Write the end of the last tile. - tile_offsets[tid * 2 + 1] = isect_id + 1; - } - - if tid != prev_tid { - // Write the end of the previous tile. - tile_offsets[prev_tid * 2 + 1] = isect_id; - // Write start of this tile. - tile_offsets[tid * 2] = isect_id; - } - } -} - #[cube(launch_unchecked)] pub fn get_tile_offsets( tile_id_from_isect: &Tensor, @@ -46,6 +21,22 @@ pub fn get_tile_offsets( #[unroll] for i in 0..CHECKS_PER_ITER { - check_tile_boundary(tile_id_from_isect, tile_offsets, base_id + i, inter); + let isect_id = base_id + i; + if isect_id < inter { + let prev_tid = tile_id_from_isect[isect_id - 1]; + let tid = tile_id_from_isect[isect_id]; + + if isect_id == inter - 1 { + // Write the end of the last tile. + tile_offsets[tid * 2 + 1] = isect_id + 1; + } + + if tid != prev_tid { + // Write the end of the previous tile. + tile_offsets[prev_tid * 2 + 1] = isect_id; + // Write start of this tile. + tile_offsets[tid * 2] = isect_id; + } + } } } diff --git a/crates/brush-render/src/lib.rs b/crates/brush-render/src/lib.rs index 33153598..9f196faa 100644 --- a/crates/brush-render/src/lib.rs +++ b/crates/brush-render/src/lib.rs @@ -8,7 +8,7 @@ use burn_wgpu::WgpuRuntime; use camera::Camera; use clap::ValueEnum; use glam::Vec3; -use render_aux::{ProjectAux, RasterizeAux, RenderAux}; +use render_aux::{ProjectAux, RasterizeAux}; use crate::gaussian_splats::SplatRenderMode; pub use crate::gaussian_splats::render_splats; @@ -38,35 +38,20 @@ pub struct RenderStats { pub num_visible: u32, } -pub trait SplatForward { - /// Render splats to a buffer. - /// - /// This projects the gaussians, sorts them, and rasterizes them to a buffer, in a - /// differentiable way. - /// The arguments are all passed as raw tensors. See [`Splats`] for a convenient Module that wraps this fun - /// The [`xy_grad_dummy`] variable is only used to carry screenspace xy gradients. - /// This function can optionally render a "u32" buffer, which is a packed RGBA (8 bits per channel) - /// buffer. This is useful when the results need to be displayed immediately. - fn render_splats( - camera: &Camera, - img_size: glam::UVec2, - means: FloatTensor, - log_scales: FloatTensor, - quats: FloatTensor, - sh_coeffs: FloatTensor, - raw_opacities: FloatTensor, - render_mode: SplatRenderMode, - background: Vec3, - bwd_info: bool, - ) -> (FloatTensor, RenderAux); -} - -/// First pass of split rendering pipeline: culling, depth sort, projection, intersection counting, prefix sum. +/// Trait for the split gaussian splatting rendering pipeline. /// -/// Returns [`ProjectAux`] which contains data needed for [`SplatRasterize::rasterize`], -/// including `cum_tiles_hit` which allows sync readback of the exact number of intersections. -pub trait SplatProjectPrepare { - fn project_prepare( +/// This trait provides two passes: +/// 1. `project`: Culling, depth sort, projection, intersection counting, prefix sum. +/// 2. `rasterize`: Intersection filling, tile sort, tile offsets, rasterization. +/// +/// The split allows for an explicit GPU sync point between passes to read back +/// the exact number of intersections needed for buffer allocation. +pub trait SplatOps { + /// First pass: project gaussians and count intersections. + /// + /// Returns [`ProjectAux`] containing projected splat data and `num_intersections` + /// tensor for explicit readback. + fn project( camera: &Camera, img_size: glam::UVec2, means: FloatTensor, @@ -75,15 +60,12 @@ pub trait SplatProjectPrepare { sh_coeffs: FloatTensor, raw_opacities: FloatTensor, render_mode: SplatRenderMode, - background: Vec3, ) -> ProjectAux; -} -/// Second pass of split rendering pipeline: intersection filling, tile sort, tile offsets, rasterization. -/// -/// Takes the output of [`SplatProjectPrepare::project_prepare`] along with the actual -/// `num_intersections` value from userland readback. -pub trait SplatRasterize { + /// Second pass: rasterize using projection data. + /// + /// Takes the output of [`Self::project`] along with the actual + /// `num_intersections` value from sync readback. fn rasterize( project_aux: &ProjectAux, num_intersections: u32, diff --git a/crates/brush-render/src/render.rs b/crates/brush-render/src/render.rs index 43ce4a56..c85118c4 100644 --- a/crates/brush-render/src/render.rs +++ b/crates/brush-render/src/render.rs @@ -1,42 +1,39 @@ use crate::{ - MainBackendBase, SplatForward, SplatProjectPrepare, SplatRasterize, + MainBackendBase, SplatOps, camera::Camera, dim_check::DimCheck, gaussian_splats::SplatRenderMode, get_tile_offset::{CHECKS_PER_ITER, get_tile_offsets}, - render_aux::{ProjectAux, RasterizeAux, RenderAux}, + render_aux::{ProjectAux, RasterizeAux}, sh::sh_degree_from_coeffs, shaders::{self, MapGaussiansToIntersect, ProjectSplats, ProjectVisible, Rasterize}, }; +use brush_kernel::bytemuck; use brush_kernel::create_dispatch_buffer_1d; +use brush_kernel::create_meta_binding; use brush_kernel::create_tensor; -use brush_kernel::create_uniform_buffer; use brush_kernel::{CubeCount, calc_cube_count_1d}; use brush_prefix_sum::prefix_sum; use brush_sort::radix_argsort; -use burn::tensor::{DType, IntDType, ops::FloatTensor}; +use burn::tensor::{DType, IntDType, Shape, ops::FloatTensor}; use burn::tensor::{ FloatDType, ops::{FloatTensorOps, IntTensorOps}, }; use burn_cubecl::cubecl::server::Bindings; - use burn_cubecl::kernel::into_contiguous; -use burn_wgpu::WgpuRuntime; -use burn_wgpu::CubeDim; +use burn_wgpu::{CubeDim, CubeTensor, WgpuRuntime}; use glam::{Vec3, uvec2}; -use std::mem::offset_of; -pub(crate) fn calc_tile_bounds(img_size: glam::UVec2) -> glam::UVec2 { +pub fn calc_tile_bounds(img_size: glam::UVec2) -> glam::UVec2 { uvec2( img_size.x.div_ceil(shaders::helpers::TILE_WIDTH), img_size.y.div_ceil(shaders::helpers::TILE_WIDTH), ) } -// Implement the first pass: ProjectPrepare -impl SplatProjectPrepare for MainBackendBase { - fn project_prepare( +impl SplatOps for MainBackendBase { + fn project( camera: &Camera, img_size: glam::UVec2, means: FloatTensor, @@ -45,7 +42,6 @@ impl SplatProjectPrepare for MainBackendBase { sh_coeffs: FloatTensor, raw_opacities: FloatTensor, render_mode: SplatRenderMode, - background: Vec3, ) -> ProjectAux { assert!( img_size[0] > 0 && img_size[1] > 0, @@ -60,7 +56,6 @@ impl SplatProjectPrepare for MainBackendBase { let raw_opacities = into_contiguous(raw_opacities); let device = &means.device.clone(); - let client = means.client.clone(); let _span = tracing::trace_span!("project_prepare").entered(); @@ -79,7 +74,7 @@ impl SplatProjectPrepare for MainBackendBase { let sh_degree = sh_degree_from_coeffs(sh_coeffs.shape.dims[1] as u32); let total_splats = means.shape.dims[0]; - let uniforms = shaders::helpers::RenderUniforms { + let project_uniforms = shaders::helpers::ProjectUniforms { viewmat: glam::Mat4::from(camera.world_to_local()).to_cols_array_2d(), camera_position: [camera.position.x, camera.position.y, camera.position.z, 0.0], focal: camera.focal(img_size).into(), @@ -88,18 +83,18 @@ impl SplatProjectPrepare for MainBackendBase { tile_bounds: tile_bounds.into(), sh_degree, total_splats: total_splats as u32, - background: [background.x, background.y, background.z, 1.0], - num_visible: 0, pad_a: 0, + pad_b: 0, }; - let uniforms_buffer = create_uniform_buffer(uniforms, device, &client); + // Separate buffer for num_visible (written atomically by ProjectSplats) + let num_visible_buffer = Self::int_zeros([1].into(), device, IntDType::U32); let client = &means.client.clone(); let mip_splat = matches!(render_mode, SplatRenderMode::Mip); // Step 1: ProjectSplats - culling pass - let (global_from_compact_gid, num_visible) = { + let global_from_compact_gid = { let global_from_presort_gid = Self::int_zeros([total_splats].into(), device, IntDType::U32); let depths = create_tensor([total_splats], device, DType::F32); @@ -110,61 +105,62 @@ impl SplatProjectPrepare for MainBackendBase { client.launch_unchecked( ProjectSplats::task(mip_splat), calc_cube_count_1d(total_splats as u32, ProjectSplats::WORKGROUP_SIZE[0]), - Bindings::new().with_buffers( - vec![ - uniforms_buffer.handle.clone().binding(), - means.handle.clone().binding(), - quats.handle.clone().binding(), - log_scales.handle.clone().binding(), - raw_opacities.handle.clone().binding(), - global_from_presort_gid.handle.clone().binding(), - depths.handle.clone().binding(), - ]), + Bindings::new() + .with_buffers(vec![ + means.handle.clone().binding(), + quats.handle.clone().binding(), + log_scales.handle.clone().binding(), + raw_opacities.handle.clone().binding(), + global_from_presort_gid.handle.clone().binding(), + depths.handle.clone().binding(), + num_visible_buffer.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(project_uniforms)), ).expect("Failed to render splats"); }); - // Get just the number of visible splats from the uniforms buffer. - let num_vis_field_offset = - offset_of!(shaders::helpers::RenderUniforms, num_visible) / 4; - let num_visible = Self::int_slice( - uniforms_buffer.clone(), - &[(num_vis_field_offset..num_vis_field_offset + 1).into()], - ); - - // Step 2: DepthSort + // Step 2: DepthSort - use dynamic count to sort only up to num_visible let (_, global_from_compact_gid) = tracing::trace_span!("DepthSort").in_scope(|| { - radix_argsort(depths, global_from_presort_gid, &num_visible, 32) + radix_argsort( + depths, + global_from_presort_gid, + 32, + Some(num_visible_buffer.clone()), + ) }); - (global_from_compact_gid, num_visible) + global_from_compact_gid }; // Step 3: ProjectVisible with intersection counting let proj_size = size_of::() / size_of::(); let projected_splats = create_tensor([total_splats, proj_size], device, DType::F32); - let splat_intersect_counts = - Self::int_zeros([total_splats + 1].into(), device, IntDType::U32); + let splat_intersect_counts = Self::int_zeros([total_splats].into(), device, IntDType::U32); tracing::trace_span!("ProjectVisibleWithCounting").in_scope(|| { - let num_vis_wg = - create_dispatch_buffer_1d(num_visible.clone(), ProjectVisible::WORKGROUP_SIZE[0]); + let num_vis_wg = create_dispatch_buffer_1d( + num_visible_buffer.clone(), + ProjectVisible::WORKGROUP_SIZE[0], + ); // SAFETY: Kernel checked to have no OOB, bounded loops. unsafe { client .launch_unchecked( - ProjectVisible::task(mip_splat, true), // count_intersections = true + ProjectVisible::task(mip_splat), CubeCount::Dynamic(num_vis_wg.handle.binding()), - Bindings::new().with_buffers(vec![ - uniforms_buffer.clone().handle.binding(), - means.handle.binding(), - log_scales.handle.binding(), - quats.handle.binding(), - sh_coeffs.handle.binding(), - raw_opacities.handle.binding(), - global_from_compact_gid.handle.clone().binding(), - projected_splats.handle.clone().binding(), - splat_intersect_counts.handle.clone().binding(), - ]), + Bindings::new() + .with_buffers(vec![ + num_visible_buffer.handle.clone().binding(), + means.handle.binding(), + log_scales.handle.binding(), + quats.handle.binding(), + sh_coeffs.handle.binding(), + raw_opacities.handle.binding(), + global_from_compact_gid.handle.clone().binding(), + projected_splats.handle.clone().binding(), + splat_intersect_counts.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(project_uniforms)), ) .expect("Failed to render splats"); } @@ -174,36 +170,20 @@ impl SplatProjectPrepare for MainBackendBase { let cum_tiles_hit = tracing::trace_span!("PrefixSumGaussHits") .in_scope(|| prefix_sum(splat_intersect_counts)); - // Sanity check - assert!( - uniforms_buffer.is_contiguous(), - "Uniforms must be contiguous" - ); - assert!( - global_from_compact_gid.is_contiguous(), - "Global from compact gid must be contiguous" - ); - assert!( - projected_splats.is_contiguous(), - "Projected splats must be contiguous" - ); - ProjectAux { projected_splats, - uniforms_buffer, + project_uniforms, + num_visible: num_visible_buffer, global_from_compact_gid, cum_tiles_hit, img_size, } } -} -// Implement the second pass: Rasterize -impl SplatRasterize for MainBackendBase { fn rasterize( project_aux: &ProjectAux, num_intersections: u32, - _background: Vec3, // Background is read from uniforms_buffer + background: Vec3, bwd_info: bool, ) -> (FloatTensor, RasterizeAux) { let _span = tracing::trace_span!("rasterize").entered(); @@ -216,68 +196,58 @@ impl SplatRasterize for MainBackendBase { let tile_bounds = calc_tile_bounds(img_size); let num_tiles = tile_bounds.x * tile_bounds.y; - // Get num_visible from uniforms buffer - let num_vis_field_offset = offset_of!(shaders::helpers::RenderUniforms, num_visible) / 4; - let num_visible = Self::int_slice( - project_aux.uniforms_buffer.clone(), - &[(num_vis_field_offset..num_vis_field_offset + 1).into()], - ); - + // Create rasterize uniforms (passed via with_metadata) + let rasterize_uniforms = shaders::helpers::RasterizeUniforms { + tile_bounds: tile_bounds.into(), + img_size: img_size.into(), + background: [background.x, background.y, background.z, 1.0], + }; // Step 1: Allocate intersection buffers with exact size (minimum 1 to avoid zero-size allocation) let buffer_size = (num_intersections as usize).max(1); let tile_id_from_isect = create_tensor([buffer_size], device, DType::U32); let compact_gid_from_isect = create_tensor([buffer_size], device, DType::U32); - // Step 2: MapGaussiansToIntersect (fill pass, prepass=false) - let num_vis_map_wg = - create_dispatch_buffer_1d(num_visible.clone(), MapGaussiansToIntersect::WORKGROUP_SIZE[0]); + // Step 2: MapGaussiansToIntersect (fill pass) + let num_vis_map_wg = create_dispatch_buffer_1d( + project_aux.num_visible.clone(), + MapGaussiansToIntersect::WORKGROUP_SIZE[0], + ); + + // Uniforms for map_gaussian_to_intersects (passed via with_metadata) + let map_uniforms = shaders::map_gaussians_to_intersect::Uniforms { + tile_bounds: tile_bounds.into(), + }; tracing::trace_span!("MapGaussiansToIntersect").in_scope(|| { // SAFETY: Kernel checked to have no OOB, bounded loops. unsafe { client .launch_unchecked( - MapGaussiansToIntersect::task(false), + MapGaussiansToIntersect::task(), CubeCount::Dynamic(num_vis_map_wg.handle.clone().binding()), - Bindings::new().with_buffers(vec![ - project_aux.uniforms_buffer.handle.clone().binding(), - project_aux.projected_splats.handle.clone().binding(), - project_aux.cum_tiles_hit.handle.clone().binding(), - tile_id_from_isect.handle.clone().binding(), - compact_gid_from_isect.handle.clone().binding(), - ]), + Bindings::new() + .with_buffers(vec![ + project_aux.num_visible.handle.clone().binding(), + project_aux.projected_splats.handle.clone().binding(), + project_aux.cum_tiles_hit.handle.clone().binding(), + tile_id_from_isect.handle.clone().binding(), + compact_gid_from_isect.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(map_uniforms)), ) .expect("Failed to render splats"); } }); - // Step 3: Tile sort - use static dispatch with actual num_intersections + // Step 3: Tile sort - use static dispatch with full buffer (num_intersections) let bits = u32::BITS - num_tiles.leading_zeros(); - // Create a tensor holding num_intersections for the sort - // Get the last element from cum_tiles_hit - let cum_len = project_aux.cum_tiles_hit.shape[0]; - let num_intersections_tensor = Self::int_slice( - project_aux.cum_tiles_hit.clone(), - &[(cum_len - 1..cum_len).into()], - ); - let (tile_id_from_isect, compact_gid_from_isect) = tracing::trace_span!("Tile sort") - .in_scope(|| { - radix_argsort( - tile_id_from_isect, - compact_gid_from_isect, - &num_intersections_tensor, - bits, - ) - }); + .in_scope(|| radix_argsort(tile_id_from_isect, compact_gid_from_isect, bits, None)); // Step 4: GetTileOffsets let cube_dim = CubeDim::new_1d(256); - let num_vis_map_wg = - create_dispatch_buffer_1d(num_intersections_tensor.clone(), 256 * CHECKS_PER_ITER); - let cube_count = CubeCount::Dynamic(num_vis_map_wg.handle.binding()); let tile_offsets = Self::int_zeros( [tile_bounds.y as usize, tile_bounds.x as usize, 2].into(), @@ -285,15 +255,27 @@ impl SplatRasterize for MainBackendBase { IntDType::U32, ); + // Create a tensor for num_intersections + let num_inter_tensor = { + let data: [u32; 1] = [num_intersections]; + CubeTensor::new_contiguous( + client.clone(), + device.clone(), + Shape::new([1]), + client.create_from_slice(bytemuck::cast_slice(&data)), + DType::U32, + ) + }; + // SAFETY: Safe kernel. unsafe { get_tile_offsets::launch_unchecked::( &client, - cube_count, + calc_cube_count_1d(num_intersections, cube_dim.x * CHECKS_PER_ITER), cube_dim, tile_id_from_isect.as_tensor_arg(1), tile_offsets.as_tensor_arg(1), - num_intersections_tensor.as_tensor_arg(1), + num_inter_tensor.as_tensor_arg(1), ) .expect("Failed to render splats"); } @@ -306,30 +288,32 @@ impl SplatRasterize for MainBackendBase { DType::F32, ); - // Update background in uniforms - we need to create a modified uniforms buffer - // For now, we'll pass the background through the existing buffer structure - // The rasterize kernel reads background from uniforms, so we need to ensure it's set - - let mut bindings = Bindings::new().with_buffers(vec![ - project_aux.uniforms_buffer.handle.clone().binding(), - compact_gid_from_isect.handle.clone().binding(), - tile_offsets.handle.clone().binding(), - project_aux.projected_splats.handle.clone().binding(), - out_img.handle.clone().binding(), - ]); - // Get total_splats from the shape of projected_splats let total_splats = project_aux.projected_splats.shape.dims[0]; - let visible = if bwd_info { + let (bindings, visible) = if bwd_info { let visible = Self::float_zeros([total_splats].into(), device, FloatDType::F32); - bindings = bindings.with_buffers(vec![ - project_aux.global_from_compact_gid.handle.clone().binding(), - visible.handle.clone().binding(), - ]); - visible + let bindings = Bindings::new() + .with_buffers(vec![ + compact_gid_from_isect.handle.clone().binding(), + tile_offsets.handle.clone().binding(), + project_aux.projected_splats.handle.clone().binding(), + out_img.handle.clone().binding(), + project_aux.global_from_compact_gid.handle.clone().binding(), + visible.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(rasterize_uniforms)); + (bindings, visible) } else { - create_tensor([1], device, DType::F32) + let bindings = Bindings::new() + .with_buffers(vec![ + compact_gid_from_isect.handle.clone().binding(), + tile_offsets.handle.clone().binding(), + project_aux.projected_splats.handle.clone().binding(), + out_img.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(rasterize_uniforms)); + (bindings, create_tensor([1], device, DType::F32)) }; let raster_task = Rasterize::task(bwd_info); @@ -346,7 +330,10 @@ impl SplatRasterize for MainBackendBase { } // Sanity checks - assert!(tile_offsets.is_contiguous(), "Tile offsets must be contiguous"); + assert!( + tile_offsets.is_contiguous(), + "Tile offsets must be contiguous" + ); assert!(visible.is_contiguous(), "Visible must be contiguous"); ( @@ -355,59 +342,8 @@ impl SplatRasterize for MainBackendBase { tile_offsets, compact_gid_from_isect, visible, + img_size: project_aux.img_size, }, ) } } - -// Implement backwards-compatible render_splats using the split pipeline -impl SplatForward for MainBackendBase { - fn render_splats( - camera: &Camera, - img_size: glam::UVec2, - means: FloatTensor, - log_scales: FloatTensor, - quats: FloatTensor, - sh_coeffs: FloatTensor, - raw_opacities: FloatTensor, - render_mode: SplatRenderMode, - background: Vec3, - bwd_info: bool, - ) -> (FloatTensor, RenderAux) { - // First pass: project and prepare (includes background in uniforms) - let project_aux = Self::project_prepare( - camera, - img_size, - means, - log_scales, - quats, - sh_coeffs, - raw_opacities, - render_mode, - background, - ); - - // Sync readback of num_intersections - #[cfg(not(target_family = "wasm"))] - let num_intersections = project_aux.num_intersections(); - - #[cfg(target_family = "wasm")] - let num_intersections = { - // On wasm, estimate max intersections - let tile_bounds = calc_tile_bounds(img_size); - let num_tiles = tile_bounds[0] * tile_bounds[1]; - let total_splats = project_aux.projected_splats.shape.dims[0] as u32; - let max_possible = num_tiles.saturating_mul(total_splats); - max_possible.min(2 * 512 * 65535) - }; - - // Second pass: rasterize - let (out_img, rasterize_aux) = - Self::rasterize(&project_aux, num_intersections, background, bwd_info); - - // Combine into RenderAux for backwards compatibility - let render_aux = RenderAux::from_parts(project_aux, rasterize_aux); - - (out_img, render_aux) - } -} diff --git a/crates/brush-render/src/render_aux.rs b/crates/brush-render/src/render_aux.rs index a156a6d8..84c25f2c 100644 --- a/crates/brush-render/src/render_aux.rs +++ b/crates/brush-render/src/render_aux.rs @@ -1,48 +1,117 @@ -use std::mem::offset_of; - use burn::{ Tensor, prelude::Backend, tensor::{ Int, ops::{FloatTensor, IntTensor}, - s, }, }; -use crate::shaders::{self, helpers::TILE_WIDTH}; +use crate::shaders::helpers::ProjectUniforms; + +/// Validate both ProjectAux and RasterizeAux outputs together. +/// This is a no-op in release builds unless `debug-validation` feature is enabled. +/// Also skipped when running benchmarks (detected via `--bench` arg). +pub fn validate_render_output(project_aux: &ProjectAux, rasterize_aux: &RasterizeAux) { + #[cfg(any(test, feature = "debug-validation"))] + { + if std::env::args().any(|a| a == "--bench") { + return; + } + project_aux.validate_values(); + use burn::tensor::ElementConversion; + let num_visible = project_aux.get_num_visible().into_scalar().elem::(); + rasterize_aux.validate_values(num_visible); + } + #[cfg(not(any(test, feature = "debug-validation")))] + { + let _ = project_aux; + let _ = rasterize_aux; + } +} -/// Output of the ProjectPrepare pass. -/// -/// Contains all data needed to perform the Rasterize pass, including -/// the `cum_tiles_hit` buffer which can be used to extract the exact -/// number of intersections via sync readback. #[derive(Debug, Clone)] pub struct ProjectAux { + pub project_uniforms: ProjectUniforms, + /// The packed projected splat information, see `ProjectedSplat` in helpers.wgsl pub projected_splats: FloatTensor, - pub uniforms_buffer: IntTensor, + pub num_visible: IntTensor, pub global_from_compact_gid: IntTensor, - /// Cumulative sum of tiles hit per splat. Last element contains total num_intersections. pub cum_tiles_hit: IntTensor, pub img_size: glam::UVec2, } impl ProjectAux { - /// Extract the total number of intersections from the cum_tiles_hit buffer. + /// Extract the total number of intersections. /// /// This requires a sync readback from the GPU. - #[cfg(not(target_family = "wasm"))] pub fn num_intersections(&self) -> u32 { use burn::tensor::ElementConversion; - let cum: Tensor = Tensor::from_primitive(self.cum_tiles_hit.clone()); - let len = cum.dims()[0]; - cum.slice(s![len - 1..len]).into_scalar().elem::() + let cum_tiles_hit: Tensor = Tensor::from_primitive(self.cum_tiles_hit.clone()); + let total = self.project_uniforms.total_splats as usize; + // The prefix sum is inclusive, so the last element is the total number of intersections + if total > 0 { + cum_tiles_hit + .slice([total - 1..total]) + .into_scalar() + .elem::() + } else { + 0 + } } - pub fn num_visible(&self) -> Tensor { - let num_vis_field_offset = offset_of!(shaders::helpers::RenderUniforms, num_visible) / 4; - Tensor::from_primitive(self.uniforms_buffer.clone()).slice(s![num_vis_field_offset]) + pub fn get_num_visible(&self) -> Tensor { + Tensor::from_primitive(self.num_visible.clone()) + } + + pub fn validate_values(&self) { + #[cfg(any(test, feature = "debug-validation"))] + { + use burn::tensor::{ElementConversion, TensorPrimitive, s}; + + use crate::validation::validate_tensor_val; + + if std::env::args().any(|a| a == "--bench") { + return; + } + + let num_visible_tensor: Tensor = self.get_num_visible(); + let total_splats = self.project_uniforms.total_splats; + let num_visible = num_visible_tensor.into_scalar().elem::() as u32; + + assert!( + num_visible <= total_splats, + "Something went wrong when calculating the number of visible gaussians. {num_visible} > {total_splats}" + ); + + // Projected splats is only valid up to num_visible and undefined for other values. + if num_visible > 0 { + let projected_splats: Tensor = + Tensor::from_primitive(TensorPrimitive::Float(self.projected_splats.clone())); + let projected_splats = projected_splats.slice(s![0..num_visible, ..]); + validate_tensor_val(&projected_splats, "projected_splats", None, None); + } + + // assert that every ID in global_from_compact_gid is valid. + // Only validate when there are visible splats + if num_visible > 0 && total_splats > 0 { + let global_from_compact_gid: Tensor = + Tensor::from_primitive(self.global_from_compact_gid.clone()); + let global_from_compact_gid = &global_from_compact_gid + .into_data() + .into_vec::() + .expect("Failed to fetch global_from_compact_gid") + [0..num_visible as usize]; + + for &global_gid in global_from_compact_gid { + assert!( + global_gid < total_splats, + "Invalid gaussian ID in global_from_compact_gid buffer. {global_gid} out of {total_splats}" + ); + } + } + } } } @@ -52,37 +121,14 @@ pub struct RasterizeAux { pub tile_offsets: IntTensor, pub compact_gid_from_isect: IntTensor, pub visible: FloatTensor, -} - -#[derive(Debug, Clone)] -pub struct RenderAux { - /// The packed projected splat information, see `ProjectedSplat` in helpers.wgsl - pub projected_splats: FloatTensor, - pub uniforms_buffer: IntTensor, - pub tile_offsets: IntTensor, - pub compact_gid_from_isect: IntTensor, - pub global_from_compact_gid: IntTensor, - pub visible: FloatTensor, pub img_size: glam::UVec2, } -impl RenderAux { - /// Combine ProjectAux and RasterizeAux into a RenderAux for backwards compatibility. - pub fn from_parts(project: ProjectAux, rasterize: RasterizeAux) -> Self { - Self { - projected_splats: project.projected_splats, - uniforms_buffer: project.uniforms_buffer, - global_from_compact_gid: project.global_from_compact_gid, - tile_offsets: rasterize.tile_offsets, - compact_gid_from_isect: rasterize.compact_gid_from_isect, - visible: rasterize.visible, - img_size: project.img_size, - } - } -} - -impl RenderAux { +impl RasterizeAux { pub fn calc_tile_depth(&self) -> Tensor { + use crate::shaders::helpers::TILE_WIDTH; + use burn::tensor::s; + let tile_offsets: Tensor = Tensor::from_primitive(self.tile_offsets.clone()); let max = tile_offsets.clone().slice(s![.., .., 1]); let min = tile_offsets.slice(s![.., .., 0]); @@ -91,15 +137,10 @@ impl RenderAux { (max - min).reshape([ty as usize, tx as usize]) } - pub fn num_visible(&self) -> Tensor { - let num_vis_field_offset = offset_of!(shaders::helpers::RenderUniforms, num_visible) / 4; - Tensor::from_primitive(self.uniforms_buffer.clone()).slice(s![num_vis_field_offset]) - } - - pub fn validate_values(&self) { + pub fn validate_values(&self, #[allow(unused)] num_visible: u32) { #[cfg(any(test, feature = "debug-validation"))] { - use burn::tensor::{ElementConversion, TensorPrimitive}; + use burn::tensor::TensorPrimitive; use crate::validation::validate_tensor_val; @@ -109,46 +150,8 @@ impl RenderAux { let compact_gid_from_isect: Tensor = Tensor::from_primitive(self.compact_gid_from_isect.clone()); - let num_visible: Tensor = self.num_visible(); - - // Get num_intersections from the last element of the flattened tile_offsets - // tile_offsets has shape [ty, tx, 2] where each [i, j, :] is [start, end] for that tile - // The last element (end offset of the last tile) is the total number of intersections - let tile_offsets_3d: Tensor = Tensor::from_primitive(self.tile_offsets.clone()); - let [ty, tx, _] = tile_offsets_3d.dims(); - // Get the end offset of the last tile: tile_offsets[ty-1, tx-1, 1] - let num_intersections = tile_offsets_3d - .slice(s![ty - 1..ty, tx - 1..tx, 1..2]) - .reshape([1]) - .into_scalar() - .elem::() as u32; - - // Get total_splats from the uniforms buffer for validation - let total_splats_field_offset = - offset_of!(shaders::helpers::RenderUniforms, total_splats) / 4; - let total_splats: Tensor = - Tensor::from_primitive(self.uniforms_buffer.clone()); - let total_splats = total_splats - .slice(s![total_splats_field_offset..total_splats_field_offset + 1]) - .into_scalar() - .elem::() as u32; - - let num_visible = num_visible.into_scalar().elem::() as u32; - - assert!( - num_visible <= total_splats, - "Something went wrong when calculating the number of visible gaussians. {num_visible} > {total_splats}" - ); - - // Projected splats is only valid up to num_visible and undefined for other values. - if num_visible > 0 { - use crate::validation::validate_tensor_val; - let projected_splats: Tensor = - Tensor::from_primitive(TensorPrimitive::Float(self.projected_splats.clone())); - let projected_splats = projected_splats.slice(s![0..num_visible]); - validate_tensor_val(&projected_splats, "projected_splats", None, None); - } + let num_intersections = compact_gid_from_isect.shape()[0] as u32; let visible: Tensor = Tensor::from_primitive(TensorPrimitive::Float(self.visible.clone())); @@ -190,12 +193,10 @@ impl RenderAux { } } - // Skip validation of compact_gid_from_isect when shape is 0 (fusion placeholder) - let declared_size = compact_gid_from_isect.dims()[0]; - if num_intersections > 0 && declared_size > 0 { - let data = compact_gid_from_isect - .slice([0..num_intersections as usize]) - .into_data(); + // Validate compact_gid_from_isect + // Only validate when there are visible splats (if num_visible=0, no valid intersections) + if num_visible > 0 { + let data = compact_gid_from_isect.into_data(); // Handle both I32 and U32 tensor types let compact_gid_vec: Vec = data @@ -207,32 +208,10 @@ impl RenderAux { }) .expect("Failed to fetch compact_gid_from_isect"); - for (i, compact_gid) in compact_gid_vec.iter().enumerate() { - assert!( - *compact_gid < num_visible, - "Invalid gaussian ID in intersection buffer. {compact_gid} out of {num_visible}. At {i} out of {num_intersections} intersections. \n - - {compact_gid_vec:?} - - \n\n\n" - ); - } - } - - // assert that every ID in global_from_compact_gid is valid. - // Only validate when there are visible splats - if num_visible > 0 && total_splats > 0 { - let global_from_compact_gid: Tensor = - Tensor::from_primitive(self.global_from_compact_gid.clone()); - let global_from_compact_gid = &global_from_compact_gid - .into_data() - .into_vec::() - .expect("Failed to fetch global_from_compact_gid")[0..num_visible as usize]; - - for &global_gid in global_from_compact_gid { + for compact_gid in &compact_gid_vec { assert!( - global_gid < total_splats, - "Invalid gaussian ID in global_from_compact_gid buffer. {global_gid} out of {total_splats}" + *compact_gid < num_visible, + "Invalid gaussian ID in intersection buffer. {compact_gid} out of {num_visible}." ); } } diff --git a/crates/brush-render/src/shaders.rs b/crates/brush-render/src/shaders.rs index 2e3c2b71..0d6d1c15 100644 --- a/crates/brush-render/src/shaders.rs +++ b/crates/brush-render/src/shaders.rs @@ -10,13 +10,10 @@ pub struct ProjectSplats { #[wgsl_kernel(source = "src/shaders/project_visible.wgsl")] pub struct ProjectVisible { mip_splatting: bool, - count_intersections: bool, } #[wgsl_kernel(source = "src/shaders/map_gaussian_to_intersects.wgsl")] -pub struct MapGaussiansToIntersect { - pub prepass: bool, -} +pub struct MapGaussiansToIntersect; #[wgsl_kernel(source = "src/shaders/rasterize.wgsl")] pub struct Rasterize { @@ -28,7 +25,8 @@ pub mod helpers { // Types used by multiple shaders - available from project_visible pub use super::project_visible::PackedVec3; pub use super::project_visible::ProjectedSplat; - pub use super::project_visible::RenderUniforms; + pub use super::project_visible::ProjectUniforms; + pub use super::rasterize::RasterizeUniforms; // Constants are now associated with the kernel structs pub const COV_BLUR: f32 = super::ProjectVisible::COV_BLUR; diff --git a/crates/brush-render/src/shaders/helpers.wgsl b/crates/brush-render/src/shaders/helpers.wgsl index 4da56fe2..d840fcec 100644 --- a/crates/brush-render/src/shaders/helpers.wgsl +++ b/crates/brush-render/src/shaders/helpers.wgsl @@ -38,7 +38,8 @@ fn map_1d_to_2d(id: u32, tiles_per_row: u32) -> vec2 { return vec2u(tile_x * TILE_WIDTH, tile_y * TILE_WIDTH) + decode_morton_2d(within_tile_id); } -struct RenderUniforms { +// Uniforms for projection passes. +struct ProjectUniforms { // View matrix transform world to view position. viewmat: mat4x4f, @@ -57,20 +58,16 @@ struct RenderUniforms { // Degree of sh coefficients used. sh_degree: u32, -#ifdef UNIFORM_WRITE - // Number of visible gaussians, written by project_forward. - // This needs to be non-atomic for other kernels as you can't have - // read-only atomic data. - num_visible: atomic, -#else - // Number of visible gaussians. - num_visible: u32, -#endif + total_splats: u32, pad_a: u32, + pad_b: u32, +} - total_splats: u32, - +// Uniforms for rasterize pass. +struct RasterizeUniforms { + tile_bounds: vec2u, + img_size: vec2u, // Nb: Alpha is ignored atm. background: vec4f, } diff --git a/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl b/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl index c8c97400..557e2a5e 100644 --- a/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl +++ b/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl @@ -1,15 +1,16 @@ #import helpers; -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; +@group(0) @binding(0) var num_visible: u32; @group(0) @binding(1) var projected: array; +@group(0) @binding(2) var splat_cum_hit_counts: array; +@group(0) @binding(3) var tile_id_from_isect: array; +@group(0) @binding(4) var compact_gid_from_isect: array; -#ifdef PREPASS - @group(0) @binding(2) var splat_intersect_counts: array; -#else - @group(0) @binding(2) var splat_cum_hit_counts: array; - @group(0) @binding(3) var tile_id_from_isect: array; - @group(0) @binding(4) var compact_gid_from_isect: array; -#endif +// Uniforms passed via with_metadata (always last binding) +struct Uniforms { + tile_bounds: vec2u, +} +@group(0) @binding(5) var uniforms: Uniforms; const WG_SIZE: u32 = 256u; @@ -22,7 +23,7 @@ fn main( ) { let compact_gid = helpers::get_global_id(wid, num_wgs, lid, WG_SIZE); - if compact_gid >= uniforms.num_visible { + if compact_gid >= num_visible { return; } @@ -38,11 +39,10 @@ fn main( let tile_bbox_min = tile_bbox.xy; let tile_bbox_max = tile_bbox.zw; - var num_tiles_hit = 0u; + // With inclusive prefix sum, use cum[compact_gid - 1] as base (or 0 for first element) + let base_isect_id = select(splat_cum_hit_counts[compact_gid - 1u], 0u, compact_gid == 0u); - #ifndef PREPASS - let base_isect_id = splat_cum_hit_counts[compact_gid]; - #endif + var num_tiles_hit = 0u; // Nb: It's really really important here the two dispatches // of this kernel arrive at the exact same num_tiles_hit count. Otherwise @@ -60,20 +60,14 @@ fn main( if helpers::will_primitive_contribute(rect, mean2d, conic, power_threshold) { let tile_id = tx + ty * uniforms.tile_bounds.x; - #ifndef PREPASS let isect_id = base_isect_id + num_tiles_hit; // Nb: isect_id MIGHT be out of bounds here for degenerate cases. // These kernels should be launched with bounds checking, so that these // writes are ignored. This will skip these intersections. tile_id_from_isect[isect_id] = tile_id; compact_gid_from_isect[isect_id] = compact_gid; - #endif num_tiles_hit += 1u; } } - - #ifdef PREPASS - splat_intersect_counts[compact_gid + 1u] = num_tiles_hit; - #endif } diff --git a/crates/brush-render/src/shaders/project_forward.wgsl b/crates/brush-render/src/shaders/project_forward.wgsl index be604f56..34369d1a 100644 --- a/crates/brush-render/src/shaders/project_forward.wgsl +++ b/crates/brush-render/src/shaders/project_forward.wgsl @@ -1,17 +1,14 @@ -#define UNIFORM_WRITE - #import helpers; -// Unfiroms contains the splat count which we're writing to. -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; - -@group(0) @binding(1) var means: array; -@group(0) @binding(2) var quats: array; -@group(0) @binding(3) var log_scales: array; -@group(0) @binding(4) var raw_opacities: array; - -@group(0) @binding(5) var global_from_compact_gid: array; -@group(0) @binding(6) var depths: array; +@group(0) @binding(0) var means: array; +@group(0) @binding(1) var quats: array; +@group(0) @binding(2) var log_scales: array; +@group(0) @binding(3) var raw_opacities: array; +@group(0) @binding(4) var global_from_compact_gid: array; +@group(0) @binding(5) var depths: array; +@group(0) @binding(6) var num_visible: atomic; +// Uniforms via with_metadata (always last binding) +@group(0) @binding(7) var uniforms: helpers::ProjectUniforms; const WG_SIZE: u32 = 256u; @@ -77,7 +74,7 @@ fn main( return; } // Now write all the data to the buffers. - let write_id = atomicAdd(&uniforms.num_visible, 1u); + let write_id = atomicAdd(&num_visible, 1u); global_from_compact_gid[write_id] = global_gid; depths[write_id] = mean_c.z; } diff --git a/crates/brush-render/src/shaders/project_visible.wgsl b/crates/brush-render/src/shaders/project_visible.wgsl index 4d9b0e1a..4c694536 100644 --- a/crates/brush-render/src/shaders/project_visible.wgsl +++ b/crates/brush-render/src/shaders/project_visible.wgsl @@ -5,8 +5,7 @@ struct IsectInfo { tile_id: u32, } -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; - +@group(0) @binding(0) var num_visible: u32; @group(0) @binding(1) var means: array; @group(0) @binding(2) var log_scales: array; @group(0) @binding(3) var quats: array; @@ -14,10 +13,9 @@ struct IsectInfo { @group(0) @binding(5) var raw_opacities: array; @group(0) @binding(6) var global_from_compact_gid: array; @group(0) @binding(7) var projected: array; - -#ifdef COUNT_INTERSECTIONS - @group(0) @binding(8) var splat_intersect_counts: array; -#endif +@group(0) @binding(8) var splat_intersect_counts: array; +// Uniforms via with_metadata (always last binding) +@group(0) @binding(9) var uniforms: helpers::ProjectUniforms; struct ShCoeffs { b0_c0: vec3f, @@ -175,7 +173,7 @@ fn main( ) { let compact_gid = helpers::get_global_id(wid, num_wgs, lid, WG_SIZE); - if compact_gid >= uniforms.num_visible { + if compact_gid >= num_visible { return; } @@ -258,7 +256,6 @@ fn main( vec4f(color, opac) ); -#ifdef COUNT_INTERSECTIONS // Count intersections for this splat (merged from map_gaussian_to_intersects prepass) let power_threshold = log(opac * 255.0); let extent = helpers::compute_bbox_extent(cov2d, power_threshold); @@ -280,6 +277,5 @@ fn main( } } - splat_intersect_counts[compact_gid + 1u] = num_tiles_hit; -#endif + splat_intersect_counts[compact_gid] = num_tiles_hit; } diff --git a/crates/brush-render/src/shaders/rasterize.wgsl b/crates/brush-render/src/shaders/rasterize.wgsl index 29c3d437..b6bc9a41 100644 --- a/crates/brush-render/src/shaders/rasterize.wgsl +++ b/crates/brush-render/src/shaders/rasterize.wgsl @@ -1,16 +1,19 @@ #import helpers -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; -@group(0) @binding(1) var compact_gid_from_isect: array; -@group(0) @binding(2) var tile_offsets: array; -@group(0) @binding(3) var projected: array; +@group(0) @binding(0) var compact_gid_from_isect: array; +@group(0) @binding(1) var tile_offsets: array; +@group(0) @binding(2) var projected: array; #ifdef BWD_INFO - @group(0) @binding(4) var out_img: array; - @group(0) @binding(5) var global_from_compact_gid: array; - @group(0) @binding(6) var visible: array; + @group(0) @binding(3) var out_img: array; + @group(0) @binding(4) var global_from_compact_gid: array; + @group(0) @binding(5) var visible: array; + // Uniforms via with_metadata (always last binding) + @group(0) @binding(6) var uniforms: helpers::RasterizeUniforms; #else - @group(0) @binding(4) var out_img: array; + @group(0) @binding(3) var out_img: array; + // Uniforms via with_metadata (always last binding) + @group(0) @binding(4) var uniforms: helpers::RasterizeUniforms; #endif var range_uniform: vec2u; diff --git a/crates/brush-render/src/tests/mod.rs b/crates/brush-render/src/tests/mod.rs index 8fc579a9..48241c0a 100644 --- a/crates/brush-render/src/tests/mod.rs +++ b/crates/brush-render/src/tests/mod.rs @@ -1,9 +1,57 @@ -use crate::{MainBackend, SplatForward, camera::Camera, gaussian_splats::SplatRenderMode}; +use crate::{ + MainBackend, SplatOps, + camera::Camera, + gaussian_splats::SplatRenderMode, + render_aux::{ProjectAux, RasterizeAux, validate_render_output}, +}; use assert_approx_eq::assert_approx_eq; use burn::tensor::{Distribution, Tensor, TensorPrimitive}; use burn_wgpu::WgpuDevice; use glam::Vec3; +/// Helper to run project + readback + rasterize for tests. +fn render_splats_test( + cam: &Camera, + img_size: glam::UVec2, + means: Tensor, + log_scales: Tensor, + quats: Tensor, + sh_coeffs: Tensor, + raw_opacity: Tensor, + render_mode: SplatRenderMode, + background: Vec3, + bwd_info: bool, +) -> ( + Tensor, + ProjectAux, + RasterizeAux, +) { + // First pass: project + let project_aux = MainBackend::project( + cam, + img_size, + means.into_primitive().tensor(), + log_scales.into_primitive().tensor(), + quats.into_primitive().tensor(), + sh_coeffs.into_primitive().tensor(), + raw_opacity.into_primitive().tensor(), + render_mode, + ); + + // Sync readback of num_intersections + let num_intersections = project_aux.num_intersections(); + + // Second pass: rasterize + let (out_img, rasterize_aux) = + MainBackend::rasterize(&project_aux, num_intersections, background, bwd_info); + + let img = Tensor::from_primitive(TensorPrimitive::Float(out_img)); + + validate_render_output(&project_aux, &rasterize_aux); + + (img, project_aux, rasterize_aux) +} + #[test] fn renders_at_all() { // Check if rendering doesn't hard crash or anything. @@ -27,21 +75,19 @@ fn renders_at_all() { .repeat_dim(0, num_points); let sh_coeffs = Tensor::::ones([num_points, 1, 3], &device); let raw_opacity = Tensor::::zeros([num_points], &device); - let (output, aux) = >::render_splats( + let (output, _project_aux, _rasterize_aux) = render_splats_test( &cam, img_size, - means.into_primitive().tensor(), - log_scales.into_primitive().tensor(), - quats.into_primitive().tensor(), - sh_coeffs.into_primitive().tensor(), - raw_opacity.into_primitive().tensor(), + means, + log_scales, + quats, + sh_coeffs, + raw_opacity, SplatRenderMode::Default, Vec3::ZERO, true, ); - aux.validate_values(); - let output: Tensor = Tensor::from_primitive(TensorPrimitive::Float(output)); let rgb = output.clone().slice([0..32, 0..32, 0..3]); let alpha = output.slice([0..32, 0..32, 3..4]); let rgb_mean = rgb.mean().to_data().as_slice::().expect("Wrong type")[0]; @@ -97,17 +143,16 @@ fn renders_many_splats() { let raw_opacity = Tensor::::random([num_splats], Distribution::Uniform(-2.0, 2.0), &device); - let (_output, aux) = >::render_splats( + let (_output, _project_aux, _rasterize_aux) = render_splats_test( &cam, img_size, - means.into_primitive().tensor(), - log_scales.into_primitive().tensor(), - quats.into_primitive().tensor(), - sh_coeffs.into_primitive().tensor(), - raw_opacity.into_primitive().tensor(), + means, + log_scales, + quats, + sh_coeffs, + raw_opacity, SplatRenderMode::Default, Vec3::ZERO, true, ); - aux.validate_values(); } diff --git a/crates/brush-sort/src/lib.rs b/crates/brush-sort/src/lib.rs index 4ef0e57c..ec14d24b 100644 --- a/crates/brush-sort/src/lib.rs +++ b/crates/brush-sort/src/lib.rs @@ -1,4 +1,5 @@ use brush_kernel::CubeCount; +use brush_kernel::calc_cube_count_1d; use brush_kernel::create_dispatch_buffer_1d; use brush_kernel::create_tensor; use brush_kernel::create_uniform_buffer; @@ -33,17 +34,21 @@ use sort_count::Uniforms; const BLOCK_SIZE: u32 = SortCount::WG * SortCount::ELEMENTS_PER_THREAD; +/// Perform a radix argsort on the input keys and values. +/// +/// If `dynamic_count` is `Some(count_buffer)`, use that buffer as the actual number +/// of keys to sort (uses dynamic GPU dispatch). If `None`, use the full buffer length +/// with static CPU dispatch. pub fn radix_argsort( input_keys: CubeTensor, input_values: CubeTensor, - n_sort: &CubeTensor, sorting_bits: u32, + dynamic_count: Option>, ) -> (CubeTensor, CubeTensor) { assert_eq!( input_keys.shape.dims[0], input_values.shape.dims[0], "Input keys and values must have the same number of elements" ); - assert_eq!(n_sort.shape.dims[0], 1, "Sort count must have one element"); assert!(sorting_bits <= 32, "Can only sort up to 32 bits"); assert!( input_keys.is_contiguous(), @@ -64,11 +69,47 @@ pub fn radix_argsort( let max_needed_wgs = max_n.div_ceil(BLOCK_SIZE); - let num_wgs = create_dispatch_buffer_1d(n_sort.clone(), BLOCK_SIZE); - let num_reduce_wgs: Tensor, 1, Int> = - Tensor::from_primitive(create_dispatch_buffer_1d(num_wgs.clone(), BLOCK_SIZE)) - * Tensor::from_ints([SortCount::BIN_COUNT, 1, 1], device); - let num_reduce_wgs: CubeTensor = num_reduce_wgs.into_primitive(); + // Handle dynamic vs static dispatch + let (num_keys_buf, num_wgs, num_reduce_wgs) = if let Some(count_buf) = dynamic_count { + // Dynamic dispatch: compute workgroup counts on GPU + // num_wgs = ceil(count / BLOCK_SIZE) + let num_wgs = create_dispatch_buffer_1d(count_buf.clone(), BLOCK_SIZE); + + // The reduce shader expects: num_reduce_wgs = BIN_COUNT * ceil(num_wgs / BLOCK_SIZE) + // This is NOT the same as ceil(num_wgs * BIN_COUNT / BLOCK_SIZE) due to ceiling! + // We need: first compute ceil(num_wgs_x / BLOCK_SIZE), then multiply by BIN_COUNT. + type Backend = CubeBackend; + let num_wgs_tensor: Tensor = Tensor::from_primitive(num_wgs.clone()); + let num_wgs_x = num_wgs_tensor.slice([0..1]); // Get just the X component (scalar) + + // num_reduce_wg_per_bin = ceil(num_wgs_x / BLOCK_SIZE) + let num_reduce_wg_per_bin_buf = + create_dispatch_buffer_1d(num_wgs_x.into_primitive(), BLOCK_SIZE); + let num_reduce_wg_per_bin: Tensor = + Tensor::from_primitive(num_reduce_wg_per_bin_buf); + let num_reduce_wg_per_bin_x = num_reduce_wg_per_bin.slice([0..1]); + + // num_reduce_wgs_total = num_reduce_wg_per_bin * BIN_COUNT + let num_reduce_total: Tensor = + num_reduce_wg_per_bin_x * (SortCount::BIN_COUNT as i32); + + // Create dispatch buffer for the total (uses 2D tiling if > 65535) + let num_reduce_wgs = create_dispatch_buffer_1d(num_reduce_total.into_primitive(), 1); + + (count_buf, CubeCount::Dynamic(num_wgs.handle.binding()), CubeCount::Dynamic(num_reduce_wgs.handle.binding())) + } else { + // Static dispatch: use full buffer size + let num_keys_buf = { + type Backend = CubeBackend; + Tensor::::from_ints([max_n as i32], device).into_primitive() + }; + // Calculate dispatch counts matching the original formula + let num_wgs_count = max_n.div_ceil(BLOCK_SIZE); + let num_reduce_wgs_count = num_wgs_count.div_ceil(BLOCK_SIZE) * SortCount::BIN_COUNT; + let num_wgs = calc_cube_count_1d(max_n, BLOCK_SIZE); + let num_reduce_wgs = calc_cube_count_1d(num_reduce_wgs_count, 1); + (num_keys_buf, num_wgs, num_reduce_wgs) + }; let mut cur_keys = input_keys; let mut cur_vals = input_values; @@ -79,14 +120,14 @@ pub fn radix_argsort( let count_buf = create_tensor([(max_needed_wgs as usize) * 16], device, DType::I32); - // use safe distpatch as dynamic work count isn't verified. + // use safe dispatch as dynamic work count isn't verified. client .launch( SortCount::task(), - CubeCount::Dynamic(num_wgs.clone().handle.binding()), + num_wgs.clone(), Bindings::new().with_buffers(vec![ uniforms_buffer.handle.clone().binding(), - n_sort.handle.clone().binding(), + num_keys_buf.handle.clone().binding(), cur_keys.handle.clone().binding(), count_buf.handle.clone().binding(), ]), @@ -99,9 +140,9 @@ pub fn radix_argsort( client .launch( SortReduce::task(), - CubeCount::Dynamic(num_reduce_wgs.handle.clone().binding()), + num_reduce_wgs.clone(), Bindings::new().with_buffers(vec![ - n_sort.handle.clone().binding(), + num_keys_buf.handle.clone().binding(), count_buf.handle.clone().binding(), reduced_buf.handle.clone().binding(), ]), @@ -115,7 +156,7 @@ pub fn radix_argsort( SortScan::task(), CubeCount::Static(1, 1, 1), Bindings::new().with_buffers(vec![ - n_sort.handle.clone().binding(), + num_keys_buf.handle.clone().binding(), reduced_buf.handle.clone().binding(), ]), ) @@ -125,9 +166,9 @@ pub fn radix_argsort( client .launch( SortScanAdd::task(), - CubeCount::Dynamic(num_reduce_wgs.handle.clone().binding()), + num_reduce_wgs.clone(), Bindings::new().with_buffers(vec![ - n_sort.handle.clone().binding(), + num_keys_buf.handle.clone().binding(), reduced_buf.handle.clone().binding(), count_buf.handle.clone().binding(), ]), @@ -141,10 +182,10 @@ pub fn radix_argsort( client .launch( SortScatter::task(), - CubeCount::Dynamic(num_wgs.handle.clone().binding()), + num_wgs.clone(), Bindings::new().with_buffers(vec![ uniforms_buffer.handle.clone().binding(), - n_sort.handle.clone().binding(), + num_keys_buf.handle.clone().binding(), cur_keys.handle.clone().binding(), cur_vals.handle.clone().binding(), count_buf.handle.clone().binding(), @@ -200,12 +241,9 @@ mod tests { let device = Default::default(); let keys = Tensor::::from_ints(keys_inp, &device).into_primitive(); - let values = Tensor::::from_ints(values_inp.as_slice(), &device) .into_primitive(); - let num_points = Tensor::::from_ints([keys_inp.len() as i32], &device) - .into_primitive(); - let (ret_keys, ret_values) = radix_argsort(keys, values, &num_points, 32); + let (ret_keys, ret_values) = radix_argsort(keys, values, 32, None); let ret_keys = Tensor::::from_primitive(ret_keys).into_data(); @@ -253,10 +291,7 @@ mod tests { Tensor::::from_ints(keys_inp.as_slice(), &device).into_primitive(); let values = Tensor::::from_ints(values_inp.as_slice(), &device).into_primitive(); - let num_points = - Tensor::::from_ints([keys_inp.len() as i32], &device).into_primitive(); - - let (ret_keys, ret_values) = radix_argsort(keys, values, &num_points, 32); + let (ret_keys, ret_values) = radix_argsort(keys, values, 32, None); let ret_keys = Tensor::::from_primitive(ret_keys).to_data(); let ret_values = Tensor::::from_primitive(ret_values).to_data(); @@ -296,10 +331,7 @@ mod tests { Tensor::::from_ints(keys_inp.as_slice(), &device).into_primitive(); let values = Tensor::::from_ints(values_inp.as_slice(), &device).into_primitive(); - let num_points = - Tensor::::from_ints([NUM_ELEMENTS as i32], &device).into_primitive(); - - let (ret_keys, ret_values) = radix_argsort(keys, values, &num_points, 32); + let (ret_keys, ret_values) = radix_argsort(keys, values, 32, None); let ret_keys = Tensor::::from_primitive(ret_keys).to_data(); let ret_values = Tensor::::from_primitive(ret_values).to_data(); diff --git a/crates/brush-train/src/eval.rs b/crates/brush-train/src/eval.rs index c3e6c97f..52ce4b29 100644 --- a/crates/brush-train/src/eval.rs +++ b/crates/brush-train/src/eval.rs @@ -5,8 +5,11 @@ use anyhow::Result; use brush_dataset::scene::{sample_to_tensor_data, view_to_sample_image}; use brush_render::camera::Camera; use brush_render::gaussian_splats::Splats; -use brush_render::render_aux::RenderAux; -use brush_render::{AlphaMode, SplatForward}; +use brush_render::render_aux::{ProjectAux, RasterizeAux}; +use brush_render::{AlphaMode, SplatOps}; + +#[cfg(target_family = "wasm")] +use brush_render::render::calc_tile_bounds; use burn::prelude::Backend; use burn::tensor::{Tensor, TensorPrimitive, s}; use glam::Vec3; @@ -19,10 +22,11 @@ pub struct EvalSample { pub rendered: Tensor, pub psnr: Tensor, pub ssim: Tensor, - pub aux: RenderAux, + pub project_aux: ProjectAux, + pub rasterize_aux: RasterizeAux, } -pub fn eval_stats>( +pub fn eval_stats>( splats: &Splats, gt_cam: &Camera, gt_img: DynamicImage, @@ -36,9 +40,10 @@ pub fn eval_stats>( let gt_tensor = Tensor::from_data(gt_tensor, device); let gt_rgb = gt_tensor.slice(s![.., .., 0..3]); - // Render on reference black background. - let (img, aux) = { - let (img, aux) = B::render_splats( + // Render on reference black background using split pipeline. + let (img, project_aux, rasterize_aux) = { + // First pass: project + let project_aux = B::project( gt_cam, res, splats.means.val().into_primitive().tensor(), @@ -47,10 +52,26 @@ pub fn eval_stats>( splats.sh_coeffs.val().into_primitive().tensor(), splats.raw_opacities.val().into_primitive().tensor(), splats.render_mode, - Vec3::ZERO, - true, ); - (Tensor::from_primitive(TensorPrimitive::Float(img)), aux) + + // Sync readback of num_intersections + #[cfg(not(target_family = "wasm"))] + let num_intersections = project_aux.num_intersections(); + + #[cfg(target_family = "wasm")] + let num_intersections = { + use burn::tensor::ops::FloatTensorOps; + let tile_bounds = calc_tile_bounds(res); + let num_tiles = tile_bounds[0] * tile_bounds[1]; + let total_splats = splats.num_splats(); + let max_possible = num_tiles.saturating_mul(total_splats); + max_possible.min(2 * 512 * 65535) + }; + + // Second pass: rasterize (with bwd_info = true for eval) + let (out_img, rasterize_aux) = B::rasterize(&project_aux, num_intersections, Vec3::ZERO, true); + + (Tensor::from_primitive(TensorPrimitive::Float(out_img)), project_aux, rasterize_aux) }; let render_rgb = img.slice(s![.., .., 0..3]); @@ -68,7 +89,8 @@ pub fn eval_stats>( psnr, ssim, rendered: render_rgb, - aux, + project_aux, + rasterize_aux, }) } diff --git a/crates/brush-train/src/train.rs b/crates/brush-train/src/train.rs index d43d940f..eadd12cd 100644 --- a/crates/brush-train/src/train.rs +++ b/crates/brush-train/src/train.rs @@ -116,7 +116,7 @@ impl SplatTrainer { let has_alpha = batch.has_alpha(); let gt_tensor = Tensor::from_data(batch.img_tensor, &device); - let (pred_image, aux, refine_weight_holder) = trace_span!("Forward").in_scope(|| { + let (pred_image, project_aux, rasterize_aux, refine_weight_holder) = trace_span!("Forward").in_scope(|| { // Could generate a random background color, but so far // results just seem worse. let background = Vec3::ZERO; @@ -130,16 +130,16 @@ impl SplatTrainer { let img = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); - (img, diff_out.aux, diff_out.refine_weight_holder) + (img, diff_out.project_aux, diff_out.rasterize_aux, diff_out.refine_weight_holder) }); let median_scale = self.bounds.median_size(); - let num_visible = aux.num_visible().inner(); + let num_visible = project_aux.get_num_visible().inner(); let pred_rgb = pred_image.clone().slice(s![.., .., 0..3]); let gt_rgb = gt_tensor.clone().slice(s![.., .., 0..3]); let visible: Tensor, 1> = - Tensor::from_primitive(TensorPrimitive::Float(aux.visible)); + Tensor::from_primitive(TensorPrimitive::Float(rasterize_aux.visible)); let loss = trace_span!("Calculate losses").in_scope(|| { let l1_rgb = (pred_rgb.clone() - gt_rgb.clone()).abs(); diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index d207ef5e..4b095c07 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -334,7 +334,7 @@ impl ScenePanel { if pixel_size.x > 8 && pixel_size.y > 8 && dirty { let _span = trace_span!("Render splats").entered(); // Could add an option for background color. - let (img, _) = render_splats( + let (img, _, _) = render_splats( &splats, &camera, pixel_size, From fe3c8affce43c24bcceee3e33f3c857cca1dd493 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 19 Jan 2026 21:44:25 +0100 Subject: [PATCH 04/29] Pt. 2 --- crates/brush-bench-test/src/reference.rs | 25 +-- crates/brush-render-bwd/src/burn_glue.rs | 60 +++----- crates/brush-render-bwd/src/tests.rs | 2 +- crates/brush-render/src/burn_glue.rs | 51 +++---- crates/brush-render/src/gaussian_splats.rs | 23 +-- crates/brush-render/src/lib.rs | 24 ++- crates/brush-render/src/render.rs | 46 +++--- crates/brush-render/src/render_aux.rs | 169 +++++++-------------- crates/brush-render/src/tests/mod.rs | 33 ++-- crates/brush-train/src/eval.rs | 21 ++- crates/brush-train/src/train.rs | 8 +- crates/brush-ui/src/scene.rs | 4 +- examples/train-2d/examples/train-2d.rs | 2 +- 13 files changed, 188 insertions(+), 280 deletions(-) diff --git a/crates/brush-bench-test/src/reference.rs b/crates/brush-bench-test/src/reference.rs index 6b662075..11b00c63 100644 --- a/crates/brush-bench-test/src/reference.rs +++ b/crates/brush-bench-test/src/reference.rs @@ -10,7 +10,7 @@ use burn::{ Tensor, backend::{Autodiff, wgpu::WgpuDevice}, prelude::Backend, - tensor::{Float, Int, TensorPrimitive}, + tensor::TensorPrimitive, }; use anyhow::{Context, Result}; @@ -134,8 +134,7 @@ async fn test_reference() -> Result<()> { ); let out: Tensor = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); - let project_aux = diff_out.project_aux; - let rasterize_aux = diff_out.rasterize_aux; + let render_aux = diff_out.render_aux; if let Some(rec) = rec.as_ref() { rec.set_time_sequence("test case", i as i64); @@ -147,28 +146,10 @@ async fn test_reference() -> Result<()> { )?; rec.log( "images/tile_depth", - &rasterize_aux.calc_tile_depth().into_rerun().await, + &render_aux.calc_tile_depth().into_rerun().await, )?; } - let num_visible: Tensor = project_aux.get_num_visible(); - let num_visible = num_visible.into_scalar_async().await.unwrap() as usize; - let global_from_compact_gid: Tensor = - Tensor::from_primitive(project_aux.global_from_compact_gid.clone()); - let gs_ids = global_from_compact_gid.clone().slice([0..num_visible]); - let projected_splats = - Tensor::from_primitive(TensorPrimitive::Float(project_aux.projected_splats.clone())); - let xys: Tensor = - projected_splats.clone().slice([0..num_visible, 0..2]); - let xys_ref = safetensor_to_burn::(&tensors.tensor("xys")?, &device); - let xys_ref = xys_ref.select(0, gs_ids.clone()); - compare("xy", xys, xys_ref, 1e-5, 2e-5); - let conics: Tensor = - projected_splats.clone().slice([0..num_visible, 2..5]); - let conics_ref = safetensor_to_burn::(&tensors.tensor("conics")?, &device); - let conics_ref = conics_ref.select(0, gs_ids.clone()); - compare("conics", conics, conics_ref, 1e-6, 2e-5); - // Check if images match. compare("img", out.clone(), img_ref, 1e-5, 1e-5); diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index 52d969b9..b50ab05c 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -1,8 +1,7 @@ use brush_render::{ - MainBackendBase, SplatOps, + MainBackendBase, RenderAux, SplatOps, camera::Camera, gaussian_splats::{SplatRenderMode, Splats}, - render_aux::{ProjectAux, RasterizeAux, validate_render_output}, sh::{sh_coeffs_for_degree, sh_degree_from_coeffs}, shaders::helpers::ProjectUniforms, }; @@ -197,8 +196,7 @@ impl> Backward for RenderBackw pub struct SplatOutputDiff { pub img: FloatTensor, - pub project_aux: ProjectAux, - pub rasterize_aux: RasterizeAux, + pub render_aux: RenderAux, pub refine_weight_holder: Tensor, } @@ -237,7 +235,7 @@ impl + SplatOps, C: CheckpointStrategy> Spla .stateful(); // First pass: project - let project_aux = >::project( + let project_output = >::project( camera, img_size, means.clone().into_primitive(), @@ -249,29 +247,19 @@ impl + SplatOps, C: CheckpointStrategy> Spla ); // Sync readback of num_intersections - let num_intersections = project_aux.num_intersections(); + let num_intersections = project_output.num_intersections(); // Second pass: rasterize (with bwd_info = true) - let (out_img, rasterize_aux) = - >::rasterize(&project_aux, num_intersections, background, true); - - // Create wrapped aux structs for Autodiff backend - let wrapped_project_aux = ProjectAux:: { - project_uniforms: project_aux.project_uniforms, - projected_splats: ::from_inner( - project_aux.projected_splats.clone(), - ), - num_visible: project_aux.num_visible.clone(), - global_from_compact_gid: project_aux.global_from_compact_gid.clone(), - cum_tiles_hit: project_aux.cum_tiles_hit.clone(), - img_size: project_aux.img_size, - }; - - let wrapped_rasterize_aux = RasterizeAux:: { - tile_offsets: rasterize_aux.tile_offsets.clone(), - compact_gid_from_isect: rasterize_aux.compact_gid_from_isect.clone(), - visible: ::from_inner(rasterize_aux.visible.clone()), - img_size: rasterize_aux.img_size, + let (out_img, render_aux, compact_gid_from_isect) = + >::rasterize(&project_output, num_intersections, background, true); + + // Create wrapped render_aux for Autodiff backend + let wrapped_render_aux = RenderAux:: { + num_visible: render_aux.num_visible.clone(), + num_intersections: render_aux.num_intersections, + visible: ::from_inner(render_aux.visible.clone()), + tile_offsets: render_aux.tile_offsets.clone(), + img_size: render_aux.img_size, }; match prep_nodes { @@ -287,13 +275,13 @@ impl + SplatOps, C: CheckpointStrategy> Spla [1] as u32, ), out_img: out_img.clone(), - projected_splats: project_aux.projected_splats, - project_uniforms: project_aux.project_uniforms, - num_visible: project_aux.num_visible, - tile_offsets: rasterize_aux.tile_offsets, - compact_gid_from_isect: rasterize_aux.compact_gid_from_isect, + projected_splats: project_output.projected_splats, + project_uniforms: project_output.project_uniforms, + num_visible: project_output.num_visible, + tile_offsets: render_aux.tile_offsets, + compact_gid_from_isect, render_mode, - global_from_compact_gid: project_aux.global_from_compact_gid, + global_from_compact_gid: project_output.global_from_compact_gid, background, }; @@ -301,8 +289,7 @@ impl + SplatOps, C: CheckpointStrategy> Spla SplatOutputDiff { img: out_img, - project_aux: wrapped_project_aux, - rasterize_aux: wrapped_rasterize_aux, + render_aux: wrapped_render_aux, refine_weight_holder, } } @@ -311,8 +298,7 @@ impl + SplatOps, C: CheckpointStrategy> Spla // keeping any state. SplatOutputDiff { img: prep.finish(out_img), - project_aux: wrapped_project_aux, - rasterize_aux: wrapped_rasterize_aux, + render_aux: wrapped_render_aux, refine_weight_holder, } } @@ -511,6 +497,6 @@ where splats.render_mode, background, ); - validate_render_output(&result.project_aux, &result.rasterize_aux); + result.render_aux.validate(); result } diff --git a/crates/brush-render-bwd/src/tests.rs b/crates/brush-render-bwd/src/tests.rs index ebdb58b3..46cb7818 100644 --- a/crates/brush-render-bwd/src/tests.rs +++ b/crates/brush-render-bwd/src/tests.rs @@ -103,7 +103,7 @@ fn diffs_many_splats() { let raw_opacity = Tensor::::random([num_points], Distribution::Uniform(-2.0, 2.0), &device); - let result = >::render_splats( + >::render_splats( &cam, img_size, means.into_primitive().tensor(), diff --git a/crates/brush-render/src/burn_glue.rs b/crates/brush-render/src/burn_glue.rs index da02c661..2a79ac8e 100644 --- a/crates/brush-render/src/burn_glue.rs +++ b/crates/brush-render/src/burn_glue.rs @@ -1,4 +1,4 @@ -use burn::tensor::{DType, Shape, ops::FloatTensor}; +use burn::tensor::{DType, Shape, ops::{FloatTensor, IntTensor}}; use burn_cubecl::{BoolElement, fusion::FusionCubeRuntime}; use burn_fusion::{ Fusion, FusionHandle, @@ -9,11 +9,11 @@ use burn_wgpu::WgpuRuntime; use glam::Vec3; use crate::{ - MainBackendBase, SplatOps, + MainBackendBase, RenderAux, SplatOps, camera::Camera, gaussian_splats::SplatRenderMode, render::calc_tile_bounds, - render_aux::{ProjectAux, RasterizeAux}, + render_aux::ProjectOutput, sh::sh_degree_from_coeffs, shaders::{self, helpers::ProjectUniforms}, }; @@ -28,7 +28,7 @@ impl SplatOps for Fusion { sh_coeffs: FloatTensor, opacity: FloatTensor, render_mode: SplatRenderMode, - ) -> ProjectAux { + ) -> ProjectOutput { #[derive(Debug)] struct CustomOp { cam: Camera, @@ -146,7 +146,7 @@ impl SplatOps for Fusion { cum_tiles_hit, ] = outputs; - ProjectAux:: { + ProjectOutput:: { projected_splats, project_uniforms, num_visible, @@ -157,11 +157,11 @@ impl SplatOps for Fusion { } fn rasterize( - project_aux: &ProjectAux, + project_output: &ProjectOutput, num_intersections: u32, background: Vec3, bwd_info: bool, - ) -> (FloatTensor, RasterizeAux) { + ) -> (FloatTensor, RenderAux, IntTensor) { #[derive(Debug)] struct CustomOp { img_size: glam::UVec2, @@ -187,7 +187,7 @@ impl SplatOps for Fusion { ] = inputs; let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs; - let inner_aux = ProjectAux:: { + let inner_output = ProjectOutput:: { projected_splats: h.get_float_tensor::(projected_splats), project_uniforms: self.project_uniforms, num_visible: h.get_int_tensor::(num_visible), @@ -197,8 +197,8 @@ impl SplatOps for Fusion { img_size: self.img_size, }; - let (img, aux) = MainBackendBase::rasterize( - &inner_aux, + let (img, aux, compact_gid) = MainBackendBase::rasterize( + &inner_output, self.num_intersections, self.background, self.bwd_info, @@ -207,19 +207,16 @@ impl SplatOps for Fusion { // Register outputs h.register_float_tensor::(&out_img.id, img); h.register_int_tensor::(&tile_offsets.id, aux.tile_offsets); - h.register_int_tensor::( - &compact_gid_from_isect.id, - aux.compact_gid_from_isect, - ); + h.register_int_tensor::(&compact_gid_from_isect.id, compact_gid); h.register_float_tensor::(&visible.id, aux.visible); } } - let client = project_aux.projected_splats.client.clone(); - let img_size = project_aux.img_size; + let client = project_output.projected_splats.client.clone(); + let img_size = project_output.img_size; let tile_bounds = calc_tile_bounds(img_size); - let num_points = project_aux.projected_splats.shape[0]; + let num_points = project_output.projected_splats.shape[0]; let channels = if bwd_info { 4 } else { 1 }; let out_img = TensorIr::uninit( @@ -237,7 +234,7 @@ impl SplatOps for Fusion { // Use actual num_intersections for buffer size let compact_gid_from_isect = TensorIr::uninit( client.create_empty_handle(), - Shape::new([num_intersections as usize]), + Shape::new([num_intersections.max(1) as usize]), DType::U32, ); @@ -249,10 +246,10 @@ impl SplatOps for Fusion { let visible = TensorIr::uninit(client.create_empty_handle(), visible_shape, DType::F32); let input_tensors = [ - project_aux.projected_splats.clone(), - project_aux.num_visible.clone(), - project_aux.global_from_compact_gid.clone(), - project_aux.cum_tiles_hit.clone(), + project_output.projected_splats.clone(), + project_output.num_visible.clone(), + project_output.global_from_compact_gid.clone(), + project_output.cum_tiles_hit.clone(), ]; let stream = OperationStreams::with_inputs(&input_tensors); let desc = CustomOpIr::new( @@ -265,7 +262,7 @@ impl SplatOps for Fusion { num_intersections, background, bwd_info, - project_uniforms: project_aux.project_uniforms, + project_uniforms: project_output.project_uniforms, desc: desc.clone(), }; @@ -277,12 +274,14 @@ impl SplatOps for Fusion { ( out_img, - RasterizeAux:: { - tile_offsets, - compact_gid_from_isect, + RenderAux:: { + num_visible: project_output.num_visible.clone(), + num_intersections, visible, + tile_offsets, img_size, }, + compact_gid_from_isect, ) } } diff --git a/crates/brush-render/src/gaussian_splats.rs b/crates/brush-render/src/gaussian_splats.rs index 2a78592b..0fb22b86 100644 --- a/crates/brush-render/src/gaussian_splats.rs +++ b/crates/brush-render/src/gaussian_splats.rs @@ -9,9 +9,8 @@ use glam::Vec3; use tracing::trace_span; use crate::{ - SplatOps, + RenderAux, SplatOps, camera::Camera, - render_aux::{ProjectAux, RasterizeAux, validate_render_output}, sh::{sh_coeffs_for_degree, sh_degree_from_coeffs}, }; @@ -243,7 +242,7 @@ pub fn render_splats>( img_size: glam::UVec2, background: Vec3, splat_scale: Option, -) -> (Tensor, ProjectAux, RasterizeAux) { +) -> (Tensor, RenderAux) { splats.validate_values(); let mut scales = splats.log_scales.val(); @@ -254,7 +253,7 @@ pub fn render_splats>( }; // First pass: project - let project_aux = B::project( + let project_output = B::project( camera, img_size, splats.means.val().into_primitive().tensor(), @@ -265,14 +264,18 @@ pub fn render_splats>( splats.render_mode, ); - let num_intersections = project_aux.num_intersections(); + // Validate before readback + project_output.validate(); - // Second pass: rasterize - let (out_img, rasterize_aux) = B::rasterize(&project_aux, num_intersections, background, false); + let num_intersections = project_output.num_intersections(); - let img = Tensor::from_primitive(TensorPrimitive::Float(out_img)); + // Second pass: rasterize (drop compact_gid_from_isect - only needed for backward) + let (out_img, render_aux, _) = B::rasterize(&project_output, num_intersections, background, false); + + // Validate rasterize outputs + render_aux.validate(); - validate_render_output(&project_aux, &rasterize_aux); + let img = Tensor::from_primitive(TensorPrimitive::Float(out_img)); - (img, project_aux, rasterize_aux) + (img, render_aux) } diff --git a/crates/brush-render/src/lib.rs b/crates/brush-render/src/lib.rs index 9f196faa..dd3bdb94 100644 --- a/crates/brush-render/src/lib.rs +++ b/crates/brush-render/src/lib.rs @@ -1,17 +1,18 @@ #![recursion_limit = "256"] use burn::prelude::Backend; -use burn::tensor::ops::FloatTensor; +use burn::tensor::ops::{FloatTensor, IntTensor}; use burn_cubecl::CubeBackend; use burn_fusion::Fusion; use burn_wgpu::WgpuRuntime; use camera::Camera; use clap::ValueEnum; use glam::Vec3; -use render_aux::{ProjectAux, RasterizeAux}; +use render_aux::ProjectOutput; use crate::gaussian_splats::SplatRenderMode; pub use crate::gaussian_splats::render_splats; +pub use crate::render_aux::RenderAux; mod burn_glue; mod dim_check; @@ -33,11 +34,6 @@ pub mod validation; pub type MainBackendBase = CubeBackend; pub type MainBackend = Fusion; -#[derive(Debug, Clone)] -pub struct RenderStats { - pub num_visible: u32, -} - /// Trait for the split gaussian splatting rendering pipeline. /// /// This trait provides two passes: @@ -46,11 +42,10 @@ pub struct RenderStats { /// /// The split allows for an explicit GPU sync point between passes to read back /// the exact number of intersections needed for buffer allocation. +/// +/// Users should typically use [`render_splats`] instead of this trait directly. pub trait SplatOps { /// First pass: project gaussians and count intersections. - /// - /// Returns [`ProjectAux`] containing projected splat data and `num_intersections` - /// tensor for explicit readback. fn project( camera: &Camera, img_size: glam::UVec2, @@ -60,18 +55,21 @@ pub trait SplatOps { sh_coeffs: FloatTensor, raw_opacities: FloatTensor, render_mode: SplatRenderMode, - ) -> ProjectAux; + ) -> ProjectOutput; /// Second pass: rasterize using projection data. /// /// Takes the output of [`Self::project`] along with the actual /// `num_intersections` value from sync readback. + /// + /// Returns `(image, render_aux, compact_gid_from_isect)` where the last + /// value is only needed for backward pass and can be dropped for forward-only. fn rasterize( - project_aux: &ProjectAux, + project_output: &ProjectOutput, num_intersections: u32, background: Vec3, bwd_info: bool, - ) -> (FloatTensor, RasterizeAux); + ) -> (FloatTensor, RenderAux, IntTensor); } #[derive( diff --git a/crates/brush-render/src/render.rs b/crates/brush-render/src/render.rs index c85118c4..f78fa036 100644 --- a/crates/brush-render/src/render.rs +++ b/crates/brush-render/src/render.rs @@ -1,10 +1,10 @@ use crate::{ - MainBackendBase, SplatOps, + MainBackendBase, RenderAux, SplatOps, camera::Camera, dim_check::DimCheck, gaussian_splats::SplatRenderMode, get_tile_offset::{CHECKS_PER_ITER, get_tile_offsets}, - render_aux::{ProjectAux, RasterizeAux}, + render_aux::ProjectOutput, sh::sh_degree_from_coeffs, shaders::{self, MapGaussiansToIntersect, ProjectSplats, ProjectVisible, Rasterize}, }; @@ -18,7 +18,7 @@ use brush_sort::radix_argsort; use burn::tensor::{DType, IntDType, Shape, ops::FloatTensor}; use burn::tensor::{ FloatDType, - ops::{FloatTensorOps, IntTensorOps}, + ops::{FloatTensorOps, IntTensor, IntTensorOps}, }; use burn_cubecl::cubecl::server::Bindings; use burn_cubecl::kernel::into_contiguous; @@ -42,7 +42,7 @@ impl SplatOps for MainBackendBase { sh_coeffs: FloatTensor, raw_opacities: FloatTensor, render_mode: SplatRenderMode, - ) -> ProjectAux { + ) -> ProjectOutput { assert!( img_size[0] > 0 && img_size[1] > 0, "Can't render images with 0 size." @@ -170,7 +170,7 @@ impl SplatOps for MainBackendBase { let cum_tiles_hit = tracing::trace_span!("PrefixSumGaussHits") .in_scope(|| prefix_sum(splat_intersect_counts)); - ProjectAux { + ProjectOutput { projected_splats, project_uniforms, num_visible: num_visible_buffer, @@ -181,16 +181,16 @@ impl SplatOps for MainBackendBase { } fn rasterize( - project_aux: &ProjectAux, + project_output: &ProjectOutput, num_intersections: u32, background: Vec3, bwd_info: bool, - ) -> (FloatTensor, RasterizeAux) { + ) -> (FloatTensor, RenderAux, IntTensor) { let _span = tracing::trace_span!("rasterize").entered(); - let device = &project_aux.projected_splats.device.clone(); - let client = project_aux.projected_splats.client.clone(); - let img_size = project_aux.img_size; + let device = &project_output.projected_splats.device.clone(); + let client = project_output.projected_splats.client.clone(); + let img_size = project_output.img_size; // Divide screen into tiles. let tile_bounds = calc_tile_bounds(img_size); @@ -210,7 +210,7 @@ impl SplatOps for MainBackendBase { // Step 2: MapGaussiansToIntersect (fill pass) let num_vis_map_wg = create_dispatch_buffer_1d( - project_aux.num_visible.clone(), + project_output.num_visible.clone(), MapGaussiansToIntersect::WORKGROUP_SIZE[0], ); @@ -228,9 +228,9 @@ impl SplatOps for MainBackendBase { CubeCount::Dynamic(num_vis_map_wg.handle.clone().binding()), Bindings::new() .with_buffers(vec![ - project_aux.num_visible.handle.clone().binding(), - project_aux.projected_splats.handle.clone().binding(), - project_aux.cum_tiles_hit.handle.clone().binding(), + project_output.num_visible.handle.clone().binding(), + project_output.projected_splats.handle.clone().binding(), + project_output.cum_tiles_hit.handle.clone().binding(), tile_id_from_isect.handle.clone().binding(), compact_gid_from_isect.handle.clone().binding(), ]) @@ -289,7 +289,7 @@ impl SplatOps for MainBackendBase { ); // Get total_splats from the shape of projected_splats - let total_splats = project_aux.projected_splats.shape.dims[0]; + let total_splats = project_output.projected_splats.shape.dims[0]; let (bindings, visible) = if bwd_info { let visible = Self::float_zeros([total_splats].into(), device, FloatDType::F32); @@ -297,9 +297,9 @@ impl SplatOps for MainBackendBase { .with_buffers(vec![ compact_gid_from_isect.handle.clone().binding(), tile_offsets.handle.clone().binding(), - project_aux.projected_splats.handle.clone().binding(), + project_output.projected_splats.handle.clone().binding(), out_img.handle.clone().binding(), - project_aux.global_from_compact_gid.handle.clone().binding(), + project_output.global_from_compact_gid.handle.clone().binding(), visible.handle.clone().binding(), ]) .with_metadata(create_meta_binding(rasterize_uniforms)); @@ -309,7 +309,7 @@ impl SplatOps for MainBackendBase { .with_buffers(vec![ compact_gid_from_isect.handle.clone().binding(), tile_offsets.handle.clone().binding(), - project_aux.projected_splats.handle.clone().binding(), + project_output.projected_splats.handle.clone().binding(), out_img.handle.clone().binding(), ]) .with_metadata(create_meta_binding(rasterize_uniforms)); @@ -338,12 +338,14 @@ impl SplatOps for MainBackendBase { ( out_img, - RasterizeAux { - tile_offsets, - compact_gid_from_isect, + RenderAux { + num_visible: project_output.num_visible.clone(), + num_intersections, visible, - img_size: project_aux.img_size, + tile_offsets, + img_size: project_output.img_size, }, + compact_gid_from_isect, ) } } diff --git a/crates/brush-render/src/render_aux.rs b/crates/brush-render/src/render_aux.rs index 84c25f2c..a2c0bbb0 100644 --- a/crates/brush-render/src/render_aux.rs +++ b/crates/brush-render/src/render_aux.rs @@ -9,32 +9,10 @@ use burn::{ use crate::shaders::helpers::ProjectUniforms; -/// Validate both ProjectAux and RasterizeAux outputs together. -/// This is a no-op in release builds unless `debug-validation` feature is enabled. -/// Also skipped when running benchmarks (detected via `--bench` arg). -pub fn validate_render_output(project_aux: &ProjectAux, rasterize_aux: &RasterizeAux) { - #[cfg(any(test, feature = "debug-validation"))] - { - if std::env::args().any(|a| a == "--bench") { - return; - } - project_aux.validate_values(); - use burn::tensor::ElementConversion; - let num_visible = project_aux.get_num_visible().into_scalar().elem::(); - rasterize_aux.validate_values(num_visible); - } - #[cfg(not(any(test, feature = "debug-validation")))] - { - let _ = project_aux; - let _ = rasterize_aux; - } -} - +/// Output from the project pass. Consumed by rasterize. #[derive(Debug, Clone)] -pub struct ProjectAux { +pub struct ProjectOutput { pub project_uniforms: ProjectUniforms, - - /// The packed projected splat information, see `ProjectedSplat` in helpers.wgsl pub projected_splats: FloatTensor, pub num_visible: IntTensor, pub global_from_compact_gid: IntTensor, @@ -42,15 +20,12 @@ pub struct ProjectAux { pub img_size: glam::UVec2, } -impl ProjectAux { - /// Extract the total number of intersections. - /// - /// This requires a sync readback from the GPU. +impl ProjectOutput { + /// Extract the total number of intersections (sync readback). pub fn num_intersections(&self) -> u32 { use burn::tensor::ElementConversion; let cum_tiles_hit: Tensor = Tensor::from_primitive(self.cum_tiles_hit.clone()); let total = self.project_uniforms.total_splats as usize; - // The prefix sum is inclusive, so the last element is the total number of intersections if total > 0 { cum_tiles_hit .slice([total - 1..total]) @@ -61,31 +36,27 @@ impl ProjectAux { } } - pub fn get_num_visible(&self) -> Tensor { - Tensor::from_primitive(self.num_visible.clone()) - } - - pub fn validate_values(&self) { + /// Validate project outputs. Call before consuming. + pub fn validate(&self) { #[cfg(any(test, feature = "debug-validation"))] { - use burn::tensor::{ElementConversion, TensorPrimitive, s}; - - use crate::validation::validate_tensor_val; - if std::env::args().any(|a| a == "--bench") { return; } - let num_visible_tensor: Tensor = self.get_num_visible(); + use crate::validation::validate_tensor_val; + use burn::tensor::{ElementConversion, TensorPrimitive, s}; + + let num_visible_tensor: Tensor = + Tensor::from_primitive(self.num_visible.clone()); let total_splats = self.project_uniforms.total_splats; let num_visible = num_visible_tensor.into_scalar().elem::() as u32; assert!( num_visible <= total_splats, - "Something went wrong when calculating the number of visible gaussians. {num_visible} > {total_splats}" + "num_visible ({num_visible}) > total_splats ({total_splats})" ); - // Projected splats is only valid up to num_visible and undefined for other values. if num_visible > 0 { let projected_splats: Tensor = Tensor::from_primitive(TensorPrimitive::Float(self.projected_splats.clone())); @@ -93,8 +64,6 @@ impl ProjectAux { validate_tensor_val(&projected_splats, "projected_splats", None, None); } - // assert that every ID in global_from_compact_gid is valid. - // Only validate when there are visible splats if num_visible > 0 && total_splats > 0 { let global_from_compact_gid: Tensor = Tensor::from_primitive(self.global_from_compact_gid.clone()); @@ -107,7 +76,7 @@ impl ProjectAux { for &global_gid in global_from_compact_gid { assert!( global_gid < total_splats, - "Invalid gaussian ID in global_from_compact_gid buffer. {global_gid} out of {total_splats}" + "Invalid gaussian ID in global_from_compact_gid: {global_gid} >= {total_splats}" ); } } @@ -115,16 +84,28 @@ impl ProjectAux { } } -/// Output of the Rasterize pass. +/// Minimal output from rendering. Contains only what callers typically need. #[derive(Debug, Clone)] -pub struct RasterizeAux { - pub tile_offsets: IntTensor, - pub compact_gid_from_isect: IntTensor, +pub struct RenderAux { + /// Number of visible splats (for stats/logging) + pub num_visible: IntTensor, + /// Total number of tile-splat intersections (for stats/logging) + pub num_intersections: u32, + /// Visibility weights per splat (for training densification) pub visible: FloatTensor, + /// Tile offsets [ty, tx, 2] with (start, end) per tile (for visualization) + pub tile_offsets: IntTensor, + /// Image size pub img_size: glam::UVec2, } -impl RasterizeAux { +impl RenderAux { + /// Get `num_visible` as a tensor. + pub fn get_num_visible(&self) -> Tensor { + Tensor::from_primitive(self.num_visible.clone()) + } + + /// Calculate tile depth map for visualization. pub fn calc_tile_depth(&self) -> Tensor { use crate::shaders::helpers::TILE_WIDTH; use burn::tensor::s; @@ -137,83 +118,45 @@ impl RasterizeAux { (max - min).reshape([ty as usize, tx as usize]) } - pub fn validate_values(&self, #[allow(unused)] num_visible: u32) { + /// Validate rasterize outputs. + pub fn validate(&self) { #[cfg(any(test, feature = "debug-validation"))] { - use burn::tensor::TensorPrimitive; - - use crate::validation::validate_tensor_val; - if std::env::args().any(|a| a == "--bench") { return; } - let compact_gid_from_isect: Tensor = - Tensor::from_primitive(self.compact_gid_from_isect.clone()); + use crate::validation::validate_tensor_val; + use burn::tensor::{ElementConversion, TensorPrimitive}; - let num_intersections = compact_gid_from_isect.shape()[0] as u32; + let num_visible = Tensor::::from_primitive(self.num_visible.clone()) + .into_scalar() + .elem::(); let visible: Tensor = Tensor::from_primitive(TensorPrimitive::Float(self.visible.clone())); let visible_2d: Tensor = visible.unsqueeze_dim(1); validate_tensor_val(&visible_2d, "visible", None, None); - // Only validate tile_offsets when there are intersections to validate - if num_intersections > 0 { - let tile_offsets: Tensor = - Tensor::from_primitive(self.tile_offsets.clone()); - - let tile_offsets = tile_offsets - .into_data() - .into_vec::() - .expect("Failed to fetch tile offsets"); - for &offsets in &tile_offsets { - assert!( - offsets <= num_intersections, - "Tile offsets exceed bounds. Value: {offsets}, num_intersections: {num_intersections}" - ); - } - - for i in 0..(tile_offsets.len() - 1) / 2 { - // Check pairs of start/end points. - let start = tile_offsets[i * 2]; - let end = tile_offsets[i * 2 + 1]; - assert!( - start < num_intersections && end <= num_intersections, - "Invalid elements in tile offsets. Start {start} ending at {end}" - ); - assert!( - end >= start, - "Invalid elements in tile offsets. Start {start} ending at {end}" - ); - assert!( - end - start <= num_visible, - "One tile has more hits than total visible splats. Start {start} ending at {end}" - ); - } - } - - // Validate compact_gid_from_isect - // Only validate when there are visible splats (if num_visible=0, no valid intersections) - if num_visible > 0 { - let data = compact_gid_from_isect.into_data(); - - // Handle both I32 and U32 tensor types - let compact_gid_vec: Vec = data - .clone() - .into_vec::() - .or_else(|_| { - data.into_vec::() - .map(|v| v.into_iter().map(|x| x as u32).collect()) - }) - .expect("Failed to fetch compact_gid_from_isect"); - - for compact_gid in &compact_gid_vec { - assert!( - *compact_gid < num_visible, - "Invalid gaussian ID in intersection buffer. {compact_gid} out of {num_visible}." - ); - } + // Validate tile_offsets + let tile_offsets: Tensor = Tensor::from_primitive(self.tile_offsets.clone()); + let tile_offsets_data = tile_offsets + .into_data() + .into_vec::() + .expect("Failed to fetch tile offsets"); + + for i in 0..(tile_offsets_data.len() / 2) { + let start = tile_offsets_data[i * 2]; + let end = tile_offsets_data[i * 2 + 1]; + assert!( + end >= start, + "Invalid tile offsets: start {start} > end {end}" + ); + assert!( + end - start <= num_visible, + "Tile has more hits ({}) than visible splats ({num_visible})", + end - start + ); } } } diff --git a/crates/brush-render/src/tests/mod.rs b/crates/brush-render/src/tests/mod.rs index 48241c0a..faa66a63 100644 --- a/crates/brush-render/src/tests/mod.rs +++ b/crates/brush-render/src/tests/mod.rs @@ -1,8 +1,7 @@ use crate::{ - MainBackend, SplatOps, + MainBackend, RenderAux, SplatOps, camera::Camera, gaussian_splats::SplatRenderMode, - render_aux::{ProjectAux, RasterizeAux, validate_render_output}, }; use assert_approx_eq::assert_approx_eq; use burn::tensor::{Distribution, Tensor, TensorPrimitive}; @@ -21,13 +20,9 @@ fn render_splats_test( render_mode: SplatRenderMode, background: Vec3, bwd_info: bool, -) -> ( - Tensor, - ProjectAux, - RasterizeAux, -) { +) -> (Tensor, RenderAux) { // First pass: project - let project_aux = MainBackend::project( + let project_output = MainBackend::project( cam, img_size, means.into_primitive().tensor(), @@ -38,18 +33,22 @@ fn render_splats_test( render_mode, ); + // Validate project output + project_output.validate(); + // Sync readback of num_intersections - let num_intersections = project_aux.num_intersections(); + let num_intersections = project_output.num_intersections(); - // Second pass: rasterize - let (out_img, rasterize_aux) = - MainBackend::rasterize(&project_aux, num_intersections, background, bwd_info); + // Second pass: rasterize (drop compact_gid_from_isect - only needed for backward) + let (out_img, render_aux, _) = + MainBackend::rasterize(&project_output, num_intersections, background, bwd_info); - let img = Tensor::from_primitive(TensorPrimitive::Float(out_img)); + // Validate render aux + render_aux.validate(); - validate_render_output(&project_aux, &rasterize_aux); + let img = Tensor::from_primitive(TensorPrimitive::Float(out_img)); - (img, project_aux, rasterize_aux) + (img, render_aux) } #[test] @@ -75,7 +74,7 @@ fn renders_at_all() { .repeat_dim(0, num_points); let sh_coeffs = Tensor::::ones([num_points, 1, 3], &device); let raw_opacity = Tensor::::zeros([num_points], &device); - let (output, _project_aux, _rasterize_aux) = render_splats_test( + let (output, _render_aux) = render_splats_test( &cam, img_size, means, @@ -143,7 +142,7 @@ fn renders_many_splats() { let raw_opacity = Tensor::::random([num_splats], Distribution::Uniform(-2.0, 2.0), &device); - let (_output, _project_aux, _rasterize_aux) = render_splats_test( + let (_output, _render_aux) = render_splats_test( &cam, img_size, means, diff --git a/crates/brush-train/src/eval.rs b/crates/brush-train/src/eval.rs index 52ce4b29..a8cc17f1 100644 --- a/crates/brush-train/src/eval.rs +++ b/crates/brush-train/src/eval.rs @@ -5,8 +5,7 @@ use anyhow::Result; use brush_dataset::scene::{sample_to_tensor_data, view_to_sample_image}; use brush_render::camera::Camera; use brush_render::gaussian_splats::Splats; -use brush_render::render_aux::{ProjectAux, RasterizeAux}; -use brush_render::{AlphaMode, SplatOps}; +use brush_render::{AlphaMode, RenderAux, SplatOps}; #[cfg(target_family = "wasm")] use brush_render::render::calc_tile_bounds; @@ -22,8 +21,7 @@ pub struct EvalSample { pub rendered: Tensor, pub psnr: Tensor, pub ssim: Tensor, - pub project_aux: ProjectAux, - pub rasterize_aux: RasterizeAux, + pub render_aux: RenderAux, } pub fn eval_stats>( @@ -41,9 +39,9 @@ pub fn eval_stats>( let gt_rgb = gt_tensor.slice(s![.., .., 0..3]); // Render on reference black background using split pipeline. - let (img, project_aux, rasterize_aux) = { + let (img, render_aux) = { // First pass: project - let project_aux = B::project( + let project_output = B::project( gt_cam, res, splats.means.val().into_primitive().tensor(), @@ -56,7 +54,7 @@ pub fn eval_stats>( // Sync readback of num_intersections #[cfg(not(target_family = "wasm"))] - let num_intersections = project_aux.num_intersections(); + let num_intersections = project_output.num_intersections(); #[cfg(target_family = "wasm")] let num_intersections = { @@ -68,10 +66,10 @@ pub fn eval_stats>( max_possible.min(2 * 512 * 65535) }; - // Second pass: rasterize (with bwd_info = true for eval) - let (out_img, rasterize_aux) = B::rasterize(&project_aux, num_intersections, Vec3::ZERO, true); + // Second pass: rasterize (with bwd_info = true for eval, drop compact_gid) + let (out_img, render_aux, _) = B::rasterize(&project_output, num_intersections, Vec3::ZERO, true); - (Tensor::from_primitive(TensorPrimitive::Float(out_img)), project_aux, rasterize_aux) + (Tensor::from_primitive(TensorPrimitive::Float(out_img)), render_aux) }; let render_rgb = img.slice(s![.., .., 0..3]); @@ -89,8 +87,7 @@ pub fn eval_stats>( psnr, ssim, rendered: render_rgb, - project_aux, - rasterize_aux, + render_aux, }) } diff --git a/crates/brush-train/src/train.rs b/crates/brush-train/src/train.rs index eadd12cd..82156b6e 100644 --- a/crates/brush-train/src/train.rs +++ b/crates/brush-train/src/train.rs @@ -116,7 +116,7 @@ impl SplatTrainer { let has_alpha = batch.has_alpha(); let gt_tensor = Tensor::from_data(batch.img_tensor, &device); - let (pred_image, project_aux, rasterize_aux, refine_weight_holder) = trace_span!("Forward").in_scope(|| { + let (pred_image, render_aux, refine_weight_holder) = trace_span!("Forward").in_scope(|| { // Could generate a random background color, but so far // results just seem worse. let background = Vec3::ZERO; @@ -130,16 +130,16 @@ impl SplatTrainer { let img = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); - (img, diff_out.project_aux, diff_out.rasterize_aux, diff_out.refine_weight_holder) + (img, diff_out.render_aux, diff_out.refine_weight_holder) }); let median_scale = self.bounds.median_size(); - let num_visible = project_aux.get_num_visible().inner(); + let num_visible = render_aux.get_num_visible().inner(); let pred_rgb = pred_image.clone().slice(s![.., .., 0..3]); let gt_rgb = gt_tensor.clone().slice(s![.., .., 0..3]); let visible: Tensor, 1> = - Tensor::from_primitive(TensorPrimitive::Float(rasterize_aux.visible)); + Tensor::from_primitive(TensorPrimitive::Float(render_aux.visible)); let loss = trace_span!("Calculate losses").in_scope(|| { let l1_rgb = (pred_rgb.clone() - gt_rgb.clone()).abs(); diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 4b095c07..b7d046af 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -250,7 +250,7 @@ impl ScenePanel { load_option } - fn start_loading(&self, source: DataSource, process: &UiProcess) { + fn start_loading(#[allow(clippy::unused_self)] &self, source: DataSource, process: &UiProcess) { process.connect_to_process(create_process( source, #[cfg(feature = "training")] @@ -334,7 +334,7 @@ impl ScenePanel { if pixel_size.x > 8 && pixel_size.y > 8 && dirty { let _span = trace_span!("Render splats").entered(); // Could add an option for background color. - let (img, _, _) = render_splats( + let (img, _render_aux) = render_splats( &splats, &camera, pixel_size, diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index ed0ae44f..7799df19 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -172,7 +172,7 @@ impl eframe::App for App { return; }; - let (img, _) = render_splats( + let (img, _render_aux) = render_splats( &msg.splats, &self.camera, glam::uvec2(self.image.width(), self.image.height()), From 727ef08b4f5dbfdb7091ebff5fb50bd692201c16 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 19 Jan 2026 22:22:14 +0100 Subject: [PATCH 05/29] More cleanup --- crates/brush-render-bwd/src/burn_glue.rs | 647 +++++++++++------- crates/brush-render-bwd/src/lib.rs | 8 +- crates/brush-render-bwd/src/render_bwd.rs | 180 ++--- .../src/shaders/project_backwards.wgsl | 2 - .../src/shaders/rasterize_backwards.wgsl | 2 - crates/brush-render-bwd/src/tests.rs | 117 ---- crates/brush-render/src/gaussian_splats.rs | 5 +- crates/brush-render/src/lib.rs | 10 +- crates/brush-render/src/render.rs | 18 +- crates/brush-render/src/render_aux.rs | 6 +- .../src/shaders/project_forward.wgsl | 1 - .../src/shaders/project_visible.wgsl | 1 - .../brush-render/src/shaders/rasterize.wgsl | 2 - crates/brush-render/src/tests/mod.rs | 8 +- crates/brush-sort/src/lib.rs | 34 +- crates/brush-train/src/eval.rs | 42 +- 16 files changed, 523 insertions(+), 560 deletions(-) delete mode 100644 crates/brush-render-bwd/src/tests.rs diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index b50ab05c..42d13859 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -30,107 +30,100 @@ use burn_fusion::{ use burn_ir::{CustomOpIr, HandleContainer, OperationIr, OperationOutput, TensorIr}; use glam::Vec3; -use crate::render_bwd::SplatGrads; - -/// Like [`SplatOps`], but for backends that support differentiation. -/// -/// This shouldn't be a separate trait, but atm is needed because of orphan trait rules. -pub trait SplatForwardDiff { - /// Render splats to a buffer. - /// - /// This projects the gaussians, sorts them, and rasterizes them to a buffer, in a - /// differentiable way. - #[allow(clippy::too_many_arguments)] - fn render_splats( - camera: &Camera, - img_size: glam::UVec2, - means: FloatTensor, - log_scales: FloatTensor, - quats: FloatTensor, - sh_coeffs: FloatTensor, - raw_opacity: FloatTensor, - render_mode: SplatRenderMode, - background: Vec3, - ) -> SplatOutputDiff; +/// State needed for the rasterize backward pass. +#[derive(Debug, Clone)] +pub struct RasterizeBwdState { + pub out_img: FloatTensor, + pub projected_splats: FloatTensor, + pub global_from_compact_gid: IntTensor, + pub compact_gid_from_isect: IntTensor, + pub tile_offsets: IntTensor, + pub background: Vec3, + pub img_size: glam::UVec2, } -pub trait SplatBackwardOps { - /// Backward pass for `render_splats`. - /// - /// Do not use directly, `render_splats` will use this to calculate gradients. - #[allow(unused_variables)] - fn render_splats_bwd( - state: GaussianBackwardState, - v_output: FloatTensor, - ) -> SplatGrads; +/// State needed for the project backward pass. +#[derive(Debug, Clone)] +pub struct ProjectBwdState { + pub means: FloatTensor, + pub log_scales: FloatTensor, + pub quats: FloatTensor, + pub raw_opac: FloatTensor, + pub num_visible: IntTensor, + pub global_from_compact_gid: IntTensor, + pub project_uniforms: ProjectUniforms, + pub sh_degree: u32, + pub render_mode: SplatRenderMode, } -/// State from the `ProjectPrepare` pass needed for backward computation. +/// Intermediate gradients from the rasterize backward pass. #[derive(Debug, Clone)] -pub struct ProjectBackwardState { - pub(crate) means: FloatTensor, - pub(crate) quats: FloatTensor, - pub(crate) log_scales: FloatTensor, - pub(crate) raw_opac: FloatTensor, - pub(crate) projected_splats: FloatTensor, - pub(crate) project_uniforms: ProjectUniforms, - pub(crate) num_visible: IntTensor, - pub(crate) global_from_compact_gid: IntTensor, - pub(crate) render_mode: SplatRenderMode, - pub(crate) sh_degree: u32, - pub(crate) background: Vec3, +pub struct RasterizeGrads { + /// Gradients w.r.t. projected splat data [`num_points`, 8]. + pub v_projected_splats: FloatTensor, + /// Gradients w.r.t. raw opacity from rasterization [`num_points`]. + pub v_raw_opac: FloatTensor, + /// Refinement weights for densification [`num_points`]. + pub v_refine_weight: FloatTensor, } -/// State from the Rasterize pass needed for backward computation. +/// Final gradients w.r.t. splat inputs from the project backward pass. #[derive(Debug, Clone)] -pub struct RasterizeBackwardState { - pub(crate) out_img: FloatTensor, - pub(crate) compact_gid_from_isect: IntTensor, - pub(crate) tile_offsets: IntTensor, +pub struct SplatGrads { + pub v_means: FloatTensor, + pub v_quats: FloatTensor, + pub v_scales: FloatTensor, + pub v_coeffs: FloatTensor, + pub v_raw_opac: FloatTensor, + pub v_refine_weight: FloatTensor, } -/// Combined backward state for compatibility with existing code. -#[derive(Debug, Clone)] -pub struct GaussianBackwardState { - pub(crate) means: FloatTensor, - pub(crate) quats: FloatTensor, - pub(crate) log_scales: FloatTensor, - pub(crate) raw_opac: FloatTensor, - pub(crate) out_img: FloatTensor, - pub(crate) projected_splats: FloatTensor, - pub(crate) project_uniforms: ProjectUniforms, - pub(crate) num_visible: IntTensor, - pub(crate) compact_gid_from_isect: IntTensor, - pub(crate) global_from_compact_gid: IntTensor, - pub(crate) tile_offsets: IntTensor, - pub(crate) render_mode: SplatRenderMode, - pub(crate) sh_degree: u32, - pub(crate) background: Vec3, +/// Backward pass trait mirroring [`SplatOps`]. +/// +/// Provides backward methods for each forward pass: +/// - `rasterize_bwd`: computes gradients w.r.t. projected splats +/// - `project_bwd`: computes gradients w.r.t. original inputs +/// +/// These are called in reverse order during backpropagation. +pub trait SplatBwdOps: SplatOps { + /// Backward pass for rasterization. + /// + /// Takes the upstream gradient `v_output` and produces intermediate gradients + /// w.r.t. the projected splat representation. + fn rasterize_bwd(state: RasterizeBwdState, v_output: FloatTensor) -> RasterizeGrads; + + /// Backward pass for projection. + /// + /// Takes the intermediate gradients from `rasterize_bwd` and produces + /// the final gradients w.r.t. the original splat inputs. + fn project_bwd(state: ProjectBwdState, rasterize_grads: RasterizeGrads) -> SplatGrads; } -impl GaussianBackwardState { - /// Construct combined state from project and rasterize backward states. - pub fn from_parts( - project: ProjectBackwardState, - rasterize: RasterizeBackwardState, - ) -> Self { - Self { - means: project.means, - quats: project.quats, - log_scales: project.log_scales, - raw_opac: project.raw_opac, - projected_splats: project.projected_splats, - project_uniforms: project.project_uniforms, - num_visible: project.num_visible, - global_from_compact_gid: project.global_from_compact_gid, - render_mode: project.render_mode, - sh_degree: project.sh_degree, - background: project.background, - out_img: rasterize.out_img, - compact_gid_from_isect: rasterize.compact_gid_from_isect, - tile_offsets: rasterize.tile_offsets, - } - } +/// State saved during forward pass for backward computation. +#[derive(Debug, Clone)] +struct GaussianBackwardState { + // Original inputs (needed for project_bwd) + means: FloatTensor, + quats: FloatTensor, + log_scales: FloatTensor, + raw_opac: FloatTensor, + + // From project forward (needed for both bwd passes) + projected_splats: FloatTensor, + project_uniforms: ProjectUniforms, + num_visible: IntTensor, + global_from_compact_gid: IntTensor, + + // From rasterize forward (needed for rasterize_bwd) + out_img: FloatTensor, + compact_gid_from_isect: IntTensor, + tile_offsets: IntTensor, + + // Settings + render_mode: SplatRenderMode, + sh_degree: u32, + background: Vec3, + img_size: glam::UVec2, } #[derive(Debug)] @@ -139,7 +132,7 @@ struct RenderBackwards; const NUM_BWD_ARGS: usize = 6; // Implement gradient registration when rendering backwards. -impl> Backward for RenderBackwards { +impl> Backward for RenderBackwards { type State = GaussianBackwardState; fn backward( @@ -151,7 +144,6 @@ impl> Backward for RenderBackw let _span = tracing::trace_span!("render_gaussians backwards").entered(); let state = ops.state; - let v_output = grads.consume::(&ops.node); // Register gradients for parent nodes (This code is already skipped entirely @@ -165,31 +157,55 @@ impl> Backward for RenderBackw raw_opacity_parent, ] = ops.parents; - let v_tens = B::render_splats_bwd(state, v_output); + // Step 1: Rasterize backward + let rasterize_state = RasterizeBwdState { + out_img: state.out_img, + projected_splats: state.projected_splats, + global_from_compact_gid: state.global_from_compact_gid.clone(), + compact_gid_from_isect: state.compact_gid_from_isect, + tile_offsets: state.tile_offsets, + background: state.background, + img_size: state.img_size, + }; + let rasterize_grads = B::rasterize_bwd(rasterize_state, v_output); + + // Step 2: Project backward + let project_state = ProjectBwdState { + means: state.means, + log_scales: state.log_scales, + quats: state.quats, + raw_opac: state.raw_opac, + num_visible: state.num_visible, + global_from_compact_gid: state.global_from_compact_gid, + project_uniforms: state.project_uniforms, + sh_degree: state.sh_degree, + render_mode: state.render_mode, + }; + let splat_grads = B::project_bwd(project_state, rasterize_grads); if let Some(node) = mean_parent { - grads.register::(node.id, v_tens.v_means); + grads.register::(node.id, splat_grads.v_means); } // Register the gradients for the dummy xy input. if let Some(node) = refine_weight { - grads.register::(node.id, v_tens.v_refine_weight); + grads.register::(node.id, splat_grads.v_refine_weight); } if let Some(node) = log_scales_parent { - grads.register::(node.id, v_tens.v_scales); + grads.register::(node.id, splat_grads.v_scales); } if let Some(node) = quats_parent { - grads.register::(node.id, v_tens.v_quats); + grads.register::(node.id, splat_grads.v_quats); } if let Some(node) = coeffs_parent { - grads.register::(node.id, v_tens.v_coeffs); + grads.register::(node.id, splat_grads.v_coeffs); } if let Some(node) = raw_opacity_parent { - grads.register::(node.id, v_tens.v_raw_opac); + grads.register::(node.id, splat_grads.v_raw_opac); } } } @@ -200,123 +216,269 @@ pub struct SplatOutputDiff { pub refine_weight_holder: Tensor, } -// Implement -impl + SplatOps, C: CheckpointStrategy> SplatForwardDiff - for Autodiff +/// Render splats on a differentiable backend. +/// +/// This is the main entry point for differentiable rendering, wrapping +/// the forward pass with autodiff support. +pub fn render_splats( + splats: &Splats>, + camera: &Camera, + img_size: glam::UVec2, + background: Vec3, +) -> SplatOutputDiff> +where + B: Backend + SplatBwdOps, + C: CheckpointStrategy, { - fn render_splats( - camera: &Camera, - img_size: glam::UVec2, - means: FloatTensor, - log_scales: FloatTensor, - quats: FloatTensor, - sh_coeffs: FloatTensor, - raw_opacity: FloatTensor, - render_mode: SplatRenderMode, - background: Vec3, - ) -> SplatOutputDiff { - // Get backend tensors & dequantize if needed. Could try and support quantized inputs - // in the future. - let device = - Tensor::::from_primitive(TensorPrimitive::Float(means.clone())).device(); - let refine_weight_holder = Tensor::::zeros([1], &device).require_grad(); - - // Prepare backward pass, and check if we even need to do it. Store nodes that need gradients. - let prep_nodes = RenderBackwards - .prepare::([ - means.node.clone(), - refine_weight_holder.clone().into_primitive().tensor().node, - log_scales.node.clone(), - quats.node.clone(), - sh_coeffs.node.clone(), - raw_opacity.node.clone(), - ]) - .compute_bound() - .stateful(); - - // First pass: project - let project_output = >::project( - camera, - img_size, - means.clone().into_primitive(), - log_scales.clone().into_primitive(), - quats.clone().into_primitive(), - sh_coeffs.clone().into_primitive(), - raw_opacity.clone().into_primitive(), - render_mode, - ); + splats.validate_values(); - // Sync readback of num_intersections - let num_intersections = project_output.num_intersections(); + let device = Tensor::, 2>::from_primitive(TensorPrimitive::Float( + splats.means.val().into_primitive().tensor(), + )) + .device(); + let refine_weight_holder = Tensor::, 1>::zeros([1], &device).require_grad(); + + // Prepare backward pass, and check if we even need to do it. + let prep_nodes = RenderBackwards + .prepare::([ + splats.means.val().into_primitive().tensor().node, + refine_weight_holder.clone().into_primitive().tensor().node, + splats.log_scales.val().into_primitive().tensor().node, + splats.rotations.val().into_primitive().tensor().node, + splats.sh_coeffs.val().into_primitive().tensor().node, + splats.raw_opacities.val().into_primitive().tensor().node, + ]) + .compute_bound() + .stateful(); + + let means = splats + .means + .val() + .into_primitive() + .tensor() + .into_primitive(); + let log_scales = splats + .log_scales + .val() + .into_primitive() + .tensor() + .into_primitive(); + let quats = splats + .rotations + .val() + .into_primitive() + .tensor() + .into_primitive(); + let sh_coeffs_dims = splats.sh_coeffs.dims(); + let sh_coeffs = splats + .sh_coeffs + .val() + .into_primitive() + .tensor() + .into_primitive(); + let raw_opacity = splats + .raw_opacities + .val() + .into_primitive() + .tensor() + .into_primitive(); + + // First pass: project + let project_output = >::project( + camera, + img_size, + means.clone(), + log_scales.clone(), + quats.clone(), + sh_coeffs, + raw_opacity.clone(), + splats.render_mode, + ); - // Second pass: rasterize (with bwd_info = true) - let (out_img, render_aux, compact_gid_from_isect) = - >::rasterize(&project_output, num_intersections, background, true); + // Sync readback of num_intersections + let num_intersections = project_output.read_num_intersections(); + + // Second pass: rasterize (with bwd_info = true) + let (out_img, render_aux, compact_gid_from_isect) = + >::rasterize(&project_output, num_intersections, background, true); + + // Create wrapped render_aux for Autodiff backend + let wrapped_render_aux = RenderAux::> { + num_visible: render_aux.num_visible.clone(), + num_intersections: render_aux.num_intersections, + visible: as AutodiffBackend>::from_inner(render_aux.visible.clone()), + tile_offsets: render_aux.tile_offsets.clone(), + img_size: render_aux.img_size, + }; + + let sh_degree = sh_degree_from_coeffs(sh_coeffs_dims[1] as u32); + + match prep_nodes { + OpsKind::Tracked(prep) => { + // Save state needed for backward pass. + let state = GaussianBackwardState { + means, + log_scales, + quats, + raw_opac: raw_opacity, + sh_degree, + out_img: out_img.clone(), + projected_splats: project_output.projected_splats, + project_uniforms: project_output.project_uniforms, + num_visible: project_output.num_visible, + tile_offsets: render_aux.tile_offsets, + compact_gid_from_isect, + render_mode: splats.render_mode, + global_from_compact_gid: project_output.global_from_compact_gid, + background, + img_size, + }; + + let out_img = prep.finish(state, out_img); + + let result = SplatOutputDiff { + img: out_img, + render_aux: wrapped_render_aux, + refine_weight_holder, + }; + result.render_aux.validate(); + result + } + OpsKind::UnTracked(prep) => { + // When no node is tracked, we can just use the original operation without + // keeping any state. + let result = SplatOutputDiff { + img: prep.finish(out_img), + render_aux: wrapped_render_aux, + refine_weight_holder, + }; + result.render_aux.validate(); + result + } + } +} - // Create wrapped render_aux for Autodiff backend - let wrapped_render_aux = RenderAux:: { - num_visible: render_aux.num_visible.clone(), - num_intersections: render_aux.num_intersections, - visible: ::from_inner(render_aux.visible.clone()), - tile_offsets: render_aux.tile_offsets.clone(), - img_size: render_aux.img_size, - }; +impl SplatBwdOps for Fusion { + fn rasterize_bwd( + state: RasterizeBwdState, + v_output: FloatTensor, + ) -> RasterizeGrads { + #[derive(Debug)] + struct CustomOp { + desc: CustomOpIr, + background: Vec3, + img_size: glam::UVec2, + } - match prep_nodes { - OpsKind::Tracked(prep) => { - // Save state needed for backward pass. - let state = GaussianBackwardState { - means: means.into_primitive(), - log_scales: log_scales.into_primitive(), - quats: quats.into_primitive(), - raw_opac: raw_opacity.into_primitive(), - sh_degree: sh_degree_from_coeffs( - Tensor::::from_primitive(TensorPrimitive::Float(sh_coeffs)).dims() - [1] as u32, - ), - out_img: out_img.clone(), - projected_splats: project_output.projected_splats, - project_uniforms: project_output.project_uniforms, - num_visible: project_output.num_visible, - tile_offsets: render_aux.tile_offsets, + impl Operation> for CustomOp { + fn execute( + &self, + h: &mut HandleContainer>>, + ) { + let (inputs, outputs) = self.desc.as_fixed(); + + let [ + v_output, + out_img, + projected_splats, + global_from_compact_gid, compact_gid_from_isect, - render_mode, - global_from_compact_gid: project_output.global_from_compact_gid, - background, + tile_offsets, + ] = inputs; + + let [v_projected_splats, v_raw_opac, v_refine_weight] = outputs; + + let inner_state = RasterizeBwdState { + out_img: h.get_float_tensor::(out_img), + projected_splats: h.get_float_tensor::(projected_splats), + global_from_compact_gid: h + .get_int_tensor::(global_from_compact_gid), + compact_gid_from_isect: h + .get_int_tensor::(compact_gid_from_isect), + tile_offsets: h.get_int_tensor::(tile_offsets), + background: self.background, + img_size: self.img_size, }; - let out_img = prep.finish(state, out_img); + let grads = >::rasterize_bwd( + inner_state, + h.get_float_tensor::(v_output), + ); - SplatOutputDiff { - img: out_img, - render_aux: wrapped_render_aux, - refine_weight_holder, - } - } - OpsKind::UnTracked(prep) => { - // When no node is tracked, we can just use the original operation without - // keeping any state. - SplatOutputDiff { - img: prep.finish(out_img), - render_aux: wrapped_render_aux, - refine_weight_holder, - } + h.register_float_tensor::( + &v_projected_splats.id, + grads.v_projected_splats, + ); + h.register_float_tensor::(&v_raw_opac.id, grads.v_raw_opac); + h.register_float_tensor::( + &v_refine_weight.id, + grads.v_refine_weight, + ); } } + + let client = v_output.client.clone(); + let num_points = state.projected_splats.shape[0]; + + let v_projected_splats = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_points, 8]), + DType::F32, + ); + let v_raw_opac = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_points]), + DType::F32, + ); + let v_refine_weight = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_points]), + DType::F32, + ); + + let input_tensors = [ + v_output, + state.out_img, + state.projected_splats, + state.global_from_compact_gid, + state.compact_gid_from_isect, + state.tile_offsets, + ]; + + let stream = OperationStreams::with_inputs(&input_tensors); + let desc = CustomOpIr::new( + "rasterize_bwd", + &input_tensors.map(|t| t.into_ir()), + &[v_projected_splats, v_raw_opac, v_refine_weight], + ); + let op = CustomOp { + desc: desc.clone(), + background: state.background, + img_size: state.img_size, + }; + + let outputs = client + .register(stream, OperationIr::Custom(desc), op) + .outputs(); + + let [v_projected_splats, v_raw_opac, v_refine_weight] = outputs; + + RasterizeGrads { + v_projected_splats, + v_raw_opac, + v_refine_weight, + } } -} -impl SplatBackwardOps for Fusion { - fn render_splats_bwd( - state: GaussianBackwardState, - v_output: FloatTensor, + fn project_bwd( + state: ProjectBwdState, + rasterize_grads: RasterizeGrads, ) -> SplatGrads { #[derive(Debug)] struct CustomOp { desc: CustomOpIr, render_mode: SplatRenderMode, sh_degree: u32, - background: Vec3, project_uniforms: ProjectUniforms, } @@ -328,57 +490,63 @@ impl SplatBackwardOps for Fusion { let (inputs, outputs) = self.desc.as_fixed(); let [ - v_output, means, - quats, log_scales, + quats, raw_opac, - out_img, - projected_splats, num_visible, - tile_offsets, - compact_gid_from_isect, global_from_compact_gid, + v_projected_splats, + v_raw_opac_in, + v_refine_weight_in, ] = inputs; - let [v_means, v_quats, v_scales, v_coeffs, v_raw_opac, v_refine] = outputs; - - let inner_state = GaussianBackwardState { + let [ + v_means, + v_quats, + v_scales, + v_coeffs, + v_raw_opac, + v_refine_weight, + ] = outputs; + + let inner_state = ProjectBwdState { means: h.get_float_tensor::(means), log_scales: h.get_float_tensor::(log_scales), quats: h.get_float_tensor::(quats), raw_opac: h.get_float_tensor::(raw_opac), - out_img: h.get_float_tensor::(out_img), - projected_splats: h.get_float_tensor::(projected_splats), - project_uniforms: self.project_uniforms, num_visible: h.get_int_tensor::(num_visible), - tile_offsets: h.get_int_tensor::(tile_offsets), - compact_gid_from_isect: h - .get_int_tensor::(compact_gid_from_isect), global_from_compact_gid: h .get_int_tensor::(global_from_compact_gid), + project_uniforms: self.project_uniforms, sh_degree: self.sh_degree, render_mode: self.render_mode, - background: self.background, }; - let grads = - >::render_splats_bwd( - inner_state, - h.get_float_tensor::(v_output), - ); + let inner_rasterize_grads = RasterizeGrads { + v_projected_splats: h.get_float_tensor::(v_projected_splats), + v_raw_opac: h.get_float_tensor::(v_raw_opac_in), + v_refine_weight: h.get_float_tensor::(v_refine_weight_in), + }; + + let grads = >::project_bwd( + inner_state, + inner_rasterize_grads, + ); - // // Register output. h.register_float_tensor::(&v_means.id, grads.v_means); h.register_float_tensor::(&v_quats.id, grads.v_quats); h.register_float_tensor::(&v_scales.id, grads.v_scales); h.register_float_tensor::(&v_coeffs.id, grads.v_coeffs); h.register_float_tensor::(&v_raw_opac.id, grads.v_raw_opac); - h.register_float_tensor::(&v_refine.id, grads.v_refine_weight); + h.register_float_tensor::( + &v_refine_weight.id, + grads.v_refine_weight, + ); } } - let client = v_output.client.clone(); + let client = state.means.client.clone(); let num_points = state.means.shape[0]; let coeffs = sh_coeffs_for_degree(state.sh_degree) as usize; @@ -414,22 +582,20 @@ impl SplatBackwardOps for Fusion { ); let input_tensors = [ - v_output, state.means, - state.quats, state.log_scales, + state.quats, state.raw_opac, - state.out_img, - state.projected_splats, state.num_visible, - state.tile_offsets, - state.compact_gid_from_isect, state.global_from_compact_gid, + rasterize_grads.v_projected_splats, + rasterize_grads.v_raw_opac, + rasterize_grads.v_refine_weight, ]; let stream = OperationStreams::with_inputs(&input_tensors); let desc = CustomOpIr::new( - "render_splat_bwd", + "project_bwd", &input_tensors.map(|t| t.into_ir()), &[ v_means, @@ -449,7 +615,6 @@ impl SplatBackwardOps for Fusion { desc, sh_degree: state.sh_degree, render_mode: state.render_mode, - background: state.background, project_uniforms: state.project_uniforms, }, ) @@ -474,29 +639,3 @@ impl SplatBackwardOps for Fusion { } } } - -/// Render splats on a differentiable backend. -pub fn render_splats( - splats: &Splats, - camera: &Camera, - img_size: glam::UVec2, - background: Vec3, -) -> SplatOutputDiff -where - B: Backend + SplatForwardDiff, -{ - splats.validate_values(); - let result = B::render_splats( - camera, - img_size, - splats.means.val().into_primitive().tensor(), - splats.log_scales.val().into_primitive().tensor(), - splats.rotations.val().into_primitive().tensor(), - splats.sh_coeffs.val().into_primitive().tensor(), - splats.raw_opacities.val().into_primitive().tensor(), - splats.render_mode, - background, - ); - result.render_aux.validate(); - result -} diff --git a/crates/brush-render-bwd/src/lib.rs b/crates/brush-render-bwd/src/lib.rs index 39698ab2..9985efe4 100644 --- a/crates/brush-render-bwd/src/lib.rs +++ b/crates/brush-render-bwd/src/lib.rs @@ -1,7 +1,7 @@ pub mod burn_glue; mod render_bwd; -pub use burn_glue::render_splats; - -#[cfg(test)] -mod tests; +pub use burn_glue::{ + ProjectBwdState, RasterizeBwdState, RasterizeGrads, SplatBwdOps, SplatGrads, SplatOutputDiff, + render_splats, +}; diff --git a/crates/brush-render-bwd/src/render_bwd.rs b/crates/brush-render-bwd/src/render_bwd.rs index 22e05d07..fc4ae864 100644 --- a/crates/brush-render-bwd/src/render_bwd.rs +++ b/crates/brush-render-bwd/src/render_bwd.rs @@ -1,20 +1,19 @@ use brush_kernel::{CubeCount, calc_cube_count_1d, create_meta_binding}; use brush_render::gaussian_splats::SplatRenderMode; use brush_render::shaders::helpers::RasterizeUniforms; +use brush_render::MainBackendBase; use brush_wgsl::wgsl_kernel; -use brush_render::MainBackendBase; use brush_render::sh::sh_coeffs_for_degree; use burn::tensor::FloatDType; -use burn::tensor::ops::FloatTensorOps; -use burn::{prelude::Backend, tensor::ops::FloatTensor}; +use burn::tensor::ops::{FloatTensor, FloatTensorOps}; use burn_cubecl::cubecl::features::TypeUsage; use burn_cubecl::cubecl::ir::{ElemType, FloatKind, StorageType}; use burn_cubecl::cubecl::server::Bindings; use burn_cubecl::kernel::into_contiguous; use glam::uvec2; -use crate::burn_glue::{GaussianBackwardState, SplatBackwardOps}; +use crate::burn_glue::{ProjectBwdState, RasterizeBwdState, RasterizeGrads, SplatBwdOps, SplatGrads}; // Kernel definitions using proc macro #[wgsl_kernel( @@ -34,62 +33,33 @@ pub struct RasterizeBackwards { pub webgpu: bool, } -#[derive(Debug, Clone)] -pub struct SplatGrads { - pub v_means: FloatTensor, - pub v_quats: FloatTensor, - pub v_scales: FloatTensor, - pub v_coeffs: FloatTensor, - pub v_raw_opac: FloatTensor, - pub v_refine_weight: FloatTensor, -} - -impl SplatBackwardOps for MainBackendBase { - fn render_splats_bwd( - state: GaussianBackwardState, +impl SplatBwdOps for MainBackendBase { + fn rasterize_bwd( + state: RasterizeBwdState, v_output: FloatTensor, - ) -> SplatGrads { + ) -> RasterizeGrads { + let _span = tracing::trace_span!("rasterize_bwd").entered(); + // Comes from loss, might not be contiguous. let v_output = into_contiguous(v_output); - // Comes from params, might not be contiguous. - let means = into_contiguous(state.means); - let log_scales = into_contiguous(state.log_scales); - let quats = into_contiguous(state.quats); - let raw_opac = into_contiguous(state.raw_opac); - // We're in charge of these, SHOULD be contiguous but might as well. let projected_splats = into_contiguous(state.projected_splats); let compact_gid_from_isect = into_contiguous(state.compact_gid_from_isect); let global_from_compact_gid = into_contiguous(state.global_from_compact_gid); let tile_offsets = into_contiguous(state.tile_offsets); + let out_img = into_contiguous(state.out_img); - let device = &state.out_img.device; - let img_dimgs = state.out_img.shape.dims; - let img_size = glam::uvec2(img_dimgs[1] as u32, img_dimgs[0] as u32); - - let num_points = means.shape.dims[0]; - - let client = &means.client; + let device = &out_img.device; + let img_size = state.img_size; + let num_points = projected_splats.shape.dims[0]; - // Setup tensors. - // Nb: these are packed vec3 values, special care is taken in the kernel to respect alignment. - let v_means = Self::float_zeros([num_points, 3].into(), device, FloatDType::F32); + let client = &projected_splats.client; - let v_scales = Self::float_zeros([num_points, 3].into(), device, FloatDType::F32); - let v_quats = Self::float_zeros([num_points, 4].into(), device, FloatDType::F32); - let v_coeffs = Self::float_zeros( - [ - num_points, - sh_coeffs_for_degree(state.sh_degree) as usize, - 3, - ] - .into(), - device, - FloatDType::F32, - ); + // Setup output tensors. + let v_projected_splats = + Self::float_zeros([num_points, 8].into(), device, FloatDType::F32); let v_raw_opac = Self::float_zeros([num_points].into(), device, FloatDType::F32); - let v_grads = Self::float_zeros([num_points, 8].into(), device, FloatDType::F32); let v_refine_weight = Self::float_zeros([num_points].into(), device, FloatDType::F32); let tile_bounds = uvec2( @@ -101,7 +71,7 @@ impl SplatBackwardOps for MainBackendBase { .div_ceil(brush_render::shaders::helpers::TILE_WIDTH), ); - // Create RasterizeUniforms for the backward rasterize pass (passed via with_metadata) + // Create RasterizeUniforms for the backward rasterize pass let rasterize_uniforms = RasterizeUniforms { tile_bounds: tile_bounds.into(), img_size: img_size.into(), @@ -114,9 +84,7 @@ impl SplatBackwardOps for MainBackendBase { .contains(TypeUsage::AtomicAdd); let webgpu = cfg!(target_family = "wasm"); - let mip_splat = matches!(state.render_mode, SplatRenderMode::Mip); - // Use checked execution, as the atomic loops are potentially unbounded. tracing::trace_span!("RasterizeBackwards").in_scope(|| { // SAFETY: Kernel checked to have no OOB, bounded loops. unsafe { @@ -127,12 +95,12 @@ impl SplatBackwardOps for MainBackendBase { Bindings::new() .with_buffers(vec![ compact_gid_from_isect.handle.binding(), - global_from_compact_gid.handle.clone().binding(), + global_from_compact_gid.handle.binding(), tile_offsets.handle.binding(), projected_splats.handle.binding(), - state.out_img.handle.binding(), + out_img.handle.binding(), v_output.handle.binding(), - v_grads.handle.clone().binding(), + v_projected_splats.handle.clone().binding(), v_raw_opac.handle.clone().binding(), v_refine_weight.handle.clone().binding(), ]) @@ -142,45 +110,95 @@ impl SplatBackwardOps for MainBackendBase { } }); - tracing::trace_span!("ProjectBackwards").in_scope(|| - // SAFETY: Kernel has to contain no OOB indexing, bounded loops. - unsafe { - client.launch_unchecked( - ProjectBackwards::task(mip_splat), - calc_cube_count_1d(num_points as u32, ProjectBackwards::WORKGROUP_SIZE[0]), - Bindings::new() - .with_buffers(vec![ - state.num_visible.handle.binding(), - means.handle.binding(), - log_scales.handle.binding(), - quats.handle.binding(), - raw_opac.handle.binding(), - global_from_compact_gid.handle.binding(), - v_grads.handle.binding(), - v_means.handle.clone().binding(), - v_scales.handle.clone().binding(), - v_quats.handle.clone().binding(), - v_coeffs.handle.clone().binding(), - v_raw_opac.handle.clone().binding(), - ]) - .with_metadata(create_meta_binding(state.project_uniforms)), - ).expect("Failed to bwd-diff splats"); - }); + RasterizeGrads { + v_projected_splats, + v_raw_opac, + v_refine_weight, + } + } + + fn project_bwd( + state: ProjectBwdState, + rasterize_grads: RasterizeGrads, + ) -> SplatGrads { + let _span = tracing::trace_span!("project_bwd").entered(); + + // Comes from params, might not be contiguous. + let means = into_contiguous(state.means); + let log_scales = into_contiguous(state.log_scales); + let quats = into_contiguous(state.quats); + let raw_opac = into_contiguous(state.raw_opac); + let global_from_compact_gid = into_contiguous(state.global_from_compact_gid); + + let device = &means.device; + let num_points = means.shape.dims[0]; + let client = &means.client; + + // Setup output tensors. + let v_means = Self::float_zeros([num_points, 3].into(), device, FloatDType::F32); + let v_scales = Self::float_zeros([num_points, 3].into(), device, FloatDType::F32); + let v_quats = Self::float_zeros([num_points, 4].into(), device, FloatDType::F32); + let v_coeffs = Self::float_zeros( + [ + num_points, + sh_coeffs_for_degree(state.sh_degree) as usize, + 3, + ] + .into(), + device, + FloatDType::F32, + ); + + let mip_splat = matches!(state.render_mode, SplatRenderMode::Mip); + + tracing::trace_span!("ProjectBackwards").in_scope(|| { + // SAFETY: Kernel has to contain no OOB indexing, bounded loops. + unsafe { + client + .launch_unchecked( + ProjectBackwards::task(mip_splat), + calc_cube_count_1d(num_points as u32, ProjectBackwards::WORKGROUP_SIZE[0]), + Bindings::new() + .with_buffers(vec![ + state.num_visible.handle.binding(), + means.handle.binding(), + log_scales.handle.binding(), + quats.handle.binding(), + raw_opac.handle.binding(), + global_from_compact_gid.handle.binding(), + rasterize_grads.v_projected_splats.handle.binding(), + v_means.handle.clone().binding(), + v_scales.handle.clone().binding(), + v_quats.handle.clone().binding(), + v_coeffs.handle.clone().binding(), + rasterize_grads.v_raw_opac.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(state.project_uniforms)), + ) + .expect("Failed to bwd-diff splats"); + } + }); assert!(v_means.is_contiguous(), "Grads must be contiguous"); assert!(v_quats.is_contiguous(), "Grads must be contiguous"); assert!(v_scales.is_contiguous(), "Grads must be contiguous"); assert!(v_coeffs.is_contiguous(), "Grads must be contiguous"); - assert!(v_raw_opac.is_contiguous(), "Grads must be contiguous"); - assert!(v_refine_weight.is_contiguous(), "Grads must be contiguous"); + assert!( + rasterize_grads.v_raw_opac.is_contiguous(), + "Grads must be contiguous" + ); + assert!( + rasterize_grads.v_refine_weight.is_contiguous(), + "Grads must be contiguous" + ); SplatGrads { v_means, v_quats, v_scales, v_coeffs, - v_raw_opac, - v_refine_weight, + v_raw_opac: rasterize_grads.v_raw_opac, + v_refine_weight: rasterize_grads.v_refine_weight, } } } diff --git a/crates/brush-render-bwd/src/shaders/project_backwards.wgsl b/crates/brush-render-bwd/src/shaders/project_backwards.wgsl index 8c8beb82..6aead3cb 100644 --- a/crates/brush-render-bwd/src/shaders/project_backwards.wgsl +++ b/crates/brush-render-bwd/src/shaders/project_backwards.wgsl @@ -16,8 +16,6 @@ @group(0) @binding(9) var v_quats: array; @group(0) @binding(10) var v_coeffs: array; @group(0) @binding(11) var v_opacs: array; - -// Uniforms via with_metadata (always last binding) @group(0) @binding(12) var uniforms: helpers::ProjectUniforms; const SH_C0: f32 = 0.2820947917738781f; diff --git a/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl b/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl index be337345..1741319c 100644 --- a/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl +++ b/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl @@ -11,7 +11,6 @@ @group(0) @binding(6) var v_splats: array>; @group(0) @binding(7) var v_opacs: array>; @group(0) @binding(8) var v_refines: array>; - // Uniforms via with_metadata (always last binding) @group(0) @binding(9) var uniforms: helpers::RasterizeUniforms; fn write_grads_atomic(id: u32, grads: f32) { @@ -27,7 +26,6 @@ @group(0) @binding(6) var v_splats: array>; @group(0) @binding(7) var v_opacs: array>; @group(0) @binding(8) var v_refines: array>; - // Uniforms via with_metadata (always last binding) @group(0) @binding(9) var uniforms: helpers::RasterizeUniforms; fn add_bitcast(cur: u32, add: f32) -> u32 { diff --git a/crates/brush-render-bwd/src/tests.rs b/crates/brush-render-bwd/src/tests.rs deleted file mode 100644 index 46cb7818..00000000 --- a/crates/brush-render-bwd/src/tests.rs +++ /dev/null @@ -1,117 +0,0 @@ -use crate::burn_glue::SplatForwardDiff; -use assert_approx_eq::assert_approx_eq; -use brush_render::{camera::Camera, gaussian_splats::SplatRenderMode}; -use burn::{ - backend::Autodiff, - tensor::{Distribution, Tensor, TensorPrimitive}, -}; -use burn_wgpu::{CubeBackend, WgpuDevice, WgpuRuntime}; -use glam::Vec3; - -type InnerBackend = CubeBackend; -type TestBackend = Autodiff; - -#[test] -fn diffs_at_all() { - // Check if backward pass doesn't hard crash or anything. - // These are some zero-sized gaussians, so we know - // what the result should look like. - let cam = Camera::new( - glam::vec3(0.0, 0.0, 0.0), - glam::Quat::IDENTITY, - 0.5, - 0.5, - glam::vec2(0.5, 0.5), - ); - let img_size = glam::uvec2(32, 32); - let device = WgpuDevice::DefaultDevice; - let num_points = 8; - let means = Tensor::::zeros([num_points, 3], &device); - let log_scales = Tensor::::ones([num_points, 3], &device) * 2.0; - let quats: Tensor = - Tensor::::from_floats(glam::Quat::IDENTITY.to_array(), &device) - .unsqueeze_dim(0) - .repeat_dim(0, num_points); - let sh_coeffs = Tensor::::ones([num_points, 1, 3], &device); - let raw_opacity = Tensor::::zeros([num_points], &device); - - let result = >::render_splats( - &cam, - img_size, - means.into_primitive().tensor(), - log_scales.into_primitive().tensor(), - quats.into_primitive().tensor(), - sh_coeffs.into_primitive().tensor(), - raw_opacity.into_primitive().tensor(), - SplatRenderMode::Default, - Vec3::ZERO, - ); - - let output: Tensor = Tensor::from_primitive(TensorPrimitive::Float(result.img)); - let rgb = output.clone().slice([0..32, 0..32, 0..3]); - let alpha = output.slice([0..32, 0..32, 3..4]); - let rgb_mean = rgb.mean().to_data().as_slice::().expect("Wrong type")[0]; - let alpha_mean = alpha - .mean() - .to_data() - .as_slice::() - .expect("Wrong type")[0]; - assert_approx_eq!(rgb_mean, 0.0, 1e-5); - assert_approx_eq!(alpha_mean, 0.0); -} - -#[test] -fn diffs_many_splats() { - // Test backward pass with a ton of splats to verify 2D dispatch works correctly. - // This exceeds the 1D 65535 * 256 = 16.7M limit. - let num_points = 30_000_000; - let cam = Camera::new( - glam::vec3(0.0, 0.0, -5.0), - glam::Quat::IDENTITY, - 0.5, - 0.5, - glam::vec2(0.5, 0.5), - ); - let img_size = glam::uvec2(64, 64); - let device = WgpuDevice::DefaultDevice; - - // Create random gaussians spread in front of the camera - let means = Tensor::::random( - [num_points, 3], - Distribution::Uniform(-2.0, 2.0), - &device, - ); - // Small scales so they don't cover everything - let log_scales = Tensor::::random( - [num_points, 3], - Distribution::Uniform(-4.0, -2.0), - &device, - ); - // Random rotations (will be normalized) - let quats = Tensor::::random( - [num_points, 4], - Distribution::Uniform(-1.0, 1.0), - &device, - ); - // Simple SH coefficients (just base color) - let sh_coeffs = Tensor::::random( - [num_points, 1, 3], - Distribution::Uniform(0.0, 1.0), - &device, - ); - // Some visible, some not - let raw_opacity = - Tensor::::random([num_points], Distribution::Uniform(-2.0, 2.0), &device); - - >::render_splats( - &cam, - img_size, - means.into_primitive().tensor(), - log_scales.into_primitive().tensor(), - quats.into_primitive().tensor(), - sh_coeffs.into_primitive().tensor(), - raw_opacity.into_primitive().tensor(), - SplatRenderMode::Default, - Vec3::ZERO, - ); -} diff --git a/crates/brush-render/src/gaussian_splats.rs b/crates/brush-render/src/gaussian_splats.rs index 0fb22b86..3e0117ef 100644 --- a/crates/brush-render/src/gaussian_splats.rs +++ b/crates/brush-render/src/gaussian_splats.rs @@ -267,10 +267,11 @@ pub fn render_splats>( // Validate before readback project_output.validate(); - let num_intersections = project_output.num_intersections(); + let num_intersections = project_output.read_num_intersections(); // Second pass: rasterize (drop compact_gid_from_isect - only needed for backward) - let (out_img, render_aux, _) = B::rasterize(&project_output, num_intersections, background, false); + let (out_img, render_aux, _) = + B::rasterize(&project_output, num_intersections, background, false); // Validate rasterize outputs render_aux.validate(); diff --git a/crates/brush-render/src/lib.rs b/crates/brush-render/src/lib.rs index dd3bdb94..ec7b82ae 100644 --- a/crates/brush-render/src/lib.rs +++ b/crates/brush-render/src/lib.rs @@ -34,7 +34,7 @@ pub mod validation; pub type MainBackendBase = CubeBackend; pub type MainBackend = Fusion; -/// Trait for the split gaussian splatting rendering pipeline. +/// Trait for the the gaussian splatting rendering pipeline. /// /// This trait provides two passes: /// 1. `project`: Culling, depth sort, projection, intersection counting, prefix sum. @@ -42,8 +42,6 @@ pub type MainBackend = Fusion; /// /// The split allows for an explicit GPU sync point between passes to read back /// the exact number of intersections needed for buffer allocation. -/// -/// Users should typically use [`render_splats`] instead of this trait directly. pub trait SplatOps { /// First pass: project gaussians and count intersections. fn project( @@ -58,12 +56,6 @@ pub trait SplatOps { ) -> ProjectOutput; /// Second pass: rasterize using projection data. - /// - /// Takes the output of [`Self::project`] along with the actual - /// `num_intersections` value from sync readback. - /// - /// Returns `(image, render_aux, compact_gid_from_isect)` where the last - /// value is only needed for backward pass and can be dropped for forward-only. fn rasterize( project_output: &ProjectOutput, num_intersections: u32, diff --git a/crates/brush-render/src/render.rs b/crates/brush-render/src/render.rs index f78fa036..fb8a4f7a 100644 --- a/crates/brush-render/src/render.rs +++ b/crates/brush-render/src/render.rs @@ -119,7 +119,6 @@ impl SplatOps for MainBackendBase { ).expect("Failed to render splats"); }); - // Step 2: DepthSort - use dynamic count to sort only up to num_visible let (_, global_from_compact_gid) = tracing::trace_span!("DepthSort").in_scope(|| { radix_argsort( depths, @@ -132,7 +131,6 @@ impl SplatOps for MainBackendBase { global_from_compact_gid }; - // Step 3: ProjectVisible with intersection counting let proj_size = size_of::() / size_of::(); let projected_splats = create_tensor([total_splats, proj_size], device, DType::F32); let splat_intersect_counts = Self::int_zeros([total_splats].into(), device, IntDType::U32); @@ -166,7 +164,6 @@ impl SplatOps for MainBackendBase { } }); - // Step 4: PrefixSum to get cumulative tile hits let cum_tiles_hit = tracing::trace_span!("PrefixSumGaussHits") .in_scope(|| prefix_sum(splat_intersect_counts)); @@ -196,7 +193,6 @@ impl SplatOps for MainBackendBase { let tile_bounds = calc_tile_bounds(img_size); let num_tiles = tile_bounds.x * tile_bounds.y; - // Create rasterize uniforms (passed via with_metadata) let rasterize_uniforms = shaders::helpers::RasterizeUniforms { tile_bounds: tile_bounds.into(), img_size: img_size.into(), @@ -214,7 +210,6 @@ impl SplatOps for MainBackendBase { MapGaussiansToIntersect::WORKGROUP_SIZE[0], ); - // Uniforms for map_gaussian_to_intersects (passed via with_metadata) let map_uniforms = shaders::map_gaussians_to_intersect::Uniforms { tile_bounds: tile_bounds.into(), }; @@ -299,7 +294,11 @@ impl SplatOps for MainBackendBase { tile_offsets.handle.clone().binding(), project_output.projected_splats.handle.clone().binding(), out_img.handle.clone().binding(), - project_output.global_from_compact_gid.handle.clone().binding(), + project_output + .global_from_compact_gid + .handle + .clone() + .binding(), visible.handle.clone().binding(), ]) .with_metadata(create_meta_binding(rasterize_uniforms)); @@ -329,13 +328,6 @@ impl SplatOps for MainBackendBase { .expect("Failed to render splats"); } - // Sanity checks - assert!( - tile_offsets.is_contiguous(), - "Tile offsets must be contiguous" - ); - assert!(visible.is_contiguous(), "Visible must be contiguous"); - ( out_img, RenderAux { diff --git a/crates/brush-render/src/render_aux.rs b/crates/brush-render/src/render_aux.rs index a2c0bbb0..70cef53b 100644 --- a/crates/brush-render/src/render_aux.rs +++ b/crates/brush-render/src/render_aux.rs @@ -1,3 +1,4 @@ +use burn::tensor::ElementConversion; use burn::{ Tensor, prelude::Backend, @@ -21,9 +22,8 @@ pub struct ProjectOutput { } impl ProjectOutput { - /// Extract the total number of intersections (sync readback). - pub fn num_intersections(&self) -> u32 { - use burn::tensor::ElementConversion; + /// Get the total number of intersections (sync readback). + pub fn read_num_intersections(&self) -> u32 { let cum_tiles_hit: Tensor = Tensor::from_primitive(self.cum_tiles_hit.clone()); let total = self.project_uniforms.total_splats as usize; if total > 0 { diff --git a/crates/brush-render/src/shaders/project_forward.wgsl b/crates/brush-render/src/shaders/project_forward.wgsl index 34369d1a..0abe49b8 100644 --- a/crates/brush-render/src/shaders/project_forward.wgsl +++ b/crates/brush-render/src/shaders/project_forward.wgsl @@ -7,7 +7,6 @@ @group(0) @binding(4) var global_from_compact_gid: array; @group(0) @binding(5) var depths: array; @group(0) @binding(6) var num_visible: atomic; -// Uniforms via with_metadata (always last binding) @group(0) @binding(7) var uniforms: helpers::ProjectUniforms; const WG_SIZE: u32 = 256u; diff --git a/crates/brush-render/src/shaders/project_visible.wgsl b/crates/brush-render/src/shaders/project_visible.wgsl index 4c694536..1d78a8db 100644 --- a/crates/brush-render/src/shaders/project_visible.wgsl +++ b/crates/brush-render/src/shaders/project_visible.wgsl @@ -14,7 +14,6 @@ struct IsectInfo { @group(0) @binding(6) var global_from_compact_gid: array; @group(0) @binding(7) var projected: array; @group(0) @binding(8) var splat_intersect_counts: array; -// Uniforms via with_metadata (always last binding) @group(0) @binding(9) var uniforms: helpers::ProjectUniforms; struct ShCoeffs { diff --git a/crates/brush-render/src/shaders/rasterize.wgsl b/crates/brush-render/src/shaders/rasterize.wgsl index b6bc9a41..8883feaf 100644 --- a/crates/brush-render/src/shaders/rasterize.wgsl +++ b/crates/brush-render/src/shaders/rasterize.wgsl @@ -8,11 +8,9 @@ @group(0) @binding(3) var out_img: array; @group(0) @binding(4) var global_from_compact_gid: array; @group(0) @binding(5) var visible: array; - // Uniforms via with_metadata (always last binding) @group(0) @binding(6) var uniforms: helpers::RasterizeUniforms; #else @group(0) @binding(3) var out_img: array; - // Uniforms via with_metadata (always last binding) @group(0) @binding(4) var uniforms: helpers::RasterizeUniforms; #endif diff --git a/crates/brush-render/src/tests/mod.rs b/crates/brush-render/src/tests/mod.rs index faa66a63..77cf6330 100644 --- a/crates/brush-render/src/tests/mod.rs +++ b/crates/brush-render/src/tests/mod.rs @@ -1,8 +1,4 @@ -use crate::{ - MainBackend, RenderAux, SplatOps, - camera::Camera, - gaussian_splats::SplatRenderMode, -}; +use crate::{MainBackend, RenderAux, SplatOps, camera::Camera, gaussian_splats::SplatRenderMode}; use assert_approx_eq::assert_approx_eq; use burn::tensor::{Distribution, Tensor, TensorPrimitive}; use burn_wgpu::WgpuDevice; @@ -37,7 +33,7 @@ fn render_splats_test( project_output.validate(); // Sync readback of num_intersections - let num_intersections = project_output.num_intersections(); + let num_intersections = project_output.read_num_intersections(); // Second pass: rasterize (drop compact_gid_from_isect - only needed for backward) let (out_img, render_aux, _) = diff --git a/crates/brush-sort/src/lib.rs b/crates/brush-sort/src/lib.rs index ec14d24b..cb63849f 100644 --- a/crates/brush-sort/src/lib.rs +++ b/crates/brush-sort/src/lib.rs @@ -71,32 +71,16 @@ pub fn radix_argsort( // Handle dynamic vs static dispatch let (num_keys_buf, num_wgs, num_reduce_wgs) = if let Some(count_buf) = dynamic_count { - // Dynamic dispatch: compute workgroup counts on GPU - // num_wgs = ceil(count / BLOCK_SIZE) let num_wgs = create_dispatch_buffer_1d(count_buf.clone(), BLOCK_SIZE); - - // The reduce shader expects: num_reduce_wgs = BIN_COUNT * ceil(num_wgs / BLOCK_SIZE) - // This is NOT the same as ceil(num_wgs * BIN_COUNT / BLOCK_SIZE) due to ceiling! - // We need: first compute ceil(num_wgs_x / BLOCK_SIZE), then multiply by BIN_COUNT. - type Backend = CubeBackend; - let num_wgs_tensor: Tensor = Tensor::from_primitive(num_wgs.clone()); - let num_wgs_x = num_wgs_tensor.slice([0..1]); // Get just the X component (scalar) - - // num_reduce_wg_per_bin = ceil(num_wgs_x / BLOCK_SIZE) - let num_reduce_wg_per_bin_buf = - create_dispatch_buffer_1d(num_wgs_x.into_primitive(), BLOCK_SIZE); - let num_reduce_wg_per_bin: Tensor = - Tensor::from_primitive(num_reduce_wg_per_bin_buf); - let num_reduce_wg_per_bin_x = num_reduce_wg_per_bin.slice([0..1]); - - // num_reduce_wgs_total = num_reduce_wg_per_bin * BIN_COUNT - let num_reduce_total: Tensor = - num_reduce_wg_per_bin_x * (SortCount::BIN_COUNT as i32); - - // Create dispatch buffer for the total (uses 2D tiling if > 65535) - let num_reduce_wgs = create_dispatch_buffer_1d(num_reduce_total.into_primitive(), 1); - - (count_buf, CubeCount::Dynamic(num_wgs.handle.binding()), CubeCount::Dynamic(num_reduce_wgs.handle.binding())) + let num_reduce_wgs: Tensor, 1, Int> = + Tensor::from_primitive(create_dispatch_buffer_1d(num_wgs.clone(), BLOCK_SIZE)) + * Tensor::from_ints([SortCount::BIN_COUNT, 1, 1], device); + let num_reduce_wgs: CubeTensor = num_reduce_wgs.into_primitive(); + ( + count_buf, + CubeCount::Dynamic(num_wgs.handle.binding()), + CubeCount::Dynamic(num_reduce_wgs.handle.binding()), + ) } else { // Static dispatch: use full buffer size let num_keys_buf = { diff --git a/crates/brush-train/src/eval.rs b/crates/brush-train/src/eval.rs index a8cc17f1..bac6b59c 100644 --- a/crates/brush-train/src/eval.rs +++ b/crates/brush-train/src/eval.rs @@ -5,12 +5,9 @@ use anyhow::Result; use brush_dataset::scene::{sample_to_tensor_data, view_to_sample_image}; use brush_render::camera::Camera; use brush_render::gaussian_splats::Splats; -use brush_render::{AlphaMode, RenderAux, SplatOps}; - -#[cfg(target_family = "wasm")] -use brush_render::render::calc_tile_bounds; +use brush_render::{AlphaMode, RenderAux, SplatOps, render_splats}; use burn::prelude::Backend; -use burn::tensor::{Tensor, TensorPrimitive, s}; +use burn::tensor::{Tensor, s}; use glam::Vec3; use image::DynamicImage; @@ -38,39 +35,8 @@ pub fn eval_stats>( let gt_tensor = Tensor::from_data(gt_tensor, device); let gt_rgb = gt_tensor.slice(s![.., .., 0..3]); - // Render on reference black background using split pipeline. - let (img, render_aux) = { - // First pass: project - let project_output = B::project( - gt_cam, - res, - splats.means.val().into_primitive().tensor(), - splats.log_scales.val().into_primitive().tensor(), - splats.rotations.val().into_primitive().tensor(), - splats.sh_coeffs.val().into_primitive().tensor(), - splats.raw_opacities.val().into_primitive().tensor(), - splats.render_mode, - ); - - // Sync readback of num_intersections - #[cfg(not(target_family = "wasm"))] - let num_intersections = project_output.num_intersections(); - - #[cfg(target_family = "wasm")] - let num_intersections = { - use burn::tensor::ops::FloatTensorOps; - let tile_bounds = calc_tile_bounds(res); - let num_tiles = tile_bounds[0] * tile_bounds[1]; - let total_splats = splats.num_splats(); - let max_possible = num_tiles.saturating_mul(total_splats); - max_possible.min(2 * 512 * 65535) - }; - - // Second pass: rasterize (with bwd_info = true for eval, drop compact_gid) - let (out_img, render_aux, _) = B::rasterize(&project_output, num_intersections, Vec3::ZERO, true); - - (Tensor::from_primitive(TensorPrimitive::Float(out_img)), render_aux) - }; + // Render on reference black background + let (img, render_aux) = render_splats(splats, gt_cam, res, Vec3::ZERO, None); let render_rgb = img.slice(s![.., .., 0..3]); // Simulate an 8-bit roundtrip for fair comparison. From a5fec63d1111293f5ae9930bf2f76848a0cc1b75 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 19 Jan 2026 22:38:14 +0100 Subject: [PATCH 06/29] Cleanup --- crates/brush-render-bwd/src/burn_glue.rs | 20 ---------------- crates/brush-render-bwd/src/render_bwd.rs | 29 +++++++++-------------- 2 files changed, 11 insertions(+), 38 deletions(-) diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index 42d13859..1bb48e58 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -78,48 +78,28 @@ pub struct SplatGrads { pub v_refine_weight: FloatTensor, } -/// Backward pass trait mirroring [`SplatOps`]. -/// -/// Provides backward methods for each forward pass: -/// - `rasterize_bwd`: computes gradients w.r.t. projected splats -/// - `project_bwd`: computes gradients w.r.t. original inputs -/// -/// These are called in reverse order during backpropagation. pub trait SplatBwdOps: SplatOps { - /// Backward pass for rasterization. - /// - /// Takes the upstream gradient `v_output` and produces intermediate gradients - /// w.r.t. the projected splat representation. fn rasterize_bwd(state: RasterizeBwdState, v_output: FloatTensor) -> RasterizeGrads; - - /// Backward pass for projection. - /// - /// Takes the intermediate gradients from `rasterize_bwd` and produces - /// the final gradients w.r.t. the original splat inputs. fn project_bwd(state: ProjectBwdState, rasterize_grads: RasterizeGrads) -> SplatGrads; } /// State saved during forward pass for backward computation. #[derive(Debug, Clone)] struct GaussianBackwardState { - // Original inputs (needed for project_bwd) means: FloatTensor, quats: FloatTensor, log_scales: FloatTensor, raw_opac: FloatTensor, - // From project forward (needed for both bwd passes) projected_splats: FloatTensor, project_uniforms: ProjectUniforms, num_visible: IntTensor, global_from_compact_gid: IntTensor, - // From rasterize forward (needed for rasterize_bwd) out_img: FloatTensor, compact_gid_from_isect: IntTensor, tile_offsets: IntTensor, - // Settings render_mode: SplatRenderMode, sh_degree: u32, background: Vec3, diff --git a/crates/brush-render-bwd/src/render_bwd.rs b/crates/brush-render-bwd/src/render_bwd.rs index fc4ae864..8ac4fcbb 100644 --- a/crates/brush-render-bwd/src/render_bwd.rs +++ b/crates/brush-render-bwd/src/render_bwd.rs @@ -1,7 +1,7 @@ use brush_kernel::{CubeCount, calc_cube_count_1d, create_meta_binding}; +use brush_render::MainBackendBase; use brush_render::gaussian_splats::SplatRenderMode; use brush_render::shaders::helpers::RasterizeUniforms; -use brush_render::MainBackendBase; use brush_wgsl::wgsl_kernel; use brush_render::sh::sh_coeffs_for_degree; @@ -13,7 +13,9 @@ use burn_cubecl::cubecl::server::Bindings; use burn_cubecl::kernel::into_contiguous; use glam::uvec2; -use crate::burn_glue::{ProjectBwdState, RasterizeBwdState, RasterizeGrads, SplatBwdOps, SplatGrads}; +use crate::burn_glue::{ + ProjectBwdState, RasterizeBwdState, RasterizeGrads, SplatBwdOps, SplatGrads, +}; // Kernel definitions using proc macro #[wgsl_kernel( @@ -57,8 +59,7 @@ impl SplatBwdOps for MainBackendBase { let client = &projected_splats.client; // Setup output tensors. - let v_projected_splats = - Self::float_zeros([num_points, 8].into(), device, FloatDType::F32); + let v_projected_splats = Self::float_zeros([num_points, 8].into(), device, FloatDType::F32); let v_raw_opac = Self::float_zeros([num_points].into(), device, FloatDType::F32); let v_refine_weight = Self::float_zeros([num_points].into(), device, FloatDType::F32); @@ -75,7 +76,12 @@ impl SplatBwdOps for MainBackendBase { let rasterize_uniforms = RasterizeUniforms { tile_bounds: tile_bounds.into(), img_size: img_size.into(), - background: [state.background.x, state.background.y, state.background.z, 1.0], + background: [ + state.background.x, + state.background.y, + state.background.z, + 1.0, + ], }; let hard_floats = client @@ -179,19 +185,6 @@ impl SplatBwdOps for MainBackendBase { } }); - assert!(v_means.is_contiguous(), "Grads must be contiguous"); - assert!(v_quats.is_contiguous(), "Grads must be contiguous"); - assert!(v_scales.is_contiguous(), "Grads must be contiguous"); - assert!(v_coeffs.is_contiguous(), "Grads must be contiguous"); - assert!( - rasterize_grads.v_raw_opac.is_contiguous(), - "Grads must be contiguous" - ); - assert!( - rasterize_grads.v_refine_weight.is_contiguous(), - "Grads must be contiguous" - ); - SplatGrads { v_means, v_quats, From fadfd00d1b32085914ffce130a1db39dcf46067f Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 19 Jan 2026 22:39:00 +0100 Subject: [PATCH 07/29] Fmt --- crates/brush-render/src/burn_glue.rs | 5 ++++- crates/brush-render/src/shaders.rs | 2 +- crates/brush-train/src/train.rs | 33 ++++++++++++++-------------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/crates/brush-render/src/burn_glue.rs b/crates/brush-render/src/burn_glue.rs index 2a79ac8e..fce4399e 100644 --- a/crates/brush-render/src/burn_glue.rs +++ b/crates/brush-render/src/burn_glue.rs @@ -1,4 +1,7 @@ -use burn::tensor::{DType, Shape, ops::{FloatTensor, IntTensor}}; +use burn::tensor::{ + DType, Shape, + ops::{FloatTensor, IntTensor}, +}; use burn_cubecl::{BoolElement, fusion::FusionCubeRuntime}; use burn_fusion::{ Fusion, FusionHandle, diff --git a/crates/brush-render/src/shaders.rs b/crates/brush-render/src/shaders.rs index 0d6d1c15..54757a0b 100644 --- a/crates/brush-render/src/shaders.rs +++ b/crates/brush-render/src/shaders.rs @@ -24,8 +24,8 @@ pub struct Rasterize { pub mod helpers { // Types used by multiple shaders - available from project_visible pub use super::project_visible::PackedVec3; - pub use super::project_visible::ProjectedSplat; pub use super::project_visible::ProjectUniforms; + pub use super::project_visible::ProjectedSplat; pub use super::rasterize::RasterizeUniforms; // Constants are now associated with the kernel structs diff --git a/crates/brush-train/src/train.rs b/crates/brush-train/src/train.rs index 82156b6e..aafee5a5 100644 --- a/crates/brush-train/src/train.rs +++ b/crates/brush-train/src/train.rs @@ -116,22 +116,23 @@ impl SplatTrainer { let has_alpha = batch.has_alpha(); let gt_tensor = Tensor::from_data(batch.img_tensor, &device); - let (pred_image, render_aux, refine_weight_holder) = trace_span!("Forward").in_scope(|| { - // Could generate a random background color, but so far - // results just seem worse. - let background = Vec3::ZERO; - - let diff_out = render_splats( - &splats, - camera, - glam::uvec2(img_w as u32, img_h as u32), - background, - ); - - let img = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); - - (img, diff_out.render_aux, diff_out.refine_weight_holder) - }); + let (pred_image, render_aux, refine_weight_holder) = + trace_span!("Forward").in_scope(|| { + // Could generate a random background color, but so far + // results just seem worse. + let background = Vec3::ZERO; + + let diff_out = render_splats( + &splats, + camera, + glam::uvec2(img_w as u32, img_h as u32), + background, + ); + + let img = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); + + (img, diff_out.render_aux, diff_out.refine_weight_holder) + }); let median_scale = self.bounds.median_size(); let num_visible = render_aux.get_num_visible().inner(); From 6c81ccf1fd42e883396a459961e3713fe9354fb3 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 19 Jan 2026 22:52:15 +0100 Subject: [PATCH 08/29] Cleanup --- crates/brush-render-bwd/src/burn_glue.rs | 249 +++++++++++----------- crates/brush-render-bwd/src/lib.rs | 5 +- crates/brush-render-bwd/src/render_bwd.rs | 65 +++--- crates/brush-render/src/render_aux.rs | 4 +- crates/brush-render/src/tests/mod.rs | 1 + 5 files changed, 154 insertions(+), 170 deletions(-) diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index 1bb48e58..fc060ce2 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -30,32 +30,6 @@ use burn_fusion::{ use burn_ir::{CustomOpIr, HandleContainer, OperationIr, OperationOutput, TensorIr}; use glam::Vec3; -/// State needed for the rasterize backward pass. -#[derive(Debug, Clone)] -pub struct RasterizeBwdState { - pub out_img: FloatTensor, - pub projected_splats: FloatTensor, - pub global_from_compact_gid: IntTensor, - pub compact_gid_from_isect: IntTensor, - pub tile_offsets: IntTensor, - pub background: Vec3, - pub img_size: glam::UVec2, -} - -/// State needed for the project backward pass. -#[derive(Debug, Clone)] -pub struct ProjectBwdState { - pub means: FloatTensor, - pub log_scales: FloatTensor, - pub quats: FloatTensor, - pub raw_opac: FloatTensor, - pub num_visible: IntTensor, - pub global_from_compact_gid: IntTensor, - pub project_uniforms: ProjectUniforms, - pub sh_degree: u32, - pub render_mode: SplatRenderMode, -} - /// Intermediate gradients from the rasterize backward pass. #[derive(Debug, Clone)] pub struct RasterizeGrads { @@ -78,9 +52,35 @@ pub struct SplatGrads { pub v_refine_weight: FloatTensor, } +/// Backward pass trait mirroring [`SplatOps`]. pub trait SplatBwdOps: SplatOps { - fn rasterize_bwd(state: RasterizeBwdState, v_output: FloatTensor) -> RasterizeGrads; - fn project_bwd(state: ProjectBwdState, rasterize_grads: RasterizeGrads) -> SplatGrads; + /// Backward pass for rasterization. + #[allow(clippy::too_many_arguments)] + fn rasterize_bwd( + out_img: FloatTensor, + projected_splats: FloatTensor, + global_from_compact_gid: IntTensor, + compact_gid_from_isect: IntTensor, + tile_offsets: IntTensor, + background: Vec3, + img_size: glam::UVec2, + v_output: FloatTensor, + ) -> RasterizeGrads; + + /// Backward pass for projection. + #[allow(clippy::too_many_arguments)] + fn project_bwd( + means: FloatTensor, + log_scales: FloatTensor, + quats: FloatTensor, + raw_opac: FloatTensor, + num_visible: IntTensor, + global_from_compact_gid: IntTensor, + project_uniforms: ProjectUniforms, + sh_degree: u32, + render_mode: SplatRenderMode, + rasterize_grads: RasterizeGrads, + ) -> SplatGrads; } /// State saved during forward pass for backward computation. @@ -138,30 +138,30 @@ impl> Backward for RenderBackwards ] = ops.parents; // Step 1: Rasterize backward - let rasterize_state = RasterizeBwdState { - out_img: state.out_img, - projected_splats: state.projected_splats, - global_from_compact_gid: state.global_from_compact_gid.clone(), - compact_gid_from_isect: state.compact_gid_from_isect, - tile_offsets: state.tile_offsets, - background: state.background, - img_size: state.img_size, - }; - let rasterize_grads = B::rasterize_bwd(rasterize_state, v_output); + let rasterize_grads = B::rasterize_bwd( + state.out_img, + state.projected_splats, + state.global_from_compact_gid.clone(), + state.compact_gid_from_isect, + state.tile_offsets, + state.background, + state.img_size, + v_output, + ); // Step 2: Project backward - let project_state = ProjectBwdState { - means: state.means, - log_scales: state.log_scales, - quats: state.quats, - raw_opac: state.raw_opac, - num_visible: state.num_visible, - global_from_compact_gid: state.global_from_compact_gid, - project_uniforms: state.project_uniforms, - sh_degree: state.sh_degree, - render_mode: state.render_mode, - }; - let splat_grads = B::project_bwd(project_state, rasterize_grads); + let splat_grads = B::project_bwd( + state.means, + state.log_scales, + state.quats, + state.raw_opac, + state.num_visible, + state.global_from_compact_gid, + state.project_uniforms, + state.sh_degree, + state.render_mode, + rasterize_grads, + ); if let Some(node) = mean_parent { grads.register::(node.id, splat_grads.v_means); @@ -339,8 +339,15 @@ where } impl SplatBwdOps for Fusion { + #[allow(clippy::too_many_arguments)] fn rasterize_bwd( - state: RasterizeBwdState, + out_img: FloatTensor, + projected_splats: FloatTensor, + global_from_compact_gid: IntTensor, + compact_gid_from_isect: IntTensor, + tile_offsets: IntTensor, + background: Vec3, + img_size: glam::UVec2, v_output: FloatTensor, ) -> RasterizeGrads { #[derive(Debug)] @@ -368,20 +375,14 @@ impl SplatBwdOps for Fusion { let [v_projected_splats, v_raw_opac, v_refine_weight] = outputs; - let inner_state = RasterizeBwdState { - out_img: h.get_float_tensor::(out_img), - projected_splats: h.get_float_tensor::(projected_splats), - global_from_compact_gid: h - .get_int_tensor::(global_from_compact_gid), - compact_gid_from_isect: h - .get_int_tensor::(compact_gid_from_isect), - tile_offsets: h.get_int_tensor::(tile_offsets), - background: self.background, - img_size: self.img_size, - }; - let grads = >::rasterize_bwd( - inner_state, + h.get_float_tensor::(out_img), + h.get_float_tensor::(projected_splats), + h.get_int_tensor::(global_from_compact_gid), + h.get_int_tensor::(compact_gid_from_isect), + h.get_int_tensor::(tile_offsets), + self.background, + self.img_size, h.get_float_tensor::(v_output), ); @@ -398,9 +399,9 @@ impl SplatBwdOps for Fusion { } let client = v_output.client.clone(); - let num_points = state.projected_splats.shape[0]; + let num_points = projected_splats.shape[0]; - let v_projected_splats = TensorIr::uninit( + let v_projected_splats_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, 8]), DType::F32, @@ -418,23 +419,23 @@ impl SplatBwdOps for Fusion { let input_tensors = [ v_output, - state.out_img, - state.projected_splats, - state.global_from_compact_gid, - state.compact_gid_from_isect, - state.tile_offsets, + out_img, + projected_splats, + global_from_compact_gid, + compact_gid_from_isect, + tile_offsets, ]; let stream = OperationStreams::with_inputs(&input_tensors); let desc = CustomOpIr::new( "rasterize_bwd", &input_tensors.map(|t| t.into_ir()), - &[v_projected_splats, v_raw_opac, v_refine_weight], + &[v_projected_splats_out, v_raw_opac, v_refine_weight], ); let op = CustomOp { desc: desc.clone(), - background: state.background, - img_size: state.img_size, + background, + img_size, }; let outputs = client @@ -450,8 +451,17 @@ impl SplatBwdOps for Fusion { } } + #[allow(clippy::too_many_arguments)] fn project_bwd( - state: ProjectBwdState, + means: FloatTensor, + log_scales: FloatTensor, + quats: FloatTensor, + raw_opac: FloatTensor, + num_visible: IntTensor, + global_from_compact_gid: IntTensor, + project_uniforms: ProjectUniforms, + sh_degree: u32, + render_mode: SplatRenderMode, rasterize_grads: RasterizeGrads, ) -> SplatGrads { #[derive(Debug)] @@ -481,27 +491,7 @@ impl SplatBwdOps for Fusion { v_refine_weight_in, ] = inputs; - let [ - v_means, - v_quats, - v_scales, - v_coeffs, - v_raw_opac, - v_refine_weight, - ] = outputs; - - let inner_state = ProjectBwdState { - means: h.get_float_tensor::(means), - log_scales: h.get_float_tensor::(log_scales), - quats: h.get_float_tensor::(quats), - raw_opac: h.get_float_tensor::(raw_opac), - num_visible: h.get_int_tensor::(num_visible), - global_from_compact_gid: h - .get_int_tensor::(global_from_compact_gid), - project_uniforms: self.project_uniforms, - sh_degree: self.sh_degree, - render_mode: self.render_mode, - }; + let [v_means, v_quats, v_scales, v_coeffs, v_raw_opac, v_refine_weight] = outputs; let inner_rasterize_grads = RasterizeGrads { v_projected_splats: h.get_float_tensor::(v_projected_splats), @@ -510,7 +500,15 @@ impl SplatBwdOps for Fusion { }; let grads = >::project_bwd( - inner_state, + h.get_float_tensor::(means), + h.get_float_tensor::(log_scales), + h.get_float_tensor::(quats), + h.get_float_tensor::(raw_opac), + h.get_int_tensor::(num_visible), + h.get_int_tensor::(global_from_compact_gid), + self.project_uniforms, + self.sh_degree, + self.render_mode, inner_rasterize_grads, ); @@ -526,48 +524,48 @@ impl SplatBwdOps for Fusion { } } - let client = state.means.client.clone(); - let num_points = state.means.shape[0]; - let coeffs = sh_coeffs_for_degree(state.sh_degree) as usize; + let client = means.client.clone(); + let num_points = means.shape[0]; + let coeffs = sh_coeffs_for_degree(sh_degree) as usize; - let v_means = TensorIr::uninit( + let v_means_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, 3]), DType::F32, ); - let v_scales = TensorIr::uninit( + let v_scales_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, 3]), DType::F32, ); - let v_quats = TensorIr::uninit( + let v_quats_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, 4]), DType::F32, ); - let v_coeffs = TensorIr::uninit( + let v_coeffs_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, coeffs, 3]), DType::F32, ); - let v_raw_opac = TensorIr::uninit( + let v_raw_opac_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points]), DType::F32, ); - let v_refine_weight = TensorIr::uninit( + let v_refine_weight_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points]), DType::F32, ); let input_tensors = [ - state.means, - state.log_scales, - state.quats, - state.raw_opac, - state.num_visible, - state.global_from_compact_gid, + means, + log_scales, + quats, + raw_opac, + num_visible, + global_from_compact_gid, rasterize_grads.v_projected_splats, rasterize_grads.v_raw_opac, rasterize_grads.v_refine_weight, @@ -578,12 +576,12 @@ impl SplatBwdOps for Fusion { "project_bwd", &input_tensors.map(|t| t.into_ir()), &[ - v_means, - v_quats, - v_scales, - v_coeffs, - v_raw_opac, - v_refine_weight, + v_means_out, + v_quats_out, + v_scales_out, + v_coeffs_out, + v_raw_opac_out, + v_refine_weight_out, ], ); @@ -593,21 +591,14 @@ impl SplatBwdOps for Fusion { OperationIr::Custom(desc.clone()), CustomOp { desc, - sh_degree: state.sh_degree, - render_mode: state.render_mode, - project_uniforms: state.project_uniforms, + sh_degree, + render_mode, + project_uniforms, }, ) .outputs(); - let [ - v_means, - v_quats, - v_scales, - v_coeffs, - v_raw_opac, - v_refine_weight, - ] = outputs; + let [v_means, v_quats, v_scales, v_coeffs, v_raw_opac, v_refine_weight] = outputs; SplatGrads { v_means, diff --git a/crates/brush-render-bwd/src/lib.rs b/crates/brush-render-bwd/src/lib.rs index 9985efe4..4bbae9e4 100644 --- a/crates/brush-render-bwd/src/lib.rs +++ b/crates/brush-render-bwd/src/lib.rs @@ -1,7 +1,4 @@ pub mod burn_glue; mod render_bwd; -pub use burn_glue::{ - ProjectBwdState, RasterizeBwdState, RasterizeGrads, SplatBwdOps, SplatGrads, SplatOutputDiff, - render_splats, -}; +pub use burn_glue::{RasterizeGrads, SplatBwdOps, SplatGrads, SplatOutputDiff, render_splats}; diff --git a/crates/brush-render-bwd/src/render_bwd.rs b/crates/brush-render-bwd/src/render_bwd.rs index 8ac4fcbb..808f9464 100644 --- a/crates/brush-render-bwd/src/render_bwd.rs +++ b/crates/brush-render-bwd/src/render_bwd.rs @@ -6,16 +6,16 @@ use brush_wgsl::wgsl_kernel; use brush_render::sh::sh_coeffs_for_degree; use burn::tensor::FloatDType; +use burn::tensor::ops::IntTensor; use burn::tensor::ops::{FloatTensor, FloatTensorOps}; use burn_cubecl::cubecl::features::TypeUsage; use burn_cubecl::cubecl::ir::{ElemType, FloatKind, StorageType}; use burn_cubecl::cubecl::server::Bindings; use burn_cubecl::kernel::into_contiguous; -use glam::uvec2; +use glam::{Vec3, uvec2}; -use crate::burn_glue::{ - ProjectBwdState, RasterizeBwdState, RasterizeGrads, SplatBwdOps, SplatGrads, -}; +use crate::burn_glue::{RasterizeGrads, SplatBwdOps, SplatGrads}; +use brush_render::shaders::helpers::ProjectUniforms; // Kernel definitions using proc macro #[wgsl_kernel( @@ -36,8 +36,15 @@ pub struct RasterizeBackwards { } impl SplatBwdOps for MainBackendBase { + #[allow(clippy::too_many_arguments)] fn rasterize_bwd( - state: RasterizeBwdState, + out_img: FloatTensor, + projected_splats: FloatTensor, + global_from_compact_gid: IntTensor, + compact_gid_from_isect: IntTensor, + tile_offsets: IntTensor, + background: Vec3, + img_size: glam::UVec2, v_output: FloatTensor, ) -> RasterizeGrads { let _span = tracing::trace_span!("rasterize_bwd").entered(); @@ -45,15 +52,7 @@ impl SplatBwdOps for MainBackendBase { // Comes from loss, might not be contiguous. let v_output = into_contiguous(v_output); - // We're in charge of these, SHOULD be contiguous but might as well. - let projected_splats = into_contiguous(state.projected_splats); - let compact_gid_from_isect = into_contiguous(state.compact_gid_from_isect); - let global_from_compact_gid = into_contiguous(state.global_from_compact_gid); - let tile_offsets = into_contiguous(state.tile_offsets); - let out_img = into_contiguous(state.out_img); - let device = &out_img.device; - let img_size = state.img_size; let num_points = projected_splats.shape.dims[0]; let client = &projected_splats.client; @@ -76,12 +75,7 @@ impl SplatBwdOps for MainBackendBase { let rasterize_uniforms = RasterizeUniforms { tile_bounds: tile_bounds.into(), img_size: img_size.into(), - background: [ - state.background.x, - state.background.y, - state.background.z, - 1.0, - ], + background: [background.x, background.y, background.z, 1.0], }; let hard_floats = client @@ -123,18 +117,26 @@ impl SplatBwdOps for MainBackendBase { } } + #[allow(clippy::too_many_arguments)] fn project_bwd( - state: ProjectBwdState, + means: FloatTensor, + log_scales: FloatTensor, + quats: FloatTensor, + raw_opac: FloatTensor, + num_visible: IntTensor, + global_from_compact_gid: IntTensor, + project_uniforms: ProjectUniforms, + sh_degree: u32, + render_mode: SplatRenderMode, rasterize_grads: RasterizeGrads, ) -> SplatGrads { let _span = tracing::trace_span!("project_bwd").entered(); // Comes from params, might not be contiguous. - let means = into_contiguous(state.means); - let log_scales = into_contiguous(state.log_scales); - let quats = into_contiguous(state.quats); - let raw_opac = into_contiguous(state.raw_opac); - let global_from_compact_gid = into_contiguous(state.global_from_compact_gid); + let means = into_contiguous(means); + let log_scales = into_contiguous(log_scales); + let quats = into_contiguous(quats); + let raw_opac = into_contiguous(raw_opac); let device = &means.device; let num_points = means.shape.dims[0]; @@ -145,17 +147,12 @@ impl SplatBwdOps for MainBackendBase { let v_scales = Self::float_zeros([num_points, 3].into(), device, FloatDType::F32); let v_quats = Self::float_zeros([num_points, 4].into(), device, FloatDType::F32); let v_coeffs = Self::float_zeros( - [ - num_points, - sh_coeffs_for_degree(state.sh_degree) as usize, - 3, - ] - .into(), + [num_points, sh_coeffs_for_degree(sh_degree) as usize, 3].into(), device, FloatDType::F32, ); - let mip_splat = matches!(state.render_mode, SplatRenderMode::Mip); + let mip_splat = matches!(render_mode, SplatRenderMode::Mip); tracing::trace_span!("ProjectBackwards").in_scope(|| { // SAFETY: Kernel has to contain no OOB indexing, bounded loops. @@ -166,7 +163,7 @@ impl SplatBwdOps for MainBackendBase { calc_cube_count_1d(num_points as u32, ProjectBackwards::WORKGROUP_SIZE[0]), Bindings::new() .with_buffers(vec![ - state.num_visible.handle.binding(), + num_visible.handle.binding(), means.handle.binding(), log_scales.handle.binding(), quats.handle.binding(), @@ -179,7 +176,7 @@ impl SplatBwdOps for MainBackendBase { v_coeffs.handle.clone().binding(), rasterize_grads.v_raw_opac.handle.clone().binding(), ]) - .with_metadata(create_meta_binding(state.project_uniforms)), + .with_metadata(create_meta_binding(project_uniforms)), ) .expect("Failed to bwd-diff splats"); } diff --git a/crates/brush-render/src/render_aux.rs b/crates/brush-render/src/render_aux.rs index 70cef53b..f1d13e11 100644 --- a/crates/brush-render/src/render_aux.rs +++ b/crates/brush-render/src/render_aux.rs @@ -57,14 +57,12 @@ impl ProjectOutput { "num_visible ({num_visible}) > total_splats ({total_splats})" ); - if num_visible > 0 { + if total_splats > 0 && num_visible > 0 { let projected_splats: Tensor = Tensor::from_primitive(TensorPrimitive::Float(self.projected_splats.clone())); let projected_splats = projected_splats.slice(s![0..num_visible, ..]); validate_tensor_val(&projected_splats, "projected_splats", None, None); - } - if num_visible > 0 && total_splats > 0 { let global_from_compact_gid: Tensor = Tensor::from_primitive(self.global_from_compact_gid.clone()); let global_from_compact_gid = &global_from_compact_gid diff --git a/crates/brush-render/src/tests/mod.rs b/crates/brush-render/src/tests/mod.rs index 77cf6330..31af8d3f 100644 --- a/crates/brush-render/src/tests/mod.rs +++ b/crates/brush-render/src/tests/mod.rs @@ -5,6 +5,7 @@ use burn_wgpu::WgpuDevice; use glam::Vec3; /// Helper to run project + readback + rasterize for tests. +/// TODO: Get rid of this. fn render_splats_test( cam: &Camera, img_size: glam::UVec2, From d24dbb2123241dec320ce98bd1d92442df92dc66 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 20 Jan 2026 00:25:15 +0100 Subject: [PATCH 09/29] Cleanup --- crates/brush-bench-test/src/benches.rs | 14 ++++- crates/brush-render/src/gaussian_splats.rs | 12 +++- crates/brush-render/src/lib.rs | 2 +- crates/brush-render/src/render.rs | 2 +- crates/brush-render/src/tests/mod.rs | 71 ++++------------------ crates/brush-train/src/eval.rs | 5 +- crates/brush-ui/src/scene.rs | 3 +- examples/train-2d/examples/train-2d.rs | 3 +- 8 files changed, 43 insertions(+), 69 deletions(-) diff --git a/crates/brush-bench-test/src/benches.rs b/crates/brush-bench-test/src/benches.rs index 6a421f09..dde61648 100644 --- a/crates/brush-bench-test/src/benches.rs +++ b/crates/brush-bench-test/src/benches.rs @@ -1,6 +1,6 @@ use brush_dataset::scene::SceneBatch; use brush_render::{ - AlphaMode, MainBackend, + AlphaMode, MainBackend, TextureMode, camera::Camera, gaussian_splats::{SplatRenderMode, Splats}, render_splats, @@ -144,7 +144,7 @@ fn generate_training_batch(resolution: (u32, u32), camera_pos: Vec3) -> SceneBat mod forward_rendering { use super::{ AutodiffModule, Backend, Camera, ITERS_PER_SYNC, MainBackend, Quat, RESOLUTIONS, - SPLAT_COUNTS, Vec3, WgpuDevice, gen_splats, render_splats, + SPLAT_COUNTS, TextureMode, Vec3, WgpuDevice, gen_splats, render_splats, }; #[divan::bench(args = SPLAT_COUNTS)] @@ -161,7 +161,14 @@ mod forward_rendering { bencher.bench_local(move || { for _ in 0..ITERS_PER_SYNC { - let _ = render_splats(&splats, &camera, glam::uvec2(1920, 1080), Vec3::ZERO, None); + let _ = render_splats( + &splats, + &camera, + glam::uvec2(1920, 1080), + Vec3::ZERO, + None, + TextureMode::Float, + ); } MainBackend::sync(&device).expect("Failed to sync"); }); @@ -187,6 +194,7 @@ mod forward_rendering { glam::uvec2(width, height), Vec3::ZERO, None, + TextureMode::Float, ); } MainBackend::sync(&device).expect("Failed to sync"); diff --git a/crates/brush-render/src/gaussian_splats.rs b/crates/brush-render/src/gaussian_splats.rs index 3e0117ef..605fda34 100644 --- a/crates/brush-render/src/gaussian_splats.rs +++ b/crates/brush-render/src/gaussian_splats.rs @@ -26,6 +26,13 @@ pub enum SplatRenderMode { Mip, } +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +pub enum TextureMode { + Packed, + #[default] + Float, +} + #[derive(Module, Debug)] pub struct Splats { pub means: Param>, @@ -242,6 +249,7 @@ pub fn render_splats>( img_size: glam::UVec2, background: Vec3, splat_scale: Option, + texture_mode: TextureMode, ) -> (Tensor, RenderAux) { splats.validate_values(); @@ -269,9 +277,9 @@ pub fn render_splats>( let num_intersections = project_output.read_num_intersections(); - // Second pass: rasterize (drop compact_gid_from_isect - only needed for backward) + let use_float = matches!(texture_mode, TextureMode::Float); let (out_img, render_aux, _) = - B::rasterize(&project_output, num_intersections, background, false); + B::rasterize(&project_output, num_intersections, background, use_float); // Validate rasterize outputs render_aux.validate(); diff --git a/crates/brush-render/src/lib.rs b/crates/brush-render/src/lib.rs index ec7b82ae..7eda03da 100644 --- a/crates/brush-render/src/lib.rs +++ b/crates/brush-render/src/lib.rs @@ -11,7 +11,7 @@ use glam::Vec3; use render_aux::ProjectOutput; use crate::gaussian_splats::SplatRenderMode; -pub use crate::gaussian_splats::render_splats; +pub use crate::gaussian_splats::{TextureMode, render_splats}; pub use crate::render_aux::RenderAux; mod burn_glue; diff --git a/crates/brush-render/src/render.rs b/crates/brush-render/src/render.rs index fb8a4f7a..ccded969 100644 --- a/crates/brush-render/src/render.rs +++ b/crates/brush-render/src/render.rs @@ -25,7 +25,7 @@ use burn_cubecl::kernel::into_contiguous; use burn_wgpu::{CubeDim, CubeTensor, WgpuRuntime}; use glam::{Vec3, uvec2}; -pub fn calc_tile_bounds(img_size: glam::UVec2) -> glam::UVec2 { +pub(crate) fn calc_tile_bounds(img_size: glam::UVec2) -> glam::UVec2 { uvec2( img_size.x.div_ceil(shaders::helpers::TILE_WIDTH), img_size.y.div_ceil(shaders::helpers::TILE_WIDTH), diff --git a/crates/brush-render/src/tests/mod.rs b/crates/brush-render/src/tests/mod.rs index 31af8d3f..80b69d28 100644 --- a/crates/brush-render/src/tests/mod.rs +++ b/crates/brush-render/src/tests/mod.rs @@ -1,53 +1,13 @@ -use crate::{MainBackend, RenderAux, SplatOps, camera::Camera, gaussian_splats::SplatRenderMode}; +use crate::{ + MainBackend, TextureMode, + camera::Camera, + gaussian_splats::{SplatRenderMode, Splats, render_splats}, +}; use assert_approx_eq::assert_approx_eq; -use burn::tensor::{Distribution, Tensor, TensorPrimitive}; +use burn::tensor::{Distribution, Tensor}; use burn_wgpu::WgpuDevice; use glam::Vec3; -/// Helper to run project + readback + rasterize for tests. -/// TODO: Get rid of this. -fn render_splats_test( - cam: &Camera, - img_size: glam::UVec2, - means: Tensor, - log_scales: Tensor, - quats: Tensor, - sh_coeffs: Tensor, - raw_opacity: Tensor, - render_mode: SplatRenderMode, - background: Vec3, - bwd_info: bool, -) -> (Tensor, RenderAux) { - // First pass: project - let project_output = MainBackend::project( - cam, - img_size, - means.into_primitive().tensor(), - log_scales.into_primitive().tensor(), - quats.into_primitive().tensor(), - sh_coeffs.into_primitive().tensor(), - raw_opacity.into_primitive().tensor(), - render_mode, - ); - - // Validate project output - project_output.validate(); - - // Sync readback of num_intersections - let num_intersections = project_output.read_num_intersections(); - - // Second pass: rasterize (drop compact_gid_from_isect - only needed for backward) - let (out_img, render_aux, _) = - MainBackend::rasterize(&project_output, num_intersections, background, bwd_info); - - // Validate render aux - render_aux.validate(); - - let img = Tensor::from_primitive(TensorPrimitive::Float(out_img)); - - (img, render_aux) -} - #[test] fn renders_at_all() { // Check if rendering doesn't hard crash or anything. @@ -71,18 +31,16 @@ fn renders_at_all() { .repeat_dim(0, num_points); let sh_coeffs = Tensor::::ones([num_points, 1, 3], &device); let raw_opacity = Tensor::::zeros([num_points], &device); - let (output, _render_aux) = render_splats_test( - &cam, - img_size, + + let splats = Splats::from_tensor_data( means, - log_scales, quats, + log_scales, sh_coeffs, raw_opacity, SplatRenderMode::Default, - Vec3::ZERO, - true, ); + let (output, _render_aux) = render_splats(&splats, &cam, img_size, Vec3::ZERO, None, TextureMode::Float); let rgb = output.clone().slice([0..32, 0..32, 0..3]); let alpha = output.slice([0..32, 0..32, 3..4]); @@ -139,16 +97,13 @@ fn renders_many_splats() { let raw_opacity = Tensor::::random([num_splats], Distribution::Uniform(-2.0, 2.0), &device); - let (_output, _render_aux) = render_splats_test( - &cam, - img_size, + let splats = Splats::from_tensor_data( means, - log_scales, quats, + log_scales, sh_coeffs, raw_opacity, SplatRenderMode::Default, - Vec3::ZERO, - true, ); + let (_output, _render_aux) = render_splats(&splats, &cam, img_size, Vec3::ZERO, None, TextureMode::Float); } diff --git a/crates/brush-train/src/eval.rs b/crates/brush-train/src/eval.rs index bac6b59c..5d6dabf5 100644 --- a/crates/brush-train/src/eval.rs +++ b/crates/brush-train/src/eval.rs @@ -5,7 +5,7 @@ use anyhow::Result; use brush_dataset::scene::{sample_to_tensor_data, view_to_sample_image}; use brush_render::camera::Camera; use brush_render::gaussian_splats::Splats; -use brush_render::{AlphaMode, RenderAux, SplatOps, render_splats}; +use brush_render::{AlphaMode, RenderAux, SplatOps, TextureMode, render_splats}; use burn::prelude::Backend; use burn::tensor::{Tensor, s}; use glam::Vec3; @@ -36,7 +36,8 @@ pub fn eval_stats>( let gt_rgb = gt_tensor.slice(s![.., .., 0..3]); // Render on reference black background - let (img, render_aux) = render_splats(splats, gt_cam, res, Vec3::ZERO, None); + let (img, render_aux) = + render_splats(splats, gt_cam, res, Vec3::ZERO, None, TextureMode::Float); let render_rgb = img.slice(s![.., .., 0..3]); // Simulate an 8-bit roundtrip for fair comparison. diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index b7d046af..5ab77637 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -14,7 +14,7 @@ use egui::{ use std::sync::Arc; use brush_render::{ - MainBackend, + MainBackend, TextureMode, camera::{Camera, focal_to_fov, fov_to_focal}, gaussian_splats::Splats, render_splats, @@ -340,6 +340,7 @@ impl ScenePanel { pixel_size, settings.background.unwrap_or(Vec3::ZERO), settings.splat_scale, + TextureMode::Packed, ); if let Some(backbuffer) = &mut self.backbuffer { diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index 7799df19..2a21d5ef 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -3,7 +3,7 @@ use brush_dataset::scene::{SceneBatch, sample_to_tensor_data}; use brush_render::{ - AlphaMode, MainBackend, + AlphaMode, MainBackend, TextureMode, bounding_box::BoundingBox, camera::{Camera, focal_to_fov, fov_to_focal}, gaussian_splats::{SplatRenderMode, Splats}, @@ -178,6 +178,7 @@ impl eframe::App for App { glam::uvec2(self.image.width(), self.image.height()), Vec3::ZERO, // Just render with a black background None, + TextureMode::Packed, ); let size = egui::vec2(self.image.width() as f32, self.image.height() as f32); From bacfba490c8ed86665eacbab61c79ce450773010 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 20 Jan 2026 00:32:21 +0100 Subject: [PATCH 10/29] Fmt --- crates/brush-render-bwd/src/burn_glue.rs | 18 ++++++++++++++++-- crates/brush-render/src/tests/mod.rs | 18 ++++++++++++++++-- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index fc060ce2..dc327560 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -491,7 +491,14 @@ impl SplatBwdOps for Fusion { v_refine_weight_in, ] = inputs; - let [v_means, v_quats, v_scales, v_coeffs, v_raw_opac, v_refine_weight] = outputs; + let [ + v_means, + v_quats, + v_scales, + v_coeffs, + v_raw_opac, + v_refine_weight, + ] = outputs; let inner_rasterize_grads = RasterizeGrads { v_projected_splats: h.get_float_tensor::(v_projected_splats), @@ -598,7 +605,14 @@ impl SplatBwdOps for Fusion { ) .outputs(); - let [v_means, v_quats, v_scales, v_coeffs, v_raw_opac, v_refine_weight] = outputs; + let [ + v_means, + v_quats, + v_scales, + v_coeffs, + v_raw_opac, + v_refine_weight, + ] = outputs; SplatGrads { v_means, diff --git a/crates/brush-render/src/tests/mod.rs b/crates/brush-render/src/tests/mod.rs index 80b69d28..e82704cc 100644 --- a/crates/brush-render/src/tests/mod.rs +++ b/crates/brush-render/src/tests/mod.rs @@ -40,7 +40,14 @@ fn renders_at_all() { raw_opacity, SplatRenderMode::Default, ); - let (output, _render_aux) = render_splats(&splats, &cam, img_size, Vec3::ZERO, None, TextureMode::Float); + let (output, _render_aux) = render_splats( + &splats, + &cam, + img_size, + Vec3::ZERO, + None, + TextureMode::Float, + ); let rgb = output.clone().slice([0..32, 0..32, 0..3]); let alpha = output.slice([0..32, 0..32, 3..4]); @@ -105,5 +112,12 @@ fn renders_many_splats() { raw_opacity, SplatRenderMode::Default, ); - let (_output, _render_aux) = render_splats(&splats, &cam, img_size, Vec3::ZERO, None, TextureMode::Float); + let (_output, _render_aux) = render_splats( + &splats, + &cam, + img_size, + Vec3::ZERO, + None, + TextureMode::Float, + ); } From 3ffd71e97cddc994ca75c5ba09e356895a979209 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 20 Jan 2026 14:57:09 +0100 Subject: [PATCH 11/29] Cleanup --- crates/brush-render/src/render.rs | 5 ----- crates/brush-render/src/render_aux.rs | 5 ----- 2 files changed, 10 deletions(-) diff --git a/crates/brush-render/src/render.rs b/crates/brush-render/src/render.rs index ccded969..b4894e0c 100644 --- a/crates/brush-render/src/render.rs +++ b/crates/brush-render/src/render.rs @@ -235,15 +235,11 @@ impl SplatOps for MainBackendBase { } }); - // Step 3: Tile sort - use static dispatch with full buffer (num_intersections) let bits = u32::BITS - num_tiles.leading_zeros(); - let (tile_id_from_isect, compact_gid_from_isect) = tracing::trace_span!("Tile sort") .in_scope(|| radix_argsort(tile_id_from_isect, compact_gid_from_isect, bits, None)); - // Step 4: GetTileOffsets let cube_dim = CubeDim::new_1d(256); - let tile_offsets = Self::int_zeros( [tile_bounds.y as usize, tile_bounds.x as usize, 2].into(), device, @@ -275,7 +271,6 @@ impl SplatOps for MainBackendBase { .expect("Failed to render splats"); } - // Step 5: Rasterize let out_dim = if bwd_info { 4 } else { 1 }; let out_img = create_tensor( [img_size.y as usize, img_size.x as usize, out_dim], diff --git a/crates/brush-render/src/render_aux.rs b/crates/brush-render/src/render_aux.rs index f1d13e11..7c6d77ed 100644 --- a/crates/brush-render/src/render_aux.rs +++ b/crates/brush-render/src/render_aux.rs @@ -85,15 +85,10 @@ impl ProjectOutput { /// Minimal output from rendering. Contains only what callers typically need. #[derive(Debug, Clone)] pub struct RenderAux { - /// Number of visible splats (for stats/logging) pub num_visible: IntTensor, - /// Total number of tile-splat intersections (for stats/logging) pub num_intersections: u32, - /// Visibility weights per splat (for training densification) pub visible: FloatTensor, - /// Tile offsets [ty, tx, 2] with (start, end) per tile (for visualization) pub tile_offsets: IntTensor, - /// Image size pub img_size: glam::UVec2, } From eec9e5c69f0bf896a4a792eda1e94333d10c3377 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 20 Jan 2026 17:51:00 +0100 Subject: [PATCH 12/29] Async rendering --- Cargo.lock | 2 + crates/brush-bench-test/src/benches.rs | 135 ++++++++++------- crates/brush-bench-test/src/reference.rs | 5 +- crates/brush-bench-test/tests/integration.rs | 19 +-- crates/brush-process/src/lib.rs | 16 +- crates/brush-process/src/message.rs | 2 + crates/brush-process/src/slot.rs | 74 ++++++--- crates/brush-process/src/train_stream.rs | 136 ++++++++--------- crates/brush-render-bwd/src/burn_glue.rs | 22 ++- crates/brush-render/Cargo.toml | 3 + crates/brush-render/src/gaussian_splats.rs | 30 ++-- crates/brush-render/src/render_aux.rs | 8 +- crates/brush-render/src/tests/mod.rs | 28 +--- crates/brush-rerun/src/visualize_tools.rs | 9 +- crates/brush-train/src/eval.rs | 8 +- crates/brush-train/src/msg.rs | 1 + crates/brush-train/src/train.rs | 36 ++--- crates/brush-ui/src/async_renderer.rs | 151 +++++++++++++++++++ crates/brush-ui/src/lib.rs | 1 + crates/brush-ui/src/scene.rs | 96 ++++++------ crates/brush-ui/src/stats.rs | 20 ++- crates/brush-ui/src/training_panel.rs | 7 +- crates/brush-ui/src/ui_process.rs | 4 +- examples/train-2d/Cargo.toml | 1 + examples/train-2d/examples/train-2d.rs | 8 +- 25 files changed, 504 insertions(+), 318 deletions(-) create mode 100644 crates/brush-ui/src/async_renderer.rs diff --git a/Cargo.lock b/Cargo.lock index 1ed80a95..419d93eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1175,6 +1175,7 @@ dependencies = [ "glam", "log", "serde", + "tokio", "tracing", ] @@ -3865,6 +3866,7 @@ dependencies = [ "brush-train", "brush-ui", "burn", + "burn-cubecl", "eframe", "egui", "glam", diff --git a/crates/brush-bench-test/src/benches.rs b/crates/brush-bench-test/src/benches.rs index dde61648..079a6c85 100644 --- a/crates/brush-bench-test/src/benches.rs +++ b/crates/brush-bench-test/src/benches.rs @@ -146,6 +146,7 @@ mod forward_rendering { AutodiffModule, Backend, Camera, ITERS_PER_SYNC, MainBackend, Quat, RESOLUTIONS, SPLAT_COUNTS, TextureMode, Vec3, WgpuDevice, gen_splats, render_splats, }; + use burn_cubecl::cubecl::future::block_on; #[divan::bench(args = SPLAT_COUNTS)] fn render_1080p(bencher: divan::Bencher, splat_count: usize) { @@ -160,17 +161,20 @@ mod forward_rendering { ); bencher.bench_local(move || { - for _ in 0..ITERS_PER_SYNC { - let _ = render_splats( - &splats, - &camera, - glam::uvec2(1920, 1080), - Vec3::ZERO, - None, - TextureMode::Float, - ); - } - MainBackend::sync(&device).expect("Failed to sync"); + block_on(async { + for _ in 0..ITERS_PER_SYNC { + let _ = render_splats( + splats.clone(), + &camera, + glam::uvec2(1920, 1080), + Vec3::ZERO, + None, + TextureMode::Float, + ) + .await; + } + MainBackend::sync(&device).expect("Failed to sync"); + }); }); } @@ -187,17 +191,20 @@ mod forward_rendering { ); bencher.bench_local(move || { - for _ in 0..ITERS_PER_SYNC { - let _ = render_splats( - &splats, - &camera, - glam::uvec2(width, height), - Vec3::ZERO, - None, - TextureMode::Float, - ); - } - MainBackend::sync(&device).expect("Failed to sync"); + block_on(async { + for _ in 0..ITERS_PER_SYNC { + let _ = render_splats( + splats.clone(), + &camera, + glam::uvec2(width, height), + Vec3::ZERO, + None, + TextureMode::Float, + ) + .await; + } + MainBackend::sync(&device).expect("Failed to sync"); + }); }); } } @@ -208,6 +215,7 @@ mod backward_rendering { Backend, Camera, DiffBackend, ITERS_PER_SYNC, MainBackend, Quat, RESOLUTIONS, Tensor, TensorPrimitive, Vec3, WgpuDevice, gen_splats, render_splats_diff, }; + use burn_cubecl::cubecl::future::block_on; #[divan::bench(args = [1_000_000, 2_000_000, 5_000_000])] fn render_grad_1080p(bencher: divan::Bencher, splat_count: usize) { @@ -222,14 +230,21 @@ mod backward_rendering { ); bencher.bench_local(move || { - for _ in 0..ITERS_PER_SYNC { - let diff_out = - render_splats_diff(&splats, &camera, glam::uvec2(1920, 1080), Vec3::ZERO); - let img: Tensor = - Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); - let _ = img.mean().backward(); - } - MainBackend::sync(&device).expect("Failed to sync"); + block_on(async { + for _ in 0..ITERS_PER_SYNC { + let diff_out = render_splats_diff( + splats.clone(), + &camera, + glam::uvec2(1920, 1080), + Vec3::ZERO, + ) + .await; + let img: Tensor = + Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); + let _ = img.mean().backward(); + } + MainBackend::sync(&device).expect("Failed to sync"); + }); }); } @@ -245,14 +260,21 @@ mod backward_rendering { glam::vec2(0.5, 0.5), ); bencher.bench_local(move || { - for _ in 0..ITERS_PER_SYNC { - let diff_out = - render_splats_diff(&splats, &camera, glam::uvec2(width, height), Vec3::ZERO); - let img: Tensor = - Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); - let _ = img.mean().backward(); - } - MainBackend::sync(&device).expect("Failed to sync"); + block_on(async { + for _ in 0..ITERS_PER_SYNC { + let diff_out = render_splats_diff( + splats.clone(), + &camera, + glam::uvec2(width, height), + Vec3::ZERO, + ) + .await; + let img: Tensor = + Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); + let _ = img.mean().backward(); + } + MainBackend::sync(&device).expect("Failed to sync"); + }); }); } } @@ -260,6 +282,7 @@ mod backward_rendering { #[divan::bench_group(max_time = 4)] mod training { use brush_render::bounding_box::BoundingBox; + use burn_cubecl::cubecl::future::block_on; use crate::benches::ITERS_PER_SYNC; @@ -270,25 +293,23 @@ mod training { #[divan::bench(args = SPLAT_COUNTS)] fn train_steps(splat_count: usize) { - burn_cubecl::cubecl::future::block_on(async { - let device = WgpuDevice::default(); - let batch1 = generate_training_batch((1920, 1080), Vec3::new(0.0, 0.0, 5.0)); - let batch2 = generate_training_batch((1920, 1080), Vec3::new(2.0, 0.0, 5.0)); - let batches = [batch1, batch2]; - let config = TrainConfig::default(); - let mut splats = gen_splats(&device, splat_count); - let mut trainer = SplatTrainer::new( - &config, - &device, - BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE), - ); - for step in 0..ITERS_PER_SYNC { - let batch = batches[step as usize % batches.len()].clone(); - let (new_splats, _) = trainer.step(batch, splats); - splats = new_splats; - } - MainBackend::sync(&device).expect("Failed to sync"); - }); + let device = WgpuDevice::default(); + let batch1 = generate_training_batch((1920, 1080), Vec3::new(0.0, 0.0, 5.0)); + let batch2 = generate_training_batch((1920, 1080), Vec3::new(2.0, 0.0, 5.0)); + let batches = [batch1, batch2]; + let config = TrainConfig::default(); + let mut splats = gen_splats(&device, splat_count); + let mut trainer = SplatTrainer::new( + &config, + &device, + BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE), + ); + for step in 0..ITERS_PER_SYNC { + let batch = batches[step as usize % batches.len()].clone(); + let (new_splats, _) = block_on(trainer.step(batch, splats)); + splats = new_splats; + } + MainBackend::sync(&device).expect("Failed to sync"); } } diff --git a/crates/brush-bench-test/src/reference.rs b/crates/brush-bench-test/src/reference.rs index 11b00c63..687d3840 100644 --- a/crates/brush-bench-test/src/reference.rs +++ b/crates/brush-bench-test/src/reference.rs @@ -127,11 +127,12 @@ async fn test_reference() -> Result<()> { ); let diff_out = brush_render_bwd::render_splats( - &splats, + splats.clone(), &cam, glam::uvec2(w as u32, h as u32), Vec3::ZERO, - ); + ) + .await; let out: Tensor = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); let render_aux = diff_out.render_aux; diff --git a/crates/brush-bench-test/tests/integration.rs b/crates/brush-bench-test/tests/integration.rs index e5bcb4a8..259f5980 100644 --- a/crates/brush-bench-test/tests/integration.rs +++ b/crates/brush-bench-test/tests/integration.rs @@ -161,8 +161,8 @@ fn test_forward_rendering() { assert!(means_data.iter().all(|&x| x.is_finite())); } -#[test] -fn test_training_step() { +#[tokio::test] +async fn test_training_step() { let device = WgpuDevice::default(); let batch = generate_test_batch((64, 64)); let splats = generate_test_splats(&device, 500); @@ -172,7 +172,7 @@ fn test_training_step() { &device, BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE), ); - let (final_splats, stats) = trainer.step(batch, splats); + let (final_splats, stats) = trainer.step(batch, splats).await; assert!(final_splats.num_splats() > 0); let loss = stats.loss.into_scalar(); @@ -190,8 +190,8 @@ fn test_batch_generation() { assert!(img_data.iter().all(|&x| (0.0..=1.1).contains(&x))); } -#[test] -fn test_multi_step_training() { +#[tokio::test] +async fn test_multi_step_training() { let device = WgpuDevice::default(); let batch = generate_test_batch((64, 64)); let config = TrainConfig::default(); @@ -205,7 +205,7 @@ fn test_multi_step_training() { // Run a few training steps for _ in 0..3 { - let (new_splats, stats) = trainer.step(batch.clone(), splats); + let (new_splats, stats) = trainer.step(batch.clone(), splats).await; splats = new_splats; let loss = stats.loss.into_scalar(); @@ -216,8 +216,8 @@ fn test_multi_step_training() { assert!(splats.num_splats() > 0); } -#[test] -fn test_gradient_validation() { +#[tokio::test] +async fn test_gradient_validation() { let device = WgpuDevice::default(); let splats = generate_test_splats(&device, 100); @@ -231,7 +231,8 @@ fn test_gradient_validation() { ); let img_size = glam::uvec2(64, 64); - let result = render_splats(&splats, &camera, img_size, Vec3::ZERO); + // Clone splats since render_splats takes ownership and we need splats for gradient validation + let result = render_splats(splats.clone(), &camera, img_size, Vec3::ZERO).await; let rendered: Tensor = Tensor::from_primitive(TensorPrimitive::Float(result.img)); diff --git a/crates/brush-process/src/lib.rs b/crates/brush-process/src/lib.rs index 775073d4..03d123a1 100644 --- a/crates/brush-process/src/lib.rs +++ b/crates/brush-process/src/lib.rs @@ -172,23 +172,21 @@ pub fn create_process< // For the first frame of a new file, clear existing frames if frame == 0 { - splat_view.clear(); + splat_view.clear().await; } - // Ensure we have space up to this frame index and set it - { - let mut guard = splat_view.write(); - if guard.len() <= frame { - guard.resize(frame + 1, splats.clone()); - } - guard[frame] = splats; - } + // Capture stats before moving splats + let num_splats = splats.num_splats(); + let sh_degree = splats.sh_degree(); + splat_view.set_at(frame, splats).await; emitter .emit(ProcessMessage::SplatsUpdated { up_axis: message.meta.up_axis, frame: frame as u32, total_frames, + num_splats, + sh_degree, }) .await; } diff --git a/crates/brush-process/src/message.rs b/crates/brush-process/src/message.rs index f4ab85d0..f26b9bda 100644 --- a/crates/brush-process/src/message.rs +++ b/crates/brush-process/src/message.rs @@ -55,6 +55,8 @@ pub enum ProcessMessage { up_axis: Option, frame: u32, total_frames: u32, + num_splats: u32, + sh_degree: u32, }, #[cfg(feature = "training")] TrainMessage(TrainMessage), diff --git a/crates/brush-process/src/slot.rs b/crates/brush-process/src/slot.rs index 1bf03674..9d32fa86 100644 --- a/crates/brush-process/src/slot.rs +++ b/crates/brush-process/src/slot.rs @@ -1,37 +1,75 @@ -use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::Arc; +use tokio::sync::Mutex; -/// A thread-safe slot for sharing data between the process and UI. -/// Uses Mutex because the inner type (Splats) is not Sync. +/// A thread-safe async slot for sharing data between the process and UI. +/// Uses tokio's async Mutex so locks can be held across await points. #[derive(Clone)] pub struct Slot(Arc>>); impl Slot { - pub fn write(&self) -> MutexGuard<'_, Vec> { - self.0.lock().unwrap() - } + /// Temporarily take ownership of a value at a specific index, do something with it, + /// and put the result back. The lock is held across the async operation. + /// + /// Uses swap + push to avoid needing a placeholder value. + /// The value is moved to the end during the operation, then swapped back. + pub async fn act(&self, index: usize, f: F) -> Option + where + F: AsyncFnOnce(T) -> (T, R), + { + let mut guard = self.0.lock().await; + let len = guard.len(); + if index >= len { + return None; + } + // Swap the target element to the end, then pop it + guard.swap(index, len - 1); + let value = guard.pop().unwrap(); + let (new_value, result) = f(value).await; - pub fn push(&self, value: T) { - self.0.lock().unwrap().push(value); + // Push it back and swap to original position + guard.push(new_value); + let new_len = guard.len(); + guard.swap(index, new_len - 1); + Some(result) } - pub fn clear(&self) { - self.0.lock().unwrap().clear(); + pub async fn map(&self, index: usize, f: F) -> Option + where + F: FnOnce(&T) -> R, + { + self.act(index, async move |value| { + let ret = f(&value); + (value, ret) + }) + .await } - pub fn len(&self) -> usize { - self.0.lock().unwrap().len() + /// Get a clone of the main (last) value. + pub async fn clone_main(&self) -> Option { + self.0.lock().await.last().cloned() } - pub fn is_empty(&self) -> bool { - self.0.lock().unwrap().is_empty() + /// Set the slot to contain a single value, clearing any previous contents. + pub async fn set(&self, value: T) { + let mut guard = self.0.lock().await; + guard.clear(); + guard.push(value); } - pub fn get(&self, index: usize) -> Option { - self.0.lock().unwrap().get(index).cloned() + /// Set the value at the given index, or push if index equals current length. + /// Panics if index > len (gaps not allowed). + pub async fn set_at(&self, index: usize, value: T) { + let mut guard = self.0.lock().await; + if index == guard.len() { + guard.push(value); + } else { + guard[index] = value; + } } - pub fn get_main(&self) -> Option { - self.0.lock().unwrap().last().cloned() + /// Clear the slot. + pub async fn clear(&self) { + self.0.lock().await.clear(); } } diff --git a/crates/brush-process/src/train_stream.rs b/crates/brush-process/src/train_stream.rs index c05dc23a..01d1ad09 100644 --- a/crates/brush-process/src/train_stream.rs +++ b/crates/brush-process/src/train_stream.rs @@ -10,11 +10,11 @@ use brush_render::{ MainBackend, gaussian_splats::{SplatRenderMode, Splats}, }; -use brush_rerun::{RerunConfig, visualize_tools::VisualizeTools}; +use brush_rerun::visualize_tools::VisualizeTools; use brush_train::{ RandomSplatsConfig, create_random_splats, eval::eval_stats, - msg::{RefineStats, TrainStepStats}, + msg::RefineStats, splats_into_autodiff, to_init_splats, train::{BOUND_PERCENTILE, SplatTrainer, get_splat_bounds}, }; @@ -127,12 +127,14 @@ pub(crate) async fn train_stream( // If the metadata has an up axis prefer that, otherwise estimate the up direction. let up_axis = up_axis.or(Some(estimated_up)); - *splat_slot.write() = vec![init_splats.clone()]; + splat_slot.set(init_splats.clone()).await; emitter .emit(ProcessMessage::SplatsUpdated { up_axis, frame: 0, total_frames: 1, + num_splats: init_splats.num_splats(), + sh_degree: init_splats.sh_degree(), }) .await; @@ -170,6 +172,7 @@ pub(crate) async fn train_stream( let export_path = base_path.join(&export_path_str); // Normalize path components let export_path: PathBuf = export_path.components().collect(); + let sh_degree = init_splats.sh_degree(); log::info!("Start training loop."); for iter in @@ -177,19 +180,19 @@ pub(crate) async fn train_stream( { let step_time = Instant::now(); - // Scope so we're sure we're sharing memory for the splat. - let stats = { - let batch = dataloader - .next_batch() - .instrument(trace_span!("Wait for next data batch")) - .await; + // Wait for next batch. + let batch = dataloader + .next_batch() + .instrument(trace_span!("Wait for next data batch")) + .await; - let mut splat_handle = splat_slot.write(); - let splats = splats_into_autodiff(splat_handle.remove(0)); - let (new_splats, stats) = trainer.step(batch, splats); - *splat_handle = vec![new_splats.valid()]; - stats - }; + let stats = splat_slot + .act(0, |splats| async { + let (new_splats, stats) = trainer.step(batch, splats_into_autodiff(splats)).await; + (new_splats.valid(), stats) + }) + .await + .unwrap(); let train_t = (iter as f32 / train_stream_config.train_config.total_steps as f32).clamp(0.0, 1.0); @@ -198,19 +201,19 @@ pub(crate) async fn train_stream( && iter.is_multiple_of(train_stream_config.train_config.refine_every) && train_t <= 0.95 { - let splats = splat_slot.get_main().unwrap(); - let (new_splats, refine) = trainer - .refine(iter, splats) - .instrument(trace_span!("Refine splats")) - .await; - *splat_slot.write() = vec![new_splats]; - Some(refine) + splat_slot + .act(0, async |splats| trainer.refine(iter, splats).await) + .await + .unwrap() } else { - None + let new_total: u32 = splat_slot.map(0, |s| s.num_splats()).await.unwrap(); + RefineStats { + num_added: 0, + num_pruned: 0, + total_splats: new_total, + } }; - let splats = splat_slot.get_main().unwrap(); - // We just finished iter 'iter', now starting iter + 1. let iter = iter + 1; let is_last_step = iter == train_stream_config.train_config.total_steps; @@ -232,7 +235,7 @@ pub(crate) async fn train_stream( &device, &emitter, &visualize, - splats.clone(), + splat_slot.clone_main().await.unwrap(), iter, eval_scene, save_path, @@ -248,7 +251,7 @@ pub(crate) async fn train_stream( #[cfg(not(target_family = "wasm"))] if iter % process_config.export_every == 0 || is_last_step { let res = export_checkpoint( - splats.clone(), + splat_slot.clone_main().await.unwrap(), &export_path, &process_config.export_name, iter, @@ -262,27 +265,37 @@ pub(crate) async fn train_stream( } } - let res = rerun_log( - &train_stream_config.rerun_config, - &visualize, - splats.clone(), - &stats, - iter, - is_last_step, - &device, - refine.as_ref(), - ) - .await - .context("Rerun visualization failed"); + { + let rerun_config = &train_stream_config.rerun_config; + visualize + .log_splat_stats(iter, refine.total_splats) + .unwrap(); + + if let Some(every) = rerun_config.rerun_log_splats_every + && (iter.is_multiple_of(every) || is_last_step) + { + let splats = splat_slot.clone_main().await.unwrap(); + visualize.log_splats(iter, splats).await.unwrap(); + } + + // Log out train stats. + if iter.is_multiple_of(rerun_config.rerun_log_train_stats_every) || is_last_step { + visualize + .log_train_stats(iter, stats.clone()) + .await + .unwrap(); + } - if let Err(error) = res { - emitter.emit(ProcessMessage::Warning { error }).await; + visualize.log_memory(iter, &WgpuRuntime::client(&device).memory_usage())?; + if refine.num_added > 0 { + visualize.log_refine_stats(iter, &refine).unwrap(); + } } - if refine.is_some() { + if refine.num_added > 0 { emitter .emit(ProcessMessage::TrainMessage(TrainMessage::RefineStep { - cur_splat_count: splats.num_splats(), + cur_splat_count: refine.total_splats, iter, })) .await; @@ -296,6 +309,8 @@ pub(crate) async fn train_stream( up_axis: None, frame: 0, total_frames: 1, + num_splats: refine.total_splats, + sh_degree, }) .await; @@ -335,12 +350,13 @@ async fn run_eval( let eval_img = view.image.load().await?; let sample = eval_stats( - &splats, + splats.clone(), &view.camera, eval_img, view.image.alpha_mode(), device, ) + .await .context("Failed to run eval for sample.")?; count += 1; @@ -398,35 +414,3 @@ async fn export_checkpoint( .context(format!("Failed to export ply {export_path:?}"))?; Ok(()) } - -async fn rerun_log( - rerun_config: &RerunConfig, - visualize: &VisualizeTools, - splats: Splats, - stats: &TrainStepStats, - iter: u32, - is_last_step: bool, - device: &WgpuDevice, - refine: Option<&RefineStats>, -) -> Result<(), anyhow::Error> { - visualize.log_splat_stats(iter, &splats)?; - - if let Some(every) = rerun_config.rerun_log_splats_every - && (iter.is_multiple_of(every) || is_last_step) - { - visualize.log_splats(iter, splats.clone()).await?; - } - - // Log out train stats. - if iter.is_multiple_of(rerun_config.rerun_log_train_stats_every) || is_last_step { - visualize.log_train_stats(iter, stats.clone()).await?; - } - - let client = WgpuRuntime::client(device); - visualize.log_memory(iter, &client.memory_usage())?; - // Emit some messages. Important to not count these in the training time (as this might pause). - if let Some(stats) = refine { - visualize.log_refine_stats(iter, stats)?; - } - Ok(()) -} diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index dc327560..81adbd2f 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -200,8 +200,10 @@ pub struct SplatOutputDiff { /// /// This is the main entry point for differentiable rendering, wrapping /// the forward pass with autodiff support. -pub fn render_splats( - splats: &Splats>, +/// +/// Takes ownership of the splats. Clone before calling if you need to reuse them. +pub async fn render_splats( + splats: Splats>, camera: &Camera, img_size: glam::UVec2, background: Vec3, @@ -218,7 +220,6 @@ where .device(); let refine_weight_holder = Tensor::, 1>::zeros([1], &device).require_grad(); - // Prepare backward pass, and check if we even need to do it. let prep_nodes = RenderBackwards .prepare::([ splats.means.val().into_primitive().tensor().node, @@ -262,8 +263,8 @@ where .into_primitive() .tensor() .into_primitive(); + let render_mode = splats.render_mode; - // First pass: project let project_output = >::project( camera, img_size, @@ -272,17 +273,15 @@ where quats.clone(), sh_coeffs, raw_opacity.clone(), - splats.render_mode, + render_mode, ); - // Sync readback of num_intersections - let num_intersections = project_output.read_num_intersections(); + // Async readback + let num_intersections = project_output.read_num_intersections().await; - // Second pass: rasterize (with bwd_info = true) let (out_img, render_aux, compact_gid_from_isect) = >::rasterize(&project_output, num_intersections, background, true); - // Create wrapped render_aux for Autodiff backend let wrapped_render_aux = RenderAux::> { num_visible: render_aux.num_visible.clone(), num_intersections: render_aux.num_intersections, @@ -295,7 +294,6 @@ where match prep_nodes { OpsKind::Tracked(prep) => { - // Save state needed for backward pass. let state = GaussianBackwardState { means, log_scales, @@ -308,7 +306,7 @@ where num_visible: project_output.num_visible, tile_offsets: render_aux.tile_offsets, compact_gid_from_isect, - render_mode: splats.render_mode, + render_mode, global_from_compact_gid: project_output.global_from_compact_gid, background, img_size, @@ -325,8 +323,6 @@ where result } OpsKind::UnTracked(prep) => { - // When no node is tracked, we can just use the original operation without - // keeping any state. let result = SplatOutputDiff { img: prep.finish(out_img), render_aux: wrapped_render_aux, diff --git a/crates/brush-render/Cargo.toml b/crates/brush-render/Cargo.toml index ee6691c0..fdac12f4 100644 --- a/crates/brush-render/Cargo.toml +++ b/crates/brush-render/Cargo.toml @@ -32,5 +32,8 @@ ignored = ["bytemuck"] [features] debug-validation = [] +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt"] } + [lints] workspace = true diff --git a/crates/brush-render/src/gaussian_splats.rs b/crates/brush-render/src/gaussian_splats.rs index 605fda34..f938f7a6 100644 --- a/crates/brush-render/src/gaussian_splats.rs +++ b/crates/brush-render/src/gaussian_splats.rs @@ -243,8 +243,10 @@ impl Splats { /// /// NB: This doesn't work on a differentiable backend. Use /// [`brush_render_bwd::render_splats`] for that. -pub fn render_splats>( - splats: &Splats, +/// +/// Takes ownership of the splats to avoid cloning internally. +pub async fn render_splats>( + splats: Splats, camera: &Camera, img_size: glam::UVec2, background: Vec3, @@ -253,38 +255,36 @@ pub fn render_splats>( ) -> (Tensor, RenderAux) { splats.validate_values(); - let mut scales = splats.log_scales.val(); + let mut scales = splats.log_scales.into_value(); - // Add in scaling if needed. if let Some(scale) = splat_scale { scales = scales + scale.ln(); }; - // First pass: project let project_output = B::project( camera, img_size, - splats.means.val().into_primitive().tensor(), + splats.means.into_value().into_primitive().tensor(), scales.into_primitive().tensor(), - splats.rotations.val().into_primitive().tensor(), - splats.sh_coeffs.val().into_primitive().tensor(), - splats.raw_opacities.val().into_primitive().tensor(), + splats.rotations.into_value().into_primitive().tensor(), + splats.sh_coeffs.into_value().into_primitive().tensor(), + splats.raw_opacities.into_value().into_primitive().tensor(), splats.render_mode, ); - // Validate before readback project_output.validate(); - let num_intersections = project_output.read_num_intersections(); + // Async readback + let num_intersections = project_output.read_num_intersections().await; let use_float = matches!(texture_mode, TextureMode::Float); let (out_img, render_aux, _) = B::rasterize(&project_output, num_intersections, background, use_float); - // Validate rasterize outputs render_aux.validate(); - let img = Tensor::from_primitive(TensorPrimitive::Float(out_img)); - - (img, render_aux) + ( + Tensor::from_primitive(TensorPrimitive::Float(out_img)), + render_aux, + ) } diff --git a/crates/brush-render/src/render_aux.rs b/crates/brush-render/src/render_aux.rs index 7c6d77ed..4c879326 100644 --- a/crates/brush-render/src/render_aux.rs +++ b/crates/brush-render/src/render_aux.rs @@ -22,14 +22,16 @@ pub struct ProjectOutput { } impl ProjectOutput { - /// Get the total number of intersections (sync readback). - pub fn read_num_intersections(&self) -> u32 { + /// Get the total number of intersections. + pub async fn read_num_intersections(&self) -> u32 { let cum_tiles_hit: Tensor = Tensor::from_primitive(self.cum_tiles_hit.clone()); let total = self.project_uniforms.total_splats as usize; if total > 0 { cum_tiles_hit .slice([total - 1..total]) - .into_scalar() + .into_scalar_async() + .await + .expect("Failed to read num_intersections") .elem::() } else { 0 diff --git a/crates/brush-render/src/tests/mod.rs b/crates/brush-render/src/tests/mod.rs index e82704cc..776f41bb 100644 --- a/crates/brush-render/src/tests/mod.rs +++ b/crates/brush-render/src/tests/mod.rs @@ -8,8 +8,8 @@ use burn::tensor::{Distribution, Tensor}; use burn_wgpu::WgpuDevice; use glam::Vec3; -#[test] -fn renders_at_all() { +#[tokio::test] +async fn renders_at_all() { // Check if rendering doesn't hard crash or anything. // These are some zero-sized gaussians, so we know // what the result should look like. @@ -40,14 +40,8 @@ fn renders_at_all() { raw_opacity, SplatRenderMode::Default, ); - let (output, _render_aux) = render_splats( - &splats, - &cam, - img_size, - Vec3::ZERO, - None, - TextureMode::Float, - ); + let (output, _render_aux) = + render_splats(splats, &cam, img_size, Vec3::ZERO, None, TextureMode::Float).await; let rgb = output.clone().slice([0..32, 0..32, 0..3]); let alpha = output.slice([0..32, 0..32, 3..4]); @@ -61,8 +55,8 @@ fn renders_at_all() { assert_approx_eq!(alpha_mean, 0.0); } -#[test] -fn renders_many_splats() { +#[tokio::test] +async fn renders_many_splats() { // Test rendering with a ton of gaussians to verify 2D dispatch works correctly. // This exceeds the 1D 65535 * 256 = 16.7M limit. let num_splats = 30_000_000; @@ -112,12 +106,6 @@ fn renders_many_splats() { raw_opacity, SplatRenderMode::Default, ); - let (_output, _render_aux) = render_splats( - &splats, - &cam, - img_size, - Vec3::ZERO, - None, - TextureMode::Float, - ); + let (_output, _render_aux) = + render_splats(splats, &cam, img_size, Vec3::ZERO, None, TextureMode::Float).await; } diff --git a/crates/brush-rerun/src/visualize_tools.rs b/crates/brush-rerun/src/visualize_tools.rs index 4d0a7a58..a8874f27 100644 --- a/crates/brush-rerun/src/visualize_tools.rs +++ b/crates/brush-rerun/src/visualize_tools.rs @@ -217,12 +217,13 @@ mod visualize_tools_impl { } #[allow(unused_variables)] - pub fn log_splat_stats(&self, iter: u32, splats: &Splats) -> Result<()> { + pub fn log_splat_stats(&self, iter: u32, num_splats: u32) -> Result<()> { if self.rec.is_enabled() { self.rec.set_time_sequence("iterations", iter); - let num = splats.num_splats(); - self.rec - .log("splats/num_splats", &rerun::Scalars::new(vec![num as f64]))?; + self.rec.log( + "splats/num_splats", + &rerun::Scalars::new(vec![num_splats as f64]), + )?; } Ok(()) } diff --git a/crates/brush-train/src/eval.rs b/crates/brush-train/src/eval.rs index 5d6dabf5..662b454b 100644 --- a/crates/brush-train/src/eval.rs +++ b/crates/brush-train/src/eval.rs @@ -21,8 +21,8 @@ pub struct EvalSample { pub render_aux: RenderAux, } -pub fn eval_stats>( - splats: &Splats, +pub async fn eval_stats>( + splats: Splats, gt_cam: &Camera, gt_img: DynamicImage, alpha_mode: AlphaMode, @@ -35,9 +35,9 @@ pub fn eval_stats>( let gt_tensor = Tensor::from_data(gt_tensor, device); let gt_rgb = gt_tensor.slice(s![.., .., 0..3]); - // Render on reference black background + // Render on reference black background - async readback let (img, render_aux) = - render_splats(splats, gt_cam, res, Vec3::ZERO, None, TextureMode::Float); + render_splats(splats, gt_cam, res, Vec3::ZERO, None, TextureMode::Float).await; let render_rgb = img.slice(s![.., .., 0..3]); // Simulate an 8-bit roundtrip for fair comparison. diff --git a/crates/brush-train/src/msg.rs b/crates/brush-train/src/msg.rs index adac9804..a65d5a6d 100644 --- a/crates/brush-train/src/msg.rs +++ b/crates/brush-train/src/msg.rs @@ -7,6 +7,7 @@ use burn::{ pub struct RefineStats { pub num_added: u32, pub num_pruned: u32, + pub total_splats: u32, } #[derive(Clone)] diff --git a/crates/brush-train/src/train.rs b/crates/brush-train/src/train.rs index aafee5a5..d6673337 100644 --- a/crates/brush-train/src/train.rs +++ b/crates/brush-train/src/train.rs @@ -34,7 +34,7 @@ use burn::{ use burn_cubecl::cubecl::Runtime; use glam::Vec3; use hashbrown::{HashMap, HashSet}; -use tracing::trace_span; +use tracing::{Instrument, trace_span}; pub const BOUND_PERCENTILE: f32 = 0.8; @@ -99,40 +99,33 @@ impl SplatTrainer { } } - pub fn step( + pub async fn step( &mut self, batch: SceneBatch, splats: Splats, ) -> (Splats, TrainStepStats) { - let _span = trace_span!("Train step").entered(); - let mut splats = splats; let [img_h, img_w, _] = batch.img_tensor.shape.clone().try_into().unwrap(); - let camera = &batch.camera; + let camera = batch.camera.clone(); // Upload tensor early. let device = splats.device(); let has_alpha = batch.has_alpha(); let gt_tensor = Tensor::from_data(batch.img_tensor, &device); - let (pred_image, render_aux, refine_weight_holder) = - trace_span!("Forward").in_scope(|| { - // Could generate a random background color, but so far - // results just seem worse. - let background = Vec3::ZERO; - - let diff_out = render_splats( - &splats, - camera, - glam::uvec2(img_w as u32, img_h as u32), - background, - ); + // Forward pass - render splats asynchronously. + // Clone splats to avoid holding references across the await. + let background = Vec3::ZERO; + let img_size = glam::uvec2(img_w as u32, img_h as u32); - let img = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); + let diff_out = render_splats(splats.clone(), &camera, img_size, background) + .instrument(trace_span!("Forward")) + .await; - (img, diff_out.render_aux, diff_out.refine_weight_holder) - }); + let pred_image = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); + let render_aux = diff_out.render_aux; + let refine_weight_holder = diff_out.refine_weight_holder; let median_scale = self.bounds.median_size(); let num_visible = render_aux.get_num_visible().inner(); @@ -388,11 +381,14 @@ impl SplatTrainer { self.bounds = get_splat_bounds(splats.clone(), BOUND_PERCENTILE).await; client.memory_cleanup(); + let splat_count = splats.num_splats(); + ( splats, RefineStats { num_added: refine_count as u32, num_pruned: pruned_count, + total_splats: splat_count, }, ) } diff --git a/crates/brush-ui/src/async_renderer.rs b/crates/brush-ui/src/async_renderer.rs new file mode 100644 index 00000000..140594f6 --- /dev/null +++ b/crates/brush-ui/src/async_renderer.rs @@ -0,0 +1,151 @@ +use brush_process::slot::Slot; +use brush_render::{ + MainBackend, TextureMode, camera::Camera, gaussian_splats::Splats, render_splats, +}; +use burn::tensor::Tensor; +use glam::Vec3; +use std::sync::{ + Arc, Mutex, + atomic::{AtomicBool, Ordering}, +}; +use tokio::sync::Notify; +use tokio_with_wasm::alias::task; + +pub struct RenderRequest { + pub slot: Slot>, + pub frame: usize, + pub camera: Camera, + pub img_size: glam::UVec2, + pub background: Vec3, + pub splat_scale: Option, +} + +/// Result of a render operation. +pub struct RenderResult { + pub image: Tensor, + pub camera: Camera, + pub img_size: glam::UVec2, +} + +pub struct AsyncRenderer { + /// Latest render request. + request: Arc>>, + request_notify: Arc, + + result: Arc>>, + + /// Flag to signal shutdown. + shutdown: Arc, +} + +impl AsyncRenderer { + pub fn new() -> Self { + let request = Arc::new(Mutex::new(None)); + let request_notify = Arc::new(Notify::new()); + let result = Arc::new(Mutex::new(None)); + let shutdown = Arc::new(AtomicBool::new(false)); + + task::spawn(render_loop( + Arc::clone(&request), + Arc::clone(&request_notify), + Arc::clone(&result), + Arc::clone(&shutdown), + )); + + Self { + request, + request_notify, + result, + shutdown, + } + } + + /// Submit a new render request. This will overwrite any pending request. + pub fn submit(&self, new_request: RenderRequest) { + { + let mut req = self.request.lock().unwrap(); + *req = Some(new_request); + } + self.request_notify.notify_one(); + } + + /// Check if a new render result is available and return it. + /// Returns `None` if no new result since last check. + pub fn try_get_result(&self) -> Option { + self.result.lock().unwrap().take() + } +} + +impl Default for AsyncRenderer { + fn default() -> Self { + Self::new() + } +} + +impl Drop for AsyncRenderer { + fn drop(&mut self) { + // Signal shutdown to the background task + self.shutdown.store(true, Ordering::SeqCst); + self.request_notify.notify_one(); + } +} + +/// Background render loop that processes the latest render request. +async fn render_loop( + request: Arc>>, + request_notify: Arc, + result: Arc>>, + shutdown: Arc, +) { + loop { + // Wait for a new request or shutdown + request_notify.notified().await; + + // Check for shutdown + if shutdown.load(Ordering::SeqCst) { + break; + } + + // Take the latest request (don't hold lock across await) + let req = { + let mut req_guard = request.lock().unwrap(); + req_guard.take() + }; + + if let Some(req) = req { + // Use act to hold the slot lock across the render, so training will wait. + // We clone the splats since render_splats takes ownership. + let render_result = { + let camera = req.camera.clone(); + let img_size = req.img_size; + let background = req.background; + let splat_scale = req.splat_scale; + + req.slot + .act(req.frame, async move |splats| { + let (image, _) = render_splats( + splats.clone(), + &camera, + img_size, + background, + splat_scale, + TextureMode::Packed, + ) + .await; + (splats, image) + }) + .await + }; + + // Store the result with render params for widget_3d + if let Some(image) = render_result { + let mut res_guard = result.lock().unwrap(); + *res_guard = Some(RenderResult { + image, + camera: req.camera, + img_size: req.img_size, + }); + } + } + } +} diff --git a/crates/brush-ui/src/lib.rs b/crates/brush-ui/src/lib.rs index 5b6e4ee8..d5e593dd 100644 --- a/crates/brush-ui/src/lib.rs +++ b/crates/brush-ui/src/lib.rs @@ -6,6 +6,7 @@ pub mod camera_controls; pub mod ui_process; +mod async_renderer; mod panels; mod scene; #[cfg(feature = "training")] diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 5ab77637..71515792 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -13,18 +13,14 @@ use egui::{ }; use std::sync::Arc; -use brush_render::{ - MainBackend, TextureMode, - camera::{Camera, focal_to_fov, fov_to_focal}, - gaussian_splats::Splats, - render_splats, -}; +use brush_render::camera::{Camera, focal_to_fov, fov_to_focal}; use eframe::egui_wgpu::Renderer; use egui::{Color32, Rect, Slider}; use glam::{UVec2, Vec3}; -use tracing::trace_span; use web_time::Instant; +use crate::async_renderer::{AsyncRenderer, RenderRequest}; + use serde::{Deserialize, Serialize}; /// Controls how often the viewport re-renders during training. @@ -151,6 +147,9 @@ pub struct ScenePanel { #[cfg(feature = "training")] #[serde(skip)] settings_popup: Option>>, + /// Async renderer for non-blocking splat rendering. + #[serde(skip)] + async_renderer: AsyncRenderer, } impl ScenePanel { @@ -268,7 +267,6 @@ impl ScenePanel { &mut self, ui: &mut egui::Ui, process: &UiProcess, - splats: Option>, interactive: bool, ) -> egui::Rect { let size = ui.available_size(); @@ -292,17 +290,13 @@ impl ScenePanel { let settings = process.get_cam_settings(); // Adjust FOV so that the scene view shows at least what's visible in the dataset view. - // The camera has original fov_x and fov_y from the dataset. We need to ensure - // the viewport shows at least that much in both dimensions. let camera_aspect = (camera.fov_x / 2.0).tan() / (camera.fov_y / 2.0).tan(); let viewport_aspect = size.x as f64 / size.y as f64; if viewport_aspect > camera_aspect { - // Viewport is wider than camera - keep fov_y, expand fov_x let focal_y = fov_to_focal(camera.fov_y, size.y); camera.fov_x = focal_to_fov(focal_y, size.x); } else { - // Viewport is taller than camera - keep fov_x, expand fov_y let focal_x = fov_to_focal(camera.fov_x, size.x); camera.fov_y = focal_to_fov(focal_x, size.y); } @@ -321,45 +315,48 @@ impl ScenePanel { if dirty { self.last_state = Some(state); - // Check again next frame, as there might be more to animate. ui.ctx().request_repaint(); } - if let Some(splats) = splats { - let pixel_size = glam::uvec2( - (size.x as f32 * ui.ctx().pixels_per_point().round()) as u32, - (size.y as f32 * ui.ctx().pixels_per_point().round()) as u32, - ); - // If this viewport is re-rendering. - if pixel_size.x > 8 && pixel_size.y > 8 && dirty { - let _span = trace_span!("Render splats").entered(); - // Could add an option for background color. - let (img, _render_aux) = render_splats( - &splats, - &camera, - pixel_size, - settings.background.unwrap_or(Vec3::ZERO), - settings.splat_scale, - TextureMode::Packed, - ); + let pixel_size = glam::uvec2( + (size.x as f32 * ui.ctx().pixels_per_point().round()) as u32, + (size.y as f32 * ui.ctx().pixels_per_point().round()) as u32, + ); - if let Some(backbuffer) = &mut self.backbuffer { - backbuffer.update_texture(img); - } + // Check for new render result from background task + if let Some(result) = self.async_renderer.try_get_result() { + if let Some(backbuffer) = &mut self.backbuffer { + backbuffer.update_texture(result.image); + } - if let Some(widget_3d) = &mut self.widget_3d - && let Some(backbuffer) = &self.backbuffer - && let Some(texture) = backbuffer.texture() - { - widget_3d.render_to_texture( - &camera, - process.model_local_to_world(), - pixel_size, - texture, - grid_opacity, - ); - } + // Render widget_3d in sync with new splat render (same camera/size) + if let Some(widget_3d) = &mut self.widget_3d + && let Some(backbuffer) = &self.backbuffer + && let Some(texture) = backbuffer.texture() + { + widget_3d.render_to_texture( + &result.camera, + process.model_local_to_world(), + result.img_size, + texture, + grid_opacity, + ); } + + ui.ctx().request_repaint(); + } + + // Submit new render request if dirty and we have splats + if pixel_size.x > 8 && pixel_size.y > 8 && dirty { + self.async_renderer.submit(RenderRequest { + slot: process.current_splats(), + frame: self.frame as usize, + camera, + img_size: pixel_size, + background: settings.background.unwrap_or(Vec3::ZERO), + splat_scale: settings.splat_scale, + }); + ui.ctx().request_repaint(); } ui.scope(|ui| { @@ -908,6 +905,7 @@ impl AppPane for ScenePanel { up_axis, frame, total_frames, + .. } => { self.has_splats = true; self.frame_count = *total_frames; @@ -1048,15 +1046,9 @@ impl AppPane for ScenePanel { ui.ctx().request_repaint(); } - // Get the splat for the current frame - let splats = process.current_splats().and_then(|sv| { - let frame_idx = self.frame as usize; - sv.get(frame_idx) - }); - let interactive = matches!(process.ui_mode(), UiMode::Default | UiMode::FullScreenSplat); - let rect = self.draw_splats(ui, process, splats, interactive); + let rect = self.draw_splats(ui, process, interactive); if interactive { self.draw_play_pause(ui, rect); diff --git a/crates/brush-ui/src/stats.rs b/crates/brush-ui/src/stats.rs index 2f389f51..cc6bbda8 100644 --- a/crates/brush-ui/src/stats.rs +++ b/crates/brush-ui/src/stats.rs @@ -19,6 +19,8 @@ pub struct StatsPanel { last_train_step: (Duration, u32), train_eval_views: (u32, u32), training_complete: bool, + num_splats: u32, + sh_degree: u32, } fn bytes_format(bytes: u64) -> String { @@ -103,11 +105,20 @@ impl AppPane for StatsPanel { self.last_train_step = (Duration::from_secs(0), 0); self.train_eval_views = (0, 0); self.training_complete = false; + self.num_splats = 0; + self.sh_degree = 0; } ProcessMessage::StartLoading { .. } => { self.last_eval = None; } - ProcessMessage::SplatsUpdated { .. } => {} + ProcessMessage::SplatsUpdated { + num_splats, + sh_degree, + .. + } => { + self.num_splats = *num_splats; + self.sh_degree = *sh_degree; + } ProcessMessage::TrainMessage(train) => match train { TrainMessage::TrainStep { iter, @@ -151,11 +162,8 @@ impl AppPane for StatsPanel { }); ui.separator(); - let (num_splats, sh_degree) = process - .current_splats() - .and_then(|sv| sv.get_main()) - .map_or((0, 0), |spl| (spl.num_splats(), spl.sh_degree())); - + let num_splats = self.num_splats; + let sh_degree = self.sh_degree; let frames = self.frames; stats_grid(ui, "model_stats_grid", |ui, v| { stat_row(ui, "Splats", format!("{num_splats}"), v); diff --git a/crates/brush-ui/src/training_panel.rs b/crates/brush-ui/src/training_panel.rs index 21ac1a1c..7ead7817 100644 --- a/crates/brush-ui/src/training_panel.rs +++ b/crates/brush-ui/src/training_panel.rs @@ -244,9 +244,7 @@ impl AppPane for TrainingPanel { } } - if let Some(slot) = process.current_splats() - && process.is_training() - { + if process.is_training() { // Right-align export button ui.with_layout(egui::Layout::right_to_left(egui::Align::Center), |ui| { // Make export button more prominent when training is complete @@ -278,9 +276,10 @@ impl AppPane for TrainingPanel { } let sender = self.export_channel.0.clone(); let ctx = ui.ctx().clone(); + let slot = process.current_splats(); task::spawn(async move { - let Some(splats) = slot.get_main() else { + let Some(splats) = slot.clone_main().await else { return; }; diff --git a/crates/brush-ui/src/ui_process.rs b/crates/brush-ui/src/ui_process.rs index e5f9cf8e..26cfe7d7 100644 --- a/crates/brush-ui/src/ui_process.rs +++ b/crates/brush-ui/src/ui_process.rs @@ -65,11 +65,11 @@ impl UiProcess { self.write().background_style = style; } - pub(crate) fn current_splats(&self) -> Option>> { + pub(crate) fn current_splats(&self) -> Slot> { self.read() .process_handle .as_ref() - .map(|s| s.splat_view.clone()) + .map_or(Slot::default(), |s| s.splat_view.clone()) } pub fn is_loading(&self) -> bool { diff --git a/examples/train-2d/Cargo.toml b/examples/train-2d/Cargo.toml index 5aadae07..c3f1e8c9 100644 --- a/examples/train-2d/Cargo.toml +++ b/examples/train-2d/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [dev-dependencies] burn.workspace = true +burn-cubecl.workspace = true brush-train.path = "../../crates/brush-train" brush-dataset.path = "../../crates/brush-dataset" brush-render.path = "../../crates/brush-render" diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index 2a21d5ef..57405317 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -67,7 +67,7 @@ fn spawn_train_loop( let mut iter = 0; loop { - let (new_splats, _) = trainer.step(batch.clone(), splats); + let (new_splats, _) = trainer.step(batch.clone(), splats).await; let (new_splats, _) = trainer.refine(iter, new_splats.valid()).await; splats = splats_into_autodiff(new_splats); iter += 1; @@ -172,14 +172,14 @@ impl eframe::App for App { return; }; - let (img, _render_aux) = render_splats( - &msg.splats, + let (img, _render_aux) = burn_cubecl::cubecl::future::block_on(render_splats( + msg.splats.clone(), &self.camera, glam::uvec2(self.image.width(), self.image.height()), Vec3::ZERO, // Just render with a black background None, TextureMode::Packed, - ); + )); let size = egui::vec2(self.image.width() as f32, self.image.height() as f32); From 6556683bc6f7026639f91e028cbe4db5b4214b47 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 20 Jan 2026 20:26:44 +0100 Subject: [PATCH 13/29] splat backbuffer async --- crates/brush-ui/src/async_renderer.rs | 151 -------------- crates/brush-ui/src/burn_texture.rs | 179 ----------------- crates/brush-ui/src/lib.rs | 3 +- crates/brush-ui/src/scene.rs | 55 +++--- crates/brush-ui/src/splat_backbuffer.rs | 250 ++++++++++++++++++++++++ examples/train-2d/examples/train-2d.rs | 63 +++--- 6 files changed, 311 insertions(+), 390 deletions(-) delete mode 100644 crates/brush-ui/src/async_renderer.rs delete mode 100644 crates/brush-ui/src/burn_texture.rs create mode 100644 crates/brush-ui/src/splat_backbuffer.rs diff --git a/crates/brush-ui/src/async_renderer.rs b/crates/brush-ui/src/async_renderer.rs deleted file mode 100644 index 140594f6..00000000 --- a/crates/brush-ui/src/async_renderer.rs +++ /dev/null @@ -1,151 +0,0 @@ -use brush_process::slot::Slot; -use brush_render::{ - MainBackend, TextureMode, camera::Camera, gaussian_splats::Splats, render_splats, -}; -use burn::tensor::Tensor; -use glam::Vec3; -use std::sync::{ - Arc, Mutex, - atomic::{AtomicBool, Ordering}, -}; -use tokio::sync::Notify; -use tokio_with_wasm::alias::task; - -pub struct RenderRequest { - pub slot: Slot>, - pub frame: usize, - pub camera: Camera, - pub img_size: glam::UVec2, - pub background: Vec3, - pub splat_scale: Option, -} - -/// Result of a render operation. -pub struct RenderResult { - pub image: Tensor, - pub camera: Camera, - pub img_size: glam::UVec2, -} - -pub struct AsyncRenderer { - /// Latest render request. - request: Arc>>, - request_notify: Arc, - - result: Arc>>, - - /// Flag to signal shutdown. - shutdown: Arc, -} - -impl AsyncRenderer { - pub fn new() -> Self { - let request = Arc::new(Mutex::new(None)); - let request_notify = Arc::new(Notify::new()); - let result = Arc::new(Mutex::new(None)); - let shutdown = Arc::new(AtomicBool::new(false)); - - task::spawn(render_loop( - Arc::clone(&request), - Arc::clone(&request_notify), - Arc::clone(&result), - Arc::clone(&shutdown), - )); - - Self { - request, - request_notify, - result, - shutdown, - } - } - - /// Submit a new render request. This will overwrite any pending request. - pub fn submit(&self, new_request: RenderRequest) { - { - let mut req = self.request.lock().unwrap(); - *req = Some(new_request); - } - self.request_notify.notify_one(); - } - - /// Check if a new render result is available and return it. - /// Returns `None` if no new result since last check. - pub fn try_get_result(&self) -> Option { - self.result.lock().unwrap().take() - } -} - -impl Default for AsyncRenderer { - fn default() -> Self { - Self::new() - } -} - -impl Drop for AsyncRenderer { - fn drop(&mut self) { - // Signal shutdown to the background task - self.shutdown.store(true, Ordering::SeqCst); - self.request_notify.notify_one(); - } -} - -/// Background render loop that processes the latest render request. -async fn render_loop( - request: Arc>>, - request_notify: Arc, - result: Arc>>, - shutdown: Arc, -) { - loop { - // Wait for a new request or shutdown - request_notify.notified().await; - - // Check for shutdown - if shutdown.load(Ordering::SeqCst) { - break; - } - - // Take the latest request (don't hold lock across await) - let req = { - let mut req_guard = request.lock().unwrap(); - req_guard.take() - }; - - if let Some(req) = req { - // Use act to hold the slot lock across the render, so training will wait. - // We clone the splats since render_splats takes ownership. - let render_result = { - let camera = req.camera.clone(); - let img_size = req.img_size; - let background = req.background; - let splat_scale = req.splat_scale; - - req.slot - .act(req.frame, async move |splats| { - let (image, _) = render_splats( - splats.clone(), - &camera, - img_size, - background, - splat_scale, - TextureMode::Packed, - ) - .await; - (splats, image) - }) - .await - }; - - // Store the result with render params for widget_3d - if let Some(image) = render_result { - let mut res_guard = result.lock().unwrap(); - *res_guard = Some(RenderResult { - image, - camera: req.camera, - img_size: req.img_size, - }); - } - } - } -} diff --git a/crates/brush-ui/src/burn_texture.rs b/crates/brush-ui/src/burn_texture.rs deleted file mode 100644 index 89da1eb7..00000000 --- a/crates/brush-ui/src/burn_texture.rs +++ /dev/null @@ -1,179 +0,0 @@ -use std::sync::Arc; - -use brush_render::{MainBackend, MainBackendBase}; -use burn::tensor::{Tensor, TensorPrimitive}; -use burn_cubecl::cubecl::Runtime; -use burn_wgpu::WgpuRuntime; -use eframe::egui_wgpu::Renderer; -use egui::TextureId; -use egui::epaint::mutex::RwLock as EguiRwLock; -use wgpu::{CommandEncoderDescriptor, TexelCopyBufferLayout, TextureViewDescriptor}; - -struct TextureState { - texture: wgpu::Texture, - id: TextureId, -} - -pub struct BurnTexture { - state: Option, - device: wgpu::Device, - queue: wgpu::Queue, - renderer: Arc>, -} - -fn create_texture(size: glam::UVec2, device: &wgpu::Device) -> wgpu::Texture { - device.create_texture(&wgpu::TextureDescriptor { - label: Some("Splat backbuffer"), - size: wgpu::Extent3d { - width: size.x, - height: size.y, - depth_or_array_layers: 1, - }, - mip_level_count: 1, - sample_count: 1, - dimension: wgpu::TextureDimension::D2, - format: wgpu::TextureFormat::Rgba8Unorm, - usage: wgpu::TextureUsages::TEXTURE_BINDING - | wgpu::TextureUsages::COPY_DST - | wgpu::TextureUsages::RENDER_ATTACHMENT, - view_formats: &[wgpu::TextureFormat::Rgba8Unorm], - }) -} - -impl BurnTexture { - pub fn new( - renderer: Arc>, - device: wgpu::Device, - queue: wgpu::Queue, - ) -> Self { - Self { - state: None, - device, - queue, - renderer, - } - } - - pub fn update_texture(&mut self, img: Tensor) -> TextureId { - let [h, w, c] = img.shape().dims(); - assert!(c == 1, "texture should be u8 packed RGBA"); - let size = glam::uvec2(w as u32, h as u32); - - let dirty = if let Some(s) = self.state.as_ref() { - s.texture.width() != size.x || s.texture.height() != size.y - } else { - true - }; - - if dirty { - // Resizing has some really bad memory profiles, so cleanup memory when it's detected. - let client = WgpuRuntime::client(&img.device()); - client.memory_cleanup(); - - let texture = create_texture(glam::uvec2(w as u32, h as u32), &self.device); - - if let Some(s) = self.state.as_mut() { - s.texture = texture; - - self.renderer.write().update_egui_texture_from_wgpu_texture( - &self.device, - &s.texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - s.id, - ); - } else { - let id = self.renderer.write().register_native_texture( - &self.device, - &texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - ); - self.state = Some(TextureState { texture, id }); - } - } - - let Some(s) = self.state.as_ref() else { - unreachable!("Somehow failed to initialize") - }; - let texture: &wgpu::Texture = &s.texture; - let [height, width, c] = img.dims(); - - let mut encoder = self - .device - .create_command_encoder(&CommandEncoderDescriptor { - label: Some("viewer encoder"), - }); - - let padded_shape = vec![height, width.div_ceil(64) * 64, c]; - - let img_prim = img.into_primitive().tensor(); - let fusion_client = img_prim.client.clone(); - let img = fusion_client.resolve_tensor_float::(img_prim); - let img: Tensor = Tensor::from_primitive(TensorPrimitive::Float(img)); - - // Create padded tensor if needed. The bytes_per_row needs to be divisible - // by 256 in WebGPU, so 4 bytes per pixel means width needs to be divisible by 64. - let img = if width % 64 != 0 { - let padded: Tensor = Tensor::zeros(&padded_shape, &img.device()); - padded.slice_assign([0..height, 0..width], img) - } else { - img - }; - - let img = img.into_primitive().tensor(); - - // Get a hold of the Burn resource. - let client = &img.client; - let img_res_handle = client.get_resource(img.handle.clone().binding()); - - // Now flush commands to make sure the resource is fully ready. - client.flush(); - - // Put compute passes in encoder before copying the buffer. - let bytes_per_row = Some(4 * padded_shape[1] as u32); - - // Now copy the buffer to the texture. - encoder.copy_buffer_to_texture( - wgpu::TexelCopyBufferInfo { - buffer: &img_res_handle.resource().buffer, - layout: TexelCopyBufferLayout { - offset: img_res_handle.resource().offset, - bytes_per_row, - rows_per_image: None, - }, - }, - wgpu::TexelCopyTextureInfo { - texture, - mip_level: 0, - origin: wgpu::Origin3d { x: 0, y: 0, z: 0 }, - aspect: wgpu::TextureAspect::All, - }, - wgpu::Extent3d { - width: width as u32, - height: height as u32, - depth_or_array_layers: 1, - }, - ); - - self.queue.submit([encoder.finish()]); - - s.id - } - - pub fn id(&self) -> Option { - self.state.as_ref().map(|s| s.id) - } - - pub fn reset(&mut self) { - self.state = None; - } - - /// Get the underlying texture for additional rendering - pub fn texture(&self) -> Option<&wgpu::Texture> { - self.state.as_ref().map(|s| &s.texture) - } - - /// Get device and queue for additional rendering - pub fn device_queue(&self) -> (&wgpu::Device, &wgpu::Queue) { - (&self.device, &self.queue) - } -} diff --git a/crates/brush-ui/src/lib.rs b/crates/brush-ui/src/lib.rs index d5e593dd..c3c0c949 100644 --- a/crates/brush-ui/src/lib.rs +++ b/crates/brush-ui/src/lib.rs @@ -1,14 +1,13 @@ #![recursion_limit = "256"] pub mod app; -pub mod burn_texture; pub mod camera_controls; pub mod ui_process; -mod async_renderer; mod panels; mod scene; +pub mod splat_backbuffer; #[cfg(feature = "training")] mod stats; mod widget_3d; diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 71515792..4892810a 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -19,7 +19,7 @@ use egui::{Color32, Rect, Slider}; use glam::{UVec2, Vec3}; use web_time::Instant; -use crate::async_renderer::{AsyncRenderer, RenderRequest}; +use crate::splat_backbuffer::{RenderRequest, SplatBackbuffer}; use serde::{Deserialize, Serialize}; @@ -58,7 +58,6 @@ impl RenderUpdateMode { use crate::{ UiMode, app::CameraSettings, - burn_texture::BurnTexture, draw_checkerboard, panels::AppPane, ui_process::{BackgroundStyle, UiProcess}, @@ -103,8 +102,9 @@ impl ErrorDisplay { #[derive(Default, Serialize, Deserialize)] pub struct ScenePanel { + /// Async splat renderer and texture backbuffer. #[serde(skip)] - pub(crate) backbuffer: Option, + backbuffer: Option, #[serde(skip)] pub(crate) last_draw: Option, #[serde(skip)] @@ -147,9 +147,6 @@ pub struct ScenePanel { #[cfg(feature = "training")] #[serde(skip)] settings_popup: Option>>, - /// Async renderer for non-blocking splat rendering. - #[serde(skip)] - async_renderer: AsyncRenderer, } impl ScenePanel { @@ -249,7 +246,8 @@ impl ScenePanel { load_option } - fn start_loading(#[allow(clippy::unused_self)] &self, source: DataSource, process: &UiProcess) { + #[allow(clippy::unused_self)] + fn start_loading(&self, source: DataSource, process: &UiProcess) { process.connect_to_process(create_process( source, #[cfg(feature = "training")] @@ -323,22 +321,30 @@ impl ScenePanel { (size.y as f32 * ui.ctx().pixels_per_point().round()) as u32, ); - // Check for new render result from background task - if let Some(result) = self.async_renderer.try_get_result() { - if let Some(backbuffer) = &mut self.backbuffer { - backbuffer.update_texture(result.image); - } + // Submit new render request if dirty and we have splats + if let Some(backbuffer) = &mut self.backbuffer + && pixel_size.x > 8 + && pixel_size.y > 8 + && dirty + { + backbuffer.submit(RenderRequest { + slot: process.current_splats(), + frame: self.frame as usize, + camera: camera.clone(), + img_size: pixel_size, + background: settings.background.unwrap_or(Vec3::ZERO), + splat_scale: settings.splat_scale, + }); - // Render widget_3d in sync with new splat render (same camera/size) + // Render widget_3d overlay to the same texture if let Some(widget_3d) = &mut self.widget_3d - && let Some(backbuffer) = &self.backbuffer && let Some(texture) = backbuffer.texture() { widget_3d.render_to_texture( - &result.camera, + &camera, process.model_local_to_world(), - result.img_size, - texture, + pixel_size, + &texture, grid_opacity, ); } @@ -346,19 +352,6 @@ impl ScenePanel { ui.ctx().request_repaint(); } - // Submit new render request if dirty and we have splats - if pixel_size.x > 8 && pixel_size.y > 8 && dirty { - self.async_renderer.submit(RenderRequest { - slot: process.current_splats(), - frame: self.frame as usize, - camera, - img_size: pixel_size, - background: settings.background.unwrap_or(Vec3::ZERO), - splat_scale: settings.splat_scale, - }); - ui.ctx().request_repaint(); - } - ui.scope(|ui| { // if training views have alpha, show a background checker. Masked images // should still use a black background. @@ -859,7 +852,7 @@ impl AppPane for ScenePanel { _adapter_info: wgpu::AdapterInfo, ) { self.widget_3d = Some(Widget3D::new(device.clone(), queue.clone())); - self.backbuffer = Some(BurnTexture::new(renderer, device, queue)); + self.backbuffer = Some(SplatBackbuffer::new(renderer, device, queue)); // Create the settings popup now that we have the base_path #[cfg(feature = "training")] diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs new file mode 100644 index 00000000..2e10f3d6 --- /dev/null +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -0,0 +1,250 @@ +//! Async splat rendering backbuffer. +//! +//! Background task renders splats directly to a wgpu texture. +//! wgpu handles GPU synchronization - no locks needed on results. + +use brush_process::slot::Slot; +use brush_render::{ + MainBackend, MainBackendBase, TextureMode, camera::Camera, gaussian_splats::Splats, + render_splats, +}; +use burn::tensor::{Tensor, TensorPrimitive}; +use burn_cubecl::cubecl::Runtime; +use burn_wgpu::WgpuRuntime; +use eframe::egui_wgpu::Renderer; +use egui::TextureId; +use egui::epaint::mutex::RwLock as EguiRwLock; +use glam::Vec3; +use std::sync::{Arc, Mutex}; +use tokio::sync::mpsc; +use tokio_with_wasm::alias::task; +use wgpu::{CommandEncoderDescriptor, TexelCopyBufferLayout, TextureViewDescriptor}; + +pub struct RenderRequest { + pub slot: Slot>, + pub frame: usize, + pub camera: Camera, + pub img_size: glam::UVec2, + pub background: Vec3, + pub splat_scale: Option, +} + +/// Shared texture state that the render loop writes to directly. +struct TextureState { + texture: Option, + texture_id: Option, +} + +/// Async splat rendering backbuffer. +/// +/// Background task renders splats and writes directly to a wgpu texture. +/// No synchronization needed on results - wgpu handles that. +pub struct SplatBackbuffer { + texture_state: Arc>, + request_tx: mpsc::UnboundedSender, +} + +impl SplatBackbuffer { + pub fn new( + renderer: Arc>, + device: wgpu::Device, + queue: wgpu::Queue, + ) -> Self { + let texture_state = Arc::new(Mutex::new(TextureState { + texture: None, + texture_id: None, + })); + let (request_tx, request_rx) = mpsc::unbounded_channel(); + + // Spawn the background render loop + task::spawn(render_loop( + Arc::clone(&texture_state), + renderer, + device, + queue, + request_rx, + )); + + Self { + texture_state, + request_tx, + } + } + + /// Submit a new render request. + pub fn submit(&self, request: RenderRequest) { + let _ = self.request_tx.send(request); + } + + /// Get the texture ID for display. + pub fn id(&self) -> Option { + self.texture_state.lock().unwrap().texture_id + } + + /// Get the underlying texture for additional rendering. + pub fn texture(&self) -> Option { + self.texture_state.lock().unwrap().texture.clone() + } +} + +fn create_texture(size: glam::UVec2, device: &wgpu::Device) -> wgpu::Texture { + device.create_texture(&wgpu::TextureDescriptor { + label: Some("Splat backbuffer"), + size: wgpu::Extent3d { + width: size.x, + height: size.y, + depth_or_array_layers: 1, + }, + mip_level_count: 1, + sample_count: 1, + dimension: wgpu::TextureDimension::D2, + format: wgpu::TextureFormat::Rgba8Unorm, + usage: wgpu::TextureUsages::TEXTURE_BINDING + | wgpu::TextureUsages::COPY_DST + | wgpu::TextureUsages::RENDER_ATTACHMENT, + view_formats: &[wgpu::TextureFormat::Rgba8Unorm], + }) +} + +fn copy_to_texture( + img: Tensor, + texture_state: &Arc>, + renderer: &Arc>, + device: &wgpu::Device, + queue: &wgpu::Queue, +) { + let [h, w, c] = img.shape().dims(); + assert!(c == 1, "texture should be u8 packed RGBA"); + let size = glam::uvec2(w as u32, h as u32); + + // Check if we need to resize/create texture + let needs_resize = { + let state = texture_state.lock().unwrap(); + state + .texture + .as_ref() + .is_none_or(|t| t.width() != size.x || t.height() != size.y) + }; + + if needs_resize { + // Cleanup memory when resizing + let client = WgpuRuntime::client(&img.device()); + client.memory_cleanup(); + + let texture = create_texture(size, device); + + let mut state = texture_state.lock().unwrap(); + if let Some(id) = state.texture_id { + // Update existing registration + renderer.write().update_egui_texture_from_wgpu_texture( + device, + &texture.create_view(&TextureViewDescriptor::default()), + wgpu::FilterMode::Linear, + id, + ); + } else { + // New registration + let id = renderer.write().register_native_texture( + device, + &texture.create_view(&TextureViewDescriptor::default()), + wgpu::FilterMode::Linear, + ); + state.texture_id = Some(id); + } + + state.texture = Some(texture); + } + + let texture = texture_state.lock().unwrap().texture.clone().unwrap(); + let [height, width, c] = img.dims(); + + let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor { + label: Some("splat backbuffer encoder"), + }); + + let padded_shape = vec![height, width.div_ceil(64) * 64, c]; + + let img_prim = img.into_primitive().tensor(); + let fusion_client = img_prim.client.clone(); + let img = fusion_client.resolve_tensor_float::(img_prim); + let img: Tensor = Tensor::from_primitive(TensorPrimitive::Float(img)); + + // Pad if needed (WebGPU requires bytes_per_row divisible by 256) + let img = if width % 64 != 0 { + let padded: Tensor = Tensor::zeros(&padded_shape, &img.device()); + padded.slice_assign([0..height, 0..width], img) + } else { + img + }; + + let img = img.into_primitive().tensor(); + let client = &img.client; + let img_res_handle = client.get_resource(img.handle.clone().binding()); + client.flush(); + + let bytes_per_row = Some(4 * padded_shape[1] as u32); + + encoder.copy_buffer_to_texture( + wgpu::TexelCopyBufferInfo { + buffer: &img_res_handle.resource().buffer, + layout: TexelCopyBufferLayout { + offset: img_res_handle.resource().offset, + bytes_per_row, + rows_per_image: None, + }, + }, + wgpu::TexelCopyTextureInfo { + texture: &texture, + mip_level: 0, + origin: wgpu::Origin3d { x: 0, y: 0, z: 0 }, + aspect: wgpu::TextureAspect::All, + }, + wgpu::Extent3d { + width: width as u32, + height: height as u32, + depth_or_array_layers: 1, + }, + ); + + queue.submit([encoder.finish()]); +} + +async fn render_loop( + texture_state: Arc>, + renderer: Arc>, + device: wgpu::Device, + queue: wgpu::Queue, + mut request_rx: mpsc::UnboundedReceiver, +) { + while let Some(mut req) = request_rx.recv().await { + // Drain channel to get the latest request + while let Ok(newer) = request_rx.try_recv() { + req = newer; + } + + let camera = req.camera.clone(); + let img_size = req.img_size; + let background = req.background; + let splat_scale = req.splat_scale; + + let render_result = req + .slot + .act(req.frame, async move |splats| { + let (image, _) = render_splats( + splats.clone(), + &camera, + img_size, + background, + splat_scale, + TextureMode::Packed, + ) + .await; + (splats, image) + }) + .await; + + if let Some(image) = render_result { + copy_to_texture(image, &texture_state, &renderer, &device, &queue); + } + } +} diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index 57405317..27452701 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -2,18 +2,18 @@ #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release use brush_dataset::scene::{SceneBatch, sample_to_tensor_data}; +use brush_process::slot::Slot; use brush_render::{ - AlphaMode, MainBackend, TextureMode, + AlphaMode, MainBackend, bounding_box::BoundingBox, camera::{Camera, focal_to_fov, fov_to_focal}, gaussian_splats::{SplatRenderMode, Splats}, - render_splats, }; use brush_train::{ RandomSplatsConfig, config::TrainConfig, create_random_splats, splats_into_autodiff, train::SplatTrainer, }; -use brush_ui::burn_texture::BurnTexture; +use brush_ui::splat_backbuffer::{RenderRequest, SplatBackbuffer}; use burn::{backend::wgpu::WgpuDevice, module::AutodiffModule, prelude::Backend}; use egui::{ImageSource, TextureHandle, TextureOptions, load::SizedTexture}; use glam::{Quat, Vec2, Vec3}; @@ -22,8 +22,8 @@ use rand::SeedableRng; use tokio::sync::mpsc::{Receiver, Sender}; struct TrainStep { - splats: Splats, iter: u32, + num_splats: u32, } fn spawn_train_loop( @@ -33,6 +33,7 @@ fn spawn_train_loop( device: WgpuDevice, ctx: egui::Context, sender: Sender, + slot: Slot>, ) { // Spawn a task that iterates over the training stream. tokio::spawn(async move { @@ -69,18 +70,16 @@ fn spawn_train_loop( loop { let (new_splats, _) = trainer.step(batch.clone(), splats).await; let (new_splats, _) = trainer.refine(iter, new_splats.valid()).await; + let num_splats = new_splats.num_splats(); + + // Update the slot with latest splats + slot.set(new_splats.clone()).await; + splats = splats_into_autodiff(new_splats); iter += 1; ctx.request_repaint(); - if sender - .send(TrainStep { - splats: splats.valid(), - iter, - }) - .await - .is_err() - { + if sender.send(TrainStep { iter, num_splats }).await.is_err() { break; } } @@ -91,7 +90,8 @@ struct App { image: image::DynamicImage, camera: Camera, tex_handle: TextureHandle, - backbuffer: BurnTexture, + backbuffer: SplatBackbuffer, + slot: Slot>, receiver: Receiver, last_step: Option, } @@ -133,6 +133,8 @@ impl App { cc.egui_ctx .load_texture("nearest_view_tex", color_img, TextureOptions::default()); + let slot = Slot::default(); + let config = TrainConfig::default(); spawn_train_loop( image.clone(), @@ -141,6 +143,7 @@ impl App { device, cc.egui_ctx.clone(), sender, + slot.clone(), ); let renderer = cc @@ -154,7 +157,8 @@ impl App { image, camera, tex_handle: handle, - backbuffer: BurnTexture::new(renderer, state.device.clone(), state.queue.clone()), + backbuffer: SplatBackbuffer::new(renderer, state.device.clone(), state.queue.clone()), + slot, receiver, last_step: None, } @@ -168,32 +172,37 @@ impl eframe::App for App { } egui::CentralPanel::default().show(ctx, |ui| { - let Some(msg) = self.last_step.as_ref() else { + let Some(step) = self.last_step.as_ref() else { + ui.label("Waiting for first training step..."); return; }; - let (img, _render_aux) = burn_cubecl::cubecl::future::block_on(render_splats( - msg.splats.clone(), - &self.camera, - glam::uvec2(self.image.width(), self.image.height()), - Vec3::ZERO, // Just render with a black background - None, - TextureMode::Packed, - )); + // Submit a render request + self.backbuffer.submit(RenderRequest { + slot: self.slot.clone(), + frame: 0, + camera: self.camera.clone(), + img_size: glam::uvec2(self.image.width(), self.image.height()), + background: Vec3::ZERO, + splat_scale: None, + }); let size = egui::vec2(self.image.width() as f32, self.image.height() as f32); ui.horizontal(|ui| { - let texture_id = self.backbuffer.update_texture(img); - ui.image(ImageSource::Texture(SizedTexture::new(texture_id, size))); + if let Some(texture_id) = self.backbuffer.id() { + ui.image(ImageSource::Texture(SizedTexture::new(texture_id, size))); + } else { + ui.label("Rendering..."); + } ui.image(ImageSource::Texture(SizedTexture::new( self.tex_handle.id(), size, ))); }); - ui.label(format!("Splats: {}", msg.splats.num_splats())); - ui.label(format!("Step: {}", msg.iter)); + ui.label(format!("Splats: {}", step.num_splats)); + ui.label(format!("Step: {}", step.iter)); }); } } From 01fda5fd9f642e308946fcc0c02f74a06258efe3 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 20 Jan 2026 21:25:05 +0100 Subject: [PATCH 14/29] Move widget to backbuffer --- crates/brush-ui/src/scene.rs | 22 +++------------------- crates/brush-ui/src/splat_backbuffer.rs | 24 ++++++++++++++++++++++++ examples/train-2d/examples/train-2d.rs | 3 +++ 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 4892810a..5b2602d9 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -61,7 +61,6 @@ use crate::{ draw_checkerboard, panels::AppPane, ui_process::{BackgroundStyle, UiProcess}, - widget_3d::Widget3D, }; #[derive(Clone, PartialEq)] @@ -128,8 +127,6 @@ pub struct ScenePanel { #[serde(skip)] last_state: Option, #[serde(skip)] - widget_3d: Option, - #[serde(skip)] source_name: Option, #[serde(skip)] source_type: Option, @@ -334,22 +331,10 @@ impl ScenePanel { img_size: pixel_size, background: settings.background.unwrap_or(Vec3::ZERO), splat_scale: settings.splat_scale, + ctx: ui.ctx().clone(), + model_transform: process.model_local_to_world(), + grid_opacity, }); - - // Render widget_3d overlay to the same texture - if let Some(widget_3d) = &mut self.widget_3d - && let Some(texture) = backbuffer.texture() - { - widget_3d.render_to_texture( - &camera, - process.model_local_to_world(), - pixel_size, - &texture, - grid_opacity, - ); - } - - ui.ctx().request_repaint(); } ui.scope(|ui| { @@ -851,7 +836,6 @@ impl AppPane for ScenePanel { _burn_device: burn_wgpu::WgpuDevice, _adapter_info: wgpu::AdapterInfo, ) { - self.widget_3d = Some(Widget3D::new(device.clone(), queue.clone())); self.backbuffer = Some(SplatBackbuffer::new(renderer, device, queue)); // Create the settings popup now that we have the base_path diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index 2e10f3d6..6473cc8e 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -3,6 +3,7 @@ //! Background task renders splats directly to a wgpu texture. //! wgpu handles GPU synchronization - no locks needed on results. +use crate::widget_3d::Widget3D; use brush_process::slot::Slot; use brush_render::{ MainBackend, MainBackendBase, TextureMode, camera::Camera, gaussian_splats::Splats, @@ -27,6 +28,11 @@ pub struct RenderRequest { pub img_size: glam::UVec2, pub background: Vec3, pub splat_scale: Option, + pub ctx: egui::Context, + /// Model transform for the 3D overlay (grid, axes). + pub model_transform: glam::Affine3A, + /// Opacity of the grid overlay (0.0 = hidden, 1.0 = fully visible). + pub grid_opacity: f32, } /// Shared texture state that the render loop writes to directly. @@ -216,6 +222,9 @@ async fn render_loop( queue: wgpu::Queue, mut request_rx: mpsc::UnboundedReceiver, ) { + // Create Widget3D for rendering the grid overlay + let widget_3d = Widget3D::new(device.clone(), queue.clone()); + while let Some(mut req) = request_rx.recv().await { // Drain channel to get the latest request while let Ok(newer) = request_rx.try_recv() { @@ -245,6 +254,21 @@ async fn render_loop( if let Some(image) = render_result { copy_to_texture(image, &texture_state, &renderer, &device, &queue); + + // Render 3D overlay (grid, axes) on top of the splats + if req.grid_opacity > 0.0 { + if let Some(texture) = texture_state.lock().unwrap().texture.clone() { + widget_3d.render_to_texture( + &req.camera, + req.model_transform, + req.img_size, + &texture, + req.grid_opacity, + ); + } + } } + + req.ctx.request_repaint(); } } diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index 27452701..979dfabb 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -185,6 +185,9 @@ impl eframe::App for App { img_size: glam::uvec2(self.image.width(), self.image.height()), background: Vec3::ZERO, splat_scale: None, + ctx: ctx.clone(), + model_transform: glam::Affine3A::IDENTITY, + grid_opacity: 0.0, }); let size = egui::vec2(self.image.width() as f32, self.image.height() as f32); From b207dac2ba400fa01bef2df1898e8504d0a14a94 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 20 Jan 2026 21:58:28 +0100 Subject: [PATCH 15/29] Fixes for wasm --- crates/brush-render/src/gaussian_splats.rs | 3 --- crates/brush-rerun/src/visualize_tools.rs | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/crates/brush-render/src/gaussian_splats.rs b/crates/brush-render/src/gaussian_splats.rs index f938f7a6..e9fa41c8 100644 --- a/crates/brush-render/src/gaussian_splats.rs +++ b/crates/brush-render/src/gaussian_splats.rs @@ -14,9 +14,6 @@ use crate::{ sh::{sh_coeffs_for_degree, sh_degree_from_coeffs}, }; -#[cfg(target_family = "wasm")] -use crate::render::calc_tile_bounds; - #[derive( Module, Clone, Copy, Debug, Eq, PartialEq, ValueEnum, serde::Serialize, serde::Deserialize, )] diff --git a/crates/brush-rerun/src/visualize_tools.rs b/crates/brush-rerun/src/visualize_tools.rs index a8874f27..79557af6 100644 --- a/crates/brush-rerun/src/visualize_tools.rs +++ b/crates/brush-rerun/src/visualize_tools.rs @@ -360,7 +360,7 @@ mod visualize_tools_impl { #[allow(unused_variables)] #[allow(clippy::unnecessary_wraps, clippy::unused_self)] - pub fn log_splat_stats(&self, _iter: u32, _splats: &Splats) -> Result<()> { + pub fn log_splat_stats(&self, _iter: u32, _num_splats: u32) -> Result<()> { Ok(()) } From 869e62e362e39d62b7aac8ef182e9ac5412e066f Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 20 Jan 2026 22:08:15 +0100 Subject: [PATCH 16/29] fmt --- crates/brush-ui/src/scene.rs | 2 +- crates/brush-ui/src/splat_backbuffer.rs | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 5b2602d9..1f92c425 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -327,7 +327,7 @@ impl ScenePanel { backbuffer.submit(RenderRequest { slot: process.current_splats(), frame: self.frame as usize, - camera: camera.clone(), + camera, img_size: pixel_size, background: settings.background.unwrap_or(Vec3::ZERO), splat_scale: settings.splat_scale, diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index 6473cc8e..b64264f9 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -256,16 +256,16 @@ async fn render_loop( copy_to_texture(image, &texture_state, &renderer, &device, &queue); // Render 3D overlay (grid, axes) on top of the splats - if req.grid_opacity > 0.0 { - if let Some(texture) = texture_state.lock().unwrap().texture.clone() { - widget_3d.render_to_texture( - &req.camera, - req.model_transform, - req.img_size, - &texture, - req.grid_opacity, - ); - } + if req.grid_opacity > 0.0 + && let Some(texture) = texture_state.lock().unwrap().texture.clone() + { + widget_3d.render_to_texture( + &req.camera, + req.model_transform, + req.img_size, + &texture, + req.grid_opacity, + ); } } From cfdbac97e01c937178bdf6880f623cbb5cb72dd2 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 20 Jan 2026 22:23:38 +0100 Subject: [PATCH 17/29] Cleanup --- crates/brush-process/src/slot.rs | 18 ++++------------- crates/brush-ui/src/splat_backbuffer.rs | 27 ++----------------------- examples/train-2d/examples/train-2d.rs | 6 ------ 3 files changed, 6 insertions(+), 45 deletions(-) diff --git a/crates/brush-process/src/slot.rs b/crates/brush-process/src/slot.rs index 9d32fa86..9a45ff5f 100644 --- a/crates/brush-process/src/slot.rs +++ b/crates/brush-process/src/slot.rs @@ -1,17 +1,12 @@ use std::sync::Arc; use tokio::sync::Mutex; -/// A thread-safe async slot for sharing data between the process and UI. -/// Uses tokio's async Mutex so locks can be held across await points. +/// Async slot for sharing data between the process and UI. #[derive(Clone)] pub struct Slot(Arc>>); impl Slot { - /// Temporarily take ownership of a value at a specific index, do something with it, - /// and put the result back. The lock is held across the async operation. - /// - /// Uses swap + push to avoid needing a placeholder value. - /// The value is moved to the end during the operation, then swapped back. + /// Take ownership of value at index, apply async function, put result back. pub async fn act(&self, index: usize, f: F) -> Option where F: AsyncFnOnce(T) -> (T, R), @@ -21,12 +16,10 @@ impl Slot { if index >= len { return None; } - // Swap the target element to the end, then pop it guard.swap(index, len - 1); let value = guard.pop().unwrap(); let (new_value, result) = f(value).await; - // Push it back and swap to original position guard.push(new_value); let new_len = guard.len(); guard.swap(index, new_len - 1); @@ -44,20 +37,18 @@ impl Slot { .await } - /// Get a clone of the main (last) value. pub async fn clone_main(&self) -> Option { self.0.lock().await.last().cloned() } - /// Set the slot to contain a single value, clearing any previous contents. + /// Replace all contents with a single value. pub async fn set(&self, value: T) { let mut guard = self.0.lock().await; guard.clear(); guard.push(value); } - /// Set the value at the given index, or push if index equals current length. - /// Panics if index > len (gaps not allowed). + /// Set value at index, or push if index == len. Panics if index > len. pub async fn set_at(&self, index: usize, value: T) { let mut guard = self.0.lock().await; if index == guard.len() { @@ -67,7 +58,6 @@ impl Slot { } } - /// Clear the slot. pub async fn clear(&self) { self.0.lock().await.clear(); } diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index b64264f9..9898ce55 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -1,8 +1,3 @@ -//! Async splat rendering backbuffer. -//! -//! Background task renders splats directly to a wgpu texture. -//! wgpu handles GPU synchronization - no locks needed on results. - use crate::widget_3d::Widget3D; use brush_process::slot::Slot; use brush_render::{ @@ -35,16 +30,12 @@ pub struct RenderRequest { pub grid_opacity: f32, } -/// Shared texture state that the render loop writes to directly. struct TextureState { texture: Option, texture_id: Option, } -/// Async splat rendering backbuffer. -/// -/// Background task renders splats and writes directly to a wgpu texture. -/// No synchronization needed on results - wgpu handles that. +/// Renders splats asynchronously in a background task. pub struct SplatBackbuffer { texture_state: Arc>, request_tx: mpsc::UnboundedSender, @@ -62,7 +53,6 @@ impl SplatBackbuffer { })); let (request_tx, request_rx) = mpsc::unbounded_channel(); - // Spawn the background render loop task::spawn(render_loop( Arc::clone(&texture_state), renderer, @@ -77,20 +67,13 @@ impl SplatBackbuffer { } } - /// Submit a new render request. pub fn submit(&self, request: RenderRequest) { let _ = self.request_tx.send(request); } - /// Get the texture ID for display. pub fn id(&self) -> Option { self.texture_state.lock().unwrap().texture_id } - - /// Get the underlying texture for additional rendering. - pub fn texture(&self) -> Option { - self.texture_state.lock().unwrap().texture.clone() - } } fn create_texture(size: glam::UVec2, device: &wgpu::Device) -> wgpu::Texture { @@ -123,7 +106,6 @@ fn copy_to_texture( assert!(c == 1, "texture should be u8 packed RGBA"); let size = glam::uvec2(w as u32, h as u32); - // Check if we need to resize/create texture let needs_resize = { let state = texture_state.lock().unwrap(); state @@ -133,7 +115,6 @@ fn copy_to_texture( }; if needs_resize { - // Cleanup memory when resizing let client = WgpuRuntime::client(&img.device()); client.memory_cleanup(); @@ -141,7 +122,6 @@ fn copy_to_texture( let mut state = texture_state.lock().unwrap(); if let Some(id) = state.texture_id { - // Update existing registration renderer.write().update_egui_texture_from_wgpu_texture( device, &texture.create_view(&TextureViewDescriptor::default()), @@ -149,7 +129,6 @@ fn copy_to_texture( id, ); } else { - // New registration let id = renderer.write().register_native_texture( device, &texture.create_view(&TextureViewDescriptor::default()), @@ -222,11 +201,10 @@ async fn render_loop( queue: wgpu::Queue, mut request_rx: mpsc::UnboundedReceiver, ) { - // Create Widget3D for rendering the grid overlay let widget_3d = Widget3D::new(device.clone(), queue.clone()); while let Some(mut req) = request_rx.recv().await { - // Drain channel to get the latest request + // Drain to get latest request while let Ok(newer) = request_rx.try_recv() { req = newer; } @@ -255,7 +233,6 @@ async fn render_loop( if let Some(image) = render_result { copy_to_texture(image, &texture_state, &renderer, &device, &queue); - // Render 3D overlay (grid, axes) on top of the splats if req.grid_opacity > 0.0 && let Some(texture) = texture_state.lock().unwrap().texture.clone() { diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index 979dfabb..a8fc455f 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -35,7 +35,6 @@ fn spawn_train_loop( sender: Sender, slot: Slot>, ) { - // Spawn a task that iterates over the training stream. tokio::spawn(async move { let seed = 42; @@ -58,7 +57,6 @@ fn spawn_train_loop( BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE), ); - // One batch of training data, it's the same every step so can just construct it once. let batch = SceneBatch { img_tensor: sample_to_tensor_data(image), alpha_mode: AlphaMode::Transparent, @@ -71,8 +69,6 @@ fn spawn_train_loop( let (new_splats, _) = trainer.step(batch.clone(), splats).await; let (new_splats, _) = trainer.refine(iter, new_splats.valid()).await; let num_splats = new_splats.num_splats(); - - // Update the slot with latest splats slot.set(new_splats.clone()).await; splats = splats_into_autodiff(new_splats); @@ -177,7 +173,6 @@ impl eframe::App for App { return; }; - // Submit a render request self.backbuffer.submit(RenderRequest { slot: self.slot.clone(), frame: 0, @@ -213,7 +208,6 @@ impl eframe::App for App { #[tokio::main] async fn main() { let native_options = eframe::NativeOptions { - // Build app display. viewport: egui::ViewportBuilder::default() .with_inner_size(egui::Vec2::new(1100.0, 500.0)) .with_active(true), From ca471a6f0e04b57c7d7706c118d0d6902aa9b2bc Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Thu, 22 Jan 2026 15:19:27 +0100 Subject: [PATCH 18/29] Less locking --- crates/brush-ui/src/scene.rs | 4 +--- crates/brush-ui/src/splat_backbuffer.rs | 30 +++++++++++++++---------- crates/brush-ui/src/widget_3d.rs | 3 +-- examples/train-2d/examples/train-2d.rs | 1 - 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 1f92c425..9b7e8800 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -310,7 +310,6 @@ impl ScenePanel { if dirty { self.last_state = Some(state); - ui.ctx().request_repaint(); } let pixel_size = glam::uvec2( @@ -319,7 +318,7 @@ impl ScenePanel { ); // Submit new render request if dirty and we have splats - if let Some(backbuffer) = &mut self.backbuffer + if let Some(backbuffer) = &self.backbuffer && pixel_size.x > 8 && pixel_size.y > 8 && dirty @@ -331,7 +330,6 @@ impl ScenePanel { img_size: pixel_size, background: settings.background.unwrap_or(Vec3::ZERO), splat_scale: settings.splat_scale, - ctx: ui.ctx().clone(), model_transform: process.model_local_to_world(), grid_opacity, }); diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index 9898ce55..1bd06fbd 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -12,10 +12,11 @@ use egui::TextureId; use egui::epaint::mutex::RwLock as EguiRwLock; use glam::Vec3; use std::sync::{Arc, Mutex}; -use tokio::sync::mpsc; +use tokio::sync::watch; use tokio_with_wasm::alias::task; use wgpu::{CommandEncoderDescriptor, TexelCopyBufferLayout, TextureViewDescriptor}; +#[derive(Clone)] pub struct RenderRequest { pub slot: Slot>, pub frame: usize, @@ -23,7 +24,6 @@ pub struct RenderRequest { pub img_size: glam::UVec2, pub background: Vec3, pub splat_scale: Option, - pub ctx: egui::Context, /// Model transform for the 3D overlay (grid, axes). pub model_transform: glam::Affine3A, /// Opacity of the grid overlay (0.0 = hidden, 1.0 = fully visible). @@ -38,7 +38,7 @@ struct TextureState { /// Renders splats asynchronously in a background task. pub struct SplatBackbuffer { texture_state: Arc>, - request_tx: mpsc::UnboundedSender, + request_tx: watch::Sender>, } impl SplatBackbuffer { @@ -51,7 +51,7 @@ impl SplatBackbuffer { texture: None, texture_id: None, })); - let (request_tx, request_rx) = mpsc::unbounded_channel(); + let (request_tx, request_rx) = watch::channel(None); task::spawn(render_loop( Arc::clone(&texture_state), @@ -67,8 +67,10 @@ impl SplatBackbuffer { } } + /// Submit a render request. If a request is already pending, + /// it will be replaced with this one (latest wins). pub fn submit(&self, request: RenderRequest) { - let _ = self.request_tx.send(request); + let _ = self.request_tx.send(Some(request)); } pub fn id(&self) -> Option { @@ -199,16 +201,22 @@ async fn render_loop( renderer: Arc>, device: wgpu::Device, queue: wgpu::Queue, - mut request_rx: mpsc::UnboundedReceiver, + mut request_rx: watch::Receiver>, ) { let widget_3d = Widget3D::new(device.clone(), queue.clone()); - while let Some(mut req) = request_rx.recv().await { - // Drain to get latest request - while let Ok(newer) = request_rx.try_recv() { - req = newer; + loop { + // Wait for a new request (watch wakes on value change) + if request_rx.changed().await.is_err() { + // Sender dropped, exit loop + break; } + // Get the latest request + let Some(req) = request_rx.borrow_and_update().clone() else { + continue; + }; + let camera = req.camera.clone(); let img_size = req.img_size; let background = req.background; @@ -245,7 +253,5 @@ async fn render_loop( ); } } - - req.ctx.request_repaint(); } } diff --git a/crates/brush-ui/src/widget_3d.rs b/crates/brush-ui/src/widget_3d.rs index 2e6e3d76..cf5975ec 100644 --- a/crates/brush-ui/src/widget_3d.rs +++ b/crates/brush-ui/src/widget_3d.rs @@ -302,7 +302,6 @@ impl Widget3D { render_pass.set_vertex_buffer(0, self.up_axis_vertex_buffer.slice(..)); render_pass.draw(0..self.up_axis_vertex_count, 0..1); } - - self.queue.submit(std::iter::once(encoder.finish())); + self.queue.submit([encoder.finish()]); } } diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index a8fc455f..aaed91a7 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -180,7 +180,6 @@ impl eframe::App for App { img_size: glam::uvec2(self.image.width(), self.image.height()), background: Vec3::ZERO, splat_scale: None, - ctx: ctx.clone(), model_transform: glam::Affine3A::IDENTITY, grid_opacity: 0.0, }); From 484d038bd73341029c9d0a03a80ef0a73dd337ee Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Sun, 25 Jan 2026 23:44:34 +0100 Subject: [PATCH 19/29] WIP --- Cargo.lock | 1 + crates/brush-ui/Cargo.toml | 1 + crates/brush-ui/src/scene.rs | 1 + crates/brush-ui/src/splat_backbuffer.rs | 210 ++++++++++-------------- 4 files changed, 93 insertions(+), 120 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 419d93eb..39e8f2d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1301,6 +1301,7 @@ dependencies = [ "humantime", "image", "log", + "pollster", "rand 0.9.2", "rrfd", "serde", diff --git a/crates/brush-ui/Cargo.toml b/crates/brush-ui/Cargo.toml index 3b0757fb..a1365642 100644 --- a/crates/brush-ui/Cargo.toml +++ b/crates/brush-ui/Cargo.toml @@ -39,6 +39,7 @@ bytemuck = { version = "1.14", features = ["derive"] } brush-dataset = { path = "../brush-dataset", optional = true } image = { workspace = true, optional = true } +pollster = "0.4.0" [target.'cfg(not(target_family = "wasm"))'.dependencies] tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 9b7e8800..4e1e9872 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -330,6 +330,7 @@ impl ScenePanel { img_size: pixel_size, background: settings.background.unwrap_or(Vec3::ZERO), splat_scale: settings.splat_scale, + ctx: ui.ctx().clone(), model_transform: process.model_local_to_world(), grid_opacity, }); diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index 1bd06fbd..0b1c349b 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -11,9 +11,8 @@ use eframe::egui_wgpu::Renderer; use egui::TextureId; use egui::epaint::mutex::RwLock as EguiRwLock; use glam::Vec3; -use std::sync::{Arc, Mutex}; -use tokio::sync::watch; -use tokio_with_wasm::alias::task; +use pollster::block_on; +use std::sync::Arc; use wgpu::{CommandEncoderDescriptor, TexelCopyBufferLayout, TextureViewDescriptor}; #[derive(Clone)] @@ -24,21 +23,20 @@ pub struct RenderRequest { pub img_size: glam::UVec2, pub background: Vec3, pub splat_scale: Option, + pub ctx: egui::Context, /// Model transform for the 3D overlay (grid, axes). pub model_transform: glam::Affine3A, /// Opacity of the grid overlay (0.0 = hidden, 1.0 = fully visible). pub grid_opacity: f32, } -struct TextureState { - texture: Option, - texture_id: Option, -} - -/// Renders splats asynchronously in a background task. pub struct SplatBackbuffer { - texture_state: Arc>, - request_tx: watch::Sender>, + texture: wgpu::Texture, + texture_id: TextureId, + renderer: Arc>, + device: wgpu::Device, + queue: wgpu::Queue, + widget_3d: Widget3D, } impl SplatBackbuffer { @@ -47,34 +45,87 @@ impl SplatBackbuffer { device: wgpu::Device, queue: wgpu::Queue, ) -> Self { - let texture_state = Arc::new(Mutex::new(TextureState { - texture: None, - texture_id: None, - })); - let (request_tx, request_rx) = watch::channel(None); + // Start with a dummy texture + let texture = create_texture(glam::uvec2(64, 64), &device); + let id = renderer.write().register_native_texture( + &device, + &texture.create_view(&TextureViewDescriptor::default()), + wgpu::FilterMode::Linear, + ); + let widget_3d = Widget3D::new(device.clone(), queue.clone()); - task::spawn(render_loop( - Arc::clone(&texture_state), + Self { + texture, + texture_id: id, renderer, device, queue, - request_rx, - )); - - Self { - texture_state, - request_tx, + widget_3d, } } - /// Submit a render request. If a request is already pending, - /// it will be replaced with this one (latest wins). - pub fn submit(&self, request: RenderRequest) { - let _ = self.request_tx.send(Some(request)); + /// Submit a render request. Spawns an async task to do the rendering. + pub fn submit(&self, req: RenderRequest) { + let needs_resize = + self.texture.width() != req.img_size.x || self.texture.height() != req.img_size.y; + if needs_resize { + // TODO: Restore this. + // let client = WgpuRuntime::client(&req.); + // client.memory_cleanup(); + + self.texture = create_texture(req.img_size, &self.device); + self.renderer.write().update_egui_texture_from_wgpu_texture( + &self.device, + &self.texture.create_view(&TextureViewDescriptor::default()), + wgpu::FilterMode::Linear, + self.texture_id, + ); + } + + block_on(async move { + let camera = req.camera.clone(); + let img_size = req.img_size; + let background = req.background; + let splat_scale = req.splat_scale; + + let splats = req.slot.clone_main().await; + + if let Some(splats) = splats { + let (image, _) = render_splats( + splats.clone(), + &camera, + img_size, + background, + splat_scale, + TextureMode::Packed, + ) + .await; + + copy_to_texture( + image, + &self.texture, + self.texture_id, + &self.renderer, + &self.device, + &self.queue, + ); + + if req.grid_opacity > 0.0 { + self.widget_3d.render_to_texture( + &req.camera, + req.model_transform, + req.img_size, + &texture, + req.grid_opacity, + ); + } + } + req.ctx.request_repaint(); + }); } pub fn id(&self) -> Option { - self.texture_state.lock().unwrap().texture_id + self.texture_id } } @@ -99,7 +150,8 @@ fn create_texture(size: glam::UVec2, device: &wgpu::Device) -> wgpu::Texture { fn copy_to_texture( img: Tensor, - texture_state: &Arc>, + texture: &wgpu::Texture, + texture_id: TextureId, renderer: &Arc>, device: &wgpu::Device, queue: &wgpu::Queue, @@ -108,41 +160,19 @@ fn copy_to_texture( assert!(c == 1, "texture should be u8 packed RGBA"); let size = glam::uvec2(w as u32, h as u32); - let needs_resize = { - let state = texture_state.lock().unwrap(); - state - .texture - .as_ref() - .is_none_or(|t| t.width() != size.x || t.height() != size.y) - }; - + let needs_resize = texture.width() != size.x || texture.height() != size.y; if needs_resize { let client = WgpuRuntime::client(&img.device()); client.memory_cleanup(); - - let texture = create_texture(size, device); - - let mut state = texture_state.lock().unwrap(); - if let Some(id) = state.texture_id { - renderer.write().update_egui_texture_from_wgpu_texture( - device, - &texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - id, - ); - } else { - let id = renderer.write().register_native_texture( - device, - &texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - ); - state.texture_id = Some(id); - } - - state.texture = Some(texture); + texture = create_texture(size, device); + renderer.write().update_egui_texture_from_wgpu_texture( + &device, + &texture.create_view(&TextureViewDescriptor::default()), + wgpu::FilterMode::Linear, + texture_id, + ); } - let texture = texture_state.lock().unwrap().texture.clone().unwrap(); let [height, width, c] = img.dims(); let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor { @@ -195,63 +225,3 @@ fn copy_to_texture( queue.submit([encoder.finish()]); } - -async fn render_loop( - texture_state: Arc>, - renderer: Arc>, - device: wgpu::Device, - queue: wgpu::Queue, - mut request_rx: watch::Receiver>, -) { - let widget_3d = Widget3D::new(device.clone(), queue.clone()); - - loop { - // Wait for a new request (watch wakes on value change) - if request_rx.changed().await.is_err() { - // Sender dropped, exit loop - break; - } - - // Get the latest request - let Some(req) = request_rx.borrow_and_update().clone() else { - continue; - }; - - let camera = req.camera.clone(); - let img_size = req.img_size; - let background = req.background; - let splat_scale = req.splat_scale; - - let render_result = req - .slot - .act(req.frame, async move |splats| { - let (image, _) = render_splats( - splats.clone(), - &camera, - img_size, - background, - splat_scale, - TextureMode::Packed, - ) - .await; - (splats, image) - }) - .await; - - if let Some(image) = render_result { - copy_to_texture(image, &texture_state, &renderer, &device, &queue); - - if req.grid_opacity > 0.0 - && let Some(texture) = texture_state.lock().unwrap().texture.clone() - { - widget_3d.render_to_texture( - &req.camera, - req.model_transform, - req.img_size, - &texture, - req.grid_opacity, - ); - } - } - } -} From c6c37a3696de49195960a272cc5d6cc6c7b5a8a4 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 26 Jan 2026 00:26:12 +0100 Subject: [PATCH 20/29] Start on native backbuffer. --- Cargo.lock | 1 - crates/brush-ui/Cargo.toml | 1 - crates/brush-ui/src/scene.rs | 203 ++++---- crates/brush-ui/src/splat_backbuffer.rs | 589 ++++++++++++++++-------- examples/train-2d/examples/train-2d.rs | 10 +- 5 files changed, 494 insertions(+), 310 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 39e8f2d4..419d93eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1301,7 +1301,6 @@ dependencies = [ "humantime", "image", "log", - "pollster", "rand 0.9.2", "rrfd", "serde", diff --git a/crates/brush-ui/Cargo.toml b/crates/brush-ui/Cargo.toml index a1365642..3b0757fb 100644 --- a/crates/brush-ui/Cargo.toml +++ b/crates/brush-ui/Cargo.toml @@ -39,7 +39,6 @@ bytemuck = { version = "1.14", features = ["derive"] } brush-dataset = { path = "../brush-dataset", optional = true } image = { workspace = true, optional = true } -pollster = "0.4.0" [target.'cfg(not(target_family = "wasm"))'.dependencies] tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 4e1e9872..8a5d8753 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -4,6 +4,7 @@ use crate::settings_popup::SettingsPopup; use brush_process::message::TrainMessage; #[cfg(feature = "training")] use std::sync::Mutex; +use wgpu::naga::back; use brush_process::{create_process, message::ProcessMessage}; use brush_vfs::DataSource; @@ -258,114 +259,6 @@ impl ScenePanel { )); } - pub(crate) fn draw_splats( - &mut self, - ui: &mut egui::Ui, - process: &UiProcess, - interactive: bool, - ) -> egui::Rect { - let size = ui.available_size(); - let size = glam::uvec2(size.x.round() as u32, size.y.round() as u32); - let (rect, response) = ui.allocate_exact_size( - egui::Vec2::new(size.x as f32, size.y as f32), - egui::Sense::drag(), - ); - if interactive { - process.tick_controls(&response, ui); - } - - // Get camera after modifying the controls. - let mut camera = process.current_camera(); - - let view_eff = (camera.world_to_local() * process.model_local_to_world()).inverse(); - let (_, rotation, position) = view_eff.to_scale_rotation_translation(); - camera.position = position; - camera.rotation = rotation; - - let settings = process.get_cam_settings(); - - // Adjust FOV so that the scene view shows at least what's visible in the dataset view. - let camera_aspect = (camera.fov_x / 2.0).tan() / (camera.fov_y / 2.0).tan(); - let viewport_aspect = size.x as f64 / size.y as f64; - - if viewport_aspect > camera_aspect { - let focal_y = fov_to_focal(camera.fov_y, size.y); - camera.fov_x = focal_to_fov(focal_y, size.x); - } else { - let focal_x = fov_to_focal(camera.fov_x, size.x); - camera.fov_y = focal_to_fov(focal_x, size.y); - } - - let grid_opacity = process.get_grid_opacity(); - - let state = RenderState { - size, - cam: camera.clone(), - settings: settings.clone(), - grid_opacity, - frame: self.frame as u32, - }; - - let dirty = self.last_state != Some(state.clone()); - - if dirty { - self.last_state = Some(state); - } - - let pixel_size = glam::uvec2( - (size.x as f32 * ui.ctx().pixels_per_point().round()) as u32, - (size.y as f32 * ui.ctx().pixels_per_point().round()) as u32, - ); - - // Submit new render request if dirty and we have splats - if let Some(backbuffer) = &self.backbuffer - && pixel_size.x > 8 - && pixel_size.y > 8 - && dirty - { - backbuffer.submit(RenderRequest { - slot: process.current_splats(), - frame: self.frame as usize, - camera, - img_size: pixel_size, - background: settings.background.unwrap_or(Vec3::ZERO), - splat_scale: settings.splat_scale, - ctx: ui.ctx().clone(), - model_transform: process.model_local_to_world(), - grid_opacity, - }); - } - - ui.scope(|ui| { - // if training views have alpha, show a background checker. Masked images - // should still use a black background. - match process.background_style() { - BackgroundStyle::Checkerboard => { - draw_checkerboard(ui, rect, Color32::WHITE); - } - BackgroundStyle::Black => { - ui.painter().rect_filled(rect, 0.0, Color32::BLACK); - } - } - - if let Some(backbuffer) = &self.backbuffer - && let Some(id) = backbuffer.id() - { - ui.painter().image( - id, - rect, - Rect { - min: egui::pos2(0.0, 0.0), - max: egui::pos2(1.0, 1.0), - }, - Color32::WHITE, - ); - } - }); - - rect - } - fn draw_play_pause(&mut self, ui: &egui::Ui, rect: Rect) { // Only show play/pause if we have a multi-frame sequence that's fully loaded if self.frame_count > 1 { @@ -1024,7 +917,99 @@ impl AppPane for ScenePanel { let interactive = matches!(process.ui_mode(), UiMode::Default | UiMode::FullScreenSplat); - let rect = self.draw_splats(ui, process, interactive); + + let size = ui.available_size(); + let size = glam::uvec2(size.x.round() as u32, size.y.round() as u32); + let (rect, response) = ui.allocate_exact_size( + egui::Vec2::new(size.x as f32, size.y as f32), + egui::Sense::drag(), + ); + if interactive { + process.tick_controls(&response, ui); + } + + // Get camera after modifying the controls. + let mut camera = process.current_camera(); + + let view_eff = (camera.world_to_local() * process.model_local_to_world()).inverse(); + let (_, rotation, position) = view_eff.to_scale_rotation_translation(); + camera.position = position; + camera.rotation = rotation; + + let settings = process.get_cam_settings(); + + // Adjust FOV so that the scene view shows at least what's visible in the dataset view. + let camera_aspect = (camera.fov_x / 2.0).tan() / (camera.fov_y / 2.0).tan(); + let viewport_aspect = size.x as f64 / size.y as f64; + + if viewport_aspect > camera_aspect { + let focal_y = fov_to_focal(camera.fov_y, size.y); + camera.fov_x = focal_to_fov(focal_y, size.x); + } else { + let focal_x = fov_to_focal(camera.fov_x, size.x); + camera.fov_y = focal_to_fov(focal_x, size.y); + } + + let grid_opacity = process.get_grid_opacity(); + + let state = RenderState { + size, + cam: camera.clone(), + settings: settings.clone(), + grid_opacity, + frame: self.frame as u32, + }; + + let dirty = self.last_state != Some(state.clone()); + + if dirty { + self.last_state = Some(state); + } + + let pixel_size = glam::uvec2( + (size.x as f32 * ui.ctx().pixels_per_point().round()) as u32, + (size.y as f32 * ui.ctx().pixels_per_point().round()) as u32, + ); + + // Submit new render request if dirty and we have splats + ui.scope(|ui| { + // if training views have alpha, show a background checker. Masked images + // should still use a black background. + match process.background_style() { + BackgroundStyle::Checkerboard => { + draw_checkerboard(ui, rect, Color32::WHITE); + } + BackgroundStyle::Black => { + ui.painter().rect_filled(rect, 0.0, Color32::BLACK); + } + } + + if let Some(backbuffer) = &mut self.backbuffer { + if dirty { + backbuffer.submit(RenderRequest { + slot: process.current_splats(), + frame: self.frame as usize, + camera, + img_size: pixel_size, + background: settings.background.unwrap_or(Vec3::ZERO), + splat_scale: settings.splat_scale, + ctx: ui.ctx().clone(), + model_transform: process.model_local_to_world(), + grid_opacity, + }); + } + // ui.painter().image( + // backbuffer.id(), + // rect, + // Rect { + // min: egui::pos2(0.0, 0.0), + // max: egui::pos2(1.0, 1.0), + // }, + // Color32::WHITE, + // ); + backbuffer.draw(); + } + }); if interactive { self.draw_play_pause(ui, rect); diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index 0b1c349b..ad1c1d73 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -5,223 +5,424 @@ use brush_render::{ render_splats, }; use burn::tensor::{Tensor, TensorPrimitive}; -use burn_cubecl::cubecl::Runtime; -use burn_wgpu::WgpuRuntime; use eframe::egui_wgpu::Renderer; use egui::TextureId; use egui::epaint::mutex::RwLock as EguiRwLock; use glam::Vec3; -use pollster::block_on; +use std::num::NonZeroU64; use std::sync::Arc; +use tokio::task; use wgpu::{CommandEncoderDescriptor, TexelCopyBufferLayout, TextureViewDescriptor}; -#[derive(Clone)] -pub struct RenderRequest { - pub slot: Slot>, - pub frame: usize, - pub camera: Camera, - pub img_size: glam::UVec2, - pub background: Vec3, - pub splat_scale: Option, - pub ctx: egui::Context, - /// Model transform for the 3D overlay (grid, axes). - pub model_transform: glam::Affine3A, - /// Opacity of the grid overlay (0.0 = hidden, 1.0 = fully visible). - pub grid_opacity: f32, -} +use eframe::{ + egui_wgpu::wgpu::util::DeviceExt as _, + egui_wgpu::{self, wgpu}, +}; -pub struct SplatBackbuffer { - texture: wgpu::Texture, - texture_id: TextureId, - renderer: Arc>, - device: wgpu::Device, - queue: wgpu::Queue, - widget_3d: Widget3D, +pub struct Custom3d { + angle: f32, } -impl SplatBackbuffer { - pub fn new( - renderer: Arc>, - device: wgpu::Device, - queue: wgpu::Queue, - ) -> Self { - // Start with a dummy texture - let texture = create_texture(glam::uvec2(64, 64), &device); - let id = renderer.write().register_native_texture( - &device, - &texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - ); - let widget_3d = Widget3D::new(device.clone(), queue.clone()); - - Self { - texture, - texture_id: id, - renderer, - device, - queue, - widget_3d, - } +impl Custom3d { + pub fn new<'a>(cc: &'a eframe::CreationContext<'a>) -> Option { + // Get the WGPU render state from the eframe creation context. This can also be retrieved + // from `eframe::Frame` when you don't have a `CreationContext` available. + let wgpu_render_state = cc.wgpu_render_state.as_ref()?; + + let device = &wgpu_render_state.device; + + let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("custom3d"), + source: wgpu::ShaderSource::Wgsl(include_str!("./custom3d_wgpu_shader.wgsl").into()), + }); + + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("custom3d"), + entries: &[wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::VERTEX, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Uniform, + has_dynamic_offset: false, + min_binding_size: NonZeroU64::new(16), + }, + count: None, + }], + }); + + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("custom3d"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { + label: Some("custom3d"), + layout: Some(&pipeline_layout), + vertex: wgpu::VertexState { + module: &shader, + entry_point: None, + buffers: &[], + compilation_options: wgpu::PipelineCompilationOptions::default(), + }, + fragment: Some(wgpu::FragmentState { + module: &shader, + entry_point: Some("fs_main"), + targets: &[Some(wgpu_render_state.target_format.into())], + compilation_options: wgpu::PipelineCompilationOptions::default(), + }), + primitive: wgpu::PrimitiveState::default(), + depth_stencil: None, + multisample: wgpu::MultisampleState::default(), + multiview: None, + cache: None, + }); + + let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("custom3d"), + contents: bytemuck::cast_slice(&[0.0_f32; 4]), // 16 bytes aligned! + // Mapping at creation (as done by the create_buffer_init utility) doesn't require us to to add the MAP_WRITE usage + // (this *happens* to workaround this bug ) + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::UNIFORM, + }); + + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("custom3d"), + layout: &bind_group_layout, + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: uniform_buffer.as_entire_binding(), + }], + }); + + // Because the graphics pipeline must have the same lifetime as the egui render pass, + // instead of storing the pipeline in our `Custom3D` struct, we insert it into the + // `paint_callback_resources` type map, which is stored alongside the render pass. + wgpu_render_state + .renderer + .write() + .callback_resources + .insert(TriangleRenderResources { + pipeline, + bind_group, + uniform_buffer, + }); + + Some(Self { angle: 0.0 }) } +} - /// Submit a render request. Spawns an async task to do the rendering. - pub fn submit(&self, req: RenderRequest) { - let needs_resize = - self.texture.width() != req.img_size.x || self.texture.height() != req.img_size.y; - if needs_resize { - // TODO: Restore this. - // let client = WgpuRuntime::client(&req.); - // client.memory_cleanup(); - - self.texture = create_texture(req.img_size, &self.device); - self.renderer.write().update_egui_texture_from_wgpu_texture( - &self.device, - &self.texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - self.texture_id, - ); - } - - block_on(async move { - let camera = req.camera.clone(); - let img_size = req.img_size; - let background = req.background; - let splat_scale = req.splat_scale; - - let splats = req.slot.clone_main().await; - - if let Some(splats) = splats { - let (image, _) = render_splats( - splats.clone(), - &camera, - img_size, - background, - splat_scale, - TextureMode::Packed, - ) - .await; - - copy_to_texture( - image, - &self.texture, - self.texture_id, - &self.renderer, - &self.device, - &self.queue, +impl crate::DemoApp for Custom3d { + fn demo_ui(&mut self, ui: &mut egui::Ui, _frame: &mut eframe::Frame) { + // TODO(emilk): Use `ScrollArea::inner_margin` + egui::CentralPanel::default().show_inside(ui, |ui| { + egui::ScrollArea::both().auto_shrink(false).show(ui, |ui| { + ui.horizontal(|ui| { + ui.spacing_mut().item_spacing.x = 0.0; + ui.label("The triangle is being painted using "); + ui.hyperlink_to("WGPU", "https://wgpu.rs"); + ui.label(" (Portable Rust graphics API awesomeness)"); + }); + ui.label( + "It's not a very impressive demo, but it shows you can embed 3D inside of egui.", ); - if req.grid_opacity > 0.0 { - self.widget_3d.render_to_texture( - &req.camera, - req.model_transform, - req.img_size, - &texture, - req.grid_opacity, - ); - } - } - req.ctx.request_repaint(); + egui::Frame::canvas(ui.style()).show(ui, |ui| { + self.custom_painting(ui); + }); + ui.label("Drag to rotate!"); + ui.add(egui_demo_lib::egui_github_link_file!()); + }); }); } +} + +// Callbacks in egui_wgpu have 3 stages: +// * prepare (per callback impl) +// * finish_prepare (once) +// * paint (per callback impl) +// +// The prepare callback is called every frame before paint and is given access to the wgpu +// Device and Queue, which can be used, for instance, to update buffers and uniforms before +// rendering. +// If [`egui_wgpu::Renderer`] has [`egui_wgpu::FinishPrepareCallback`] registered, +// it will be called after all `prepare` callbacks have been called. +// You can use this to update any shared resources that need to be updated once per frame +// after all callbacks have been processed. +// +// On both prepare methods you can use the main `CommandEncoder` that is passed-in, +// return an arbitrary number of user-defined `CommandBuffer`s, or both. +// The main command buffer, as well as all user-defined ones, will be submitted together +// to the GPU in a single call. +// +// The paint callback is called after finish prepare and is given access to egui's main render pass, +// which can be used to issue draw commands. +struct CustomTriangleCallback { + angle: f32, +} + +impl egui_wgpu::CallbackTrait for CustomTriangleCallback { + fn prepare( + &self, + device: &wgpu::Device, + queue: &wgpu::Queue, + _screen_descriptor: &egui_wgpu::ScreenDescriptor, + _egui_encoder: &mut wgpu::CommandEncoder, + resources: &mut egui_wgpu::CallbackResources, + ) -> Vec { + let resources: &TriangleRenderResources = resources.get().unwrap(); + resources.prepare(device, queue, self.angle); + Vec::new() + } + + fn paint( + &self, + _info: egui::PaintCallbackInfo, + render_pass: &mut wgpu::RenderPass<'static>, + resources: &egui_wgpu::CallbackResources, + ) { + let resources: &TriangleRenderResources = resources.get().unwrap(); + resources.paint(render_pass); + } +} + +impl Custom3d { + fn custom_painting(&mut self, ui: &mut egui::Ui) { + let (rect, response) = + ui.allocate_exact_size(egui::Vec2::splat(300.0), egui::Sense::drag()); - pub fn id(&self) -> Option { - self.texture_id + self.angle += response.drag_motion().x * 0.01; + ui.painter().add(egui_wgpu::Callback::new_paint_callback( + rect, + CustomTriangleCallback { angle: self.angle }, + )); } } -fn create_texture(size: glam::UVec2, device: &wgpu::Device) -> wgpu::Texture { - device.create_texture(&wgpu::TextureDescriptor { - label: Some("Splat backbuffer"), - size: wgpu::Extent3d { - width: size.x, - height: size.y, - depth_or_array_layers: 1, - }, - mip_level_count: 1, - sample_count: 1, - dimension: wgpu::TextureDimension::D2, - format: wgpu::TextureFormat::Rgba8Unorm, - usage: wgpu::TextureUsages::TEXTURE_BINDING - | wgpu::TextureUsages::COPY_DST - | wgpu::TextureUsages::RENDER_ATTACHMENT, - view_formats: &[wgpu::TextureFormat::Rgba8Unorm], - }) +struct TriangleRenderResources { + pipeline: wgpu::RenderPipeline, + bind_group: wgpu::BindGroup, + uniform_buffer: wgpu::Buffer, } -fn copy_to_texture( - img: Tensor, - texture: &wgpu::Texture, - texture_id: TextureId, - renderer: &Arc>, - device: &wgpu::Device, - queue: &wgpu::Queue, -) { - let [h, w, c] = img.shape().dims(); - assert!(c == 1, "texture should be u8 packed RGBA"); - let size = glam::uvec2(w as u32, h as u32); - - let needs_resize = texture.width() != size.x || texture.height() != size.y; - if needs_resize { - let client = WgpuRuntime::client(&img.device()); - client.memory_cleanup(); - texture = create_texture(size, device); - renderer.write().update_egui_texture_from_wgpu_texture( - &device, - &texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - texture_id, +impl TriangleRenderResources { + fn prepare(&self, _device: &wgpu::Device, queue: &wgpu::Queue, angle: f32) { + // Update our uniform buffer with the angle from the UI + queue.write_buffer( + &self.uniform_buffer, + 0, + bytemuck::cast_slice(&[angle, 0.0, 0.0, 0.0]), ); } - let [height, width, c] = img.dims(); - - let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor { - label: Some("splat backbuffer encoder"), - }); - - let padded_shape = vec![height, width.div_ceil(64) * 64, c]; - - let img_prim = img.into_primitive().tensor(); - let fusion_client = img_prim.client.clone(); - let img = fusion_client.resolve_tensor_float::(img_prim); - let img: Tensor = Tensor::from_primitive(TensorPrimitive::Float(img)); - - // Pad if needed (WebGPU requires bytes_per_row divisible by 256) - let img = if width % 64 != 0 { - let padded: Tensor = Tensor::zeros(&padded_shape, &img.device()); - padded.slice_assign([0..height, 0..width], img) - } else { - img - }; - - let img = img.into_primitive().tensor(); - let client = &img.client; - let img_res_handle = client.get_resource(img.handle.clone().binding()); - client.flush(); - - let bytes_per_row = Some(4 * padded_shape[1] as u32); - - encoder.copy_buffer_to_texture( - wgpu::TexelCopyBufferInfo { - buffer: &img_res_handle.resource().buffer, - layout: TexelCopyBufferLayout { - offset: img_res_handle.resource().offset, - bytes_per_row, - rows_per_image: None, - }, - }, - wgpu::TexelCopyTextureInfo { - texture: &texture, - mip_level: 0, - origin: wgpu::Origin3d { x: 0, y: 0, z: 0 }, - aspect: wgpu::TextureAspect::All, - }, - wgpu::Extent3d { - width: width as u32, - height: height as u32, - depth_or_array_layers: 1, - }, - ); - - queue.submit([encoder.finish()]); + fn paint(&self, render_pass: &mut wgpu::RenderPass<'_>) { + // Draw our triangle! + render_pass.set_pipeline(&self.pipeline); + render_pass.set_bind_group(0, &self.bind_group, &[]); + render_pass.draw(0..3, 0..1); + } } + +// #[derive(Clone)] +// pub struct RenderRequest { +// pub slot: Slot>, +// pub frame: usize, +// pub camera: Camera, +// pub img_size: glam::UVec2, +// pub background: Vec3, +// pub splat_scale: Option, +// pub ctx: egui::Context, +// /// Model transform for the 3D overlay (grid, axes). +// pub model_transform: glam::Affine3A, +// /// Opacity of the grid overlay (0.0 = hidden, 1.0 = fully visible). +// pub grid_opacity: f32, +// } + +// pub struct SplatBackbuffer { +// texture: wgpu::Texture, +// texture_id: TextureId, +// device: wgpu::Device, +// queue: wgpu::Queue, +// widget_3d: Arc, +// } + +// impl SplatBackbuffer { +// pub fn new<'a>(cc: &'a eframe::CreationContext<'a>) -> Option { +// // Start with a dummy texture +// let texture = create_texture(glam::uvec2(64, 64), &device); +// let id = renderer.write().register_native_texture( +// &device, +// &texture.create_view(&TextureViewDescriptor::default()), +// wgpu::FilterMode::Linear, +// ); +// let widget_3d = Arc::new(Widget3D::new(device.clone(), queue.clone())); + +// Some(Self { +// texture, +// texture_id: id, +// device, +// queue, +// widget_3d, +// }) +// } + +// /// Submit a render request. Spawns an async task to do the rendering. +// pub fn submit(&mut self, req: RenderRequest) { +// if req.img_size.x <= 8 || req.img_size.y <= 8 { +// return; +// } + +// // Check resizing. This is done sync, as it requires a renderer lock... +// let needs_resize = +// self.texture.width() != req.img_size.x || self.texture.height() != req.img_size.y; +// if needs_resize { +// // TODO: Restore this. +// // let client = WgpuRuntime::client(&req.); +// // client.memory_cleanup(); +// self.texture = create_texture(req.img_size, &self.device); +// self.renderer.write().update_egui_texture_from_wgpu_texture( +// &self.device, +// &self.texture.create_view(&TextureViewDescriptor::default()), +// wgpu::FilterMode::Linear, +// self.texture_id, +// ); +// } + +// let texture = self.texture.clone(); +// let device = self.device.clone(); +// let queue = self.queue.clone(); + +// let camera = req.camera.clone(); +// let img_size = req.img_size; +// let background = req.background; +// let splat_scale = req.splat_scale; + +// let widget = self.widget_3d.clone(); + +// task::spawn(async move { +// let splats = req.slot.clone_main().await; + +// if let Some(splats) = splats { +// let (image, _) = render_splats( +// splats.clone(), +// &camera, +// img_size, +// background, +// splat_scale, +// TextureMode::Packed, +// ) +// .await; + +// copy_to_texture(image, &texture, &device, &queue); + +// if req.grid_opacity > 0.0 { +// widget.render_to_texture( +// &req.camera, +// req.model_transform, +// req.img_size, +// &texture, +// req.grid_opacity, +// ); +// } + +// req.ctx.request_repaint(); +// } +// }); +// } + +// pub fn id(&self) -> TextureId { +// self.texture_id +// } +// } + +// impl egui_wgpu::CallbackTrait for SplatBackbuffer {} + +// impl SplatBackbuffer { +// fn custom_painting(&mut self, ui: &mut egui::Ui) { +// let (rect, response) = +// ui.allocate_exact_size(egui::Vec2::splat(300.0), egui::Sense::drag()); + +// self.angle += response.drag_motion().x * 0.01; +// ui.painter().add(egui_wgpu::Callback::new_paint_callback( +// rect, +// CustomTriangleCallback { angle: self.angle }, +// )); +// } +// } + +// fn create_texture(size: glam::UVec2, device: &wgpu::Device) -> wgpu::Texture { +// device.create_texture(&wgpu::TextureDescriptor { +// label: Some("Splat backbuffer"), +// size: wgpu::Extent3d { +// width: size.x, +// height: size.y, +// depth_or_array_layers: 1, +// }, +// mip_level_count: 1, +// sample_count: 1, +// dimension: wgpu::TextureDimension::D2, +// format: wgpu::TextureFormat::Rgba8Unorm, +// usage: wgpu::TextureUsages::TEXTURE_BINDING +// | wgpu::TextureUsages::COPY_DST +// | wgpu::TextureUsages::RENDER_ATTACHMENT, +// view_formats: &[wgpu::TextureFormat::Rgba8Unorm], +// }) +// } + +// fn copy_to_texture( +// img: Tensor, +// texture: &wgpu::Texture, +// device: &wgpu::Device, +// queue: &wgpu::Queue, +// ) { +// let [height, width, c] = img.dims(); +// let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor { +// label: Some("splat backbuffer encoder"), +// }); + +// let padded_shape = vec![height, width.div_ceil(64) * 64, c]; + +// let img_prim = img.into_primitive().tensor(); +// let fusion_client = img_prim.client.clone(); +// let img = fusion_client.resolve_tensor_float::(img_prim); +// let img: Tensor = Tensor::from_primitive(TensorPrimitive::Float(img)); + +// // Pad if needed (WebGPU requires bytes_per_row divisible by 256) +// let img = if width % 64 != 0 { +// let padded: Tensor = Tensor::zeros(&padded_shape, &img.device()); +// padded.slice_assign([0..height, 0..width], img) +// } else { +// img +// }; + +// let img = img.into_primitive().tensor(); +// let client = &img.client; +// let img_res_handle = client.get_resource(img.handle.clone().binding()); +// client.flush(); + +// let bytes_per_row = Some(4 * padded_shape[1] as u32); + +// encoder.copy_buffer_to_texture( +// wgpu::TexelCopyBufferInfo { +// buffer: &img_res_handle.resource().buffer, +// layout: TexelCopyBufferLayout { +// offset: img_res_handle.resource().offset, +// bytes_per_row, +// rows_per_image: None, +// }, +// }, +// wgpu::TexelCopyTextureInfo { +// texture, +// mip_level: 0, +// origin: wgpu::Origin3d { x: 0, y: 0, z: 0 }, +// aspect: wgpu::TextureAspect::All, +// }, +// wgpu::Extent3d { +// width: width as u32, +// height: height as u32, +// depth_or_array_layers: 1, +// }, +// ); + +// queue.submit([encoder.finish()]); +// } diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index aaed91a7..2ed7ea27 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -182,16 +182,16 @@ impl eframe::App for App { splat_scale: None, model_transform: glam::Affine3A::IDENTITY, grid_opacity: 0.0, + ctx: ctx.clone(), }); let size = egui::vec2(self.image.width() as f32, self.image.height() as f32); ui.horizontal(|ui| { - if let Some(texture_id) = self.backbuffer.id() { - ui.image(ImageSource::Texture(SizedTexture::new(texture_id, size))); - } else { - ui.label("Rendering..."); - } + ui.image(ImageSource::Texture(SizedTexture::new( + self.backbuffer.id(), + size, + ))); ui.image(ImageSource::Texture(SizedTexture::new( self.tex_handle.id(), size, From 471afbfbb797aea86008f021250397b23f15b864 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 26 Jan 2026 14:55:48 +0100 Subject: [PATCH 21/29] WIP --- crates/brush-kernel/build.rs | 36 -- crates/brush-ui/src/scene.rs | 54 +- crates/brush-ui/src/splat_backbuffer.rs | 624 +++++++++--------------- crates/brush-ui/src/widget_3d.rs | 203 +++----- examples/train-2d/examples/train-2d.rs | 4 +- 5 files changed, 342 insertions(+), 579 deletions(-) delete mode 100644 crates/brush-kernel/build.rs diff --git a/crates/brush-kernel/build.rs b/crates/brush-kernel/build.rs deleted file mode 100644 index 1b8f58cc..00000000 --- a/crates/brush-kernel/build.rs +++ /dev/null @@ -1,36 +0,0 @@ -//! Build script to clean up old generated shader files. -//! -//! Previously, shaders were generated via build.rs into src/shaders/mod.rs files. -//! Now we use proc macros instead. This script removes any leftover generated files -//! to avoid confusion or compilation issues. - -use std::path::Path; - -fn main() { - // Only rerun if this build script changes - println!("cargo:rerun-if-changed=build.rs"); - - // Clean up old generated shader files in the workspace - let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap(); - let crates_dir = Path::new(&manifest_dir).parent().unwrap(); - - // List of crates that used to have generated shaders/mod.rs - let crates_with_old_shaders = [ - "brush-kernel", - "brush-prefix-sum", - "brush-sort", - "brush-render", - "brush-render-bwd", - ]; - - for crate_name in crates_with_old_shaders { - let old_file = crates_dir.join(crate_name).join("src/shaders/mod.rs"); - if old_file.exists() { - println!( - "cargo:warning=Removing old generated file: {}", - old_file.display() - ); - let _ = std::fs::remove_file(&old_file); - } - } -} diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 8a5d8753..ee2c69e3 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -2,10 +2,6 @@ use crate::settings_popup::SettingsPopup; #[cfg(feature = "training")] use brush_process::message::TrainMessage; -#[cfg(feature = "training")] -use std::sync::Mutex; -use wgpu::naga::back; - use brush_process::{create_process, message::ProcessMessage}; use brush_vfs::DataSource; use core::f32; @@ -13,6 +9,8 @@ use egui::{ Align2, Button, Frame, RichText, containers::Popup, epaint::mutex::RwLock as EguiRwLock, }; use std::sync::Arc; +#[cfg(feature = "training")] +use std::sync::Mutex; use brush_render::camera::{Camera, focal_to_fov, fov_to_focal}; use eframe::egui_wgpu::Renderer; @@ -21,17 +19,15 @@ use glam::{UVec2, Vec3}; use web_time::Instant; use crate::splat_backbuffer::{RenderRequest, SplatBackbuffer}; +use crate::widget_3d::Widget3DCallback; use serde::{Deserialize, Serialize}; /// Controls how often the viewport re-renders during training. #[derive(Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum RenderUpdateMode { - /// Don't re-render during training Off, - /// Re-render every 100 iterations Low, - /// Re-render every 5 iterations (default) #[default] Live, } @@ -728,7 +724,13 @@ impl AppPane for ScenePanel { _burn_device: burn_wgpu::WgpuDevice, _adapter_info: wgpu::AdapterInfo, ) { - self.backbuffer = Some(SplatBackbuffer::new(renderer, device, queue)); + // Initialize Widget3D resources for the grid overlay + // renderer + // .write() + // .callback_resources + // .insert(Widget3DResources::new(&device, target_format)); + + self.backbuffer = Some(SplatBackbuffer::new(renderer, device, &queue)); // Create the settings popup now that we have the base_path #[cfg(feature = "training")] @@ -989,25 +991,35 @@ impl AppPane for ScenePanel { backbuffer.submit(RenderRequest { slot: process.current_splats(), frame: self.frame as usize, - camera, + camera: camera.clone(), img_size: pixel_size, background: settings.background.unwrap_or(Vec3::ZERO), splat_scale: settings.splat_scale, ctx: ui.ctx().clone(), - model_transform: process.model_local_to_world(), - grid_opacity, }); } - // ui.painter().image( - // backbuffer.id(), - // rect, - // Rect { - // min: egui::pos2(0.0, 0.0), - // max: egui::pos2(1.0, 1.0), - // }, - // Color32::WHITE, - // ); - backbuffer.draw(); + ui.painter().image( + backbuffer.id(), + rect, + Rect { + min: egui::pos2(0.0, 0.0), + max: egui::pos2(1.0, 1.0), + }, + Color32::WHITE, + ); + + // Draw the 3D grid overlay using egui's wgpu callback + if grid_opacity > 0.0 { + ui.painter() + .add(eframe::egui_wgpu::Callback::new_paint_callback( + rect, + Widget3DCallback { + camera, + model_transform: process.model_local_to_world(), + grid_opacity, + }, + )); + } } }); diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index ad1c1d73..8a217144 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -1,428 +1,264 @@ -use crate::widget_3d::Widget3D; use brush_process::slot::Slot; use brush_render::{ MainBackend, MainBackendBase, TextureMode, camera::Camera, gaussian_splats::Splats, render_splats, }; use burn::tensor::{Tensor, TensorPrimitive}; -use eframe::egui_wgpu::Renderer; use egui::TextureId; -use egui::epaint::mutex::RwLock as EguiRwLock; use glam::Vec3; -use std::num::NonZeroU64; use std::sync::Arc; -use tokio::task; +use tokio::sync::mpsc; use wgpu::{CommandEncoderDescriptor, TexelCopyBufferLayout, TextureViewDescriptor}; -use eframe::{ - egui_wgpu::wgpu::util::DeviceExt as _, - egui_wgpu::{self, wgpu}, -}; - -pub struct Custom3d { - angle: f32, +use eframe::egui_wgpu::{self, wgpu}; + +/// Request sent to the async render worker. +#[derive(Clone)] +pub struct RenderRequest { + pub slot: Slot>, + pub frame: usize, + pub camera: Camera, + pub img_size: glam::UVec2, + pub background: Vec3, + pub splat_scale: Option, + pub ctx: egui::Context, } -impl Custom3d { - pub fn new<'a>(cc: &'a eframe::CreationContext<'a>) -> Option { - // Get the WGPU render state from the eframe creation context. This can also be retrieved - // from `eframe::Frame` when you don't have a `CreationContext` available. - let wgpu_render_state = cc.wgpu_render_state.as_ref()?; - - let device = &wgpu_render_state.device; - - let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { - label: Some("custom3d"), - source: wgpu::ShaderSource::Wgsl(include_str!("./custom3d_wgpu_shader.wgsl").into()), - }); - - let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { - label: Some("custom3d"), - entries: &[wgpu::BindGroupLayoutEntry { - binding: 0, - visibility: wgpu::ShaderStages::VERTEX, - ty: wgpu::BindingType::Buffer { - ty: wgpu::BufferBindingType::Uniform, - has_dynamic_offset: false, - min_binding_size: NonZeroU64::new(16), - }, - count: None, - }], - }); - - let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { - label: Some("custom3d"), - bind_group_layouts: &[&bind_group_layout], - push_constant_ranges: &[], - }); - - let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { - label: Some("custom3d"), - layout: Some(&pipeline_layout), - vertex: wgpu::VertexState { - module: &shader, - entry_point: None, - buffers: &[], - compilation_options: wgpu::PipelineCompilationOptions::default(), - }, - fragment: Some(wgpu::FragmentState { - module: &shader, - entry_point: Some("fs_main"), - targets: &[Some(wgpu_render_state.target_format.into())], - compilation_options: wgpu::PipelineCompilationOptions::default(), - }), - primitive: wgpu::PrimitiveState::default(), - depth_stencil: None, - multisample: wgpu::MultisampleState::default(), - multiview: None, - cache: None, - }); +pub struct SplatRenderResources { + pub texture: wgpu::Texture, + pub size: glam::UVec2, +} - let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { - label: Some("custom3d"), - contents: bytemuck::cast_slice(&[0.0_f32; 4]), // 16 bytes aligned! - // Mapping at creation (as done by the create_buffer_init utility) doesn't require us to to add the MAP_WRITE usage - // (this *happens* to workaround this bug ) - usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::UNIFORM, - }); +pub struct SplatBackbuffer { + /// Channel to send requests to the async worker. + request_sender: mpsc::UnboundedSender, + /// Egui texture ID for displaying results. + texture_id: TextureId, + /// Device reference for texture operations. + device: wgpu::Device, + /// Renderer reference for texture updates. + renderer: Arc>, +} - let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { - label: Some("custom3d"), - layout: &bind_group_layout, - entries: &[wgpu::BindGroupEntry { - binding: 0, - resource: uniform_buffer.as_entire_binding(), - }], - }); +impl SplatBackbuffer { + pub fn new( + renderer: Arc>, + device: wgpu::Device, + queue: &wgpu::Queue, + ) -> Self { + // Create initial texture + let initial_size = glam::uvec2(64, 64); + let texture = create_texture(initial_size, &device); + let texture_id = renderer.write().register_native_texture( + &device, + &texture.create_view(&TextureViewDescriptor::default()), + wgpu::FilterMode::Linear, + ); - // Because the graphics pipeline must have the same lifetime as the egui render pass, - // instead of storing the pipeline in our `Custom3D` struct, we insert it into the - // `paint_callback_resources` type map, which is stored alongside the render pass. - wgpu_render_state - .renderer + // Insert resources into callback_resources + renderer .write() .callback_resources - .insert(TriangleRenderResources { - pipeline, - bind_group, - uniform_buffer, + .insert(SplatRenderResources { + texture, + size: initial_size, }); - Some(Self { angle: 0.0 }) - } -} - -impl crate::DemoApp for Custom3d { - fn demo_ui(&mut self, ui: &mut egui::Ui, _frame: &mut eframe::Frame) { - // TODO(emilk): Use `ScrollArea::inner_margin` - egui::CentralPanel::default().show_inside(ui, |ui| { - egui::ScrollArea::both().auto_shrink(false).show(ui, |ui| { - ui.horizontal(|ui| { - ui.spacing_mut().item_spacing.x = 0.0; - ui.label("The triangle is being painted using "); - ui.hyperlink_to("WGPU", "https://wgpu.rs"); - ui.label(" (Portable Rust graphics API awesomeness)"); - }); - ui.label( - "It's not a very impressive demo, but it shows you can embed 3D inside of egui.", - ); + // Create channel for render requests + let (tx, rx) = mpsc::unbounded_channel(); + + // Spawn the async render worker + let worker_device = device.clone(); + let worker_queue = queue.clone(); + let worker_renderer = renderer.clone(); + tokio::task::spawn(render_worker( + rx, + worker_device, + worker_queue, + worker_renderer, + )); - egui::Frame::canvas(ui.style()).show(ui, |ui| { - self.custom_painting(ui); - }); - ui.label("Drag to rotate!"); - ui.add(egui_demo_lib::egui_github_link_file!()); - }); - }); + Self { + request_sender: tx, + texture_id, + device, + renderer, + } } -} -// Callbacks in egui_wgpu have 3 stages: -// * prepare (per callback impl) -// * finish_prepare (once) -// * paint (per callback impl) -// -// The prepare callback is called every frame before paint and is given access to the wgpu -// Device and Queue, which can be used, for instance, to update buffers and uniforms before -// rendering. -// If [`egui_wgpu::Renderer`] has [`egui_wgpu::FinishPrepareCallback`] registered, -// it will be called after all `prepare` callbacks have been called. -// You can use this to update any shared resources that need to be updated once per frame -// after all callbacks have been processed. -// -// On both prepare methods you can use the main `CommandEncoder` that is passed-in, -// return an arbitrary number of user-defined `CommandBuffer`s, or both. -// The main command buffer, as well as all user-defined ones, will be submitted together -// to the GPU in a single call. -// -// The paint callback is called after finish prepare and is given access to egui's main render pass, -// which can be used to issue draw commands. -struct CustomTriangleCallback { - angle: f32, -} - -impl egui_wgpu::CallbackTrait for CustomTriangleCallback { - fn prepare( - &self, - device: &wgpu::Device, - queue: &wgpu::Queue, - _screen_descriptor: &egui_wgpu::ScreenDescriptor, - _egui_encoder: &mut wgpu::CommandEncoder, - resources: &mut egui_wgpu::CallbackResources, - ) -> Vec { - let resources: &TriangleRenderResources = resources.get().unwrap(); - resources.prepare(device, queue, self.angle); - Vec::new() + /// Submit a render request. The async worker will process it. + pub fn submit(&mut self, req: RenderRequest) { + if req.img_size.x <= 8 || req.img_size.y <= 8 { + return; + } + + // Handle resize synchronously since it requires renderer lock + let needs_resize = { + let renderer = self.renderer.read(); + let resources: Option<&SplatRenderResources> = renderer.callback_resources.get(); + resources.is_none_or(|r| r.size != req.img_size) + }; + + if needs_resize { + let new_texture = create_texture(req.img_size, &self.device); + self.renderer.write().update_egui_texture_from_wgpu_texture( + &self.device, + &new_texture.create_view(&TextureViewDescriptor::default()), + wgpu::FilterMode::Linear, + self.texture_id, + ); + + // Update the texture in callback_resources + let mut renderer = self.renderer.write(); + if let Some(resources) = renderer + .callback_resources + .get_mut::() + { + resources.texture = new_texture; + resources.size = req.img_size; + } + } + // Send request to worker (ignore send errors if channel closed) + let _ = self.request_sender.send(req); } - fn paint( - &self, - _info: egui::PaintCallbackInfo, - render_pass: &mut wgpu::RenderPass<'static>, - resources: &egui_wgpu::CallbackResources, - ) { - let resources: &TriangleRenderResources = resources.get().unwrap(); - resources.paint(render_pass); + /// Get the texture ID for displaying the rendered result. + pub fn id(&self) -> TextureId { + self.texture_id } } -impl Custom3d { - fn custom_painting(&mut self, ui: &mut egui::Ui) { - let (rect, response) = - ui.allocate_exact_size(egui::Vec2::splat(300.0), egui::Sense::drag()); - - self.angle += response.drag_motion().x * 0.01; - ui.painter().add(egui_wgpu::Callback::new_paint_callback( - rect, - CustomTriangleCallback { angle: self.angle }, - )); +/// Async render worker that processes render requests. +async fn render_worker( + mut receiver: mpsc::UnboundedReceiver, + device: wgpu::Device, + queue: wgpu::Queue, + renderer: Arc>, +) { + loop { + // Wait for at least one request + let Some(mut request) = receiver.recv().await else { + break; // Channel closed + }; + + // Coalesce: drain channel, keep only the last request + while let Ok(newer) = receiver.try_recv() { + request = newer; + } + + // Skip tiny sizes + if request.img_size.x <= 8 || request.img_size.y <= 8 { + continue; + } + + // Clone splats (async) + let Some(splats) = request.slot.clone_main().await else { + continue; + }; + + // Render (async) + let (image, _) = render_splats( + splats, + &request.camera, + request.img_size, + request.background, + request.splat_scale, + TextureMode::Packed, + ) + .await; + + // Get the texture from callback_resources and copy to it + { + let renderer_guard = renderer.read(); + if let Some(resources) = renderer_guard + .callback_resources + .get::() + { + // Only copy if size matches (resize handled in submit) + if resources.size == request.img_size { + copy_to_texture(image, &resources.texture, &device, &queue); + } + } + } + + // Trigger egui repaint + request.ctx.request_repaint(); } } -struct TriangleRenderResources { - pipeline: wgpu::RenderPipeline, - bind_group: wgpu::BindGroup, - uniform_buffer: wgpu::Buffer, +fn create_texture(size: glam::UVec2, device: &wgpu::Device) -> wgpu::Texture { + device.create_texture(&wgpu::TextureDescriptor { + label: Some("Splat backbuffer"), + size: wgpu::Extent3d { + width: size.x, + height: size.y, + depth_or_array_layers: 1, + }, + mip_level_count: 1, + sample_count: 1, + dimension: wgpu::TextureDimension::D2, + format: wgpu::TextureFormat::Rgba8Unorm, + usage: wgpu::TextureUsages::TEXTURE_BINDING + | wgpu::TextureUsages::COPY_DST + | wgpu::TextureUsages::RENDER_ATTACHMENT, + view_formats: &[wgpu::TextureFormat::Rgba8Unorm], + }) } -impl TriangleRenderResources { - fn prepare(&self, _device: &wgpu::Device, queue: &wgpu::Queue, angle: f32) { - // Update our uniform buffer with the angle from the UI - queue.write_buffer( - &self.uniform_buffer, - 0, - bytemuck::cast_slice(&[angle, 0.0, 0.0, 0.0]), - ); - } - - fn paint(&self, render_pass: &mut wgpu::RenderPass<'_>) { - // Draw our triangle! - render_pass.set_pipeline(&self.pipeline); - render_pass.set_bind_group(0, &self.bind_group, &[]); - render_pass.draw(0..3, 0..1); - } +fn copy_to_texture( + img: Tensor, + texture: &wgpu::Texture, + device: &wgpu::Device, + queue: &wgpu::Queue, +) { + let [height, width, c] = img.dims(); + let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor { + label: Some("splat backbuffer encoder"), + }); + + let padded_shape = vec![height, width.div_ceil(64) * 64, c]; + + let img_prim = img.into_primitive().tensor(); + let fusion_client = img_prim.client.clone(); + let img = fusion_client.resolve_tensor_float::(img_prim); + let img: Tensor = Tensor::from_primitive(TensorPrimitive::Float(img)); + + // Pad if needed (WebGPU requires bytes_per_row divisible by 256) + let img = if width % 64 != 0 { + let padded: Tensor = Tensor::zeros(&padded_shape, &img.device()); + padded.slice_assign([0..height, 0..width], img) + } else { + img + }; + + let img = img.into_primitive().tensor(); + let client = &img.client; + let img_res_handle = client.get_resource(img.handle.clone().binding()); + client.flush(); + + let bytes_per_row = Some(4 * padded_shape[1] as u32); + + encoder.copy_buffer_to_texture( + wgpu::TexelCopyBufferInfo { + buffer: &img_res_handle.resource().buffer, + layout: TexelCopyBufferLayout { + offset: img_res_handle.resource().offset, + bytes_per_row, + rows_per_image: None, + }, + }, + wgpu::TexelCopyTextureInfo { + texture, + mip_level: 0, + origin: wgpu::Origin3d { x: 0, y: 0, z: 0 }, + aspect: wgpu::TextureAspect::All, + }, + wgpu::Extent3d { + width: width as u32, + height: height as u32, + depth_or_array_layers: 1, + }, + ); + + queue.submit([encoder.finish()]); } - -// #[derive(Clone)] -// pub struct RenderRequest { -// pub slot: Slot>, -// pub frame: usize, -// pub camera: Camera, -// pub img_size: glam::UVec2, -// pub background: Vec3, -// pub splat_scale: Option, -// pub ctx: egui::Context, -// /// Model transform for the 3D overlay (grid, axes). -// pub model_transform: glam::Affine3A, -// /// Opacity of the grid overlay (0.0 = hidden, 1.0 = fully visible). -// pub grid_opacity: f32, -// } - -// pub struct SplatBackbuffer { -// texture: wgpu::Texture, -// texture_id: TextureId, -// device: wgpu::Device, -// queue: wgpu::Queue, -// widget_3d: Arc, -// } - -// impl SplatBackbuffer { -// pub fn new<'a>(cc: &'a eframe::CreationContext<'a>) -> Option { -// // Start with a dummy texture -// let texture = create_texture(glam::uvec2(64, 64), &device); -// let id = renderer.write().register_native_texture( -// &device, -// &texture.create_view(&TextureViewDescriptor::default()), -// wgpu::FilterMode::Linear, -// ); -// let widget_3d = Arc::new(Widget3D::new(device.clone(), queue.clone())); - -// Some(Self { -// texture, -// texture_id: id, -// device, -// queue, -// widget_3d, -// }) -// } - -// /// Submit a render request. Spawns an async task to do the rendering. -// pub fn submit(&mut self, req: RenderRequest) { -// if req.img_size.x <= 8 || req.img_size.y <= 8 { -// return; -// } - -// // Check resizing. This is done sync, as it requires a renderer lock... -// let needs_resize = -// self.texture.width() != req.img_size.x || self.texture.height() != req.img_size.y; -// if needs_resize { -// // TODO: Restore this. -// // let client = WgpuRuntime::client(&req.); -// // client.memory_cleanup(); -// self.texture = create_texture(req.img_size, &self.device); -// self.renderer.write().update_egui_texture_from_wgpu_texture( -// &self.device, -// &self.texture.create_view(&TextureViewDescriptor::default()), -// wgpu::FilterMode::Linear, -// self.texture_id, -// ); -// } - -// let texture = self.texture.clone(); -// let device = self.device.clone(); -// let queue = self.queue.clone(); - -// let camera = req.camera.clone(); -// let img_size = req.img_size; -// let background = req.background; -// let splat_scale = req.splat_scale; - -// let widget = self.widget_3d.clone(); - -// task::spawn(async move { -// let splats = req.slot.clone_main().await; - -// if let Some(splats) = splats { -// let (image, _) = render_splats( -// splats.clone(), -// &camera, -// img_size, -// background, -// splat_scale, -// TextureMode::Packed, -// ) -// .await; - -// copy_to_texture(image, &texture, &device, &queue); - -// if req.grid_opacity > 0.0 { -// widget.render_to_texture( -// &req.camera, -// req.model_transform, -// req.img_size, -// &texture, -// req.grid_opacity, -// ); -// } - -// req.ctx.request_repaint(); -// } -// }); -// } - -// pub fn id(&self) -> TextureId { -// self.texture_id -// } -// } - -// impl egui_wgpu::CallbackTrait for SplatBackbuffer {} - -// impl SplatBackbuffer { -// fn custom_painting(&mut self, ui: &mut egui::Ui) { -// let (rect, response) = -// ui.allocate_exact_size(egui::Vec2::splat(300.0), egui::Sense::drag()); - -// self.angle += response.drag_motion().x * 0.01; -// ui.painter().add(egui_wgpu::Callback::new_paint_callback( -// rect, -// CustomTriangleCallback { angle: self.angle }, -// )); -// } -// } - -// fn create_texture(size: glam::UVec2, device: &wgpu::Device) -> wgpu::Texture { -// device.create_texture(&wgpu::TextureDescriptor { -// label: Some("Splat backbuffer"), -// size: wgpu::Extent3d { -// width: size.x, -// height: size.y, -// depth_or_array_layers: 1, -// }, -// mip_level_count: 1, -// sample_count: 1, -// dimension: wgpu::TextureDimension::D2, -// format: wgpu::TextureFormat::Rgba8Unorm, -// usage: wgpu::TextureUsages::TEXTURE_BINDING -// | wgpu::TextureUsages::COPY_DST -// | wgpu::TextureUsages::RENDER_ATTACHMENT, -// view_formats: &[wgpu::TextureFormat::Rgba8Unorm], -// }) -// } - -// fn copy_to_texture( -// img: Tensor, -// texture: &wgpu::Texture, -// device: &wgpu::Device, -// queue: &wgpu::Queue, -// ) { -// let [height, width, c] = img.dims(); -// let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor { -// label: Some("splat backbuffer encoder"), -// }); - -// let padded_shape = vec![height, width.div_ceil(64) * 64, c]; - -// let img_prim = img.into_primitive().tensor(); -// let fusion_client = img_prim.client.clone(); -// let img = fusion_client.resolve_tensor_float::(img_prim); -// let img: Tensor = Tensor::from_primitive(TensorPrimitive::Float(img)); - -// // Pad if needed (WebGPU requires bytes_per_row divisible by 256) -// let img = if width % 64 != 0 { -// let padded: Tensor = Tensor::zeros(&padded_shape, &img.device()); -// padded.slice_assign([0..height, 0..width], img) -// } else { -// img -// }; - -// let img = img.into_primitive().tensor(); -// let client = &img.client; -// let img_res_handle = client.get_resource(img.handle.clone().binding()); -// client.flush(); - -// let bytes_per_row = Some(4 * padded_shape[1] as u32); - -// encoder.copy_buffer_to_texture( -// wgpu::TexelCopyBufferInfo { -// buffer: &img_res_handle.resource().buffer, -// layout: TexelCopyBufferLayout { -// offset: img_res_handle.resource().offset, -// bytes_per_row, -// rows_per_image: None, -// }, -// }, -// wgpu::TexelCopyTextureInfo { -// texture, -// mip_level: 0, -// origin: wgpu::Origin3d { x: 0, y: 0, z: 0 }, -// aspect: wgpu::TextureAspect::All, -// }, -// wgpu::Extent3d { -// width: width as u32, -// height: height as u32, -// depth_or_array_layers: 1, -// }, -// ); - -// queue.submit([encoder.finish()]); -// } diff --git a/crates/brush-ui/src/widget_3d.rs b/crates/brush-ui/src/widget_3d.rs index cf5975ec..03775ab7 100644 --- a/crates/brush-ui/src/widget_3d.rs +++ b/crates/brush-ui/src/widget_3d.rs @@ -1,3 +1,5 @@ +use brush_render::camera::Camera; +use eframe::egui_wgpu::{self, wgpu}; use glam::{Mat4, Vec3}; use wgpu::util::DeviceExt; @@ -13,7 +15,7 @@ struct Vertex { struct Uniforms { view_proj: [[f32; 4]; 4], grid_opacity: f32, - _padding: [f32; 3], // Padding for alignment + _padding: [f32; 3], } impl Vertex { @@ -29,9 +31,8 @@ impl Vertex { } } -pub struct Widget3D { - device: wgpu::Device, - queue: wgpu::Queue, +/// Resources for Widget3D stored in callback_resources. +pub struct Widget3DResources { pipeline: wgpu::RenderPipeline, uniform_buffer: wgpu::Buffer, uniform_bind_group: wgpu::BindGroup, @@ -41,14 +42,14 @@ pub struct Widget3D { up_axis_vertex_count: u32, } -impl Widget3D { - pub fn new(device: wgpu::Device, queue: wgpu::Queue) -> Self { +impl Widget3DResources { + /// Create Widget3D resources. Call this once during initialization. + pub fn new(device: &wgpu::Device, target_format: wgpu::TextureFormat) -> Self { let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { label: Some("Widget 3D Shader"), source: wgpu::ShaderSource::Wgsl(include_str!("shaders/widget_3d.wgsl").into()), }); - // Create uniform buffer let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor { label: Some("Widget 3D Uniform Buffer"), size: std::mem::size_of::() as u64, @@ -56,12 +57,11 @@ impl Widget3D { mapped_at_creation: false, }); - // Create bind group layout and bind group let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { label: Some("Widget 3D Bind Group Layout"), entries: &[wgpu::BindGroupLayoutEntry { binding: 0, - visibility: wgpu::ShaderStages::VERTEX | wgpu::ShaderStages::FRAGMENT, // Fragment needs access for grid_opacity + visibility: wgpu::ShaderStages::VERTEX | wgpu::ShaderStages::FRAGMENT, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, @@ -80,13 +80,13 @@ impl Widget3D { }], }); - // Create render pipeline let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: Some("Widget 3D Pipeline Layout"), bind_group_layouts: &[&bind_group_layout], push_constant_ranges: &[], }); + // Pipeline without depth stencil - draws on top of egui content let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { label: Some("Widget 3D Pipeline"), layout: Some(&pipeline_layout), @@ -100,7 +100,7 @@ impl Widget3D { module: &shader, entry_point: Some("fs_main"), targets: &[Some(wgpu::ColorTargetState { - format: wgpu::TextureFormat::Rgba8Unorm, + format: target_format, blend: Some(wgpu::BlendState::ALPHA_BLENDING), write_mask: wgpu::ColorWrites::ALL, })], @@ -115,19 +115,12 @@ impl Widget3D { polygon_mode: wgpu::PolygonMode::Fill, conservative: false, }, - depth_stencil: Some(wgpu::DepthStencilState { - format: wgpu::TextureFormat::Depth32Float, - depth_write_enabled: true, - depth_compare: wgpu::CompareFunction::Less, - stencil: wgpu::StencilState::default(), - bias: wgpu::DepthBiasState::default(), - }), + depth_stencil: None, // No depth buffer - draw on top multisample: wgpu::MultisampleState::default(), multiview: None, cache: None, }); - // Create geometry let (grid_vertices, grid_vertex_count) = Self::create_grid_geometry(); let grid_vertex_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { label: Some("Grid Vertex Buffer"), @@ -143,8 +136,6 @@ impl Widget3D { }); Self { - device, - queue, pipeline, uniform_buffer, uniform_bind_group, @@ -159,13 +150,10 @@ impl Widget3D { let mut vertices = Vec::new(); let size = 10.0; let step = 1.0; - let color = [0.3, 0.3, 0.3, 0.8]; // Semi-transparent gray + let color = [0.3, 0.3, 0.3, 0.8]; - // Create grid lines in XZ plane (Y=0) for OpenCV coordinates - // This creates a ground plane since Y is down in OpenCV let mut i = -size; while i <= size { - // Lines parallel to X axis vertices.push(Vertex { position: [-size, 0.0, i], color, @@ -174,8 +162,6 @@ impl Widget3D { position: [size, 0.0, i], color, }); - - // Lines parallel to Z axis vertices.push(Vertex { position: [i, 0.0, -size], color, @@ -184,7 +170,6 @@ impl Widget3D { position: [i, 0.0, size], color, }); - i += step; } @@ -192,116 +177,84 @@ impl Widget3D { } fn create_up_axis_geometry() -> (Vec, u32) { - let mut vertices = Vec::new(); - let length = 1.5; - - // Single blue line pointing up (negative Y in OpenCV coordinates) - vertices.push(Vertex { - position: [0.0, 0.0, 0.0], - color: [0.0, 0.5, 1.0, 1.0], // Light blue - }); - vertices.push(Vertex { - position: [0.0, -length, 0.0], // Negative Y is up - color: [0.0, 0.5, 1.0, 1.0], // Light blue - }); - + let vertices = vec![ + Vertex { + position: [0.0, 0.0, 0.0], + color: [0.0, 0.5, 1.0, 1.0], + }, + Vertex { + position: [0.0, -1.5, 0.0], + color: [0.0, 0.5, 1.0, 1.0], + }, + ]; (vertices, 2) } - pub fn render_to_texture( - &self, - camera: &brush_render::camera::Camera, - model_transform: glam::Affine3A, - size: glam::UVec2, - target_texture: &wgpu::Texture, - grid_opacity: f32, - ) { - let output_view = target_texture.create_view(&wgpu::TextureViewDescriptor::default()); - - // Create depth texture - let depth_texture = self.device.create_texture(&wgpu::TextureDescriptor { - label: Some("Widget 3D Depth Texture"), - size: wgpu::Extent3d { - width: size.x, - height: size.y, - depth_or_array_layers: 1, - }, - mip_level_count: 1, - sample_count: 1, - dimension: wgpu::TextureDimension::D2, - format: wgpu::TextureFormat::Depth32Float, - usage: wgpu::TextureUsages::RENDER_ATTACHMENT, - view_formats: &[], - }); - let depth_view = depth_texture.create_view(&wgpu::TextureViewDescriptor::default()); - - // Use perspective_lh since camera uses +Z as forward - // But flip Y since camera uses Y-down while perspective_lh uses Y-up - let aspect = size.x as f32 / size.y as f32; - let proj_matrix = Mat4::perspective_lh(camera.fov_y as f32, aspect, 0.1, 1000.0); - - // Y-flip to convert from Y-up to Y-down - let y_flip = Mat4::from_scale(Vec3::new(1.0, -1.0, 1.0)); - - // The camera already has model transform baked in - // To get world-space view, we need to undo the model transform by applying its inverse - let view_matrix = camera.world_to_local(); - let world_view = Mat4::from(view_matrix) * Mat4::from(model_transform.inverse()); - - // Apply Y flip and combine with projection - let view_proj = proj_matrix * y_flip * world_view; - + /// Update uniforms for rendering. + pub fn prepare(&self, queue: &wgpu::Queue, view_proj: Mat4, grid_opacity: f32) { let uniforms = Uniforms { view_proj: view_proj.to_cols_array_2d(), grid_opacity, _padding: [0.0; 3], }; + queue.write_buffer(&self.uniform_buffer, 0, bytemuck::cast_slice(&[uniforms])); + } - self.queue - .write_buffer(&self.uniform_buffer, 0, bytemuck::cast_slice(&[uniforms])); + /// Issue draw commands to the provided render pass. + pub fn paint(&self, render_pass: &mut wgpu::RenderPass<'_>) { + render_pass.set_pipeline(&self.pipeline); + render_pass.set_bind_group(0, &self.uniform_bind_group, &[]); - // Render - let mut encoder = self - .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("Widget 3D Render Encoder"), - }); + render_pass.set_vertex_buffer(0, self.grid_vertex_buffer.slice(..)); + render_pass.draw(0..self.grid_vertex_count, 0..1); - { - let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { - label: Some("Widget 3D Render Pass"), - color_attachments: &[Some(wgpu::RenderPassColorAttachment { - view: &output_view, - resolve_target: None, - ops: wgpu::Operations { - load: wgpu::LoadOp::Load, // Load existing content instead of clearing - store: wgpu::StoreOp::Store, - }, - depth_slice: None, - })], - depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment { - view: &depth_view, - depth_ops: Some(wgpu::Operations { - load: wgpu::LoadOp::Clear(1.0), - store: wgpu::StoreOp::Store, - }), - stencil_ops: None, - }), - timestamp_writes: None, - occlusion_query_set: None, - }); + render_pass.set_vertex_buffer(0, self.up_axis_vertex_buffer.slice(..)); + render_pass.draw(0..self.up_axis_vertex_count, 0..1); + } +} + +/// Callback for rendering the 3D widget overlay via egui's paint system. +pub struct Widget3DCallback { + pub camera: Camera, + pub model_transform: glam::Affine3A, + pub grid_opacity: f32, +} + +impl egui_wgpu::CallbackTrait for Widget3DCallback { + fn prepare( + &self, + _device: &wgpu::Device, + queue: &wgpu::Queue, + screen_descriptor: &egui_wgpu::ScreenDescriptor, + _egui_encoder: &mut wgpu::CommandEncoder, + resources: &mut egui_wgpu::CallbackResources, + ) -> Vec { + let Some(widget_3d) = resources.get::() else { + return Vec::new(); + }; - render_pass.set_pipeline(&self.pipeline); - render_pass.set_bind_group(0, &self.uniform_bind_group, &[]); + let aspect = screen_descriptor.size_in_pixels[0] as f32 + / screen_descriptor.size_in_pixels[1] as f32; + let proj_matrix = Mat4::perspective_lh(self.camera.fov_y as f32, aspect, 0.1, 1000.0); + let y_flip = Mat4::from_scale(Vec3::new(1.0, -1.0, 1.0)); + let view_matrix = self.camera.world_to_local(); + let world_view = Mat4::from(view_matrix) * Mat4::from(self.model_transform.inverse()); + let view_proj = proj_matrix * y_flip * world_view; - // Draw grid - render_pass.set_vertex_buffer(0, self.grid_vertex_buffer.slice(..)); - render_pass.draw(0..self.grid_vertex_count, 0..1); + widget_3d.prepare(queue, view_proj, self.grid_opacity); - // Draw up axis - render_pass.set_vertex_buffer(0, self.up_axis_vertex_buffer.slice(..)); - render_pass.draw(0..self.up_axis_vertex_count, 0..1); - } - self.queue.submit([encoder.finish()]); + Vec::new() + } + + fn paint( + &self, + _info: egui::PaintCallbackInfo, + render_pass: &mut wgpu::RenderPass<'static>, + resources: &egui_wgpu::CallbackResources, + ) { + let Some(widget_3d) = resources.get::() else { + return; + }; + widget_3d.paint(render_pass); } } diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index 2ed7ea27..55d38723 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -153,7 +153,7 @@ impl App { image, camera, tex_handle: handle, - backbuffer: SplatBackbuffer::new(renderer, state.device.clone(), state.queue.clone()), + backbuffer: SplatBackbuffer::new(renderer, state.device.clone(), &state.queue), slot, receiver, last_step: None, @@ -180,8 +180,6 @@ impl eframe::App for App { img_size: glam::uvec2(self.image.width(), self.image.height()), background: Vec3::ZERO, splat_scale: None, - model_transform: glam::Affine3A::IDENTITY, - grid_opacity: 0.0, ctx: ctx.clone(), }); From 959dc800421b25f23971e0639e8e898ede75da63 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 26 Jan 2026 16:59:02 +0100 Subject: [PATCH 22/29] WIP --- crates/brush-ui/src/app.rs | 10 +- crates/brush-ui/src/panels.rs | 16 +- crates/brush-ui/src/scene.rs | 69 ++----- crates/brush-ui/src/splat_backbuffer.rs | 257 +++++++----------------- crates/brush-ui/src/stats.rs | 16 +- crates/brush-ui/src/widget_3d.rs | 97 +++++---- examples/train-2d/examples/train-2d.rs | 19 +- 7 files changed, 166 insertions(+), 318 deletions(-) diff --git a/crates/brush-ui/src/app.rs b/crates/brush-ui/src/app.rs index 166fa4c4..60f169da 100644 --- a/crates/brush-ui/src/app.rs +++ b/crates/brush-ui/src/app.rs @@ -249,13 +249,9 @@ impl App { // Initialize all panels with runtime state for (_, tile) in tree.tiles.iter_mut() { if let egui_tiles::Tile::Pane(pane) = tile { - pane.get_mut().as_pane_mut().init( - state.device.clone(), - state.queue.clone(), - state.renderer.clone(), - burn_device.clone(), - state.adapter.get_info(), - ); + pane.get_mut() + .as_pane_mut() + .init(state, burn_device.clone()); } } diff --git a/crates/brush-ui/src/panels.rs b/crates/brush-ui/src/panels.rs index 6b83e732..ed0345ec 100644 --- a/crates/brush-ui/src/panels.rs +++ b/crates/brush-ui/src/panels.rs @@ -1,8 +1,6 @@ -use std::sync::Arc; - use brush_process::message::ProcessMessage; -use eframe::egui_wgpu::Renderer; -use egui::mutex::RwLock; +use burn_wgpu::WgpuDevice; +use eframe::egui_wgpu::RenderState; use crate::ui_process::UiProcess; @@ -11,15 +9,7 @@ pub(crate) trait AppPane { /// Initialize runtime state after creation or deserialization. #[allow(unused_variables)] - fn init( - &mut self, - device: wgpu::Device, - queue: wgpu::Queue, - renderer: Arc>, - burn_device: burn_wgpu::WgpuDevice, - adapter_info: wgpu::AdapterInfo, - ) { - } + fn init(&mut self, state: &RenderState, burn_device: WgpuDevice) {} /// Draw the pane's UI's content. fn ui(&mut self, ui: &mut egui::Ui, process: &UiProcess); diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index ee2c69e3..37e67007 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -5,21 +5,20 @@ use brush_process::message::TrainMessage; use brush_process::{create_process, message::ProcessMessage}; use brush_vfs::DataSource; use core::f32; -use egui::{ - Align2, Button, Frame, RichText, containers::Popup, epaint::mutex::RwLock as EguiRwLock, -}; -use std::sync::Arc; +use eframe::egui_wgpu::RenderState; +use egui::{Align2, Button, Frame, RichText, containers::Popup}; #[cfg(feature = "training")] -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; use brush_render::camera::{Camera, focal_to_fov, fov_to_focal}; -use eframe::egui_wgpu::Renderer; use egui::{Color32, Rect, Slider}; use glam::{UVec2, Vec3}; use web_time::Instant; -use crate::splat_backbuffer::{RenderRequest, SplatBackbuffer}; -use crate::widget_3d::Widget3DCallback; +use crate::{ + splat_backbuffer::{RenderRequest, SplatBackbuffer}, + widget_3d::GridWidget, +}; use serde::{Deserialize, Serialize}; @@ -61,7 +60,7 @@ use crate::{ }; #[derive(Clone, PartialEq)] -struct RenderState { +struct SplatRenderState { size: UVec2, cam: Camera, settings: CameraSettings, @@ -98,7 +97,8 @@ impl ErrorDisplay { #[derive(Default, Serialize, Deserialize)] pub struct ScenePanel { - /// Async splat renderer and texture backbuffer. + #[serde(skip)] + grid: Option, #[serde(skip)] backbuffer: Option, #[serde(skip)] @@ -122,7 +122,7 @@ pub struct ScenePanel { #[serde(skip)] seen_warning_count: usize, #[serde(skip)] - last_state: Option, + last_state: Option, #[serde(skip)] source_name: Option, #[serde(skip)] @@ -716,21 +716,9 @@ impl AppPane for ScenePanel { } } - fn init( - &mut self, - device: wgpu::Device, - queue: wgpu::Queue, - renderer: Arc>, - _burn_device: burn_wgpu::WgpuDevice, - _adapter_info: wgpu::AdapterInfo, - ) { - // Initialize Widget3D resources for the grid overlay - // renderer - // .write() - // .callback_resources - // .insert(Widget3DResources::new(&device, target_format)); - - self.backbuffer = Some(SplatBackbuffer::new(renderer, device, &queue)); + fn init(&mut self, state: &RenderState, _burn_device: burn_wgpu::WgpuDevice) { + GridWidget::new(state); + self.backbuffer = Some(SplatBackbuffer::new()); // Create the settings popup now that we have the base_path #[cfg(feature = "training")] @@ -954,7 +942,7 @@ impl AppPane for ScenePanel { let grid_opacity = process.get_grid_opacity(); - let state = RenderState { + let state = SplatRenderState { size, cam: camera.clone(), settings: settings.clone(), @@ -998,28 +986,13 @@ impl AppPane for ScenePanel { ctx: ui.ctx().clone(), }); } - ui.painter().image( - backbuffer.id(), - rect, - Rect { - min: egui::pos2(0.0, 0.0), - max: egui::pos2(1.0, 1.0), - }, - Color32::WHITE, - ); - // Draw the 3D grid overlay using egui's wgpu callback - if grid_opacity > 0.0 { - ui.painter() - .add(eframe::egui_wgpu::Callback::new_paint_callback( - rect, - Widget3DCallback { - camera, - model_transform: process.model_local_to_world(), - grid_opacity, - }, - )); - } + backbuffer.paint(rect, ui); + } + + if let Some(grid) = &mut self.grid { + let model_ltw = process.model_local_to_world(); + grid.paint(rect, camera, model_ltw, grid_opacity, ui); } }); diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index 8a217144..4d63c26f 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -3,137 +3,106 @@ use brush_render::{ MainBackend, MainBackendBase, TextureMode, camera::Camera, gaussian_splats::Splats, render_splats, }; -use burn::tensor::{Tensor, TensorPrimitive}; -use egui::TextureId; -use glam::Vec3; -use std::sync::Arc; +use burn::tensor::Tensor; +use egui::Rect; +use glam::{UVec2, Vec3}; use tokio::sync::mpsc; -use wgpu::{CommandEncoderDescriptor, TexelCopyBufferLayout, TextureViewDescriptor}; -use eframe::egui_wgpu::{self, wgpu}; +use eframe::egui_wgpu::{self, CallbackTrait, wgpu}; /// Request sent to the async render worker. #[derive(Clone)] pub struct RenderRequest { pub slot: Slot>, pub frame: usize, + pub camera: Camera, - pub img_size: glam::UVec2, pub background: Vec3, pub splat_scale: Option, - pub ctx: egui::Context, -} -pub struct SplatRenderResources { - pub texture: wgpu::Texture, - pub size: glam::UVec2, + pub img_size: UVec2, + pub ctx: egui::Context, } pub struct SplatBackbuffer { - /// Channel to send requests to the async worker. - request_sender: mpsc::UnboundedSender, - /// Egui texture ID for displaying results. - texture_id: TextureId, - /// Device reference for texture operations. - device: wgpu::Device, - /// Renderer reference for texture updates. - renderer: Arc>, + req_send: mpsc::UnboundedSender, + img_rec: mpsc::Receiver>, + last_image: Option>, } impl SplatBackbuffer { - pub fn new( - renderer: Arc>, - device: wgpu::Device, - queue: &wgpu::Queue, - ) -> Self { - // Create initial texture - let initial_size = glam::uvec2(64, 64); - let texture = create_texture(initial_size, &device); - let texture_id = renderer.write().register_native_texture( - &device, - &texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - ); - - // Insert resources into callback_resources - renderer - .write() - .callback_resources - .insert(SplatRenderResources { - texture, - size: initial_size, - }); - + pub fn new() -> Self { // Create channel for render requests - let (tx, rx) = mpsc::unbounded_channel(); - - // Spawn the async render worker - let worker_device = device.clone(); - let worker_queue = queue.clone(); - let worker_renderer = renderer.clone(); - tokio::task::spawn(render_worker( - rx, - worker_device, - worker_queue, - worker_renderer, - )); + let (req_send, req_rec) = mpsc::unbounded_channel(); + let (img_send, img_rec) = mpsc::channel(1); + tokio::task::spawn(render_worker(req_rec, img_send)); Self { - request_sender: tx, - texture_id, - device, - renderer, + req_send, + img_rec, + last_image: None, } } /// Submit a render request. The async worker will process it. pub fn submit(&mut self, req: RenderRequest) { - if req.img_size.x <= 8 || req.img_size.y <= 8 { - return; - } - - // Handle resize synchronously since it requires renderer lock - let needs_resize = { - let renderer = self.renderer.read(); - let resources: Option<&SplatRenderResources> = renderer.callback_resources.get(); - resources.is_none_or(|r| r.size != req.img_size) - }; + // Send request to worker (ignore send errors if channel closed) + let _ = self.req_send.send(req); + } - if needs_resize { - let new_texture = create_texture(req.img_size, &self.device); - self.renderer.write().update_egui_texture_from_wgpu_texture( - &self.device, - &new_texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - self.texture_id, - ); + pub fn paint( + &mut self, // Not used atm, but, in the future the widget might have some state. + rect: Rect, + ui: &egui::Ui, + ) { + while let Ok(img) = self.img_rec.try_recv() { + self.last_image = Some(img); + } - // Update the texture in callback_resources - let mut renderer = self.renderer.write(); - if let Some(resources) = renderer - .callback_resources - .get_mut::() - { - resources.texture = new_texture; - resources.size = req.img_size; - } + if let Some(image) = &self.last_image { + ui.painter() + .add(eframe::egui_wgpu::Callback::new_paint_callback( + rect, + SplatBackbufferPainter { + rect, + last_img: image.clone(), + }, + )); } - // Send request to worker (ignore send errors if channel closed) - let _ = self.request_sender.send(req); } +} + +struct SplatBackbufferPainter { + rect: Rect, + last_img: Tensor, +} - /// Get the texture ID for displaying the rendered result. - pub fn id(&self) -> TextureId { - self.texture_id +impl CallbackTrait for SplatBackbufferPainter { + fn paint( + &self, + _info: egui::PaintCallbackInfo, + _render_pass: &mut wgpu::RenderPass<'static>, + _callback_resources: &egui_wgpu::CallbackResources, + ) { + let last_img = self.last_img.clone().into_primitive().tensor(); + + let client = last_img.client.clone(); + let prim_tensor = client.resolve_tensor_int::(last_img); + let prim_client = prim_tensor.client; + let img_res_handle = prim_client.get_resource(prim_tensor.handle.binding()); + let buffer = img_res_handle.resource(); + // TODO: Draw our tensor to the screen here from straight wgpu! + // The shader can read from the wgpu binding. + // Nearest neighbour sampling is fine, we _usually_ have things rendered to the correct nr. of pixels, + // only when we're midway a resize do we not, so not much point in filtering. + // NB: The format of this tensor is NOT a float, but PACKED RGBA8, secretely. } } /// Async render worker that processes render requests. async fn render_worker( mut receiver: mpsc::UnboundedReceiver, - device: wgpu::Device, - queue: wgpu::Queue, - renderer: Arc>, + img_sender: mpsc::Sender>, ) { loop { // Wait for at least one request @@ -146,11 +115,6 @@ async fn render_worker( request = newer; } - // Skip tiny sizes - if request.img_size.x <= 8 || request.img_size.y <= 8 { - continue; - } - // Clone splats (async) let Some(splats) = request.slot.clone_main().await else { continue; @@ -167,98 +131,13 @@ async fn render_worker( ) .await; - // Get the texture from callback_resources and copy to it - { - let renderer_guard = renderer.read(); - if let Some(resources) = renderer_guard - .callback_resources - .get::() - { - // Only copy if size matches (resize handled in submit) - if resources.size == request.img_size { - copy_to_texture(image, &resources.texture, &device, &queue); - } - } - } + // Don't care about errors if channel is closed. + let _ = img_sender.send(image).await; + + // TODO: Store the latest img tensor. Only continue + // once it has been removed from the channel. // Trigger egui repaint request.ctx.request_repaint(); } } - -fn create_texture(size: glam::UVec2, device: &wgpu::Device) -> wgpu::Texture { - device.create_texture(&wgpu::TextureDescriptor { - label: Some("Splat backbuffer"), - size: wgpu::Extent3d { - width: size.x, - height: size.y, - depth_or_array_layers: 1, - }, - mip_level_count: 1, - sample_count: 1, - dimension: wgpu::TextureDimension::D2, - format: wgpu::TextureFormat::Rgba8Unorm, - usage: wgpu::TextureUsages::TEXTURE_BINDING - | wgpu::TextureUsages::COPY_DST - | wgpu::TextureUsages::RENDER_ATTACHMENT, - view_formats: &[wgpu::TextureFormat::Rgba8Unorm], - }) -} - -fn copy_to_texture( - img: Tensor, - texture: &wgpu::Texture, - device: &wgpu::Device, - queue: &wgpu::Queue, -) { - let [height, width, c] = img.dims(); - let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor { - label: Some("splat backbuffer encoder"), - }); - - let padded_shape = vec![height, width.div_ceil(64) * 64, c]; - - let img_prim = img.into_primitive().tensor(); - let fusion_client = img_prim.client.clone(); - let img = fusion_client.resolve_tensor_float::(img_prim); - let img: Tensor = Tensor::from_primitive(TensorPrimitive::Float(img)); - - // Pad if needed (WebGPU requires bytes_per_row divisible by 256) - let img = if width % 64 != 0 { - let padded: Tensor = Tensor::zeros(&padded_shape, &img.device()); - padded.slice_assign([0..height, 0..width], img) - } else { - img - }; - - let img = img.into_primitive().tensor(); - let client = &img.client; - let img_res_handle = client.get_resource(img.handle.clone().binding()); - client.flush(); - - let bytes_per_row = Some(4 * padded_shape[1] as u32); - - encoder.copy_buffer_to_texture( - wgpu::TexelCopyBufferInfo { - buffer: &img_res_handle.resource().buffer, - layout: TexelCopyBufferLayout { - offset: img_res_handle.resource().offset, - bytes_per_row, - rows_per_image: None, - }, - }, - wgpu::TexelCopyTextureInfo { - texture, - mip_level: 0, - origin: wgpu::Origin3d { x: 0, y: 0, z: 0 }, - aspect: wgpu::TextureAspect::All, - }, - wgpu::Extent3d { - width: width as u32, - height: height as u32, - depth_or_array_layers: 1, - }, - ); - - queue.submit([encoder.finish()]); -} diff --git a/crates/brush-ui/src/stats.rs b/crates/brush-ui/src/stats.rs index cc6bbda8..ffcb9899 100644 --- a/crates/brush-ui/src/stats.rs +++ b/crates/brush-ui/src/stats.rs @@ -1,12 +1,9 @@ -use std::sync::Arc; - use crate::{UiMode, panels::AppPane, ui_process::UiProcess}; use brush_process::message::ProcessMessage; use brush_process::message::TrainMessage; use burn_cubecl::cubecl::Runtime; use burn_wgpu::{WgpuDevice, WgpuRuntime}; -use eframe::egui_wgpu::Renderer; -use egui::mutex::RwLock; +use eframe::egui_wgpu::RenderState; use web_time::Duration; use wgpu::AdapterInfo; @@ -81,16 +78,9 @@ impl AppPane for StatsPanel { "Stats".into() } - fn init( - &mut self, - _device: wgpu::Device, - _queue: wgpu::Queue, - _renderer: Arc>, - burn_device: burn_wgpu::WgpuDevice, - adapter_info: wgpu::AdapterInfo, - ) { + fn init(&mut self, state: &RenderState, burn_device: burn_wgpu::WgpuDevice) { self.device = Some(burn_device); - self.adapter_info = Some(adapter_info); + self.adapter_info = Some(state.adapter.get_info()); } fn is_visible(&self, process: &UiProcess) -> bool { diff --git a/crates/brush-ui/src/widget_3d.rs b/crates/brush-ui/src/widget_3d.rs index 03775ab7..7dfb22af 100644 --- a/crates/brush-ui/src/widget_3d.rs +++ b/crates/brush-ui/src/widget_3d.rs @@ -1,5 +1,6 @@ use brush_render::camera::Camera; -use eframe::egui_wgpu::{self, wgpu}; +use eframe::egui_wgpu::{self, RenderState, wgpu}; +use egui::Rect; use glam::{Mat4, Vec3}; use wgpu::util::DeviceExt; @@ -31,8 +32,42 @@ impl Vertex { } } -/// Resources for Widget3D stored in callback_resources. -pub struct Widget3DResources { +pub struct GridWidget {} + +impl GridWidget { + pub fn new(state: &RenderState) -> Self { + state + .renderer + .write() + .callback_resources + .insert(GridWidgetResources::new(&state.device, state.target_format)); + Self {} + } + + #[expect(clippy::unused_self)] + pub fn paint( + &self, // Not used atm,but, in the future the widget might have some state. + rect: Rect, + camera: Camera, + model_transform: glam::Affine3A, + grid_opacity: f32, + ui: &egui::Ui, + ) { + if grid_opacity > 0.0 { + ui.painter() + .add(eframe::egui_wgpu::Callback::new_paint_callback( + rect, + GridWidgetPainter { + camera, + model_transform, + grid_opacity, + }, + )); + } + } +} + +struct GridWidgetResources { pipeline: wgpu::RenderPipeline, uniform_buffer: wgpu::Buffer, uniform_bind_group: wgpu::BindGroup, @@ -42,8 +77,7 @@ pub struct Widget3DResources { up_axis_vertex_count: u32, } -impl Widget3DResources { - /// Create Widget3D resources. Call this once during initialization. +impl GridWidgetResources { pub fn new(device: &wgpu::Device, target_format: wgpu::TextureFormat) -> Self { let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { label: Some("Widget 3D Shader"), @@ -189,38 +223,16 @@ impl Widget3DResources { ]; (vertices, 2) } - - /// Update uniforms for rendering. - pub fn prepare(&self, queue: &wgpu::Queue, view_proj: Mat4, grid_opacity: f32) { - let uniforms = Uniforms { - view_proj: view_proj.to_cols_array_2d(), - grid_opacity, - _padding: [0.0; 3], - }; - queue.write_buffer(&self.uniform_buffer, 0, bytemuck::cast_slice(&[uniforms])); - } - - /// Issue draw commands to the provided render pass. - pub fn paint(&self, render_pass: &mut wgpu::RenderPass<'_>) { - render_pass.set_pipeline(&self.pipeline); - render_pass.set_bind_group(0, &self.uniform_bind_group, &[]); - - render_pass.set_vertex_buffer(0, self.grid_vertex_buffer.slice(..)); - render_pass.draw(0..self.grid_vertex_count, 0..1); - - render_pass.set_vertex_buffer(0, self.up_axis_vertex_buffer.slice(..)); - render_pass.draw(0..self.up_axis_vertex_count, 0..1); - } } /// Callback for rendering the 3D widget overlay via egui's paint system. -pub struct Widget3DCallback { +struct GridWidgetPainter { pub camera: Camera, pub model_transform: glam::Affine3A, pub grid_opacity: f32, } -impl egui_wgpu::CallbackTrait for Widget3DCallback { +impl egui_wgpu::CallbackTrait for GridWidgetPainter { fn prepare( &self, _device: &wgpu::Device, @@ -229,20 +241,28 @@ impl egui_wgpu::CallbackTrait for Widget3DCallback { _egui_encoder: &mut wgpu::CommandEncoder, resources: &mut egui_wgpu::CallbackResources, ) -> Vec { - let Some(widget_3d) = resources.get::() else { + let Some(resources) = resources.get::() else { return Vec::new(); }; - let aspect = screen_descriptor.size_in_pixels[0] as f32 - / screen_descriptor.size_in_pixels[1] as f32; + let aspect = + screen_descriptor.size_in_pixels[0] as f32 / screen_descriptor.size_in_pixels[1] as f32; let proj_matrix = Mat4::perspective_lh(self.camera.fov_y as f32, aspect, 0.1, 1000.0); let y_flip = Mat4::from_scale(Vec3::new(1.0, -1.0, 1.0)); let view_matrix = self.camera.world_to_local(); let world_view = Mat4::from(view_matrix) * Mat4::from(self.model_transform.inverse()); let view_proj = proj_matrix * y_flip * world_view; - widget_3d.prepare(queue, view_proj, self.grid_opacity); - + let uniforms = Uniforms { + view_proj: view_proj.to_cols_array_2d(), + grid_opacity: self.grid_opacity, + _padding: [0.0; 3], + }; + queue.write_buffer( + &resources.uniform_buffer, + 0, + bytemuck::cast_slice(&[uniforms]), + ); Vec::new() } @@ -252,9 +272,14 @@ impl egui_wgpu::CallbackTrait for Widget3DCallback { render_pass: &mut wgpu::RenderPass<'static>, resources: &egui_wgpu::CallbackResources, ) { - let Some(widget_3d) = resources.get::() else { + let Some(resources) = resources.get::() else { return; }; - widget_3d.paint(render_pass); + render_pass.set_pipeline(&resources.pipeline); + render_pass.set_bind_group(0, &resources.uniform_bind_group, &[]); + render_pass.set_vertex_buffer(0, resources.grid_vertex_buffer.slice(..)); + render_pass.draw(0..resources.grid_vertex_count, 0..1); + render_pass.set_vertex_buffer(0, resources.up_axis_vertex_buffer.slice(..)); + render_pass.draw(0..resources.up_axis_vertex_count, 0..1); } } diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index 55d38723..a62f810b 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -142,18 +142,11 @@ impl App { slot.clone(), ); - let renderer = cc - .wgpu_render_state - .as_ref() - .expect("No wgpu renderer enabled in egui") - .renderer - .clone(); - Self { image, camera, tex_handle: handle, - backbuffer: SplatBackbuffer::new(renderer, state.device.clone(), &state.queue), + backbuffer: SplatBackbuffer::new(), slot, receiver, last_step: None, @@ -186,10 +179,12 @@ impl eframe::App for App { let size = egui::vec2(self.image.width() as f32, self.image.height() as f32); ui.horizontal(|ui| { - ui.image(ImageSource::Texture(SizedTexture::new( - self.backbuffer.id(), - size, - ))); + // TODO: Fixup to new methods. + + // ui.image(ImageSource::Texture(SizedTexture::new( + // self.backbuffer.id(), + // size, + // ))); ui.image(ImageSource::Texture(SizedTexture::new( self.tex_handle.id(), size, From 6c8adf426e8afee0edd601d991b69d0069f33637 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 26 Jan 2026 19:17:35 +0100 Subject: [PATCH 23/29] Render splats straight to UI --- crates/brush-ui/src/scene.rs | 94 ++----- crates/brush-ui/src/splat_backbuffer.rs | 316 +++++++++++++++++++----- examples/train-2d/examples/train-2d.rs | 36 +-- 3 files changed, 303 insertions(+), 143 deletions(-) diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 37e67007..4695d8aa 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -1,27 +1,21 @@ #[cfg(feature = "training")] use crate::settings_popup::SettingsPopup; +use crate::{splat_backbuffer::SplatBackbuffer, widget_3d::GridWidget}; #[cfg(feature = "training")] use brush_process::message::TrainMessage; use brush_process::{create_process, message::ProcessMessage}; +use brush_render::camera::{focal_to_fov, fov_to_focal}; use brush_vfs::DataSource; use core::f32; use eframe::egui_wgpu::RenderState; use egui::{Align2, Button, Frame, RichText, containers::Popup}; +use egui::{Color32, Rect, Slider}; +use glam::Vec3; +use serde::{Deserialize, Serialize}; #[cfg(feature = "training")] use std::sync::{Arc, Mutex}; - -use brush_render::camera::{Camera, focal_to_fov, fov_to_focal}; -use egui::{Color32, Rect, Slider}; -use glam::{UVec2, Vec3}; use web_time::Instant; -use crate::{ - splat_backbuffer::{RenderRequest, SplatBackbuffer}, - widget_3d::GridWidget, -}; - -use serde::{Deserialize, Serialize}; - /// Controls how often the viewport re-renders during training. #[derive(Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum RenderUpdateMode { @@ -52,22 +46,11 @@ impl RenderUpdateMode { } use crate::{ - UiMode, - app::CameraSettings, - draw_checkerboard, + UiMode, draw_checkerboard, panels::AppPane, ui_process::{BackgroundStyle, UiProcess}, }; -#[derive(Clone, PartialEq)] -struct SplatRenderState { - size: UVec2, - cam: Camera, - settings: CameraSettings, - grid_opacity: f32, - frame: u32, -} - struct ErrorDisplay { headline: String, context: Vec, @@ -105,6 +88,8 @@ pub struct ScenePanel { pub(crate) last_draw: Option, #[serde(skip)] has_splats: bool, + #[serde(skip)] + splats_dirty: bool, /// Current frame for animated sequences (as float for smooth interpolation). #[serde(skip)] frame: f32, @@ -122,8 +107,6 @@ pub struct ScenePanel { #[serde(skip)] seen_warning_count: usize, #[serde(skip)] - last_state: Option, - #[serde(skip)] source_name: Option, #[serde(skip)] source_type: Option, @@ -364,8 +347,8 @@ impl ScenePanel { impl ScenePanel { fn reset(&mut self) { self.last_draw = None; - self.last_state = None; self.has_splats = false; + self.splats_dirty = false; self.frame = 0.0; self.frame_count = 0; self.paused = false; @@ -710,15 +693,15 @@ impl AppPane for ScenePanel { }; // If enabling rendering from Off, force a redraw if old_mode == RenderUpdateMode::Off { - self.last_state = None; + self.splats_dirty = true; } } } } fn init(&mut self, state: &RenderState, _burn_device: burn_wgpu::WgpuDevice) { - GridWidget::new(state); - self.backbuffer = Some(SplatBackbuffer::new()); + self.grid = Some(GridWidget::new(state)); + self.backbuffer = Some(SplatBackbuffer::new(state)); // Create the settings popup now that we have the base_path #[cfg(feature = "training")] @@ -767,12 +750,11 @@ impl AppPane for ScenePanel { .. } => { self.has_splats = true; + self.splats_dirty = true; self.frame_count = *total_frames; // For non-training updates (e.g., loading), always redraw if !process.is_training() { - self.last_state = None; - // When training, datasets handle this. if let Some(up_axis) = up_axis { process.set_model_up(*up_axis); @@ -791,7 +773,7 @@ impl AppPane for ScenePanel { // Check if enough iterations have passed since last render if *iter >= self.last_rendered_iter + interval || self.last_rendered_iter == 0 { self.last_rendered_iter = *iter; - self.last_state = None; + self.splats_dirty = true; } } } @@ -940,28 +922,7 @@ impl AppPane for ScenePanel { camera.fov_y = focal_to_fov(focal_x, size.y); } - let grid_opacity = process.get_grid_opacity(); - - let state = SplatRenderState { - size, - cam: camera.clone(), - settings: settings.clone(), - grid_opacity, - frame: self.frame as u32, - }; - - let dirty = self.last_state != Some(state.clone()); - - if dirty { - self.last_state = Some(state); - } - - let pixel_size = glam::uvec2( - (size.x as f32 * ui.ctx().pixels_per_point().round()) as u32, - (size.y as f32 * ui.ctx().pixels_per_point().round()) as u32, - ); - - // Submit new render request if dirty and we have splats + // Render the splats and grid ui.scope(|ui| { // if training views have alpha, show a background checker. Masked images // should still use a black background. @@ -975,23 +936,22 @@ impl AppPane for ScenePanel { } if let Some(backbuffer) = &mut self.backbuffer { - if dirty { - backbuffer.submit(RenderRequest { - slot: process.current_splats(), - frame: self.frame as usize, - camera: camera.clone(), - img_size: pixel_size, - background: settings.background.unwrap_or(Vec3::ZERO), - splat_scale: settings.splat_scale, - ctx: ui.ctx().clone(), - }); - } - - backbuffer.paint(rect, ui); + backbuffer.paint( + rect, + ui, + &process.current_splats(), + &camera, + self.frame as usize, + settings.background.unwrap_or(Vec3::ZERO), + settings.splat_scale, + self.splats_dirty, + ); + self.splats_dirty = false; } if let Some(grid) = &mut self.grid { let model_ltw = process.model_local_to_world(); + let grid_opacity = process.get_grid_opacity(); grid.paint(rect, camera, model_ltw, grid_opacity, ui); } }); diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index 4d63c26f..273bc41d 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -10,92 +10,296 @@ use tokio::sync::mpsc; use eframe::egui_wgpu::{self, CallbackTrait, wgpu}; -/// Request sent to the async render worker. +/// Internal request sent to the async render worker. #[derive(Clone)] -pub struct RenderRequest { - pub slot: Slot>, - pub frame: usize, - - pub camera: Camera, - pub background: Vec3, - pub splat_scale: Option, +struct RenderRequest { + slot: Slot>, + ctx: egui::Context, + state: LastRenderState, +} - pub img_size: UVec2, - pub ctx: egui::Context, +/// State used to track if we need to re-render. +#[derive(Clone, PartialEq)] +struct LastRenderState { + frame: usize, + camera: Camera, + background: Vec3, + splat_scale: Option, + img_size: UVec2, } pub struct SplatBackbuffer { req_send: mpsc::UnboundedSender, img_rec: mpsc::Receiver>, last_image: Option>, + last_state: Option, } impl SplatBackbuffer { - pub fn new() -> Self { + pub fn new(state: &eframe::egui_wgpu::RenderState) -> Self { // Create channel for render requests let (req_send, req_rec) = mpsc::unbounded_channel(); let (img_send, img_rec) = mpsc::channel(1); + // Register splat backbuffer resources + state + .renderer + .write() + .callback_resources + .insert(SplatBackbufferResources::new( + &state.device, + state.target_format, + )); + tokio::task::spawn(render_worker(req_rec, img_send)); Self { req_send, img_rec, last_image: None, + last_state: None, } } - /// Submit a render request. The async worker will process it. - pub fn submit(&mut self, req: RenderRequest) { - // Send request to worker (ignore send errors if channel closed) - let _ = self.req_send.send(req); - } - + #[allow(clippy::too_many_arguments)] pub fn paint( - &mut self, // Not used atm, but, in the future the widget might have some state. + &mut self, rect: Rect, ui: &egui::Ui, + slot: &Slot>, + camera: &Camera, + frame: usize, + background: Vec3, + splat_scale: Option, + splats_dirty: bool, ) { + // Calculate pixel size for rendering + let ppp = ui.ctx().pixels_per_point(); + let img_size = UVec2::new( + (rect.width() * ppp).round() as u32, + (rect.height() * ppp).round() as u32, + ); + + // Check if we need to re-render + let current_state = LastRenderState { + frame, + camera: camera.clone(), + background, + splat_scale, + img_size, + }; + + let dirty = splats_dirty || self.last_state.as_ref() != Some(¤t_state); + + if dirty { + self.last_state = Some(current_state.clone()); + // Send request to worker (ignore send errors if channel closed) + let _ = self.req_send.send(RenderRequest { + slot: slot.clone(), + state: current_state, + ctx: ui.ctx().clone(), + }); + } + while let Ok(img) = self.img_rec.try_recv() { self.last_image = Some(img); } if let Some(image) = &self.last_image { + let shape = image.shape(); + let img_height = shape.dims[0] as u32; + let img_width = shape.dims[1] as u32; + ui.painter() .add(eframe::egui_wgpu::Callback::new_paint_callback( rect, SplatBackbufferPainter { - rect, last_img: image.clone(), + img_width, + img_height, }, )); } } } +#[repr(C)] +#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] +struct Uniforms { + img_width: u32, + img_height: u32, +} + +pub struct SplatBackbufferResources { + pipeline: wgpu::RenderPipeline, + uniform_buffer: wgpu::Buffer, + bind_group_layout: wgpu::BindGroupLayout, + // Per-frame bind group - created in prepare() with the current tensor buffer + bind_group: Option, +} + +impl SplatBackbufferResources { + pub fn new(device: &wgpu::Device, target_format: wgpu::TextureFormat) -> Self { + let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Splat Backbuffer Shader"), + source: wgpu::ShaderSource::Wgsl(include_str!("shaders/splat_backbuffer.wgsl").into()), + }); + let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Splat Backbuffer Uniform Buffer"), + size: std::mem::size_of::() as u64, + usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("Splat Backbuffer Bind Group Layout"), + entries: &[ + // Uniform buffer for image dimensions + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::VERTEX | wgpu::ShaderStages::FRAGMENT, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Uniform, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + // Storage buffer for image data (read-only) + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::FRAGMENT, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + }); + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Splat Backbuffer Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { + label: Some("Splat Backbuffer Pipeline"), + layout: Some(&pipeline_layout), + vertex: wgpu::VertexState { + module: &shader, + entry_point: Some("vs_main"), + buffers: &[], // No vertex buffers - using fullscreen triangle trick + compilation_options: wgpu::PipelineCompilationOptions::default(), + }, + fragment: Some(wgpu::FragmentState { + module: &shader, + entry_point: Some("fs_main"), + targets: &[Some(wgpu::ColorTargetState { + format: target_format, + blend: Some(wgpu::BlendState::ALPHA_BLENDING), + write_mask: wgpu::ColorWrites::ALL, + })], + compilation_options: wgpu::PipelineCompilationOptions::default(), + }), + primitive: wgpu::PrimitiveState { + topology: wgpu::PrimitiveTopology::TriangleList, + strip_index_format: None, + front_face: wgpu::FrontFace::Ccw, + cull_mode: None, + unclipped_depth: false, + polygon_mode: wgpu::PolygonMode::Fill, + conservative: false, + }, + depth_stencil: None, + multisample: wgpu::MultisampleState::default(), + multiview: None, + cache: None, + }); + + Self { + pipeline, + uniform_buffer, + bind_group_layout, + bind_group: None, + } + } +} + struct SplatBackbufferPainter { - rect: Rect, last_img: Tensor, + img_width: u32, + img_height: u32, } impl CallbackTrait for SplatBackbufferPainter { + fn prepare( + &self, + device: &wgpu::Device, + queue: &wgpu::Queue, + _screen_descriptor: &egui_wgpu::ScreenDescriptor, + _egui_encoder: &mut wgpu::CommandEncoder, + resources: &mut egui_wgpu::CallbackResources, + ) -> Vec { + let Some(res) = resources.get_mut::() else { + return Vec::new(); + }; + + // Update uniform buffer with image dimensions + queue.write_buffer( + &res.uniform_buffer, + 0, + bytemuck::cast_slice(&[Uniforms { + img_width: self.img_width, + img_height: self.img_height, + }]), + ); + + // Extract the wgpu buffer from the Burn tensor + let last_img = self.last_img.clone().into_primitive().tensor(); + let prim_tensor = last_img + .client + .clone() + .resolve_tensor_int::(last_img); + let img_res_handle = prim_tensor + .client + .get_resource(prim_tensor.handle.binding()); + + // Create a new bind group with the current tensor buffer + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("Splat Backbuffer Bind Group"), + layout: &res.bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: res.uniform_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: img_res_handle.resource().buffer.as_entire_binding(), + }, + ], + }); + + res.bind_group = Some(bind_group); + Vec::new() + } + fn paint( &self, _info: egui::PaintCallbackInfo, - _render_pass: &mut wgpu::RenderPass<'static>, - _callback_resources: &egui_wgpu::CallbackResources, + render_pass: &mut wgpu::RenderPass<'static>, + callback_resources: &egui_wgpu::CallbackResources, ) { - let last_img = self.last_img.clone().into_primitive().tensor(); + let Some(res) = callback_resources.get::() else { + return; + }; + + let Some(bind_group) = res.bind_group.as_ref() else { + return; + }; - let client = last_img.client.clone(); - let prim_tensor = client.resolve_tensor_int::(last_img); - let prim_client = prim_tensor.client; - let img_res_handle = prim_client.get_resource(prim_tensor.handle.binding()); - let buffer = img_res_handle.resource(); - // TODO: Draw our tensor to the screen here from straight wgpu! - // The shader can read from the wgpu binding. - // Nearest neighbour sampling is fine, we _usually_ have things rendered to the correct nr. of pixels, - // only when we're midway a resize do we not, so not much point in filtering. - // NB: The format of this tensor is NOT a float, but PACKED RGBA8, secretely. + render_pass.set_pipeline(&res.pipeline); + render_pass.set_bind_group(0, bind_group, &[]); + render_pass.draw(0..3, 0..1); // 3 vertices for fullscreen triangle } } @@ -105,39 +309,35 @@ async fn render_worker( img_sender: mpsc::Sender>, ) { loop { - // Wait for at least one request + // Wait for at least one request and get latest. let Some(mut request) = receiver.recv().await else { - break; // Channel closed + break; }; - - // Coalesce: drain channel, keep only the last request while let Ok(newer) = receiver.try_recv() { request = newer; } - // Clone splats (async) - let Some(splats) = request.slot.clone_main().await else { - continue; - }; - - // Render (async) - let (image, _) = render_splats( - splats, - &request.camera, - request.img_size, - request.background, - request.splat_scale, - TextureMode::Packed, - ) - .await; + let image = request + .slot + .act(request.state.frame, async |splats| { + let (image, _) = render_splats( + splats.clone(), + &request.state.camera, + request.state.img_size, + request.state.background, + request.state.splat_scale, + TextureMode::Packed, + ) + .await; + (splats, image) + }) + .await; - // Don't care about errors if channel is closed. - let _ = img_sender.send(image).await; - - // TODO: Store the latest img tensor. Only continue - // once it has been removed from the channel. + if let Some(image) = image { + let _ = img_sender.send(image).await; + } - // Trigger egui repaint + // Trigger egui repaint so the new texture gets picked up. request.ctx.request_repaint(); } } diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index a62f810b..b465d4f4 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -13,7 +13,7 @@ use brush_train::{ RandomSplatsConfig, config::TrainConfig, create_random_splats, splats_into_autodiff, train::SplatTrainer, }; -use brush_ui::splat_backbuffer::{RenderRequest, SplatBackbuffer}; +use brush_ui::splat_backbuffer::SplatBackbuffer; use burn::{backend::wgpu::WgpuDevice, module::AutodiffModule, prelude::Backend}; use egui::{ImageSource, TextureHandle, TextureOptions, load::SizedTexture}; use glam::{Quat, Vec2, Vec3}; @@ -90,6 +90,7 @@ struct App { slot: Slot>, receiver: Receiver, last_step: Option, + splats_dirty: bool, } impl App { @@ -98,6 +99,7 @@ impl App { .wgpu_render_state .as_ref() .expect("No wgpu renderer enabled in egui"); + let device = brush_process::burn_init_device( state.adapter.clone(), state.device.clone(), @@ -128,7 +130,6 @@ impl App { let handle = cc.egui_ctx .load_texture("nearest_view_tex", color_img, TextureOptions::default()); - let slot = Slot::default(); let config = TrainConfig::default(); @@ -146,10 +147,11 @@ impl App { image, camera, tex_handle: handle, - backbuffer: SplatBackbuffer::new(), + backbuffer: SplatBackbuffer::new(state), slot, receiver, last_step: None, + splats_dirty: false, } } } @@ -158,6 +160,7 @@ impl eframe::App for App { fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) { while let Ok(step) = self.receiver.try_recv() { self.last_step = Some(step); + self.splats_dirty = true; } egui::CentralPanel::default().show(ctx, |ui| { @@ -166,25 +169,22 @@ impl eframe::App for App { return; }; - self.backbuffer.submit(RenderRequest { - slot: self.slot.clone(), - frame: 0, - camera: self.camera.clone(), - img_size: glam::uvec2(self.image.width(), self.image.height()), - background: Vec3::ZERO, - splat_scale: None, - ctx: ctx.clone(), - }); - let size = egui::vec2(self.image.width() as f32, self.image.height() as f32); ui.horizontal(|ui| { - // TODO: Fixup to new methods. + let (rect, _response) = ui.allocate_exact_size(size, egui::Sense::hover()); + self.backbuffer.paint( + rect, + ui, + &self.slot, + &self.camera, + 0, + Vec3::ZERO, + None, + self.splats_dirty, + ); + self.splats_dirty = false; - // ui.image(ImageSource::Texture(SizedTexture::new( - // self.backbuffer.id(), - // size, - // ))); ui.image(ImageSource::Texture(SizedTexture::new( self.tex_handle.id(), size, From 51b06ac80e4a34ee6e8eef0fd19a0640bbe94fb8 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 26 Jan 2026 19:17:44 +0100 Subject: [PATCH 24/29] Missing shader --- .../src/shaders/splat_backbuffer.wgsl | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 crates/brush-ui/src/shaders/splat_backbuffer.wgsl diff --git a/crates/brush-ui/src/shaders/splat_backbuffer.wgsl b/crates/brush-ui/src/shaders/splat_backbuffer.wgsl new file mode 100644 index 00000000..1f3d47bd --- /dev/null +++ b/crates/brush-ui/src/shaders/splat_backbuffer.wgsl @@ -0,0 +1,50 @@ +struct Uniforms { + img_width: u32, + img_height: u32, +} + +@group(0) @binding(0) var uniforms: Uniforms; +@group(0) @binding(1) var image_data: array; + +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) uv: vec2, +} + +@vertex +fn vs_main(@builtin(vertex_index) vertex_index: u32) -> VertexOutput { + // Fullscreen triangle using oversized triangle technique + // Creates a triangle twice as large as the viewport - GPU clips to screen + var out: VertexOutput; + let x = f32((vertex_index << 1u) & 2u); // 0, 2, 0 for indices 0, 1, 2 + let y = f32(vertex_index & 2u); // 0, 0, 2 for indices 0, 1, 2 + + // Position in clip space: (-1,-1), (3,-1), (-1,3) - oversized triangle + out.position = vec4(x * 2.0 - 1.0, y * 2.0 - 1.0, 0.0, 1.0); + + // UV coordinates with Y flipped (image top = uv.y 0, clip space top = y 1) + out.uv = vec2(x, 1.0 - y); + + return out; +} + +@fragment +fn fs_main(in: VertexOutput) -> @location(0) vec4 { + let pixel_x = u32(in.uv.x * f32(uniforms.img_width)); + let pixel_y = u32(in.uv.y * f32(uniforms.img_height)); + + if (pixel_x >= uniforms.img_width || pixel_y >= uniforms.img_height) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + + let idx = pixel_y * uniforms.img_width + pixel_x; + let packed = image_data[idx]; + + // Unpack RGBA8: R|(G<<8)|(B<<16)|(A<<24) + let r = f32(packed & 0xFFu) / 255.0; + let g = f32((packed >> 8u) & 0xFFu) / 255.0; + let b = f32((packed >> 16u) & 0xFFu) / 255.0; + let a = f32((packed >> 24u) & 0xFFu) / 255.0; + + return vec4(r, g, b, a); +} From 856cd6e00b443cdac575d06ad29dacf33d837589 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 26 Jan 2026 20:19:10 +0100 Subject: [PATCH 25/29] Fix live update --- crates/brush-ui/src/scene.rs | 32 +++++++++++-------------- crates/brush-ui/src/splat_backbuffer.rs | 2 +- crates/brush-ui/src/ui_process.rs | 16 ++++++++++++- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 4695d8aa..6a21494e 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -1,8 +1,6 @@ #[cfg(feature = "training")] use crate::settings_popup::SettingsPopup; use crate::{splat_backbuffer::SplatBackbuffer, widget_3d::GridWidget}; -#[cfg(feature = "training")] -use brush_process::message::TrainMessage; use brush_process::{create_process, message::ProcessMessage}; use brush_render::camera::{focal_to_fov, fov_to_focal}; use brush_vfs::DataSource; @@ -685,16 +683,11 @@ impl AppPane for ScenePanel { let new_idx = idx.round() as usize; if new_idx != old_idx { - let old_mode = self.render_update_mode; self.render_update_mode = match new_idx { 0 => RenderUpdateMode::Off, 1 => RenderUpdateMode::Low, _ => RenderUpdateMode::Live, }; - // If enabling rendering from Off, force a redraw - if old_mode == RenderUpdateMode::Off { - self.splats_dirty = true; - } } } } @@ -750,11 +743,12 @@ impl AppPane for ScenePanel { .. } => { self.has_splats = true; - self.splats_dirty = true; self.frame_count = *total_frames; // For non-training updates (e.g., loading), always redraw if !process.is_training() { + self.splats_dirty = true; + // When training, datasets handle this. if let Some(up_axis) = up_axis { process.set_model_up(*up_axis); @@ -764,16 +758,18 @@ impl AppPane for ScenePanel { if *total_frames <= 1 || *frame < *total_frames - 1 { self.frame = *frame as f32; } - } - } - #[cfg(feature = "training")] - ProcessMessage::TrainMessage(TrainMessage::TrainStep { iter, .. }) => { - // Check if we should redraw based on render update mode - if let Some(interval) = self.render_update_mode.update_interval() { - // Check if enough iterations have passed since last render - if *iter >= self.last_rendered_iter + interval || self.last_rendered_iter == 0 { - self.last_rendered_iter = *iter; - self.splats_dirty = true; + } else { + // Check if we should redraw based on render update mode + if let Some(interval) = self.render_update_mode.update_interval() { + let iter = process.train_iter(); + + // Check if enough iterations have passed since last render + if iter >= self.last_rendered_iter + interval + || self.last_rendered_iter == 0 + { + self.last_rendered_iter = iter; + self.splats_dirty = true; + } } } } diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index 273bc41d..c53d0478 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -95,8 +95,8 @@ impl SplatBackbuffer { // Send request to worker (ignore send errors if channel closed) let _ = self.req_send.send(RenderRequest { slot: slot.clone(), - state: current_state, ctx: ui.ctx().clone(), + state: current_state, }); } diff --git a/crates/brush-ui/src/ui_process.rs b/crates/brush-ui/src/ui_process.rs index 26cfe7d7..76df1bf3 100644 --- a/crates/brush-ui/src/ui_process.rs +++ b/crates/brush-ui/src/ui_process.rs @@ -1,6 +1,9 @@ use crate::{UiMode, app::CameraSettings, camera_controls::CameraController}; use anyhow::Result; -use brush_process::{message::ProcessMessage, slot::Slot}; +use brush_process::{ + message::{ProcessMessage, TrainMessage}, + slot::Slot, +}; use brush_render::{MainBackend, camera::Camera, gaussian_splats::Splats}; use burn_wgpu::WgpuDevice; use egui::{Response, TextureHandle}; @@ -108,6 +111,10 @@ impl UiProcess { self.read().train_paused } + pub(crate) fn train_iter(&self) -> u32 { + self.read().train_iter + } + pub fn get_cam_settings(&self) -> CameraSettings { self.read().controls.settings.clone() } @@ -235,10 +242,15 @@ impl UiProcess { Ok(ProcessMessage::StartLoading { training, .. }) => { inner.is_training = *training; inner.is_loading = true; + inner.train_iter = 0; } Ok(ProcessMessage::DoneLoading) => { inner.is_loading = false; } + #[cfg(feature = "training")] + Ok(ProcessMessage::TrainMessage(TrainMessage::TrainStep { iter, .. })) => { + inner.train_iter = *iter; + } Err(_) => { inner.is_loading = false; inner.is_training = false; @@ -293,6 +305,7 @@ struct UiProcessInner { ui_mode: UiMode, background_style: BackgroundStyle, train_paused: bool, + train_iter: u32, reset_layout_requested: bool, session_reset_requested: bool, ui_ctx: egui::Context, @@ -313,6 +326,7 @@ impl UiProcessInner { splat_scale: None, is_loading: false, is_training: false, + train_iter: 0, process_handle: None, ui_mode: UiMode::Default, background_style: BackgroundStyle::Black, From 6d4468a058ec483d01cd6db8fe32ce18975a4371 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 26 Jan 2026 20:33:32 +0100 Subject: [PATCH 26/29] Fix no default features --- Cargo.toml | 2 +- crates/brush-ui/src/scene.rs | 1 - crates/brush-ui/src/ui_process.rs | 5 +---- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 37e830da..bbfdb251 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ tracing = "0.1.41" tracing-tracy = "0.11.3" tracing-subscriber = "0.3.19" -winapi = "0.3" +winapi = { version = "0.3", features = ["wincon"] } tokio = { version = "1.42.0", default-features = false } tokio_with_wasm = "0.8.2" diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 6a21494e..8a5b6296 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -25,7 +25,6 @@ pub enum RenderUpdateMode { impl RenderUpdateMode { /// Returns the iteration interval for this mode, or None if rendering is disabled. - #[cfg(feature = "training")] fn update_interval(&self) -> Option { match self { Self::Off => None, diff --git a/crates/brush-ui/src/ui_process.rs b/crates/brush-ui/src/ui_process.rs index 76df1bf3..cdc2b824 100644 --- a/crates/brush-ui/src/ui_process.rs +++ b/crates/brush-ui/src/ui_process.rs @@ -1,9 +1,6 @@ use crate::{UiMode, app::CameraSettings, camera_controls::CameraController}; use anyhow::Result; -use brush_process::{ - message::{ProcessMessage, TrainMessage}, - slot::Slot, -}; +use brush_process::{message::ProcessMessage, slot::Slot}; use brush_render::{MainBackend, camera::Camera, gaussian_splats::Splats}; use burn_wgpu::WgpuDevice; use egui::{Response, TextureHandle}; From 28ab1b0984edbde57eb1d3f7a71c58ef73f87d18 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 26 Jan 2026 20:38:05 +0100 Subject: [PATCH 27/29] Fix --- crates/brush-ui/src/ui_process.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/brush-ui/src/ui_process.rs b/crates/brush-ui/src/ui_process.rs index cdc2b824..81508100 100644 --- a/crates/brush-ui/src/ui_process.rs +++ b/crates/brush-ui/src/ui_process.rs @@ -245,7 +245,9 @@ impl UiProcess { inner.is_loading = false; } #[cfg(feature = "training")] - Ok(ProcessMessage::TrainMessage(TrainMessage::TrainStep { iter, .. })) => { + Ok(ProcessMessage::TrainMessage( + brush_process::message::TrainMessage::TrainStep { iter, .. }, + )) => { inner.train_iter = *iter; } Err(_) => { From b605e7f4a753214625cb0b846d0dd79f01b6caf2 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 26 Jan 2026 21:02:08 +0100 Subject: [PATCH 28/29] Misc clean --- crates/brush-ui/src/shaders/splat_backbuffer.wgsl | 7 ------- crates/brush-ui/src/splat_backbuffer.rs | 4 +--- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/crates/brush-ui/src/shaders/splat_backbuffer.wgsl b/crates/brush-ui/src/shaders/splat_backbuffer.wgsl index 1f3d47bd..c5fea9b2 100644 --- a/crates/brush-ui/src/shaders/splat_backbuffer.wgsl +++ b/crates/brush-ui/src/shaders/splat_backbuffer.wgsl @@ -14,17 +14,11 @@ struct VertexOutput { @vertex fn vs_main(@builtin(vertex_index) vertex_index: u32) -> VertexOutput { // Fullscreen triangle using oversized triangle technique - // Creates a triangle twice as large as the viewport - GPU clips to screen var out: VertexOutput; let x = f32((vertex_index << 1u) & 2u); // 0, 2, 0 for indices 0, 1, 2 let y = f32(vertex_index & 2u); // 0, 0, 2 for indices 0, 1, 2 - - // Position in clip space: (-1,-1), (3,-1), (-1,3) - oversized triangle out.position = vec4(x * 2.0 - 1.0, y * 2.0 - 1.0, 0.0, 1.0); - - // UV coordinates with Y flipped (image top = uv.y 0, clip space top = y 1) out.uv = vec2(x, 1.0 - y); - return out; } @@ -45,6 +39,5 @@ fn fs_main(in: VertexOutput) -> @location(0) vec4 { let g = f32((packed >> 8u) & 0xFFu) / 255.0; let b = f32((packed >> 16u) & 0xFFu) / 255.0; let a = f32((packed >> 24u) & 0xFFu) / 255.0; - return vec4(r, g, b, a); } diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index c53d0478..dbebffa0 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -10,7 +10,6 @@ use tokio::sync::mpsc; use eframe::egui_wgpu::{self, CallbackTrait, wgpu}; -/// Internal request sent to the async render worker. #[derive(Clone)] struct RenderRequest { slot: Slot>, @@ -18,7 +17,6 @@ struct RenderRequest { state: LastRenderState, } -/// State used to track if we need to re-render. #[derive(Clone, PartialEq)] struct LastRenderState { frame: usize, @@ -299,7 +297,7 @@ impl CallbackTrait for SplatBackbufferPainter { render_pass.set_pipeline(&res.pipeline); render_pass.set_bind_group(0, bind_group, &[]); - render_pass.draw(0..3, 0..1); // 3 vertices for fullscreen triangle + render_pass.draw(0..3, 0..1); } } From b6cc3a99cb865a622b8d0986efa6268a351eee9f Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 27 Jan 2026 00:27:40 +0100 Subject: [PATCH 29/29] Simplify some code --- crates/brush-ui/src/app.rs | 6 +-- crates/brush-ui/src/panels.rs | 3 +- crates/brush-ui/src/scene.rs | 2 +- crates/brush-ui/src/splat_backbuffer.rs | 3 +- crates/brush-ui/src/stats.rs | 59 ++++++++++++------------- crates/brush-ui/src/ui_process.rs | 4 ++ 6 files changed, 38 insertions(+), 39 deletions(-) diff --git a/crates/brush-ui/src/app.rs b/crates/brush-ui/src/app.rs index 60f169da..eb0da69f 100644 --- a/crates/brush-ui/src/app.rs +++ b/crates/brush-ui/src/app.rs @@ -230,7 +230,7 @@ impl App { ); log::info!("Connecting context to Burn device & GUI context."); - let context = std::sync::Arc::new(UiProcess::new(burn_device.clone(), cc.egui_ctx.clone())); + let context = std::sync::Arc::new(UiProcess::new(burn_device, cc.egui_ctx.clone())); if let Some(process) = init_process { context.connect_to_process(process); @@ -249,9 +249,7 @@ impl App { // Initialize all panels with runtime state for (_, tile) in tree.tiles.iter_mut() { if let egui_tiles::Tile::Pane(pane) = tile { - pane.get_mut() - .as_pane_mut() - .init(state, burn_device.clone()); + pane.get_mut().as_pane_mut().init(state); } } diff --git a/crates/brush-ui/src/panels.rs b/crates/brush-ui/src/panels.rs index ed0345ec..9d1285f0 100644 --- a/crates/brush-ui/src/panels.rs +++ b/crates/brush-ui/src/panels.rs @@ -1,5 +1,4 @@ use brush_process::message::ProcessMessage; -use burn_wgpu::WgpuDevice; use eframe::egui_wgpu::RenderState; use crate::ui_process::UiProcess; @@ -9,7 +8,7 @@ pub(crate) trait AppPane { /// Initialize runtime state after creation or deserialization. #[allow(unused_variables)] - fn init(&mut self, state: &RenderState, burn_device: WgpuDevice) {} + fn init(&mut self, state: &RenderState) {} /// Draw the pane's UI's content. fn ui(&mut self, ui: &mut egui::Ui, process: &UiProcess); diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index 8a5b6296..a28fe541 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -691,7 +691,7 @@ impl AppPane for ScenePanel { } } - fn init(&mut self, state: &RenderState, _burn_device: burn_wgpu::WgpuDevice) { + fn init(&mut self, state: &RenderState) { self.grid = Some(GridWidget::new(state)); self.backbuffer = Some(SplatBackbuffer::new(state)); diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs index dbebffa0..580f81fd 100644 --- a/crates/brush-ui/src/splat_backbuffer.rs +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -7,6 +7,7 @@ use burn::tensor::Tensor; use egui::Rect; use glam::{UVec2, Vec3}; use tokio::sync::mpsc; +use tokio_with_wasm::alias::task; use eframe::egui_wgpu::{self, CallbackTrait, wgpu}; @@ -49,7 +50,7 @@ impl SplatBackbuffer { state.target_format, )); - tokio::task::spawn(render_worker(req_rec, img_send)); + task::spawn(render_worker(req_rec, img_send)); Self { req_send, img_rec, diff --git a/crates/brush-ui/src/stats.rs b/crates/brush-ui/src/stats.rs index ffcb9899..5b9eb3b3 100644 --- a/crates/brush-ui/src/stats.rs +++ b/crates/brush-ui/src/stats.rs @@ -2,14 +2,13 @@ use crate::{UiMode, panels::AppPane, ui_process::UiProcess}; use brush_process::message::ProcessMessage; use brush_process::message::TrainMessage; use burn_cubecl::cubecl::Runtime; -use burn_wgpu::{WgpuDevice, WgpuRuntime}; +use burn_wgpu::WgpuRuntime; use eframe::egui_wgpu::RenderState; use web_time::Duration; use wgpu::AdapterInfo; #[derive(Default)] pub struct StatsPanel { - device: Option, last_eval: Option, frames: u32, adapter_info: Option, @@ -78,8 +77,7 @@ impl AppPane for StatsPanel { "Stats".into() } - fn init(&mut self, state: &RenderState, burn_device: burn_wgpu::WgpuDevice) { - self.device = Some(burn_device); + fn init(&mut self, state: &RenderState) { self.adapter_info = Some(state.adapter.get_info()); } @@ -187,40 +185,39 @@ impl AppPane for StatsPanel { }); } - if let Some(device) = &self.device { - ui.add_space(10.0); - ui.heading("GPU"); - ui.separator(); + let device = process.burn_device(); + let client = WgpuRuntime::client(&device); + let memory = client.memory_usage(); + + ui.add_space(10.0); + ui.heading("GPU"); + ui.separator(); - let client = WgpuRuntime::client(device); - let memory = client.memory_usage(); + stats_grid(ui, "memory_stats_grid", |ui, v| { + stat_row(ui, "Bytes in use", bytes_format(memory.bytes_in_use), v); + stat_row(ui, "Bytes reserved", bytes_format(memory.bytes_reserved), v); + stat_row( + ui, + "Active allocations", + format!("{}", memory.number_allocs), + v, + ); + }); - stats_grid(ui, "memory_stats_grid", |ui, v| { - stat_row(ui, "Bytes in use", bytes_format(memory.bytes_in_use), v); - stat_row(ui, "Bytes reserved", bytes_format(memory.bytes_reserved), v); + // On WASM, adapter info is mostly private, not worth showing. + if !cfg!(target_family = "wasm") + && let Some(adapter_info) = &self.adapter_info + { + stats_grid(ui, "gpu_info_grid", |ui, v| { + stat_row(ui, "Name", &adapter_info.name, v); + stat_row(ui, "Type", format!("{:?}", adapter_info.device_type), v); stat_row( ui, - "Active allocations", - format!("{}", memory.number_allocs), + "Driver", + format!("{}, {}", adapter_info.driver, adapter_info.driver_info), v, ); }); - - // On WASM, adapter info is mostly private, not worth showing. - if !cfg!(target_family = "wasm") - && let Some(adapter_info) = &self.adapter_info - { - stats_grid(ui, "gpu_info_grid", |ui, v| { - stat_row(ui, "Name", &adapter_info.name, v); - stat_row(ui, "Type", format!("{:?}", adapter_info.device_type), v); - stat_row( - ui, - "Driver", - format!("{}, {}", adapter_info.driver, adapter_info.driver_info), - v, - ); - }); - } } }); } diff --git a/crates/brush-ui/src/ui_process.rs b/crates/brush-ui/src/ui_process.rs index 81508100..5cf49e6b 100644 --- a/crates/brush-ui/src/ui_process.rs +++ b/crates/brush-ui/src/ui_process.rs @@ -292,6 +292,10 @@ impl UiProcess { inner.session_reset_requested = false; requested } + + pub fn burn_device(&self) -> WgpuDevice { + self.read().burn_device.clone() + } } struct UiProcessInner {