diff --git a/Cargo.lock b/Cargo.lock index 1ed80a95..419d93eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1175,6 +1175,7 @@ dependencies = [ "glam", "log", "serde", + "tokio", "tracing", ] @@ -3865,6 +3866,7 @@ dependencies = [ "brush-train", "brush-ui", "burn", + "burn-cubecl", "eframe", "egui", "glam", diff --git a/Cargo.toml b/Cargo.toml index 37e830da..bbfdb251 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ tracing = "0.1.41" tracing-tracy = "0.11.3" tracing-subscriber = "0.3.19" -winapi = "0.3" +winapi = { version = "0.3", features = ["wincon"] } tokio = { version = "1.42.0", default-features = false } tokio_with_wasm = "0.8.2" diff --git a/crates/brush-bench-test/src/benches.rs b/crates/brush-bench-test/src/benches.rs index 6a421f09..079a6c85 100644 --- a/crates/brush-bench-test/src/benches.rs +++ b/crates/brush-bench-test/src/benches.rs @@ -1,6 +1,6 @@ use brush_dataset::scene::SceneBatch; use brush_render::{ - AlphaMode, MainBackend, + AlphaMode, MainBackend, TextureMode, camera::Camera, gaussian_splats::{SplatRenderMode, Splats}, render_splats, @@ -144,8 +144,9 @@ fn generate_training_batch(resolution: (u32, u32), camera_pos: Vec3) -> SceneBat mod forward_rendering { use super::{ AutodiffModule, Backend, Camera, ITERS_PER_SYNC, MainBackend, Quat, RESOLUTIONS, - SPLAT_COUNTS, Vec3, WgpuDevice, gen_splats, render_splats, + SPLAT_COUNTS, TextureMode, Vec3, WgpuDevice, gen_splats, render_splats, }; + use burn_cubecl::cubecl::future::block_on; #[divan::bench(args = SPLAT_COUNTS)] fn render_1080p(bencher: divan::Bencher, splat_count: usize) { @@ -160,10 +161,20 @@ mod forward_rendering { ); bencher.bench_local(move || { - for _ in 0..ITERS_PER_SYNC { - let _ = render_splats(&splats, &camera, glam::uvec2(1920, 1080), Vec3::ZERO, None); - } - MainBackend::sync(&device).expect("Failed to sync"); + block_on(async { + for _ in 0..ITERS_PER_SYNC { + let _ = render_splats( + splats.clone(), + &camera, + glam::uvec2(1920, 1080), + Vec3::ZERO, + None, + 
TextureMode::Float, + ) + .await; + } + MainBackend::sync(&device).expect("Failed to sync"); + }); }); } @@ -180,16 +191,20 @@ mod forward_rendering { ); bencher.bench_local(move || { - for _ in 0..ITERS_PER_SYNC { - let _ = render_splats( - &splats, - &camera, - glam::uvec2(width, height), - Vec3::ZERO, - None, - ); - } - MainBackend::sync(&device).expect("Failed to sync"); + block_on(async { + for _ in 0..ITERS_PER_SYNC { + let _ = render_splats( + splats.clone(), + &camera, + glam::uvec2(width, height), + Vec3::ZERO, + None, + TextureMode::Float, + ) + .await; + } + MainBackend::sync(&device).expect("Failed to sync"); + }); }); } } @@ -200,6 +215,7 @@ mod backward_rendering { Backend, Camera, DiffBackend, ITERS_PER_SYNC, MainBackend, Quat, RESOLUTIONS, Tensor, TensorPrimitive, Vec3, WgpuDevice, gen_splats, render_splats_diff, }; + use burn_cubecl::cubecl::future::block_on; #[divan::bench(args = [1_000_000, 2_000_000, 5_000_000])] fn render_grad_1080p(bencher: divan::Bencher, splat_count: usize) { @@ -214,14 +230,21 @@ mod backward_rendering { ); bencher.bench_local(move || { - for _ in 0..ITERS_PER_SYNC { - let diff_out = - render_splats_diff(&splats, &camera, glam::uvec2(1920, 1080), Vec3::ZERO); - let img: Tensor = - Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); - let _ = img.mean().backward(); - } - MainBackend::sync(&device).expect("Failed to sync"); + block_on(async { + for _ in 0..ITERS_PER_SYNC { + let diff_out = render_splats_diff( + splats.clone(), + &camera, + glam::uvec2(1920, 1080), + Vec3::ZERO, + ) + .await; + let img: Tensor = + Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); + let _ = img.mean().backward(); + } + MainBackend::sync(&device).expect("Failed to sync"); + }); }); } @@ -237,14 +260,21 @@ mod backward_rendering { glam::vec2(0.5, 0.5), ); bencher.bench_local(move || { - for _ in 0..ITERS_PER_SYNC { - let diff_out = - render_splats_diff(&splats, &camera, glam::uvec2(width, height), Vec3::ZERO); - let img: 
Tensor = - Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); - let _ = img.mean().backward(); - } - MainBackend::sync(&device).expect("Failed to sync"); + block_on(async { + for _ in 0..ITERS_PER_SYNC { + let diff_out = render_splats_diff( + splats.clone(), + &camera, + glam::uvec2(width, height), + Vec3::ZERO, + ) + .await; + let img: Tensor = + Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); + let _ = img.mean().backward(); + } + MainBackend::sync(&device).expect("Failed to sync"); + }); }); } } @@ -252,6 +282,7 @@ mod backward_rendering { #[divan::bench_group(max_time = 4)] mod training { use brush_render::bounding_box::BoundingBox; + use burn_cubecl::cubecl::future::block_on; use crate::benches::ITERS_PER_SYNC; @@ -262,25 +293,23 @@ mod training { #[divan::bench(args = SPLAT_COUNTS)] fn train_steps(splat_count: usize) { - burn_cubecl::cubecl::future::block_on(async { - let device = WgpuDevice::default(); - let batch1 = generate_training_batch((1920, 1080), Vec3::new(0.0, 0.0, 5.0)); - let batch2 = generate_training_batch((1920, 1080), Vec3::new(2.0, 0.0, 5.0)); - let batches = [batch1, batch2]; - let config = TrainConfig::default(); - let mut splats = gen_splats(&device, splat_count); - let mut trainer = SplatTrainer::new( - &config, - &device, - BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE), - ); - for step in 0..ITERS_PER_SYNC { - let batch = batches[step as usize % batches.len()].clone(); - let (new_splats, _) = trainer.step(batch, splats); - splats = new_splats; - } - MainBackend::sync(&device).expect("Failed to sync"); - }); + let device = WgpuDevice::default(); + let batch1 = generate_training_batch((1920, 1080), Vec3::new(0.0, 0.0, 5.0)); + let batch2 = generate_training_batch((1920, 1080), Vec3::new(2.0, 0.0, 5.0)); + let batches = [batch1, batch2]; + let config = TrainConfig::default(); + let mut splats = gen_splats(&device, splat_count); + let mut trainer = SplatTrainer::new( + &config, + &device, + 
BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE), + ); + for step in 0..ITERS_PER_SYNC { + let batch = batches[step as usize % batches.len()].clone(); + let (new_splats, _) = block_on(trainer.step(batch, splats)); + splats = new_splats; + } + MainBackend::sync(&device).expect("Failed to sync"); } } diff --git a/crates/brush-bench-test/src/reference.rs b/crates/brush-bench-test/src/reference.rs index d9b907c4..687d3840 100644 --- a/crates/brush-bench-test/src/reference.rs +++ b/crates/brush-bench-test/src/reference.rs @@ -10,7 +10,7 @@ use burn::{ Tensor, backend::{Autodiff, wgpu::WgpuDevice}, prelude::Backend, - tensor::{Float, Int, TensorPrimitive}, + tensor::TensorPrimitive, }; use anyhow::{Context, Result}; @@ -127,16 +127,15 @@ async fn test_reference() -> Result<()> { ); let diff_out = brush_render_bwd::render_splats( - &splats, + splats.clone(), &cam, glam::uvec2(w as u32, h as u32), Vec3::ZERO, - ); + ) + .await; - let (out, aux) = ( - Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)), - diff_out.aux, - ); + let out: Tensor = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); + let render_aux = diff_out.render_aux; if let Some(rec) = rec.as_ref() { rec.set_time_sequence("test case", i as i64); @@ -148,28 +147,10 @@ async fn test_reference() -> Result<()> { )?; rec.log( "images/tile_depth", - &aux.calc_tile_depth().into_rerun().await, + &render_aux.calc_tile_depth().into_rerun().await, )?; } - let num_visible: Tensor = aux.num_visible(); - let num_visible = num_visible.into_scalar_async().await.unwrap() as usize; - let global_from_compact_gid: Tensor = - Tensor::from_primitive(aux.global_from_compact_gid.clone()); - let gs_ids = global_from_compact_gid.clone().slice([0..num_visible]); - let projected_splats = - Tensor::from_primitive(TensorPrimitive::Float(aux.projected_splats.clone())); - let xys: Tensor = - projected_splats.clone().slice([0..num_visible, 0..2]); - let xys_ref = safetensor_to_burn::(&tensors.tensor("xys")?, &device); - 
let xys_ref = xys_ref.select(0, gs_ids.clone()); - compare("xy", xys, xys_ref, 1e-5, 2e-5); - let conics: Tensor = - projected_splats.clone().slice([0..num_visible, 2..5]); - let conics_ref = safetensor_to_burn::(&tensors.tensor("conics")?, &device); - let conics_ref = conics_ref.select(0, gs_ids.clone()); - compare("conics", conics, conics_ref, 1e-6, 2e-5); - // Check if images match. compare("img", out.clone(), img_ref, 1e-5, 1e-5); diff --git a/crates/brush-bench-test/tests/integration.rs b/crates/brush-bench-test/tests/integration.rs index e5bcb4a8..259f5980 100644 --- a/crates/brush-bench-test/tests/integration.rs +++ b/crates/brush-bench-test/tests/integration.rs @@ -161,8 +161,8 @@ fn test_forward_rendering() { assert!(means_data.iter().all(|&x| x.is_finite())); } -#[test] -fn test_training_step() { +#[tokio::test] +async fn test_training_step() { let device = WgpuDevice::default(); let batch = generate_test_batch((64, 64)); let splats = generate_test_splats(&device, 500); @@ -172,7 +172,7 @@ fn test_training_step() { &device, BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE), ); - let (final_splats, stats) = trainer.step(batch, splats); + let (final_splats, stats) = trainer.step(batch, splats).await; assert!(final_splats.num_splats() > 0); let loss = stats.loss.into_scalar(); @@ -190,8 +190,8 @@ fn test_batch_generation() { assert!(img_data.iter().all(|&x| (0.0..=1.1).contains(&x))); } -#[test] -fn test_multi_step_training() { +#[tokio::test] +async fn test_multi_step_training() { let device = WgpuDevice::default(); let batch = generate_test_batch((64, 64)); let config = TrainConfig::default(); @@ -205,7 +205,7 @@ fn test_multi_step_training() { // Run a few training steps for _ in 0..3 { - let (new_splats, stats) = trainer.step(batch.clone(), splats); + let (new_splats, stats) = trainer.step(batch.clone(), splats).await; splats = new_splats; let loss = stats.loss.into_scalar(); @@ -216,8 +216,8 @@ fn test_multi_step_training() { assert!(splats.num_splats() 
> 0); } -#[test] -fn test_gradient_validation() { +#[tokio::test] +async fn test_gradient_validation() { let device = WgpuDevice::default(); let splats = generate_test_splats(&device, 100); @@ -231,7 +231,8 @@ fn test_gradient_validation() { ); let img_size = glam::uvec2(64, 64); - let result = render_splats(&splats, &camera, img_size, Vec3::ZERO); + // Clone splats since render_splats takes ownership and we need splats for gradient validation + let result = render_splats(splats.clone(), &camera, img_size, Vec3::ZERO).await; let rendered: Tensor = Tensor::from_primitive(TensorPrimitive::Float(result.img)); diff --git a/crates/brush-kernel/build.rs b/crates/brush-kernel/build.rs deleted file mode 100644 index 1b8f58cc..00000000 --- a/crates/brush-kernel/build.rs +++ /dev/null @@ -1,36 +0,0 @@ -//! Build script to clean up old generated shader files. -//! -//! Previously, shaders were generated via build.rs into src/shaders/mod.rs files. -//! Now we use proc macros instead. This script removes any leftover generated files -//! to avoid confusion or compilation issues. 
- -use std::path::Path; - -fn main() { - // Only rerun if this build script changes - println!("cargo:rerun-if-changed=build.rs"); - - // Clean up old generated shader files in the workspace - let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap(); - let crates_dir = Path::new(&manifest_dir).parent().unwrap(); - - // List of crates that used to have generated shaders/mod.rs - let crates_with_old_shaders = [ - "brush-kernel", - "brush-prefix-sum", - "brush-sort", - "brush-render", - "brush-render-bwd", - ]; - - for crate_name in crates_with_old_shaders { - let old_file = crates_dir.join(crate_name).join("src/shaders/mod.rs"); - if old_file.exists() { - println!( - "cargo:warning=Removing old generated file: {}", - old_file.display() - ); - let _ = std::fs::remove_file(&old_file); - } - } -} diff --git a/crates/brush-process/src/lib.rs b/crates/brush-process/src/lib.rs index 775073d4..03d123a1 100644 --- a/crates/brush-process/src/lib.rs +++ b/crates/brush-process/src/lib.rs @@ -172,23 +172,21 @@ pub fn create_process< // For the first frame of a new file, clear existing frames if frame == 0 { - splat_view.clear(); + splat_view.clear().await; } - // Ensure we have space up to this frame index and set it - { - let mut guard = splat_view.write(); - if guard.len() <= frame { - guard.resize(frame + 1, splats.clone()); - } - guard[frame] = splats; - } + // Capture stats before moving splats + let num_splats = splats.num_splats(); + let sh_degree = splats.sh_degree(); + splat_view.set_at(frame, splats).await; emitter .emit(ProcessMessage::SplatsUpdated { up_axis: message.meta.up_axis, frame: frame as u32, total_frames, + num_splats, + sh_degree, }) .await; } diff --git a/crates/brush-process/src/message.rs b/crates/brush-process/src/message.rs index f4ab85d0..f26b9bda 100644 --- a/crates/brush-process/src/message.rs +++ b/crates/brush-process/src/message.rs @@ -55,6 +55,8 @@ pub enum ProcessMessage { up_axis: Option, frame: u32, total_frames: u32, + num_splats: 
u32, + sh_degree: u32, }, #[cfg(feature = "training")] TrainMessage(TrainMessage), diff --git a/crates/brush-process/src/slot.rs b/crates/brush-process/src/slot.rs index 1bf03674..9a45ff5f 100644 --- a/crates/brush-process/src/slot.rs +++ b/crates/brush-process/src/slot.rs @@ -1,37 +1,65 @@ -use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::Arc; +use tokio::sync::Mutex; -/// A thread-safe slot for sharing data between the process and UI. -/// Uses Mutex because the inner type (Splats) is not Sync. +/// Async slot for sharing data between the process and UI. #[derive(Clone)] pub struct Slot(Arc>>); impl Slot { - pub fn write(&self) -> MutexGuard<'_, Vec> { - self.0.lock().unwrap() - } + /// Take ownership of value at index, apply async function, put result back. + pub async fn act(&self, index: usize, f: F) -> Option + where + F: AsyncFnOnce(T) -> (T, R), + { + let mut guard = self.0.lock().await; + let len = guard.len(); + if index >= len { + return None; + } + guard.swap(index, len - 1); + let value = guard.pop().unwrap(); + let (new_value, result) = f(value).await; - pub fn push(&self, value: T) { - self.0.lock().unwrap().push(value); + guard.push(new_value); + let new_len = guard.len(); + guard.swap(index, new_len - 1); + Some(result) } - pub fn clear(&self) { - self.0.lock().unwrap().clear(); + pub async fn map(&self, index: usize, f: F) -> Option + where + F: FnOnce(&T) -> R, + { + self.act(index, async move |value| { + let ret = f(&value); + (value, ret) + }) + .await } - pub fn len(&self) -> usize { - self.0.lock().unwrap().len() + pub async fn clone_main(&self) -> Option { + self.0.lock().await.last().cloned() } - pub fn is_empty(&self) -> bool { - self.0.lock().unwrap().is_empty() + /// Replace all contents with a single value. 
+ pub async fn set(&self, value: T) { + let mut guard = self.0.lock().await; + guard.clear(); + guard.push(value); } - pub fn get(&self, index: usize) -> Option { - self.0.lock().unwrap().get(index).cloned() + /// Set value at index, or push if index == len. Panics if index > len. + pub async fn set_at(&self, index: usize, value: T) { + let mut guard = self.0.lock().await; + if index == guard.len() { + guard.push(value); + } else { + guard[index] = value; + } } - pub fn get_main(&self) -> Option { - self.0.lock().unwrap().last().cloned() + pub async fn clear(&self) { + self.0.lock().await.clear(); } } diff --git a/crates/brush-process/src/train_stream.rs b/crates/brush-process/src/train_stream.rs index c05dc23a..01d1ad09 100644 --- a/crates/brush-process/src/train_stream.rs +++ b/crates/brush-process/src/train_stream.rs @@ -10,11 +10,11 @@ use brush_render::{ MainBackend, gaussian_splats::{SplatRenderMode, Splats}, }; -use brush_rerun::{RerunConfig, visualize_tools::VisualizeTools}; +use brush_rerun::visualize_tools::VisualizeTools; use brush_train::{ RandomSplatsConfig, create_random_splats, eval::eval_stats, - msg::{RefineStats, TrainStepStats}, + msg::RefineStats, splats_into_autodiff, to_init_splats, train::{BOUND_PERCENTILE, SplatTrainer, get_splat_bounds}, }; @@ -127,12 +127,14 @@ pub(crate) async fn train_stream( // If the metadata has an up axis prefer that, otherwise estimate the up direction. 
let up_axis = up_axis.or(Some(estimated_up)); - *splat_slot.write() = vec![init_splats.clone()]; + splat_slot.set(init_splats.clone()).await; emitter .emit(ProcessMessage::SplatsUpdated { up_axis, frame: 0, total_frames: 1, + num_splats: init_splats.num_splats(), + sh_degree: init_splats.sh_degree(), }) .await; @@ -170,6 +172,7 @@ pub(crate) async fn train_stream( let export_path = base_path.join(&export_path_str); // Normalize path components let export_path: PathBuf = export_path.components().collect(); + let sh_degree = init_splats.sh_degree(); log::info!("Start training loop."); for iter in @@ -177,19 +180,19 @@ pub(crate) async fn train_stream( { let step_time = Instant::now(); - // Scope so we're sure we're sharing memory for the splat. - let stats = { - let batch = dataloader - .next_batch() - .instrument(trace_span!("Wait for next data batch")) - .await; + // Wait for next batch. + let batch = dataloader + .next_batch() + .instrument(trace_span!("Wait for next data batch")) + .await; - let mut splat_handle = splat_slot.write(); - let splats = splats_into_autodiff(splat_handle.remove(0)); - let (new_splats, stats) = trainer.step(batch, splats); - *splat_handle = vec![new_splats.valid()]; - stats - }; + let stats = splat_slot + .act(0, |splats| async { + let (new_splats, stats) = trainer.step(batch, splats_into_autodiff(splats)).await; + (new_splats.valid(), stats) + }) + .await + .unwrap(); let train_t = (iter as f32 / train_stream_config.train_config.total_steps as f32).clamp(0.0, 1.0); @@ -198,19 +201,19 @@ pub(crate) async fn train_stream( && iter.is_multiple_of(train_stream_config.train_config.refine_every) && train_t <= 0.95 { - let splats = splat_slot.get_main().unwrap(); - let (new_splats, refine) = trainer - .refine(iter, splats) - .instrument(trace_span!("Refine splats")) - .await; - *splat_slot.write() = vec![new_splats]; - Some(refine) + splat_slot + .act(0, async |splats| trainer.refine(iter, splats).await) + .await + .unwrap() } else { - None + 
let new_total: u32 = splat_slot.map(0, |s| s.num_splats()).await.unwrap(); + RefineStats { + num_added: 0, + num_pruned: 0, + total_splats: new_total, + } }; - let splats = splat_slot.get_main().unwrap(); - // We just finished iter 'iter', now starting iter + 1. let iter = iter + 1; let is_last_step = iter == train_stream_config.train_config.total_steps; @@ -232,7 +235,7 @@ pub(crate) async fn train_stream( &device, &emitter, &visualize, - splats.clone(), + splat_slot.clone_main().await.unwrap(), iter, eval_scene, save_path, @@ -248,7 +251,7 @@ pub(crate) async fn train_stream( #[cfg(not(target_family = "wasm"))] if iter % process_config.export_every == 0 || is_last_step { let res = export_checkpoint( - splats.clone(), + splat_slot.clone_main().await.unwrap(), &export_path, &process_config.export_name, iter, @@ -262,27 +265,37 @@ pub(crate) async fn train_stream( } } - let res = rerun_log( - &train_stream_config.rerun_config, - &visualize, - splats.clone(), - &stats, - iter, - is_last_step, - &device, - refine.as_ref(), - ) - .await - .context("Rerun visualization failed"); + { + let rerun_config = &train_stream_config.rerun_config; + visualize + .log_splat_stats(iter, refine.total_splats) + .unwrap(); + + if let Some(every) = rerun_config.rerun_log_splats_every + && (iter.is_multiple_of(every) || is_last_step) + { + let splats = splat_slot.clone_main().await.unwrap(); + visualize.log_splats(iter, splats).await.unwrap(); + } + + // Log out train stats. 
+ if iter.is_multiple_of(rerun_config.rerun_log_train_stats_every) || is_last_step { + visualize + .log_train_stats(iter, stats.clone()) + .await + .unwrap(); + } - if let Err(error) = res { - emitter.emit(ProcessMessage::Warning { error }).await; + visualize.log_memory(iter, &WgpuRuntime::client(&device).memory_usage())?; + if refine.num_added > 0 { + visualize.log_refine_stats(iter, &refine).unwrap(); + } } - if refine.is_some() { + if refine.num_added > 0 { emitter .emit(ProcessMessage::TrainMessage(TrainMessage::RefineStep { - cur_splat_count: splats.num_splats(), + cur_splat_count: refine.total_splats, iter, })) .await; @@ -296,6 +309,8 @@ pub(crate) async fn train_stream( up_axis: None, frame: 0, total_frames: 1, + num_splats: refine.total_splats, + sh_degree, }) .await; @@ -335,12 +350,13 @@ async fn run_eval( let eval_img = view.image.load().await?; let sample = eval_stats( - &splats, + splats.clone(), &view.camera, eval_img, view.image.alpha_mode(), device, ) + .await .context("Failed to run eval for sample.")?; count += 1; @@ -398,35 +414,3 @@ async fn export_checkpoint( .context(format!("Failed to export ply {export_path:?}"))?; Ok(()) } - -async fn rerun_log( - rerun_config: &RerunConfig, - visualize: &VisualizeTools, - splats: Splats, - stats: &TrainStepStats, - iter: u32, - is_last_step: bool, - device: &WgpuDevice, - refine: Option<&RefineStats>, -) -> Result<(), anyhow::Error> { - visualize.log_splat_stats(iter, &splats)?; - - if let Some(every) = rerun_config.rerun_log_splats_every - && (iter.is_multiple_of(every) || is_last_step) - { - visualize.log_splats(iter, splats.clone()).await?; - } - - // Log out train stats. - if iter.is_multiple_of(rerun_config.rerun_log_train_stats_every) || is_last_step { - visualize.log_train_stats(iter, stats.clone()).await?; - } - - let client = WgpuRuntime::client(device); - visualize.log_memory(iter, &client.memory_usage())?; - // Emit some messages. 
Important to not count these in the training time (as this might pause). - if let Some(stats) = refine { - visualize.log_refine_stats(iter, stats)?; - } - Ok(()) -} diff --git a/crates/brush-render-bwd/src/burn_glue.rs b/crates/brush-render-bwd/src/burn_glue.rs index 72b88723..81adbd2f 100644 --- a/crates/brush-render-bwd/src/burn_glue.rs +++ b/crates/brush-render-bwd/src/burn_glue.rs @@ -1,9 +1,9 @@ use brush_render::{ - MainBackendBase, SplatForward, + MainBackendBase, RenderAux, SplatOps, camera::Camera, gaussian_splats::{SplatRenderMode, Splats}, - render_aux::RenderAux, sh::{sh_coeffs_for_degree, sh_degree_from_coeffs}, + shaders::helpers::ProjectUniforms, }; use burn::{ backend::{ @@ -30,55 +30,80 @@ use burn_fusion::{ use burn_ir::{CustomOpIr, HandleContainer, OperationIr, OperationOutput, TensorIr}; use glam::Vec3; -use crate::render_bwd::SplatGrads; +/// Intermediate gradients from the rasterize backward pass. +#[derive(Debug, Clone)] +pub struct RasterizeGrads { + /// Gradients w.r.t. projected splat data [`num_points`, 8]. + pub v_projected_splats: FloatTensor, + /// Gradients w.r.t. raw opacity from rasterization [`num_points`]. + pub v_raw_opac: FloatTensor, + /// Refinement weights for densification [`num_points`]. + pub v_refine_weight: FloatTensor, +} -/// Like [`SplatForward`], but for backends that support differentiation. -/// -/// This shouldn't be a separate trait, but atm is needed because of orphan trait rules. -pub trait SplatForwardDiff { - /// Render splats to a buffer. - /// - /// This projects the gaussians, sorts them, and rasterizes them to a buffer, in a - /// differentiable way. +/// Final gradients w.r.t. splat inputs from the project backward pass. +#[derive(Debug, Clone)] +pub struct SplatGrads { + pub v_means: FloatTensor, + pub v_quats: FloatTensor, + pub v_scales: FloatTensor, + pub v_coeffs: FloatTensor, + pub v_raw_opac: FloatTensor, + pub v_refine_weight: FloatTensor, +} + +/// Backward pass trait mirroring [`SplatOps`]. 
+pub trait SplatBwdOps: SplatOps { + /// Backward pass for rasterization. #[allow(clippy::too_many_arguments)] - fn render_splats( - camera: &Camera, + fn rasterize_bwd( + out_img: FloatTensor, + projected_splats: FloatTensor, + global_from_compact_gid: IntTensor, + compact_gid_from_isect: IntTensor, + tile_offsets: IntTensor, + background: Vec3, img_size: glam::UVec2, + v_output: FloatTensor, + ) -> RasterizeGrads; + + /// Backward pass for projection. + #[allow(clippy::too_many_arguments)] + fn project_bwd( means: FloatTensor, log_scales: FloatTensor, quats: FloatTensor, - sh_coeffs: FloatTensor, - raw_opacity: FloatTensor, + raw_opac: FloatTensor, + num_visible: IntTensor, + global_from_compact_gid: IntTensor, + project_uniforms: ProjectUniforms, + sh_degree: u32, render_mode: SplatRenderMode, - background: Vec3, - ) -> SplatOutputDiff; -} - -pub trait SplatBackwardOps { - /// Backward pass for `render_splats`. - /// - /// Do not use directly, `render_splats` will use this to calculate gradients. - #[allow(unused_variables)] - fn render_splats_bwd( - state: GaussianBackwardState, - v_output: FloatTensor, + rasterize_grads: RasterizeGrads, ) -> SplatGrads; } +/// State saved during forward pass for backward computation. 
#[derive(Debug, Clone)] -pub struct GaussianBackwardState { - pub(crate) means: FloatTensor, - pub(crate) quats: FloatTensor, - pub(crate) log_scales: FloatTensor, - pub(crate) raw_opac: FloatTensor, - pub(crate) out_img: FloatTensor, - pub(crate) projected_splats: FloatTensor, - pub(crate) uniforms_buffer: IntTensor, - pub(crate) compact_gid_from_isect: IntTensor, - pub(crate) global_from_compact_gid: IntTensor, - pub(crate) tile_offsets: IntTensor, - pub(crate) render_mode: SplatRenderMode, - pub(crate) sh_degree: u32, +struct GaussianBackwardState { + means: FloatTensor, + quats: FloatTensor, + log_scales: FloatTensor, + raw_opac: FloatTensor, + + projected_splats: FloatTensor, + project_uniforms: ProjectUniforms, + num_visible: IntTensor, + global_from_compact_gid: IntTensor, + + out_img: FloatTensor, + compact_gid_from_isect: IntTensor, + tile_offsets: IntTensor, + + render_mode: SplatRenderMode, + sh_degree: u32, + background: Vec3, + img_size: glam::UVec2, } #[derive(Debug)] @@ -87,7 +112,7 @@ struct RenderBackwards; const NUM_BWD_ARGS: usize = 6; // Implement gradient registration when rendering backwards. 
-impl> Backward for RenderBackwards { +impl> Backward for RenderBackwards { type State = GaussianBackwardState; fn backward( @@ -99,7 +124,6 @@ impl> Backward for RenderBackw let _span = tracing::trace_span!("render_gaussians backwards").entered(); let state = ops.state; - let v_output = grads.consume::(&ops.node); // Register gradients for parent nodes (This code is already skipped entirely @@ -113,152 +137,335 @@ impl> Backward for RenderBackw raw_opacity_parent, ] = ops.parents; - let v_tens = B::render_splats_bwd(state, v_output); + // Step 1: Rasterize backward + let rasterize_grads = B::rasterize_bwd( + state.out_img, + state.projected_splats, + state.global_from_compact_gid.clone(), + state.compact_gid_from_isect, + state.tile_offsets, + state.background, + state.img_size, + v_output, + ); + + // Step 2: Project backward + let splat_grads = B::project_bwd( + state.means, + state.log_scales, + state.quats, + state.raw_opac, + state.num_visible, + state.global_from_compact_gid, + state.project_uniforms, + state.sh_degree, + state.render_mode, + rasterize_grads, + ); if let Some(node) = mean_parent { - grads.register::(node.id, v_tens.v_means); + grads.register::(node.id, splat_grads.v_means); } // Register the gradients for the dummy xy input. 
if let Some(node) = refine_weight { - grads.register::(node.id, v_tens.v_refine_weight); + grads.register::(node.id, splat_grads.v_refine_weight); } if let Some(node) = log_scales_parent { - grads.register::(node.id, v_tens.v_scales); + grads.register::(node.id, splat_grads.v_scales); } if let Some(node) = quats_parent { - grads.register::(node.id, v_tens.v_quats); + grads.register::(node.id, splat_grads.v_quats); } if let Some(node) = coeffs_parent { - grads.register::(node.id, v_tens.v_coeffs); + grads.register::(node.id, splat_grads.v_coeffs); } if let Some(node) = raw_opacity_parent { - grads.register::(node.id, v_tens.v_raw_opac); + grads.register::(node.id, splat_grads.v_raw_opac); } } } pub struct SplatOutputDiff { pub img: FloatTensor, - pub aux: RenderAux, + pub render_aux: RenderAux, pub refine_weight_holder: Tensor, } -// Implement -impl + SplatForward, C: CheckpointStrategy> - SplatForwardDiff for Autodiff +/// Render splats on a differentiable backend. +/// +/// This is the main entry point for differentiable rendering, wrapping +/// the forward pass with autodiff support. +/// +/// Takes ownership of the splats. Clone before calling if you need to reuse them. 
+pub async fn render_splats( + splats: Splats>, + camera: &Camera, + img_size: glam::UVec2, + background: Vec3, +) -> SplatOutputDiff> +where + B: Backend + SplatBwdOps, + C: CheckpointStrategy, { - fn render_splats( - camera: &Camera, - img_size: glam::UVec2, - means: FloatTensor, - log_scales: FloatTensor, - quats: FloatTensor, - sh_coeffs: FloatTensor, - raw_opacity: FloatTensor, - render_mode: SplatRenderMode, + splats.validate_values(); + + let device = Tensor::, 2>::from_primitive(TensorPrimitive::Float( + splats.means.val().into_primitive().tensor(), + )) + .device(); + let refine_weight_holder = Tensor::, 1>::zeros([1], &device).require_grad(); + + let prep_nodes = RenderBackwards + .prepare::([ + splats.means.val().into_primitive().tensor().node, + refine_weight_holder.clone().into_primitive().tensor().node, + splats.log_scales.val().into_primitive().tensor().node, + splats.rotations.val().into_primitive().tensor().node, + splats.sh_coeffs.val().into_primitive().tensor().node, + splats.raw_opacities.val().into_primitive().tensor().node, + ]) + .compute_bound() + .stateful(); + + let means = splats + .means + .val() + .into_primitive() + .tensor() + .into_primitive(); + let log_scales = splats + .log_scales + .val() + .into_primitive() + .tensor() + .into_primitive(); + let quats = splats + .rotations + .val() + .into_primitive() + .tensor() + .into_primitive(); + let sh_coeffs_dims = splats.sh_coeffs.dims(); + let sh_coeffs = splats + .sh_coeffs + .val() + .into_primitive() + .tensor() + .into_primitive(); + let raw_opacity = splats + .raw_opacities + .val() + .into_primitive() + .tensor() + .into_primitive(); + let render_mode = splats.render_mode; + + let project_output = >::project( + camera, + img_size, + means.clone(), + log_scales.clone(), + quats.clone(), + sh_coeffs, + raw_opacity.clone(), + render_mode, + ); + + // Async readback + let num_intersections = project_output.read_num_intersections().await; + + let (out_img, render_aux, 
compact_gid_from_isect) = + >::rasterize(&project_output, num_intersections, background, true); + + let wrapped_render_aux = RenderAux::> { + num_visible: render_aux.num_visible.clone(), + num_intersections: render_aux.num_intersections, + visible: as AutodiffBackend>::from_inner(render_aux.visible.clone()), + tile_offsets: render_aux.tile_offsets.clone(), + img_size: render_aux.img_size, + }; + + let sh_degree = sh_degree_from_coeffs(sh_coeffs_dims[1] as u32); + + match prep_nodes { + OpsKind::Tracked(prep) => { + let state = GaussianBackwardState { + means, + log_scales, + quats, + raw_opac: raw_opacity, + sh_degree, + out_img: out_img.clone(), + projected_splats: project_output.projected_splats, + project_uniforms: project_output.project_uniforms, + num_visible: project_output.num_visible, + tile_offsets: render_aux.tile_offsets, + compact_gid_from_isect, + render_mode, + global_from_compact_gid: project_output.global_from_compact_gid, + background, + img_size, + }; + + let out_img = prep.finish(state, out_img); + + let result = SplatOutputDiff { + img: out_img, + render_aux: wrapped_render_aux, + refine_weight_holder, + }; + result.render_aux.validate(); + result + } + OpsKind::UnTracked(prep) => { + let result = SplatOutputDiff { + img: prep.finish(out_img), + render_aux: wrapped_render_aux, + refine_weight_holder, + }; + result.render_aux.validate(); + result + } + } +} + +impl SplatBwdOps for Fusion { + #[allow(clippy::too_many_arguments)] + fn rasterize_bwd( + out_img: FloatTensor, + projected_splats: FloatTensor, + global_from_compact_gid: IntTensor, + compact_gid_from_isect: IntTensor, + tile_offsets: IntTensor, background: Vec3, - ) -> SplatOutputDiff { - // Get backend tensors & dequantize if needed. Could try and support quantized inputs - // in the future. 
- let device = - Tensor::::from_primitive(TensorPrimitive::Float(means.clone())).device(); - let refine_weight_holder = Tensor::::zeros([1], &device).require_grad(); - - // Prepare backward pass, and check if we even need to do it. Store nodes that need gradients. - let prep_nodes = RenderBackwards - .prepare::([ - means.node.clone(), - refine_weight_holder.clone().into_primitive().tensor().node, - log_scales.node.clone(), - quats.node.clone(), - sh_coeffs.node.clone(), - raw_opacity.node.clone(), - ]) - .compute_bound() - .stateful(); - - // Render complete forward pass. - let (out_img, aux) = >::render_splats( - camera, - img_size, - means.clone().into_primitive(), - log_scales.clone().into_primitive(), - quats.clone().into_primitive(), - sh_coeffs.clone().into_primitive(), - raw_opacity.clone().into_primitive(), - render_mode, - background, - true, + img_size: glam::UVec2, + v_output: FloatTensor, + ) -> RasterizeGrads { + #[derive(Debug)] + struct CustomOp { + desc: CustomOpIr, + background: Vec3, + img_size: glam::UVec2, + } + + impl Operation> for CustomOp { + fn execute( + &self, + h: &mut HandleContainer>>, + ) { + let (inputs, outputs) = self.desc.as_fixed(); + + let [ + v_output, + out_img, + projected_splats, + global_from_compact_gid, + compact_gid_from_isect, + tile_offsets, + ] = inputs; + + let [v_projected_splats, v_raw_opac, v_refine_weight] = outputs; + + let grads = >::rasterize_bwd( + h.get_float_tensor::(out_img), + h.get_float_tensor::(projected_splats), + h.get_int_tensor::(global_from_compact_gid), + h.get_int_tensor::(compact_gid_from_isect), + h.get_int_tensor::(tile_offsets), + self.background, + self.img_size, + h.get_float_tensor::(v_output), + ); + + h.register_float_tensor::( + &v_projected_splats.id, + grads.v_projected_splats, + ); + h.register_float_tensor::(&v_raw_opac.id, grads.v_raw_opac); + h.register_float_tensor::( + &v_refine_weight.id, + grads.v_refine_weight, + ); + } + } + + let client = v_output.client.clone(); + let 
num_points = projected_splats.shape[0]; + + let v_projected_splats_out = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_points, 8]), + DType::F32, + ); + let v_raw_opac = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_points]), + DType::F32, + ); + let v_refine_weight = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_points]), + DType::F32, ); - let wrapped_aux = RenderAux:: { - projected_splats: ::from_inner(aux.projected_splats.clone()), - num_intersections: aux.num_intersections, - tile_offsets: aux.tile_offsets.clone(), - compact_gid_from_isect: aux.compact_gid_from_isect.clone(), - global_from_compact_gid: aux.global_from_compact_gid.clone(), - uniforms_buffer: aux.uniforms_buffer.clone(), - visible: ::from_inner(aux.visible), - img_size: aux.img_size, + let input_tensors = [ + v_output, + out_img, + projected_splats, + global_from_compact_gid, + compact_gid_from_isect, + tile_offsets, + ]; + + let stream = OperationStreams::with_inputs(&input_tensors); + let desc = CustomOpIr::new( + "rasterize_bwd", + &input_tensors.map(|t| t.into_ir()), + &[v_projected_splats_out, v_raw_opac, v_refine_weight], + ); + let op = CustomOp { + desc: desc.clone(), + background, + img_size, }; - match prep_nodes { - OpsKind::Tracked(prep) => { - // Save state needed for backward pass. 
- let state = GaussianBackwardState { - means: means.into_primitive(), - log_scales: log_scales.into_primitive(), - quats: quats.into_primitive(), - raw_opac: raw_opacity.into_primitive(), - sh_degree: sh_degree_from_coeffs( - Tensor::::from_primitive(TensorPrimitive::Float(sh_coeffs)).dims() - [1] as u32, - ), - out_img: out_img.clone(), - projected_splats: aux.projected_splats, - uniforms_buffer: aux.uniforms_buffer, - tile_offsets: aux.tile_offsets, - compact_gid_from_isect: aux.compact_gid_from_isect, - render_mode, - global_from_compact_gid: aux.global_from_compact_gid, - }; + let outputs = client + .register(stream, OperationIr::Custom(desc), op) + .outputs(); - let out_img = prep.finish(state, out_img); + let [v_projected_splats, v_raw_opac, v_refine_weight] = outputs; - SplatOutputDiff { - img: out_img, - aux: wrapped_aux, - refine_weight_holder, - } - } - OpsKind::UnTracked(prep) => { - // When no node is tracked, we can just use the original operation without - // keeping any state. 
- SplatOutputDiff { - img: prep.finish(out_img), - aux: wrapped_aux, - refine_weight_holder, - } - } + RasterizeGrads { + v_projected_splats, + v_raw_opac, + v_refine_weight, } } -} -impl SplatBackwardOps for Fusion { - fn render_splats_bwd( - state: GaussianBackwardState, - v_output: FloatTensor, + #[allow(clippy::too_many_arguments)] + fn project_bwd( + means: FloatTensor, + log_scales: FloatTensor, + quats: FloatTensor, + raw_opac: FloatTensor, + num_visible: IntTensor, + global_from_compact_gid: IntTensor, + project_uniforms: ProjectUniforms, + sh_degree: u32, + render_mode: SplatRenderMode, + rasterize_grads: RasterizeGrads, ) -> SplatGrads { #[derive(Debug)] struct CustomOp { desc: CustomOpIr, render_mode: SplatRenderMode, sh_degree: u32, + project_uniforms: ProjectUniforms, } impl Operation> for CustomOp { @@ -269,114 +476,115 @@ impl SplatBackwardOps for Fusion { let (inputs, outputs) = self.desc.as_fixed(); let [ - v_output, means, - quats, log_scales, + quats, raw_opac, - out_img, - projected_splats, - uniforms_buffer, - tile_offsets, - compact_gid_from_isect, + num_visible, global_from_compact_gid, + v_projected_splats, + v_raw_opac_in, + v_refine_weight_in, ] = inputs; - let [v_means, v_quats, v_scales, v_coeffs, v_raw_opac, v_refine] = outputs; - - let inner_state = GaussianBackwardState { - means: h.get_float_tensor::(means), - log_scales: h.get_float_tensor::(log_scales), - quats: h.get_float_tensor::(quats), - raw_opac: h.get_float_tensor::(raw_opac), - out_img: h.get_float_tensor::(out_img), - projected_splats: h.get_float_tensor::(projected_splats), - uniforms_buffer: h.get_int_tensor::(uniforms_buffer), - tile_offsets: h.get_int_tensor::(tile_offsets), - compact_gid_from_isect: h - .get_int_tensor::(compact_gid_from_isect), - global_from_compact_gid: h - .get_int_tensor::(global_from_compact_gid), - sh_degree: self.sh_degree, - render_mode: self.render_mode, + let [ + v_means, + v_quats, + v_scales, + v_coeffs, + v_raw_opac, + v_refine_weight, + 
] = outputs; + + let inner_rasterize_grads = RasterizeGrads { + v_projected_splats: h.get_float_tensor::(v_projected_splats), + v_raw_opac: h.get_float_tensor::(v_raw_opac_in), + v_refine_weight: h.get_float_tensor::(v_refine_weight_in), }; - let grads = - >::render_splats_bwd( - inner_state, - h.get_float_tensor::(v_output), - ); + let grads = >::project_bwd( + h.get_float_tensor::(means), + h.get_float_tensor::(log_scales), + h.get_float_tensor::(quats), + h.get_float_tensor::(raw_opac), + h.get_int_tensor::(num_visible), + h.get_int_tensor::(global_from_compact_gid), + self.project_uniforms, + self.sh_degree, + self.render_mode, + inner_rasterize_grads, + ); - // // Register output. h.register_float_tensor::(&v_means.id, grads.v_means); h.register_float_tensor::(&v_quats.id, grads.v_quats); h.register_float_tensor::(&v_scales.id, grads.v_scales); h.register_float_tensor::(&v_coeffs.id, grads.v_coeffs); h.register_float_tensor::(&v_raw_opac.id, grads.v_raw_opac); - h.register_float_tensor::(&v_refine.id, grads.v_refine_weight); + h.register_float_tensor::( + &v_refine_weight.id, + grads.v_refine_weight, + ); } } - let client = v_output.client.clone(); - let num_points = state.means.shape[0]; - let coeffs = sh_coeffs_for_degree(state.sh_degree) as usize; + let client = means.client.clone(); + let num_points = means.shape[0]; + let coeffs = sh_coeffs_for_degree(sh_degree) as usize; - let v_means = TensorIr::uninit( + let v_means_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, 3]), DType::F32, ); - let v_scales = TensorIr::uninit( + let v_scales_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, 3]), DType::F32, ); - let v_quats = TensorIr::uninit( + let v_quats_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, 4]), DType::F32, ); - let v_coeffs = TensorIr::uninit( + let v_coeffs_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, coeffs, 3]), DType::F32, 
); - let v_raw_opac = TensorIr::uninit( + let v_raw_opac_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points]), DType::F32, ); - let v_refine_weight = TensorIr::uninit( + let v_refine_weight_out = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points]), DType::F32, ); let input_tensors = [ - v_output, - state.means, - state.quats, - state.log_scales, - state.raw_opac, - state.out_img, - state.projected_splats, - state.uniforms_buffer, - state.tile_offsets, - state.compact_gid_from_isect, - state.global_from_compact_gid, + means, + log_scales, + quats, + raw_opac, + num_visible, + global_from_compact_gid, + rasterize_grads.v_projected_splats, + rasterize_grads.v_raw_opac, + rasterize_grads.v_refine_weight, ]; let stream = OperationStreams::with_inputs(&input_tensors); let desc = CustomOpIr::new( - "render_splat_bwd", + "project_bwd", &input_tensors.map(|t| t.into_ir()), &[ - v_means, - v_quats, - v_scales, - v_coeffs, - v_raw_opac, - v_refine_weight, + v_means_out, + v_quats_out, + v_scales_out, + v_coeffs_out, + v_raw_opac_out, + v_refine_weight_out, ], ); @@ -385,10 +593,10 @@ impl SplatBackwardOps for Fusion { stream, OperationIr::Custom(desc.clone()), CustomOp { - // state, desc, - sh_degree: state.sh_degree, - render_mode: state.render_mode, + sh_degree, + render_mode, + project_uniforms, }, ) .outputs(); @@ -412,29 +620,3 @@ impl SplatBackwardOps for Fusion { } } } - -/// Render splats on a differentiable backend. 
-pub fn render_splats( - splats: &Splats, - camera: &Camera, - img_size: glam::UVec2, - background: Vec3, -) -> SplatOutputDiff -where - B: Backend + SplatForwardDiff, -{ - splats.validate_values(); - let result = B::render_splats( - camera, - img_size, - splats.means.val().into_primitive().tensor(), - splats.log_scales.val().into_primitive().tensor(), - splats.rotations.val().into_primitive().tensor(), - splats.sh_coeffs.val().into_primitive().tensor(), - splats.raw_opacities.val().into_primitive().tensor(), - splats.render_mode, - background, - ); - result.aux.validate_values(); - result -} diff --git a/crates/brush-render-bwd/src/lib.rs b/crates/brush-render-bwd/src/lib.rs index 39698ab2..4bbae9e4 100644 --- a/crates/brush-render-bwd/src/lib.rs +++ b/crates/brush-render-bwd/src/lib.rs @@ -1,7 +1,4 @@ pub mod burn_glue; mod render_bwd; -pub use burn_glue::render_splats; - -#[cfg(test)] -mod tests; +pub use burn_glue::{RasterizeGrads, SplatBwdOps, SplatGrads, SplatOutputDiff, render_splats}; diff --git a/crates/brush-render-bwd/src/render_bwd.rs b/crates/brush-render-bwd/src/render_bwd.rs index c63574be..808f9464 100644 --- a/crates/brush-render-bwd/src/render_bwd.rs +++ b/crates/brush-render-bwd/src/render_bwd.rs @@ -1,19 +1,21 @@ -use brush_kernel::{CubeCount, calc_cube_count_1d}; +use brush_kernel::{CubeCount, calc_cube_count_1d, create_meta_binding}; +use brush_render::MainBackendBase; use brush_render::gaussian_splats::SplatRenderMode; +use brush_render::shaders::helpers::RasterizeUniforms; use brush_wgsl::wgsl_kernel; -use brush_render::MainBackendBase; use brush_render::sh::sh_coeffs_for_degree; use burn::tensor::FloatDType; -use burn::tensor::ops::FloatTensorOps; -use burn::{prelude::Backend, tensor::ops::FloatTensor}; +use burn::tensor::ops::IntTensor; +use burn::tensor::ops::{FloatTensor, FloatTensorOps}; use burn_cubecl::cubecl::features::TypeUsage; use burn_cubecl::cubecl::ir::{ElemType, FloatKind, StorageType}; use 
burn_cubecl::cubecl::server::Bindings; use burn_cubecl::kernel::into_contiguous; -use glam::uvec2; +use glam::{Vec3, uvec2}; -use crate::burn_glue::{GaussianBackwardState, SplatBackwardOps}; +use crate::burn_glue::{RasterizeGrads, SplatBwdOps, SplatGrads}; +use brush_render::shaders::helpers::ProjectUniforms; // Kernel definitions using proc macro #[wgsl_kernel( @@ -33,63 +35,31 @@ pub struct RasterizeBackwards { pub webgpu: bool, } -#[derive(Debug, Clone)] -pub struct SplatGrads { - pub v_means: FloatTensor, - pub v_quats: FloatTensor, - pub v_scales: FloatTensor, - pub v_coeffs: FloatTensor, - pub v_raw_opac: FloatTensor, - pub v_refine_weight: FloatTensor, -} - -impl SplatBackwardOps for MainBackendBase { - fn render_splats_bwd( - state: GaussianBackwardState, +impl SplatBwdOps for MainBackendBase { + #[allow(clippy::too_many_arguments)] + fn rasterize_bwd( + out_img: FloatTensor, + projected_splats: FloatTensor, + global_from_compact_gid: IntTensor, + compact_gid_from_isect: IntTensor, + tile_offsets: IntTensor, + background: Vec3, + img_size: glam::UVec2, v_output: FloatTensor, - ) -> SplatGrads { + ) -> RasterizeGrads { + let _span = tracing::trace_span!("rasterize_bwd").entered(); + // Comes from loss, might not be contiguous. let v_output = into_contiguous(v_output); - // Comes from params, might not be contiguous. - let means = into_contiguous(state.means); - let log_scales = into_contiguous(state.log_scales); - let quats = into_contiguous(state.quats); - let raw_opac = into_contiguous(state.raw_opac); - - // We're in charge of these, SHOULD be contiguous but might as well. 
- let projected_splats = into_contiguous(state.projected_splats); - let uniforms_buffer = into_contiguous(state.uniforms_buffer); - let compact_gid_from_isect = into_contiguous(state.compact_gid_from_isect); - let global_from_compact_gid = into_contiguous(state.global_from_compact_gid); - let tile_offsets = into_contiguous(state.tile_offsets); - - let device = &state.out_img.device; - let img_dimgs = state.out_img.shape.dims; - let img_size = glam::uvec2(img_dimgs[1] as u32, img_dimgs[0] as u32); + let device = &out_img.device; + let num_points = projected_splats.shape.dims[0]; - let num_points = means.shape.dims[0]; - - let client = &means.client; + let client = &projected_splats.client; - // Setup tensors. - // Nb: these are packed vec3 values, special care is taken in the kernel to respect alignment. - let v_means = Self::float_zeros([num_points, 3].into(), device, FloatDType::F32); - - let v_scales = Self::float_zeros([num_points, 3].into(), device, FloatDType::F32); - let v_quats = Self::float_zeros([num_points, 4].into(), device, FloatDType::F32); - let v_coeffs = Self::float_zeros( - [ - num_points, - sh_coeffs_for_degree(state.sh_degree) as usize, - 3, - ] - .into(), - device, - FloatDType::F32, - ); + // Setup output tensors. 
+ let v_projected_splats = Self::float_zeros([num_points, 8].into(), device, FloatDType::F32); let v_raw_opac = Self::float_zeros([num_points].into(), device, FloatDType::F32); - let v_grads = Self::float_zeros([num_points, 8].into(), device, FloatDType::F32); let v_refine_weight = Self::float_zeros([num_points].into(), device, FloatDType::F32); let tile_bounds = uvec2( @@ -101,15 +71,20 @@ impl SplatBackwardOps for MainBackendBase { .div_ceil(brush_render::shaders::helpers::TILE_WIDTH), ); + // Create RasterizeUniforms for the backward rasterize pass + let rasterize_uniforms = RasterizeUniforms { + tile_bounds: tile_bounds.into(), + img_size: img_size.into(), + background: [background.x, background.y, background.z, 1.0], + }; + let hard_floats = client .properties() .type_usage(StorageType::Atomic(ElemType::Float(FloatKind::F32))) .contains(TypeUsage::AtomicAdd); let webgpu = cfg!(target_family = "wasm"); - let mip_splat = matches!(state.render_mode, SplatRenderMode::Mip); - // Use checked execution, as the atomic loops are potentially unbounded. tracing::trace_span!("RasterizeBackwards").in_scope(|| { // SAFETY: Kernel checked to have no OOB, bounded loops. 
unsafe { @@ -117,61 +92,103 @@ impl SplatBackwardOps for MainBackendBase { .launch_unchecked( RasterizeBackwards::task(hard_floats, webgpu), CubeCount::Static(tile_bounds.x * tile_bounds.y, 1, 1), - Bindings::new().with_buffers(vec![ - uniforms_buffer.handle.clone().binding(), - compact_gid_from_isect.handle.binding(), - global_from_compact_gid.handle.clone().binding(), - tile_offsets.handle.binding(), - projected_splats.handle.binding(), - state.out_img.handle.binding(), - v_output.handle.binding(), - v_grads.handle.clone().binding(), - v_raw_opac.handle.clone().binding(), - v_refine_weight.handle.clone().binding(), - ]), + Bindings::new() + .with_buffers(vec![ + compact_gid_from_isect.handle.binding(), + global_from_compact_gid.handle.binding(), + tile_offsets.handle.binding(), + projected_splats.handle.binding(), + out_img.handle.binding(), + v_output.handle.binding(), + v_projected_splats.handle.clone().binding(), + v_raw_opac.handle.clone().binding(), + v_refine_weight.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(rasterize_uniforms)), ) .expect("Failed to bwd-diff splats"); } }); - tracing::trace_span!("ProjectBackwards").in_scope(|| - // SAFETY: Kernel has to contain no OOB indexing, bounded loops. 
- unsafe { - client.launch_unchecked( - ProjectBackwards::task(mip_splat), - calc_cube_count_1d(num_points as u32, ProjectBackwards::WORKGROUP_SIZE[0]), - Bindings::new().with_buffers( - vec![ - uniforms_buffer.handle.binding(), - means.handle.binding(), - log_scales.handle.binding(), - quats.handle.binding(), - raw_opac.handle.binding(), - global_from_compact_gid.handle.binding(), - v_grads.handle.binding(), - v_means.handle.clone().binding(), - v_scales.handle.clone().binding(), - v_quats.handle.clone().binding(), - v_coeffs.handle.clone().binding(), - v_raw_opac.handle.clone().binding(), - ]), - ).expect("Failed to bwd-diff splats"); - }); - - assert!(v_means.is_contiguous(), "Grads must be contiguous"); - assert!(v_quats.is_contiguous(), "Grads must be contiguous"); - assert!(v_scales.is_contiguous(), "Grads must be contiguous"); - assert!(v_coeffs.is_contiguous(), "Grads must be contiguous"); - assert!(v_raw_opac.is_contiguous(), "Grads must be contiguous"); - assert!(v_refine_weight.is_contiguous(), "Grads must be contiguous"); + RasterizeGrads { + v_projected_splats, + v_raw_opac, + v_refine_weight, + } + } + + #[allow(clippy::too_many_arguments)] + fn project_bwd( + means: FloatTensor, + log_scales: FloatTensor, + quats: FloatTensor, + raw_opac: FloatTensor, + num_visible: IntTensor, + global_from_compact_gid: IntTensor, + project_uniforms: ProjectUniforms, + sh_degree: u32, + render_mode: SplatRenderMode, + rasterize_grads: RasterizeGrads, + ) -> SplatGrads { + let _span = tracing::trace_span!("project_bwd").entered(); + + // Comes from params, might not be contiguous. + let means = into_contiguous(means); + let log_scales = into_contiguous(log_scales); + let quats = into_contiguous(quats); + let raw_opac = into_contiguous(raw_opac); + + let device = &means.device; + let num_points = means.shape.dims[0]; + let client = &means.client; + + // Setup output tensors. 
+ let v_means = Self::float_zeros([num_points, 3].into(), device, FloatDType::F32); + let v_scales = Self::float_zeros([num_points, 3].into(), device, FloatDType::F32); + let v_quats = Self::float_zeros([num_points, 4].into(), device, FloatDType::F32); + let v_coeffs = Self::float_zeros( + [num_points, sh_coeffs_for_degree(sh_degree) as usize, 3].into(), + device, + FloatDType::F32, + ); + + let mip_splat = matches!(render_mode, SplatRenderMode::Mip); + + tracing::trace_span!("ProjectBackwards").in_scope(|| { + // SAFETY: Kernel has to contain no OOB indexing, bounded loops. + unsafe { + client + .launch_unchecked( + ProjectBackwards::task(mip_splat), + calc_cube_count_1d(num_points as u32, ProjectBackwards::WORKGROUP_SIZE[0]), + Bindings::new() + .with_buffers(vec![ + num_visible.handle.binding(), + means.handle.binding(), + log_scales.handle.binding(), + quats.handle.binding(), + raw_opac.handle.binding(), + global_from_compact_gid.handle.binding(), + rasterize_grads.v_projected_splats.handle.binding(), + v_means.handle.clone().binding(), + v_scales.handle.clone().binding(), + v_quats.handle.clone().binding(), + v_coeffs.handle.clone().binding(), + rasterize_grads.v_raw_opac.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(project_uniforms)), + ) + .expect("Failed to bwd-diff splats"); + } + }); SplatGrads { v_means, v_quats, v_scales, v_coeffs, - v_raw_opac, - v_refine_weight, + v_raw_opac: rasterize_grads.v_raw_opac, + v_refine_weight: rasterize_grads.v_refine_weight, } } } diff --git a/crates/brush-render-bwd/src/shaders/project_backwards.wgsl b/crates/brush-render-bwd/src/shaders/project_backwards.wgsl index 1c569646..6aead3cb 100644 --- a/crates/brush-render-bwd/src/shaders/project_backwards.wgsl +++ b/crates/brush-render-bwd/src/shaders/project_backwards.wgsl @@ -1,6 +1,6 @@ #import helpers; -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; +@group(0) @binding(0) var num_visible: u32; @group(0) @binding(1) var means: array; 
@group(0) @binding(2) var log_scales: array; @@ -16,6 +16,7 @@ @group(0) @binding(9) var v_quats: array; @group(0) @binding(10) var v_coeffs: array; @group(0) @binding(11) var v_opacs: array; +@group(0) @binding(12) var uniforms: helpers::ProjectUniforms; const SH_C0: f32 = 0.2820947917738781f; @@ -299,7 +300,7 @@ fn main( @builtin(local_invocation_index) lid: u32, ) { let compact_gid = helpers::get_global_id(wid, num_wgs, lid, WG_SIZE); - if compact_gid >= uniforms.num_visible { + if compact_gid >= num_visible { return; } diff --git a/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl b/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl index 9e28b003..1741319c 100644 --- a/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl +++ b/crates/brush-render-bwd/src/shaders/rasterize_backwards.wgsl @@ -1,17 +1,17 @@ #import helpers; -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; -@group(0) @binding(1) var compact_gid_from_isect: array; -@group(0) @binding(2) var global_from_compact_gid: array; -@group(0) @binding(3) var tile_offsets: array; -@group(0) @binding(4) var projected: array; -@group(0) @binding(5) var output: array; -@group(0) @binding(6) var v_output: array; +@group(0) @binding(0) var compact_gid_from_isect: array; +@group(0) @binding(1) var global_from_compact_gid: array; +@group(0) @binding(2) var tile_offsets: array; +@group(0) @binding(3) var projected: array; +@group(0) @binding(4) var output: array; +@group(0) @binding(5) var v_output: array; #ifdef HARD_FLOAT - @group(0) @binding(7) var v_splats: array>; - @group(0) @binding(8) var v_opacs: array>; - @group(0) @binding(9) var v_refines: array>; + @group(0) @binding(6) var v_splats: array>; + @group(0) @binding(7) var v_opacs: array>; + @group(0) @binding(8) var v_refines: array>; + @group(0) @binding(9) var uniforms: helpers::RasterizeUniforms; fn write_grads_atomic(id: u32, grads: f32) { atomicAdd(&v_splats[id], grads); @@ -23,9 +23,10 @@ 
atomicAdd(&v_opacs[id], grads); } #else - @group(0) @binding(7) var v_splats: array>; - @group(0) @binding(8) var v_opacs: array>; - @group(0) @binding(9) var v_refines: array>; + @group(0) @binding(6) var v_splats: array>; + @group(0) @binding(7) var v_opacs: array>; + @group(0) @binding(8) var v_refines: array>; + @group(0) @binding(9) var uniforms: helpers::RasterizeUniforms; fn add_bitcast(cur: u32, add: f32) -> u32 { return bitcast(bitcast(cur) + add); diff --git a/crates/brush-render-bwd/src/tests.rs b/crates/brush-render-bwd/src/tests.rs deleted file mode 100644 index 655239c7..00000000 --- a/crates/brush-render-bwd/src/tests.rs +++ /dev/null @@ -1,119 +0,0 @@ -use crate::burn_glue::SplatForwardDiff; -use assert_approx_eq::assert_approx_eq; -use brush_render::{camera::Camera, gaussian_splats::SplatRenderMode}; -use burn::{ - backend::Autodiff, - tensor::{Distribution, Tensor, TensorPrimitive}, -}; -use burn_wgpu::{CubeBackend, WgpuDevice, WgpuRuntime}; -use glam::Vec3; - -type InnerBackend = CubeBackend; -type TestBackend = Autodiff; - -#[test] -fn diffs_at_all() { - // Check if backward pass doesn't hard crash or anything. - // These are some zero-sized gaussians, so we know - // what the result should look like. 
- let cam = Camera::new( - glam::vec3(0.0, 0.0, 0.0), - glam::Quat::IDENTITY, - 0.5, - 0.5, - glam::vec2(0.5, 0.5), - ); - let img_size = glam::uvec2(32, 32); - let device = WgpuDevice::DefaultDevice; - let num_points = 8; - let means = Tensor::::zeros([num_points, 3], &device); - let log_scales = Tensor::::ones([num_points, 3], &device) * 2.0; - let quats: Tensor = - Tensor::::from_floats(glam::Quat::IDENTITY.to_array(), &device) - .unsqueeze_dim(0) - .repeat_dim(0, num_points); - let sh_coeffs = Tensor::::ones([num_points, 1, 3], &device); - let raw_opacity = Tensor::::zeros([num_points], &device); - - let result = >::render_splats( - &cam, - img_size, - means.into_primitive().tensor(), - log_scales.into_primitive().tensor(), - quats.into_primitive().tensor(), - sh_coeffs.into_primitive().tensor(), - raw_opacity.into_primitive().tensor(), - SplatRenderMode::Default, - Vec3::ZERO, - ); - result.aux.validate_values(); - - let output: Tensor = Tensor::from_primitive(TensorPrimitive::Float(result.img)); - let rgb = output.clone().slice([0..32, 0..32, 0..3]); - let alpha = output.slice([0..32, 0..32, 3..4]); - let rgb_mean = rgb.mean().to_data().as_slice::().expect("Wrong type")[0]; - let alpha_mean = alpha - .mean() - .to_data() - .as_slice::() - .expect("Wrong type")[0]; - assert_approx_eq!(rgb_mean, 0.0, 1e-5); - assert_approx_eq!(alpha_mean, 0.0); -} - -#[test] -fn diffs_many_splats() { - // Test backward pass with a ton of splats to verify 2D dispatch works correctly. - // This exceeds the 1D 65535 * 256 = 16.7M limit. 
- let num_points = 30_000_000; - let cam = Camera::new( - glam::vec3(0.0, 0.0, -5.0), - glam::Quat::IDENTITY, - 0.5, - 0.5, - glam::vec2(0.5, 0.5), - ); - let img_size = glam::uvec2(64, 64); - let device = WgpuDevice::DefaultDevice; - - // Create random gaussians spread in front of the camera - let means = Tensor::::random( - [num_points, 3], - Distribution::Uniform(-2.0, 2.0), - &device, - ); - // Small scales so they don't cover everything - let log_scales = Tensor::::random( - [num_points, 3], - Distribution::Uniform(-4.0, -2.0), - &device, - ); - // Random rotations (will be normalized) - let quats = Tensor::::random( - [num_points, 4], - Distribution::Uniform(-1.0, 1.0), - &device, - ); - // Simple SH coefficients (just base color) - let sh_coeffs = Tensor::::random( - [num_points, 1, 3], - Distribution::Uniform(0.0, 1.0), - &device, - ); - // Some visible, some not - let raw_opacity = - Tensor::::random([num_points], Distribution::Uniform(-2.0, 2.0), &device); - - let result = >::render_splats( - &cam, - img_size, - means.into_primitive().tensor(), - log_scales.into_primitive().tensor(), - quats.into_primitive().tensor(), - sh_coeffs.into_primitive().tensor(), - raw_opacity.into_primitive().tensor(), - SplatRenderMode::Default, - Vec3::ZERO, - ); - result.aux.validate_values(); -} diff --git a/crates/brush-render/Cargo.toml b/crates/brush-render/Cargo.toml index ee6691c0..fdac12f4 100644 --- a/crates/brush-render/Cargo.toml +++ b/crates/brush-render/Cargo.toml @@ -32,5 +32,8 @@ ignored = ["bytemuck"] [features] debug-validation = [] +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt"] } + [lints] workspace = true diff --git a/crates/brush-render/src/burn_glue.rs b/crates/brush-render/src/burn_glue.rs index 62bfb1a3..fce4399e 100644 --- a/crates/brush-render/src/burn_glue.rs +++ b/crates/brush-render/src/burn_glue.rs @@ -1,4 +1,7 @@ -use burn::tensor::{DType, Shape, ops::FloatTensor}; +use burn::tensor::{ + DType, Shape, + 
ops::{FloatTensor, IntTensor}, +}; use burn_cubecl::{BoolElement, fusion::FusionCubeRuntime}; use burn_fusion::{ Fusion, FusionHandle, @@ -9,16 +12,17 @@ use burn_wgpu::WgpuRuntime; use glam::Vec3; use crate::{ - MainBackendBase, SplatForward, + MainBackendBase, RenderAux, SplatOps, camera::Camera, gaussian_splats::SplatRenderMode, - render::{calc_tile_bounds, max_intersections}, - render_aux::RenderAux, - shaders, + render::calc_tile_bounds, + render_aux::ProjectOutput, + sh::sh_degree_from_coeffs, + shaders::{self, helpers::ProjectUniforms}, }; -impl SplatForward for Fusion { - fn render_splats( +impl SplatOps for Fusion { + fn project( cam: &Camera, img_size: glam::UVec2, means: FloatTensor, @@ -27,16 +31,12 @@ impl SplatForward for Fusion { sh_coeffs: FloatTensor, opacity: FloatTensor, render_mode: SplatRenderMode, - background: Vec3, - bwd_info: bool, - ) -> (FloatTensor, RenderAux) { + ) -> ProjectOutput { #[derive(Debug)] struct CustomOp { cam: Camera, img_size: glam::UVec2, render_mode: SplatRenderMode, - bwd_info: bool, - background: Vec3, desc: CustomOpIr, } @@ -49,19 +49,13 @@ impl SplatForward for Fusion { let [means, log_scales, quats, sh_coeffs, opacity] = inputs; let [ - // Img - out_img, - // Aux projected_splats, - uniforms_buffer, - num_intersections, - tile_offsets, - compact_gid_from_isect, + num_visible, global_from_compact_gid, - visible, + cum_tiles_hit, ] = outputs; - let (img, aux) = MainBackendBase::render_splats( + let aux = MainBackendBase::project( &self.cam, self.img_size, h.get_float_tensor::(means), @@ -70,110 +64,76 @@ impl SplatForward for Fusion { h.get_float_tensor::(sh_coeffs), h.get_float_tensor::(opacity), self.render_mode, - self.background, - self.bwd_info, ); - // Register output. 
- h.register_float_tensor::(&out_img.id, img); + // Register outputs (project_uniforms is stored on ProjectAux directly) h.register_float_tensor::( &projected_splats.id, aux.projected_splats, ); - h.register_int_tensor::(&uniforms_buffer.id, aux.uniforms_buffer); - h.register_int_tensor::( - &num_intersections.id, - aux.num_intersections, - ); - h.register_int_tensor::(&tile_offsets.id, aux.tile_offsets); - h.register_int_tensor::( - &compact_gid_from_isect.id, - aux.compact_gid_from_isect, - ); + h.register_int_tensor::(&num_visible.id, aux.num_visible); h.register_int_tensor::( &global_from_compact_gid.id, aux.global_from_compact_gid, ); - - h.register_float_tensor::(&visible.id, aux.visible); + h.register_int_tensor::(&cum_tiles_hit.id, aux.cum_tiles_hit); } } let client = means.client.clone(); - let num_points = means.shape[0]; - - let proj_size = size_of::() / 4; - let uniforms_size = size_of::() / 4; + let sh_degree = sh_degree_from_coeffs(sh_coeffs.shape[1] as u32); let tile_bounds = calc_tile_bounds(img_size); - let max_intersects = max_intersections(img_size, num_points as u32); - // If render_u32_buffer is true, we render a packed buffer of u32 values, otherwise - // render RGBA f32 values. 
- let channels = if bwd_info { 4 } else { 1 }; - - let out_img = TensorIr::uninit( - client.create_empty_handle(), - Shape::new([img_size.y as usize, img_size.x as usize, channels]), - if bwd_info { DType::F32 } else { DType::U32 }, - ); - - let visible_shape = if bwd_info { - Shape::new([num_points]) - } else { - Shape::new([1]) - }; + let proj_size = size_of::() / 4; let projected_splats = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points, proj_size]), DType::F32, ); - let uniforms_buffer = TensorIr::uninit( - client.create_empty_handle(), - Shape::new([uniforms_size]), - DType::U32, - ); - let num_intersections = + let num_visible = TensorIr::uninit(client.create_empty_handle(), Shape::new([1]), DType::U32); - let tile_offsets = TensorIr::uninit( - client.create_empty_handle(), - Shape::new([tile_bounds.y as usize, tile_bounds.x as usize, 2]), - DType::U32, - ); - let compact_gid_from_isect = TensorIr::uninit( + let global_from_compact_gid = TensorIr::uninit( client.create_empty_handle(), - Shape::new([max_intersects as usize]), + Shape::new([num_points]), DType::U32, ); - let global_from_compact_gid = TensorIr::uninit( + let cum_tiles_hit = TensorIr::uninit( client.create_empty_handle(), Shape::new([num_points]), DType::U32, ); - let visible = TensorIr::uninit(client.create_empty_handle(), visible_shape, DType::F32); + + // Create project_uniforms from camera and img_size (stored on ProjectAux directly) + let project_uniforms = ProjectUniforms { + viewmat: glam::Mat4::from(cam.world_to_local()).to_cols_array_2d(), + camera_position: [cam.position.x, cam.position.y, cam.position.z, 0.0], + focal: cam.focal(img_size).into(), + pixel_center: cam.center(img_size).into(), + img_size: img_size.into(), + tile_bounds: tile_bounds.into(), + sh_degree, + total_splats: num_points as u32, + pad_a: 0, + pad_b: 0, + }; let input_tensors = [means, log_scales, quats, sh_coeffs, opacity]; let stream = OperationStreams::with_inputs(&input_tensors); let desc 
= CustomOpIr::new( - "render_splats", + "project_prepare", &input_tensors.map(|t| t.into_ir()), &[ - out_img, projected_splats, - uniforms_buffer, - num_intersections, - tile_offsets, - compact_gid_from_isect, + num_visible, global_from_compact_gid, - visible, + cum_tiles_hit, ], ); let op = CustomOp { cam: cam.clone(), img_size, - bwd_info, - background, render_mode, desc: desc.clone(), }; @@ -183,30 +143,148 @@ impl SplatForward for Fusion { .outputs(); let [ - // Img - out_img, - // Aux projected_splats, - uniforms_buffer, - num_intersections, - tile_offsets, - compact_gid_from_isect, + num_visible, global_from_compact_gid, - visible, + cum_tiles_hit, ] = outputs; + ProjectOutput:: { + projected_splats, + project_uniforms, + num_visible, + global_from_compact_gid, + cum_tiles_hit, + img_size, + } + } + + fn rasterize( + project_output: &ProjectOutput, + num_intersections: u32, + background: Vec3, + bwd_info: bool, + ) -> (FloatTensor, RenderAux, IntTensor) { + #[derive(Debug)] + struct CustomOp { + img_size: glam::UVec2, + num_intersections: u32, + background: Vec3, + bwd_info: bool, + project_uniforms: ProjectUniforms, + desc: CustomOpIr, + } + + impl Operation> for CustomOp { + fn execute( + &self, + h: &mut HandleContainer>>, + ) { + let (inputs, outputs) = self.desc.as_fixed(); + + let [ + projected_splats, + num_visible, + global_from_compact_gid, + cum_tiles_hit, + ] = inputs; + let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs; + + let inner_output = ProjectOutput:: { + projected_splats: h.get_float_tensor::(projected_splats), + project_uniforms: self.project_uniforms, + num_visible: h.get_int_tensor::(num_visible), + global_from_compact_gid: h + .get_int_tensor::(global_from_compact_gid), + cum_tiles_hit: h.get_int_tensor::(cum_tiles_hit), + img_size: self.img_size, + }; + + let (img, aux, compact_gid) = MainBackendBase::rasterize( + &inner_output, + self.num_intersections, + self.background, + self.bwd_info, + ); + + // Register 
outputs + h.register_float_tensor::(&out_img.id, img); + h.register_int_tensor::(&tile_offsets.id, aux.tile_offsets); + h.register_int_tensor::(&compact_gid_from_isect.id, compact_gid); + h.register_float_tensor::(&visible.id, aux.visible); + } + } + + let client = project_output.projected_splats.client.clone(); + let img_size = project_output.img_size; + let tile_bounds = calc_tile_bounds(img_size); + + let num_points = project_output.projected_splats.shape[0]; + + let channels = if bwd_info { 4 } else { 1 }; + let out_img = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([img_size.y as usize, img_size.x as usize, channels]), + if bwd_info { DType::F32 } else { DType::U32 }, + ); + + let tile_offsets = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([tile_bounds.y as usize, tile_bounds.x as usize, 2]), + DType::U32, + ); + + // Use actual num_intersections for buffer size + let compact_gid_from_isect = TensorIr::uninit( + client.create_empty_handle(), + Shape::new([num_intersections.max(1) as usize]), + DType::U32, + ); + + let visible_shape = if bwd_info { + Shape::new([num_points]) + } else { + Shape::new([1]) + }; + let visible = TensorIr::uninit(client.create_empty_handle(), visible_shape, DType::F32); + + let input_tensors = [ + project_output.projected_splats.clone(), + project_output.num_visible.clone(), + project_output.global_from_compact_gid.clone(), + project_output.cum_tiles_hit.clone(), + ]; + let stream = OperationStreams::with_inputs(&input_tensors); + let desc = CustomOpIr::new( + "rasterize", + &input_tensors.map(|t| t.into_ir()), + &[out_img, tile_offsets, compact_gid_from_isect, visible], + ); + let op = CustomOp { + img_size, + num_intersections, + background, + bwd_info, + project_uniforms: project_output.project_uniforms, + desc: desc.clone(), + }; + + let outputs = client + .register(stream, OperationIr::Custom(desc), op) + .outputs(); + + let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs; + 
( out_img, RenderAux:: { - projected_splats, - uniforms_buffer, + num_visible: project_output.num_visible.clone(), num_intersections, - tile_offsets, - compact_gid_from_isect, - global_from_compact_gid, visible, + tile_offsets, img_size, }, + compact_gid_from_isect, ) } } diff --git a/crates/brush-render/src/gaussian_splats.rs b/crates/brush-render/src/gaussian_splats.rs index 16085311..e9fa41c8 100644 --- a/crates/brush-render/src/gaussian_splats.rs +++ b/crates/brush-render/src/gaussian_splats.rs @@ -9,9 +9,8 @@ use glam::Vec3; use tracing::trace_span; use crate::{ - SplatForward, + RenderAux, SplatOps, camera::Camera, - render_aux::RenderAux, sh::{sh_coeffs_for_degree, sh_degree_from_coeffs}, }; @@ -24,6 +23,13 @@ pub enum SplatRenderMode { Mip, } +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +pub enum TextureMode { + Packed, + #[default] + Float, +} + #[derive(Module, Debug)] pub struct Splats { pub means: Param>, @@ -234,37 +240,48 @@ impl Splats { /// /// NB: This doesn't work on a differentiable backend. Use /// [`brush_render_bwd::render_splats`] for that. -pub fn render_splats>( - splats: &Splats, +/// +/// Takes ownership of the splats to avoid cloning internally. +pub async fn render_splats>( + splats: Splats, camera: &Camera, img_size: glam::UVec2, background: Vec3, splat_scale: Option, + texture_mode: TextureMode, ) -> (Tensor, RenderAux) { splats.validate_values(); - let mut scales = splats.log_scales.val(); + let mut scales = splats.log_scales.into_value(); - // Add in scaling if needed. 
if let Some(scale) = splat_scale { scales = scales + scale.ln(); }; - let (img, aux) = B::render_splats( + let project_output = B::project( camera, img_size, - splats.means.val().into_primitive().tensor(), + splats.means.into_value().into_primitive().tensor(), scales.into_primitive().tensor(), - splats.rotations.val().into_primitive().tensor(), - splats.sh_coeffs.val().into_primitive().tensor(), - splats.raw_opacities.val().into_primitive().tensor(), + splats.rotations.into_value().into_primitive().tensor(), + splats.sh_coeffs.into_value().into_primitive().tensor(), + splats.raw_opacities.into_value().into_primitive().tensor(), splats.render_mode, - background, - false, ); - let img = Tensor::from_primitive(TensorPrimitive::Float(img)); - aux.validate_values(); + project_output.validate(); + + // Async readback + let num_intersections = project_output.read_num_intersections().await; + + let use_float = matches!(texture_mode, TextureMode::Float); + let (out_img, render_aux, _) = + B::rasterize(&project_output, num_intersections, background, use_float); + + render_aux.validate(); - (img, aux) + ( + Tensor::from_primitive(TensorPrimitive::Float(out_img)), + render_aux, + ) } diff --git a/crates/brush-render/src/get_tile_offset.rs b/crates/brush-render/src/get_tile_offset.rs index d16561bf..6223f277 100644 --- a/crates/brush-render/src/get_tile_offset.rs +++ b/crates/brush-render/src/get_tile_offset.rs @@ -7,30 +7,6 @@ use burn_cubecl::cubecl::prelude::{ pub(crate) const CHECKS_PER_ITER: u32 = 8; -#[cube] -fn check_tile_boundary( - tile_id_from_isect: &Tensor, - tile_offsets: &mut Tensor, - isect_id: u32, - inter: u32, -) { - if isect_id < inter { - let prev_tid = tile_id_from_isect[isect_id - 1]; - let tid = tile_id_from_isect[isect_id]; - - if isect_id == inter - 1 { - // Write the end of the last tile. - tile_offsets[tid * 2 + 1] = isect_id + 1; - } - if tid != prev_tid { - // Write the end of the previous tile. 
- tile_offsets[prev_tid * 2 + 1] = isect_id; - // Write start of this tile. - tile_offsets[tid * 2] = isect_id; - } - } -} - #[cube(launch_unchecked)] pub fn get_tile_offsets( tile_id_from_isect: &Tensor, @@ -45,6 +21,22 @@ pub fn get_tile_offsets( #[unroll] for i in 0..CHECKS_PER_ITER { - check_tile_boundary(tile_id_from_isect, tile_offsets, base_id + i, inter); + let isect_id = base_id + i; + if isect_id < inter { + let prev_tid = tile_id_from_isect[isect_id - 1]; + let tid = tile_id_from_isect[isect_id]; + + if isect_id == inter - 1 { + // Write the end of the last tile. + tile_offsets[tid * 2 + 1] = isect_id + 1; + } + + if tid != prev_tid { + // Write the end of the previous tile. + tile_offsets[prev_tid * 2 + 1] = isect_id; + // Write start of this tile. + tile_offsets[tid * 2] = isect_id; + } + } } } diff --git a/crates/brush-render/src/lib.rs b/crates/brush-render/src/lib.rs index 7ef89d2c..7eda03da 100644 --- a/crates/brush-render/src/lib.rs +++ b/crates/brush-render/src/lib.rs @@ -1,17 +1,18 @@ #![recursion_limit = "256"] use burn::prelude::Backend; -use burn::tensor::ops::FloatTensor; +use burn::tensor::ops::{FloatTensor, IntTensor}; use burn_cubecl::CubeBackend; use burn_fusion::Fusion; use burn_wgpu::WgpuRuntime; use camera::Camera; use clap::ValueEnum; use glam::Vec3; -use render_aux::RenderAux; +use render_aux::ProjectOutput; use crate::gaussian_splats::SplatRenderMode; -pub use crate::gaussian_splats::render_splats; +pub use crate::gaussian_splats::{TextureMode, render_splats}; +pub use crate::render_aux::RenderAux; mod burn_glue; mod dim_check; @@ -33,28 +34,17 @@ pub mod validation; pub type MainBackendBase = CubeBackend; pub type MainBackend = Fusion; -#[derive(Debug, Clone)] -pub struct RenderStats { - pub num_visible: u32, - pub num_intersections: u32, -} - -// The maximum number of intersections that can be rendered. -// -// With 2D dispatch support, we can now handle more than the original 65535 workgroup limit. 
-// Doubled from the original 512 * 65535 to allow higher resolution rendering. -const INTERSECTS_UPPER_BOUND: u32 = 2 * 512 * 65535; - -pub trait SplatForward { - /// Render splats to a buffer. - /// - /// This projects the gaussians, sorts them, and rasterizes them to a buffer, in a - /// differentiable way. - /// The arguments are all passed as raw tensors. See [`Splats`] for a convenient Module that wraps this fun - /// The [`xy_grad_dummy`] variable is only used to carry screenspace xy gradients. - /// This function can optionally render a "u32" buffer, which is a packed RGBA (8 bits per channel) - /// buffer. This is useful when the results need to be displayed immediately. - fn render_splats( +/// Trait for the gaussian splatting rendering pipeline. +/// +/// This trait provides two passes: +/// 1. `project`: Culling, depth sort, projection, intersection counting, prefix sum. +/// 2. `rasterize`: Intersection filling, tile sort, tile offsets, rasterization. +/// +/// The split allows for an explicit GPU sync point between passes to read back +/// the exact number of intersections needed for buffer allocation. +pub trait SplatOps { + /// First pass: project gaussians and count intersections. + fn project( camera: &Camera, img_size: glam::UVec2, means: FloatTensor, @@ -63,9 +53,15 @@ pub trait SplatForward { sh_coeffs: FloatTensor, raw_opacities: FloatTensor, render_mode: SplatRenderMode, + ) -> ProjectOutput; + + /// Second pass: rasterize using projection data. 
+ fn rasterize( + project_output: &ProjectOutput, + num_intersections: u32, background: Vec3, bwd_info: bool, - ) -> (FloatTensor, RenderAux); + ) -> (FloatTensor, RenderAux, IntTensor); } #[derive( diff --git a/crates/brush-render/src/render.rs b/crates/brush-render/src/render.rs index 97dad971..b4894e0c 100644 --- a/crates/brush-render/src/render.rs +++ b/crates/brush-render/src/render.rs @@ -1,31 +1,29 @@ use crate::{ - INTERSECTS_UPPER_BOUND, MainBackendBase, SplatForward, + MainBackendBase, RenderAux, SplatOps, camera::Camera, dim_check::DimCheck, gaussian_splats::SplatRenderMode, get_tile_offset::{CHECKS_PER_ITER, get_tile_offsets}, - render_aux::RenderAux, + render_aux::ProjectOutput, sh::sh_degree_from_coeffs, shaders::{self, MapGaussiansToIntersect, ProjectSplats, ProjectVisible, Rasterize}, }; +use brush_kernel::bytemuck; use brush_kernel::create_dispatch_buffer_1d; +use brush_kernel::create_meta_binding; use brush_kernel::create_tensor; -use brush_kernel::create_uniform_buffer; use brush_kernel::{CubeCount, calc_cube_count_1d}; use brush_prefix_sum::prefix_sum; use brush_sort::radix_argsort; -use burn::tensor::{DType, IntDType, ops::FloatTensor}; +use burn::tensor::{DType, IntDType, Shape, ops::FloatTensor}; use burn::tensor::{ FloatDType, - ops::{FloatTensorOps, IntTensorOps}, + ops::{FloatTensorOps, IntTensor, IntTensorOps}, }; use burn_cubecl::cubecl::server::Bindings; - use burn_cubecl::kernel::into_contiguous; -use burn_wgpu::CubeDim; -use burn_wgpu::WgpuRuntime; +use burn_wgpu::{CubeDim, CubeTensor, WgpuRuntime}; use glam::{Vec3, uvec2}; -use std::mem::offset_of; pub(crate) fn calc_tile_bounds(img_size: glam::UVec2) -> glam::UVec2 { uvec2( @@ -34,26 +32,8 @@ pub(crate) fn calc_tile_bounds(img_size: glam::UVec2) -> glam::UVec2 { ) } -// On wasm, we cannot do a sync readback at all. -// Instead, can just estimate a max number of intersects. All the kernels only handle the actual -// number of intersects, and spin up empty threads for the rest atm. 
In the future, could use indirect -// dispatch to avoid this. -// Estimating the max number of intersects can be a bad hack though... The worst case scenario is so massive -// that it's easy to run out of memory... How do we actually properly deal with this :/ -pub fn max_intersections(img_size: glam::UVec2, num_splats: u32) -> u32 { - // Divide screen into tiles. - let tile_bounds = calc_tile_bounds(img_size); - // Assume on average each splat is maximally covering half x half the screen, - // and adjust for the variance such that we're fairly certain we have enough intersections. - let num_tiles = tile_bounds[0] * tile_bounds[1]; - let max_possible = num_tiles.saturating_mul(num_splats); - // clamp to max nr. of dispatches. - max_possible.min(INTERSECTS_UPPER_BOUND) -} - -// Implement forward functions for the inner wgpu backend. -impl SplatForward for MainBackendBase { - fn render_splats( +impl SplatOps for MainBackendBase { + fn project( camera: &Camera, img_size: glam::UVec2, means: FloatTensor, @@ -62,9 +42,7 @@ impl SplatForward for MainBackendBase { sh_coeffs: FloatTensor, raw_opacities: FloatTensor, render_mode: SplatRenderMode, - background: Vec3, - bwd_info: bool, - ) -> (FloatTensor, RenderAux) { + ) -> ProjectOutput { assert!( img_size[0] > 0 && img_size[1] > 0, "Can't render images with 0 size." @@ -78,9 +56,8 @@ impl SplatForward for MainBackendBase { let raw_opacities = into_contiguous(raw_opacities); let device = &means.device.clone(); - let client = means.client.clone(); - let _span = tracing::trace_span!("render_forward").entered(); + let _span = tracing::trace_span!("project_prepare").entered(); // Check whether input dimensions are valid. DimCheck::new() @@ -93,24 +70,11 @@ impl SplatForward for MainBackendBase { // Divide screen into tiles. let tile_bounds = calc_tile_bounds(img_size); - // A note on some confusing naming that'll be used throughout this function: - // Gaussians are stored in various states of buffers, eg. 
at the start they're all in one big buffer, - // then we sparsely store some results, then sort gaussian based on depths, etc. - // Overall this means there's lots of indices flying all over the place, and it's hard to keep track - // what is indexing what. So, for some sanity, try to match a few "gaussian ids" (gid) variable names. - // - Global Gaussian ID - global_gid - // - Compacted Gaussian ID - compact_gid - // - Per tile intersection depth sorted ID - tiled_gid - // - Sorted by tile per tile intersection depth sorted ID - sorted_tiled_gid - // Then, various buffers map between these, which are named x_from_y_gid, eg. - // global_from_compact_gid. - // Tile rendering setup. let sh_degree = sh_degree_from_coeffs(sh_coeffs.shape.dims[1] as u32); let total_splats = means.shape.dims[0]; - let max_intersects = max_intersections(img_size, total_splats as u32); - let uniforms = shaders::helpers::RenderUniforms { + let project_uniforms = shaders::helpers::ProjectUniforms { viewmat: glam::Mat4::from(camera.world_to_local()).to_cols_array_2d(), camera_position: [camera.position.x, camera.position.y, camera.position.z, 0.0], focal: camera.focal(img_size).into(), @@ -119,21 +83,18 @@ impl SplatForward for MainBackendBase { tile_bounds: tile_bounds.into(), sh_degree, total_splats: total_splats as u32, - max_intersects, - background: [background.x, background.y, background.z, 1.0], - // Nb: Bit of a hack as these aren't _really_ uniforms but are written to by the shaders. - num_visible: 0, + pad_a: 0, + pad_b: 0, }; - // Nb: This contains both static metadata and some dynamic data so can't pass this as metadata to execute. In the future - // should separate the two. 
- let uniforms_buffer = create_uniform_buffer(uniforms, device, &client); + // Separate buffer for num_visible (written atomically by ProjectSplats) + let num_visible_buffer = Self::int_zeros([1].into(), device, IntDType::U32); let client = &means.client.clone(); - let mip_splat = matches!(render_mode, SplatRenderMode::Mip); - let (global_from_compact_gid, num_visible) = { + // Step 1: ProjectSplats - culling pass + let global_from_compact_gid = { let global_from_presort_gid = Self::int_zeros([total_splats].into(), device, IntDType::U32); let depths = create_tensor([total_splats], device, DType::F32); @@ -144,204 +105,211 @@ impl SplatForward for MainBackendBase { client.launch_unchecked( ProjectSplats::task(mip_splat), calc_cube_count_1d(total_splats as u32, ProjectSplats::WORKGROUP_SIZE[0]), - Bindings::new().with_buffers( - vec![ - uniforms_buffer.handle.clone().binding(), - means.handle.clone().binding(), - quats.handle.clone().binding(), - log_scales.handle.clone().binding(), - raw_opacities.handle.clone().binding(), - global_from_presort_gid.handle.clone().binding(), - depths.handle.clone().binding(), - ]), + Bindings::new() + .with_buffers(vec![ + means.handle.clone().binding(), + quats.handle.clone().binding(), + log_scales.handle.clone().binding(), + raw_opacities.handle.clone().binding(), + global_from_presort_gid.handle.clone().binding(), + depths.handle.clone().binding(), + num_visible_buffer.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(project_uniforms)), ).expect("Failed to render splats"); }); - // Get just the number of visible splats from the uniforms buffer. - let num_vis_field_offset = - offset_of!(shaders::helpers::RenderUniforms, num_visible) / 4; - let num_visible = Self::int_slice( - uniforms_buffer.clone(), - &[(num_vis_field_offset..num_vis_field_offset + 1).into()], - ); - let (_, global_from_compact_gid) = tracing::trace_span!("DepthSort").in_scope(|| { - // Interpret the depth as a u32. 
This is fine for a radix sort, as long as the depth > 0.0, - // which we know to be the case given how we cull splats. - radix_argsort(depths, global_from_presort_gid, &num_visible, 32) + radix_argsort( + depths, + global_from_presort_gid, + 32, + Some(num_visible_buffer.clone()), + ) }); - (global_from_compact_gid, num_visible) + global_from_compact_gid }; - // Create a buffer of 'projected' splats, that is, - // project XY, projected conic, and converted color. let proj_size = size_of::() / size_of::(); let projected_splats = create_tensor([total_splats, proj_size], device, DType::F32); + let splat_intersect_counts = Self::int_zeros([total_splats].into(), device, IntDType::U32); - tracing::trace_span!("ProjectVisible").in_scope(|| { - // Create a buffer to determine how many threads to dispatch for all visible splats. - let num_vis_wg = - create_dispatch_buffer_1d(num_visible.clone(), ProjectVisible::WORKGROUP_SIZE[0]); + tracing::trace_span!("ProjectVisibleWithCounting").in_scope(|| { + let num_vis_wg = create_dispatch_buffer_1d( + num_visible_buffer.clone(), + ProjectVisible::WORKGROUP_SIZE[0], + ); // SAFETY: Kernel checked to have no OOB, bounded loops. 
unsafe { client .launch_unchecked( ProjectVisible::task(mip_splat), CubeCount::Dynamic(num_vis_wg.handle.binding()), - Bindings::new().with_buffers(vec![ - uniforms_buffer.clone().handle.binding(), - means.handle.binding(), - log_scales.handle.binding(), - quats.handle.binding(), - sh_coeffs.handle.binding(), - raw_opacities.handle.binding(), - global_from_compact_gid.handle.clone().binding(), - projected_splats.handle.clone().binding(), - ]), + Bindings::new() + .with_buffers(vec![ + num_visible_buffer.handle.clone().binding(), + means.handle.binding(), + log_scales.handle.binding(), + quats.handle.binding(), + sh_coeffs.handle.binding(), + raw_opacities.handle.binding(), + global_from_compact_gid.handle.clone().binding(), + projected_splats.handle.clone().binding(), + splat_intersect_counts.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(project_uniforms)), ) .expect("Failed to render splats"); } }); - // Each intersection maps to a gaussian. - let (tile_offsets, compact_gid_from_isect, num_intersections) = { - let num_tiles = tile_bounds.x * tile_bounds.y; - - let splat_intersect_counts = - Self::int_zeros([total_splats + 1].into(), device, IntDType::U32); - - let num_vis_map_wg = - create_dispatch_buffer_1d(num_visible, MapGaussiansToIntersect::WORKGROUP_SIZE[0]); - - // First do a prepass to compute the tile counts, then fill in intersection counts. - tracing::trace_span!("MapGaussiansToIntersectPrepass").in_scope(|| { - // SAFETY: Kernel checked to have no OOB, bounded loops. 
- unsafe { - client - .launch_unchecked( - MapGaussiansToIntersect::task(true), - CubeCount::Dynamic(num_vis_map_wg.handle.clone().binding()), - Bindings::new().with_buffers(vec![ - uniforms_buffer.handle.clone().binding(), - projected_splats.handle.clone().binding(), - splat_intersect_counts.handle.clone().binding(), - ]), - ) - .expect("Failed to render splats"); - } - }); + let cum_tiles_hit = tracing::trace_span!("PrefixSumGaussHits") + .in_scope(|| prefix_sum(splat_intersect_counts)); - // TODO: Only need to do this up to num_visible gaussians really. - let cum_tiles_hit = tracing::trace_span!("PrefixSumGaussHits") - .in_scope(|| prefix_sum(splat_intersect_counts)); - - let tile_id_from_isect = create_tensor([max_intersects as usize], device, DType::U32); - let compact_gid_from_isect = - create_tensor([max_intersects as usize], device, DType::U32); - - // Zero this out, as the kernel _might_ not run at all if no gaussians are visible. - let num_intersections = Self::int_zeros([1].into(), device, IntDType::U32); - - tracing::trace_span!("MapGaussiansToIntersect").in_scope(|| { - // SAFETY: Kernel checked to have no OOB, bounded loops. - unsafe { - client - .launch_unchecked( - MapGaussiansToIntersect::task(false), - CubeCount::Dynamic(num_vis_map_wg.handle.clone().binding()), - Bindings::new().with_buffers(vec![ - uniforms_buffer.handle.clone().binding(), - projected_splats.handle.clone().binding(), - cum_tiles_hit.handle.binding(), - tile_id_from_isect.handle.clone().binding(), - compact_gid_from_isect.handle.clone().binding(), - num_intersections.handle.clone().binding(), - ]), - ) - .expect("Failed to render splats"); - } - }); + ProjectOutput { + projected_splats, + project_uniforms, + num_visible: num_visible_buffer, + global_from_compact_gid, + cum_tiles_hit, + img_size, + } + } - // We're sorting by tile ID, but we know beforehand what the maximum value - // can be. We don't need to sort all the leading 0 bits! 
- let bits = u32::BITS - num_tiles.leading_zeros(); - - let (tile_id_from_isect, compact_gid_from_isect) = tracing::trace_span!("Tile sort") - .in_scope(|| { - radix_argsort( - tile_id_from_isect, - compact_gid_from_isect, - &num_intersections, - bits, - ) - }); - - let cube_dim = CubeDim::new_1d(256); - let num_vis_map_wg = - create_dispatch_buffer_1d(num_intersections.clone(), 256 * CHECKS_PER_ITER); - let cube_count = CubeCount::Dynamic(num_vis_map_wg.handle.binding()); - - // Tiles without splats will be written as having a range of [0, 0]. - let tile_offsets = Self::int_zeros( - [tile_bounds.y as usize, tile_bounds.x as usize, 2].into(), - device, - IntDType::U32, - ); + fn rasterize( + project_output: &ProjectOutput, + num_intersections: u32, + background: Vec3, + bwd_info: bool, + ) -> (FloatTensor, RenderAux, IntTensor) { + let _span = tracing::trace_span!("rasterize").entered(); + + let device = &project_output.projected_splats.device.clone(); + let client = project_output.projected_splats.client.clone(); + let img_size = project_output.img_size; + + // Divide screen into tiles. + let tile_bounds = calc_tile_bounds(img_size); + let num_tiles = tile_bounds.x * tile_bounds.y; + + let rasterize_uniforms = shaders::helpers::RasterizeUniforms { + tile_bounds: tile_bounds.into(), + img_size: img_size.into(), + background: [background.x, background.y, background.z, 1.0], + }; - // SAFETY: Safe kernel. 
+ // Step 1: Allocate intersection buffers with exact size (minimum 1 to avoid zero-size allocation) + let buffer_size = (num_intersections as usize).max(1); + let tile_id_from_isect = create_tensor([buffer_size], device, DType::U32); + let compact_gid_from_isect = create_tensor([buffer_size], device, DType::U32); + + // Step 2: MapGaussiansToIntersect (fill pass) + let num_vis_map_wg = create_dispatch_buffer_1d( + project_output.num_visible.clone(), + MapGaussiansToIntersect::WORKGROUP_SIZE[0], + ); + + let map_uniforms = shaders::map_gaussians_to_intersect::Uniforms { + tile_bounds: tile_bounds.into(), + }; + + tracing::trace_span!("MapGaussiansToIntersect").in_scope(|| { + // SAFETY: Kernel checked to have no OOB, bounded loops. unsafe { - get_tile_offsets::launch_unchecked::( - client, - cube_count, - cube_dim, - tile_id_from_isect.as_tensor_arg(1), - tile_offsets.as_tensor_arg(1), - num_intersections.as_tensor_arg(1), - ) - .expect("Failed to render splats"); + client + .launch_unchecked( + MapGaussiansToIntersect::task(), + CubeCount::Dynamic(num_vis_map_wg.handle.clone().binding()), + Bindings::new() + .with_buffers(vec![ + project_output.num_visible.handle.clone().binding(), + project_output.projected_splats.handle.clone().binding(), + project_output.cum_tiles_hit.handle.clone().binding(), + tile_id_from_isect.handle.clone().binding(), + compact_gid_from_isect.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(map_uniforms)), + ) + .expect("Failed to render splats"); } + }); - (tile_offsets, compact_gid_from_isect, num_intersections) - }; + let bits = u32::BITS - num_tiles.leading_zeros(); + let (tile_id_from_isect, compact_gid_from_isect) = tracing::trace_span!("Tile sort") + .in_scope(|| radix_argsort(tile_id_from_isect, compact_gid_from_isect, bits, None)); - let _span = tracing::trace_span!("Rasterize").entered(); + let cube_dim = CubeDim::new_1d(256); + let tile_offsets = Self::int_zeros( + [tile_bounds.y as usize, tile_bounds.x as 
usize, 2].into(), + device, + IntDType::U32, + ); - let out_dim = if bwd_info { - 4 - } else { - // Channels are packed into 4 bytes, aka one float. - 1 + // Create a tensor for num_intersections + let num_inter_tensor = { + let data: [u32; 1] = [num_intersections]; + CubeTensor::new_contiguous( + client.clone(), + device.clone(), + Shape::new([1]), + client.create_from_slice(bytemuck::cast_slice(&data)), + DType::U32, + ) }; + // SAFETY: Safe kernel. + unsafe { + get_tile_offsets::launch_unchecked::( + &client, + calc_cube_count_1d(num_intersections, cube_dim.x * CHECKS_PER_ITER), + cube_dim, + tile_id_from_isect.as_tensor_arg(1), + tile_offsets.as_tensor_arg(1), + num_inter_tensor.as_tensor_arg(1), + ) + .expect("Failed to render splats"); + } + + let out_dim = if bwd_info { 4 } else { 1 }; let out_img = create_tensor( [img_size.y as usize, img_size.x as usize, out_dim], device, DType::F32, ); - let mut bindings = Bindings::new().with_buffers(vec![ - uniforms_buffer.handle.clone().binding(), - compact_gid_from_isect.handle.clone().binding(), - tile_offsets.handle.clone().binding(), - projected_splats.handle.clone().binding(), - out_img.handle.clone().binding(), - ]); + // Get total_splats from the shape of projected_splats + let total_splats = project_output.projected_splats.shape.dims[0]; - let visible = if bwd_info { + let (bindings, visible) = if bwd_info { let visible = Self::float_zeros([total_splats].into(), device, FloatDType::F32); - // Add the buffer to the bindings - bindings = bindings.with_buffers(vec![ - global_from_compact_gid.handle.clone().binding(), - visible.handle.clone().binding(), - ]); - visible + let bindings = Bindings::new() + .with_buffers(vec![ + compact_gid_from_isect.handle.clone().binding(), + tile_offsets.handle.clone().binding(), + project_output.projected_splats.handle.clone().binding(), + out_img.handle.clone().binding(), + project_output + .global_from_compact_gid + .handle + .clone() + .binding(), + 
visible.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(rasterize_uniforms)); + (bindings, visible) } else { - create_tensor([1], device, DType::F32) + let bindings = Bindings::new() + .with_buffers(vec![ + compact_gid_from_isect.handle.clone().binding(), + tile_offsets.handle.clone().binding(), + project_output.projected_splats.handle.clone().binding(), + out_img.handle.clone().binding(), + ]) + .with_metadata(create_meta_binding(rasterize_uniforms)); + (bindings, create_tensor([1], device, DType::F32)) }; - // Compile the kernel, including/excluding info for backwards pass. - // see the BWD_INFO define in the rasterize shader. let raster_task = Rasterize::task(bwd_info); // SAFETY: Kernel checked to have no OOB, bounded loops. @@ -355,41 +323,16 @@ impl SplatForward for MainBackendBase { .expect("Failed to render splats"); } - // Sanity check the buffers. - assert!( - uniforms_buffer.is_contiguous(), - "Uniforms must be contiguous" - ); - assert!( - tile_offsets.is_contiguous(), - "Tile offsets must be contiguous" - ); - assert!( - global_from_compact_gid.is_contiguous(), - "Global from compact gid must be contiguous" - ); - assert!(visible.is_contiguous(), "Visible must be contiguous"); - assert!( - projected_splats.is_contiguous(), - "Projected splats must be contiguous" - ); - assert!( - num_intersections.is_contiguous(), - "Num intersections must be contiguous" - ); - ( out_img, RenderAux { - uniforms_buffer, - tile_offsets, + num_visible: project_output.num_visible.clone(), num_intersections, - projected_splats, - compact_gid_from_isect, - global_from_compact_gid, visible, - img_size, + tile_offsets, + img_size: project_output.img_size, }, + compact_gid_from_isect, ) } } diff --git a/crates/brush-render/src/render_aux.rs b/crates/brush-render/src/render_aux.rs index 3a9d3588..4c879326 100644 --- a/crates/brush-render/src/render_aux.rs +++ b/crates/brush-render/src/render_aux.rs @@ -1,167 +1,156 @@ -use std::mem::offset_of; - +use 
burn::tensor::ElementConversion; use burn::{ Tensor, prelude::Backend, tensor::{ Int, ops::{FloatTensor, IntTensor}, - s, }, }; -use crate::shaders::{self, helpers::TILE_WIDTH}; +use crate::shaders::helpers::ProjectUniforms; +/// Output from the project pass. Consumed by rasterize. #[derive(Debug, Clone)] -pub struct RenderAux { - /// The packed projected splat information, see `ProjectedSplat` in helpers.wgsl +pub struct ProjectOutput { + pub project_uniforms: ProjectUniforms, pub projected_splats: FloatTensor, - pub uniforms_buffer: IntTensor, - pub num_intersections: IntTensor, - pub tile_offsets: IntTensor, - pub compact_gid_from_isect: IntTensor, + pub num_visible: IntTensor, pub global_from_compact_gid: IntTensor, - pub visible: FloatTensor, + pub cum_tiles_hit: IntTensor, pub img_size: glam::UVec2, } -impl RenderAux { - pub fn calc_tile_depth(&self) -> Tensor { - let tile_offsets: Tensor = Tensor::from_primitive(self.tile_offsets.clone()); - let max = tile_offsets.clone().slice(s![.., .., 1]); - let min = tile_offsets.slice(s![.., .., 0]); - let [w, h] = self.img_size.into(); - let [ty, tx] = [h.div_ceil(TILE_WIDTH), w.div_ceil(TILE_WIDTH)]; - (max - min).reshape([ty as usize, tx as usize]) - } - - pub fn num_intersections(&self) -> Tensor { - Tensor::from_primitive(self.num_intersections.clone()) - } - - pub fn num_visible(&self) -> Tensor { - let num_vis_field_offset = offset_of!(shaders::helpers::RenderUniforms, num_visible) / 4; - Tensor::from_primitive(self.uniforms_buffer.clone()).slice(s![num_vis_field_offset]) +impl ProjectOutput { + /// Get the total number of intersections. 
+ pub async fn read_num_intersections(&self) -> u32 { + let cum_tiles_hit: Tensor = Tensor::from_primitive(self.cum_tiles_hit.clone()); + let total = self.project_uniforms.total_splats as usize; + if total > 0 { + cum_tiles_hit + .slice([total - 1..total]) + .into_scalar_async() + .await + .expect("Failed to read num_intersections") + .elem::() + } else { + 0 + } } - pub fn validate_values(&self) { + /// Validate project outputs. Call before consuming. + pub fn validate(&self) { #[cfg(any(test, feature = "debug-validation"))] { - use burn::tensor::{ElementConversion, TensorPrimitive}; - - use crate::{ - INTERSECTS_UPPER_BOUND, render::max_intersections, validation::validate_tensor_val, - }; - if std::env::args().any(|a| a == "--bench") { return; } - let num_intersects: Tensor = self.num_intersections(); - let compact_gid_from_isect: Tensor = - Tensor::from_primitive(self.compact_gid_from_isect.clone()); - let num_visible: Tensor = self.num_visible(); - - let num_intersections = num_intersects.into_scalar().elem::(); - let num_points = compact_gid_from_isect.dims()[0] as u32; - let num_visible = num_visible.into_scalar().elem::() as u32; - let img_size = self.img_size; + use crate::validation::validate_tensor_val; + use burn::tensor::{ElementConversion, TensorPrimitive, s}; - let max_intersects = max_intersections(img_size, num_points); + let num_visible_tensor: Tensor = + Tensor::from_primitive(self.num_visible.clone()); + let total_splats = self.project_uniforms.total_splats; + let num_visible = num_visible_tensor.into_scalar().elem::() as u32; assert!( - num_intersections < max_intersects as i32, - "Too many intersections, estimated too low of a number. {num_intersections} / {max_intersects}" + num_visible <= total_splats, + "num_visible ({num_visible}) > total_splats ({total_splats})" ); - assert!( - num_intersections < INTERSECTS_UPPER_BOUND as i32, - "Too many intersections, Brush currently can't handle this. 
{num_intersections} > {INTERSECTS_UPPER_BOUND}" - ); - - assert!( - num_visible <= num_points, - "Something went wrong when calculating the number of visible gaussians. {num_visible} > {num_points}" - ); - - // Projected splats is only valid up to num_visible and undefined for other values. - if num_visible > 0 { - use crate::validation::validate_tensor_val; - + if total_splats > 0 && num_visible > 0 { let projected_splats: Tensor = Tensor::from_primitive(TensorPrimitive::Float(self.projected_splats.clone())); - let projected_splats = projected_splats.slice(s![0..num_visible]); + let projected_splats = projected_splats.slice(s![0..num_visible, ..]); validate_tensor_val(&projected_splats, "projected_splats", None, None); - } - let visible: Tensor = - Tensor::from_primitive(TensorPrimitive::Float(self.visible.clone())); - validate_tensor_val(&visible, "visible", None, None); - - let tile_offsets: Tensor = Tensor::from_primitive(self.tile_offsets.clone()); - - let tile_offsets = tile_offsets - .into_data() - .into_vec::() - .expect("Failed to fetch tile offsets"); - for &offsets in &tile_offsets { - assert!( - offsets as i32 <= num_intersections, - "Tile offsets exceed bounds. Value: {offsets}, num_intersections: {num_intersections}" - ); - } + let global_from_compact_gid: Tensor = + Tensor::from_primitive(self.global_from_compact_gid.clone()); + let global_from_compact_gid = &global_from_compact_gid + .into_data() + .into_vec::() + .expect("Failed to fetch global_from_compact_gid") + [0..num_visible as usize]; - if num_intersections > 0 { - for i in 0..(tile_offsets.len() - 1) / 2 { - // Check pairs of start/end points. - let start = tile_offsets[i * 2] as i32; - let end = tile_offsets[i * 2 + 1] as i32; + for &global_gid in global_from_compact_gid { assert!( - start < num_intersections && end <= num_intersections, - "Invalid elements in tile offsets. Start {start} ending at {end}" - ); - assert!( - end >= start, - "Invalid elements in tile offsets. 
Start {start} ending at {end}" - ); - assert!( - end - start <= num_visible as i32, - "One tile has more hits than total visible splats. Start {start} ending at {end}" + global_gid < total_splats, + "Invalid gaussian ID in global_from_compact_gid: {global_gid} >= {total_splats}" ); } } + } + } +} - if num_intersections > 0 { - let compact_gid_from_isect = &compact_gid_from_isect - .slice([0..num_intersections as usize]) - .into_data() - .into_vec::() - .expect("Failed to fetch compact_gid_from_isect"); +/// Minimal output from rendering. Contains only what callers typically need. +#[derive(Debug, Clone)] +pub struct RenderAux { + pub num_visible: IntTensor, + pub num_intersections: u32, + pub visible: FloatTensor, + pub tile_offsets: IntTensor, + pub img_size: glam::UVec2, +} - for (i, &compact_gid) in compact_gid_from_isect.iter().enumerate() { - assert!( - compact_gid < num_visible, - "Invalid gaussian ID in intersection buffer. {compact_gid} out of {num_visible}. At {i} out of {num_intersections} intersections. \n +impl RenderAux { + /// Get `num_visible` as a tensor. + pub fn get_num_visible(&self) -> Tensor { + Tensor::from_primitive(self.num_visible.clone()) + } - {compact_gid_from_isect:?} + /// Calculate tile depth map for visualization. + pub fn calc_tile_depth(&self) -> Tensor { + use crate::shaders::helpers::TILE_WIDTH; + use burn::tensor::s; - \n\n\n" - ); - } + let tile_offsets: Tensor = Tensor::from_primitive(self.tile_offsets.clone()); + let max = tile_offsets.clone().slice(s![.., .., 1]); + let min = tile_offsets.slice(s![.., .., 0]); + let [w, h] = self.img_size.into(); + let [ty, tx] = [h.div_ceil(TILE_WIDTH), w.div_ceil(TILE_WIDTH)]; + (max - min).reshape([ty as usize, tx as usize]) + } + + /// Validate rasterize outputs. + pub fn validate(&self) { + #[cfg(any(test, feature = "debug-validation"))] + { + if std::env::args().any(|a| a == "--bench") { + return; } - // assert that every ID in global_from_compact_gid is valid. 
- let global_from_compact_gid: Tensor = - Tensor::from_primitive(self.global_from_compact_gid.clone()); - let global_from_compact_gid = &global_from_compact_gid + use crate::validation::validate_tensor_val; + use burn::tensor::{ElementConversion, TensorPrimitive}; + + let num_visible = Tensor::::from_primitive(self.num_visible.clone()) + .into_scalar() + .elem::(); + + let visible: Tensor = + Tensor::from_primitive(TensorPrimitive::Float(self.visible.clone())); + let visible_2d: Tensor = visible.unsqueeze_dim(1); + validate_tensor_val(&visible_2d, "visible", None, None); + + // Validate tile_offsets + let tile_offsets: Tensor = Tensor::from_primitive(self.tile_offsets.clone()); + let tile_offsets_data = tile_offsets .into_data() .into_vec::() - .expect("Failed to fetch global_from_compact_gid")[0..num_visible as usize]; + .expect("Failed to fetch tile offsets"); - for &global_gid in global_from_compact_gid { + for i in 0..(tile_offsets_data.len() / 2) { + let start = tile_offsets_data[i * 2]; + let end = tile_offsets_data[i * 2 + 1]; + assert!( + end >= start, + "Invalid tile offsets: start {start} > end {end}" + ); assert!( - global_gid < num_points, - "Invalid gaussian ID in global_from_compact_gid buffer. 
{global_gid} out of {num_points}" + end - start <= num_visible, + "Tile has more hits ({}) than visible splats ({num_visible})", + end - start ); } } diff --git a/crates/brush-render/src/shaders.rs b/crates/brush-render/src/shaders.rs index b06e6ef9..54757a0b 100644 --- a/crates/brush-render/src/shaders.rs +++ b/crates/brush-render/src/shaders.rs @@ -13,9 +13,7 @@ pub struct ProjectVisible { } #[wgsl_kernel(source = "src/shaders/map_gaussian_to_intersects.wgsl")] -pub struct MapGaussiansToIntersect { - pub prepass: bool, -} +pub struct MapGaussiansToIntersect; #[wgsl_kernel(source = "src/shaders/rasterize.wgsl")] pub struct Rasterize { @@ -26,8 +24,9 @@ pub struct Rasterize { pub mod helpers { // Types used by multiple shaders - available from project_visible pub use super::project_visible::PackedVec3; + pub use super::project_visible::ProjectUniforms; pub use super::project_visible::ProjectedSplat; - pub use super::project_visible::RenderUniforms; + pub use super::rasterize::RasterizeUniforms; // Constants are now associated with the kernel structs pub const COV_BLUR: f32 = super::ProjectVisible::COV_BLUR; diff --git a/crates/brush-render/src/shaders/helpers.wgsl b/crates/brush-render/src/shaders/helpers.wgsl index 7a610942..d840fcec 100644 --- a/crates/brush-render/src/shaders/helpers.wgsl +++ b/crates/brush-render/src/shaders/helpers.wgsl @@ -38,7 +38,8 @@ fn map_1d_to_2d(id: u32, tiles_per_row: u32) -> vec2 { return vec2u(tile_x * TILE_WIDTH, tile_y * TILE_WIDTH) + decode_morton_2d(within_tile_id); } -struct RenderUniforms { +// Uniforms for projection passes. +struct ProjectUniforms { // View matrix transform world to view position. viewmat: mat4x4f, @@ -57,19 +58,16 @@ struct RenderUniforms { // Degree of sh coefficients used. sh_degree: u32, -#ifdef UNIFORM_WRITE - // Number of visible gaussians, written by project_forward. - // This needs to be non-atomic for other kernels as you can't have - // read-only atomic data. 
- num_visible: atomic, -#else - // Number of visible gaussians. - num_visible: u32, -#endif - total_splats: u32, - max_intersects: u32, + pad_a: u32, + pad_b: u32, +} + +// Uniforms for rasterize pass. +struct RasterizeUniforms { + tile_bounds: vec2u, + img_size: vec2u, // Nb: Alpha is ignored atm. background: vec4f, } diff --git a/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl b/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl index 927f258b..557e2a5e 100644 --- a/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl +++ b/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl @@ -1,16 +1,16 @@ #import helpers; -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; +@group(0) @binding(0) var num_visible: u32; @group(0) @binding(1) var projected: array; +@group(0) @binding(2) var splat_cum_hit_counts: array; +@group(0) @binding(3) var tile_id_from_isect: array; +@group(0) @binding(4) var compact_gid_from_isect: array; -#ifdef PREPASS - @group(0) @binding(2) var splat_intersect_counts: array; -#else - @group(0) @binding(2) var splat_cum_hit_counts: array; - @group(0) @binding(3) var tile_id_from_isect: array; - @group(0) @binding(4) var compact_gid_from_isect: array; - @group(0) @binding(5) var num_intersections: array; -#endif +// Uniforms passed via with_metadata (always last binding) +struct Uniforms { + tile_bounds: vec2u, +} +@group(0) @binding(5) var uniforms: Uniforms; const WG_SIZE: u32 = 256u; @@ -23,13 +23,7 @@ fn main( ) { let compact_gid = helpers::get_global_id(wid, num_wgs, lid, WG_SIZE); -#ifndef PREPASS - if compact_gid == 0u { - num_intersections[0] = splat_cum_hit_counts[uniforms.num_visible]; - } -#endif - - if compact_gid >= uniforms.num_visible { + if compact_gid >= num_visible { return; } @@ -45,11 +39,10 @@ fn main( let tile_bbox_min = tile_bbox.xy; let tile_bbox_max = tile_bbox.zw; - var num_tiles_hit = 0u; + // With inclusive prefix sum, use cum[compact_gid - 1] as base (or 0 for 
first element) + let base_isect_id = select(splat_cum_hit_counts[compact_gid - 1u], 0u, compact_gid == 0u); - #ifndef PREPASS - let base_isect_id = splat_cum_hit_counts[compact_gid]; - #endif + var num_tiles_hit = 0u; // Nb: It's really really important here the two dispatches // of this kernel arrive at the exact same num_tiles_hit count. Otherwise @@ -67,20 +60,14 @@ fn main( if helpers::will_primitive_contribute(rect, mean2d, conic, power_threshold) { let tile_id = tx + ty * uniforms.tile_bounds.x; - #ifndef PREPASS let isect_id = base_isect_id + num_tiles_hit; // Nb: isect_id MIGHT be out of bounds here for degenerate cases. // These kernels should be launched with bounds checking, so that these // writes are ignored. This will skip these intersections. tile_id_from_isect[isect_id] = tile_id; compact_gid_from_isect[isect_id] = compact_gid; - #endif num_tiles_hit += 1u; } } - - #ifdef PREPASS - splat_intersect_counts[compact_gid + 1u] = num_tiles_hit; - #endif } diff --git a/crates/brush-render/src/shaders/project_forward.wgsl b/crates/brush-render/src/shaders/project_forward.wgsl index be604f56..0abe49b8 100644 --- a/crates/brush-render/src/shaders/project_forward.wgsl +++ b/crates/brush-render/src/shaders/project_forward.wgsl @@ -1,17 +1,13 @@ -#define UNIFORM_WRITE - #import helpers; -// Unfiroms contains the splat count which we're writing to. 
-@group(0) @binding(0) var uniforms: helpers::RenderUniforms; - -@group(0) @binding(1) var means: array; -@group(0) @binding(2) var quats: array; -@group(0) @binding(3) var log_scales: array; -@group(0) @binding(4) var raw_opacities: array; - -@group(0) @binding(5) var global_from_compact_gid: array; -@group(0) @binding(6) var depths: array; +@group(0) @binding(0) var means: array; +@group(0) @binding(1) var quats: array; +@group(0) @binding(2) var log_scales: array; +@group(0) @binding(3) var raw_opacities: array; +@group(0) @binding(4) var global_from_compact_gid: array; +@group(0) @binding(5) var depths: array; +@group(0) @binding(6) var num_visible: atomic; +@group(0) @binding(7) var uniforms: helpers::ProjectUniforms; const WG_SIZE: u32 = 256u; @@ -77,7 +73,7 @@ fn main( return; } // Now write all the data to the buffers. - let write_id = atomicAdd(&uniforms.num_visible, 1u); + let write_id = atomicAdd(&num_visible, 1u); global_from_compact_gid[write_id] = global_gid; depths[write_id] = mean_c.z; } diff --git a/crates/brush-render/src/shaders/project_visible.wgsl b/crates/brush-render/src/shaders/project_visible.wgsl index 3853fed0..1d78a8db 100644 --- a/crates/brush-render/src/shaders/project_visible.wgsl +++ b/crates/brush-render/src/shaders/project_visible.wgsl @@ -5,8 +5,7 @@ struct IsectInfo { tile_id: u32, } -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; - +@group(0) @binding(0) var num_visible: u32; @group(0) @binding(1) var means: array; @group(0) @binding(2) var log_scales: array; @group(0) @binding(3) var quats: array; @@ -14,6 +13,8 @@ struct IsectInfo { @group(0) @binding(5) var raw_opacities: array; @group(0) @binding(6) var global_from_compact_gid: array; @group(0) @binding(7) var projected: array; +@group(0) @binding(8) var splat_intersect_counts: array; +@group(0) @binding(9) var uniforms: helpers::ProjectUniforms; struct ShCoeffs { b0_c0: vec3f, @@ -171,7 +172,7 @@ fn main( ) { let compact_gid = helpers::get_global_id(wid, 
num_wgs, lid, WG_SIZE); - if compact_gid >= uniforms.num_visible { + if compact_gid >= num_visible { return; } @@ -246,9 +247,34 @@ fn main( let viewdir = normalize(mean - uniforms.camera_position.xyz); var color = sh_coeffs_to_color(sh_degree, viewdir, sh) + vec3f(0.5); + let conic_packed = vec3f(conic[0][0], conic[0][1], conic[1][1]); + projected[compact_gid] = helpers::create_projected_splat( mean2d, - vec3f(conic[0][0], conic[0][1], conic[1][1]), + conic_packed, vec4f(color, opac) ); + + // Count intersections for this splat (merged from map_gaussian_to_intersects prepass) + let power_threshold = log(opac * 255.0); + let extent = helpers::compute_bbox_extent(cov2d, power_threshold); + let tile_bbox = helpers::get_tile_bbox(mean2d, extent, uniforms.tile_bounds); + let tile_bbox_min = tile_bbox.xy; + let tile_bbox_max = tile_bbox.zw; + + var num_tiles_hit = 0u; + let tile_bbox_width = tile_bbox_max.x - tile_bbox_min.x; + let num_tiles_bbox = (tile_bbox_max.y - tile_bbox_min.y) * tile_bbox_width; + + for (var tile_idx = 0u; tile_idx < num_tiles_bbox; tile_idx++) { + let tx = (tile_idx % tile_bbox_width) + tile_bbox_min.x; + let ty = (tile_idx / tile_bbox_width) + tile_bbox_min.y; + + let rect = helpers::tile_rect(vec2u(tx, ty)); + if helpers::will_primitive_contribute(rect, mean2d, conic_packed, power_threshold) { + num_tiles_hit += 1u; + } + } + + splat_intersect_counts[compact_gid] = num_tiles_hit; } diff --git a/crates/brush-render/src/shaders/rasterize.wgsl b/crates/brush-render/src/shaders/rasterize.wgsl index 29c3d437..8883feaf 100644 --- a/crates/brush-render/src/shaders/rasterize.wgsl +++ b/crates/brush-render/src/shaders/rasterize.wgsl @@ -1,16 +1,17 @@ #import helpers -@group(0) @binding(0) var uniforms: helpers::RenderUniforms; -@group(0) @binding(1) var compact_gid_from_isect: array; -@group(0) @binding(2) var tile_offsets: array; -@group(0) @binding(3) var projected: array; +@group(0) @binding(0) var compact_gid_from_isect: array; +@group(0) 
@binding(1) var tile_offsets: array; +@group(0) @binding(2) var projected: array; #ifdef BWD_INFO - @group(0) @binding(4) var out_img: array; - @group(0) @binding(5) var global_from_compact_gid: array; - @group(0) @binding(6) var visible: array; + @group(0) @binding(3) var out_img: array; + @group(0) @binding(4) var global_from_compact_gid: array; + @group(0) @binding(5) var visible: array; + @group(0) @binding(6) var uniforms: helpers::RasterizeUniforms; #else - @group(0) @binding(4) var out_img: array; + @group(0) @binding(3) var out_img: array; + @group(0) @binding(4) var uniforms: helpers::RasterizeUniforms; #endif var range_uniform: vec2u; diff --git a/crates/brush-render/src/tests/mod.rs b/crates/brush-render/src/tests/mod.rs index 8fc579a9..776f41bb 100644 --- a/crates/brush-render/src/tests/mod.rs +++ b/crates/brush-render/src/tests/mod.rs @@ -1,11 +1,15 @@ -use crate::{MainBackend, SplatForward, camera::Camera, gaussian_splats::SplatRenderMode}; +use crate::{ + MainBackend, TextureMode, + camera::Camera, + gaussian_splats::{SplatRenderMode, Splats, render_splats}, +}; use assert_approx_eq::assert_approx_eq; -use burn::tensor::{Distribution, Tensor, TensorPrimitive}; +use burn::tensor::{Distribution, Tensor}; use burn_wgpu::WgpuDevice; use glam::Vec3; -#[test] -fn renders_at_all() { +#[tokio::test] +async fn renders_at_all() { // Check if rendering doesn't hard crash or anything. // These are some zero-sized gaussians, so we know // what the result should look like. 
@@ -27,21 +31,18 @@ fn renders_at_all() { .repeat_dim(0, num_points); let sh_coeffs = Tensor::::ones([num_points, 1, 3], &device); let raw_opacity = Tensor::::zeros([num_points], &device); - let (output, aux) = >::render_splats( - &cam, - img_size, - means.into_primitive().tensor(), - log_scales.into_primitive().tensor(), - quats.into_primitive().tensor(), - sh_coeffs.into_primitive().tensor(), - raw_opacity.into_primitive().tensor(), + + let splats = Splats::from_tensor_data( + means, + quats, + log_scales, + sh_coeffs, + raw_opacity, SplatRenderMode::Default, - Vec3::ZERO, - true, ); - aux.validate_values(); + let (output, _render_aux) = + render_splats(splats, &cam, img_size, Vec3::ZERO, None, TextureMode::Float).await; - let output: Tensor = Tensor::from_primitive(TensorPrimitive::Float(output)); let rgb = output.clone().slice([0..32, 0..32, 0..3]); let alpha = output.slice([0..32, 0..32, 3..4]); let rgb_mean = rgb.mean().to_data().as_slice::().expect("Wrong type")[0]; @@ -54,8 +55,8 @@ fn renders_at_all() { assert_approx_eq!(alpha_mean, 0.0); } -#[test] -fn renders_many_splats() { +#[tokio::test] +async fn renders_many_splats() { // Test rendering with a ton of gaussians to verify 2D dispatch works correctly. // This exceeds the 1D 65535 * 256 = 16.7M limit. 
let num_splats = 30_000_000; @@ -97,17 +98,14 @@ fn renders_many_splats() { let raw_opacity = Tensor::::random([num_splats], Distribution::Uniform(-2.0, 2.0), &device); - let (_output, aux) = >::render_splats( - &cam, - img_size, - means.into_primitive().tensor(), - log_scales.into_primitive().tensor(), - quats.into_primitive().tensor(), - sh_coeffs.into_primitive().tensor(), - raw_opacity.into_primitive().tensor(), + let splats = Splats::from_tensor_data( + means, + quats, + log_scales, + sh_coeffs, + raw_opacity, SplatRenderMode::Default, - Vec3::ZERO, - true, ); - aux.validate_values(); + let (_output, _render_aux) = + render_splats(splats, &cam, img_size, Vec3::ZERO, None, TextureMode::Float).await; } diff --git a/crates/brush-rerun/src/visualize_tools.rs b/crates/brush-rerun/src/visualize_tools.rs index ce119049..79557af6 100644 --- a/crates/brush-rerun/src/visualize_tools.rs +++ b/crates/brush-rerun/src/visualize_tools.rs @@ -217,12 +217,13 @@ mod visualize_tools_impl { } #[allow(unused_variables)] - pub fn log_splat_stats(&self, iter: u32, splats: &Splats) -> Result<()> { + pub fn log_splat_stats(&self, iter: u32, num_splats: u32) -> Result<()> { if self.rec.is_enabled() { self.rec.set_time_sequence("iterations", iter); - let num = splats.num_splats(); - self.rec - .log("splats/num_splats", &rerun::Scalars::new(vec![num as f64]))?; + self.rec.log( + "splats/num_splats", + &rerun::Scalars::new(vec![num_splats as f64]), + )?; } Ok(()) } @@ -245,17 +246,6 @@ mod visualize_tools_impl { .log("lr/coeffs", &rerun::Scalars::new(vec![stats.lr_coeffs]))?; self.rec .log("lr/opac", &rerun::Scalars::new(vec![stats.lr_opac]))?; - - self.rec.log( - "splats/num_intersects", - &rerun::Scalars::new(vec![ - stats - .num_intersections - .into_scalar_async() - .await? 
- .elem::(), - ]), - )?; self.rec.log( "splats/splats_visible", &rerun::Scalars::new(vec![ @@ -370,7 +360,7 @@ mod visualize_tools_impl { #[allow(unused_variables)] #[allow(clippy::unnecessary_wraps, clippy::unused_self)] - pub fn log_splat_stats(&self, _iter: u32, _splats: &Splats) -> Result<()> { + pub fn log_splat_stats(&self, _iter: u32, _num_splats: u32) -> Result<()> { Ok(()) } diff --git a/crates/brush-sort/src/lib.rs b/crates/brush-sort/src/lib.rs index 4ef0e57c..cb63849f 100644 --- a/crates/brush-sort/src/lib.rs +++ b/crates/brush-sort/src/lib.rs @@ -1,4 +1,5 @@ use brush_kernel::CubeCount; +use brush_kernel::calc_cube_count_1d; use brush_kernel::create_dispatch_buffer_1d; use brush_kernel::create_tensor; use brush_kernel::create_uniform_buffer; @@ -33,17 +34,21 @@ use sort_count::Uniforms; const BLOCK_SIZE: u32 = SortCount::WG * SortCount::ELEMENTS_PER_THREAD; +/// Perform a radix argsort on the input keys and values. +/// +/// If `dynamic_count` is `Some(count_buffer)`, use that buffer as the actual number +/// of keys to sort (uses dynamic GPU dispatch). If `None`, use the full buffer length +/// with static CPU dispatch. 
pub fn radix_argsort( input_keys: CubeTensor, input_values: CubeTensor, - n_sort: &CubeTensor, sorting_bits: u32, + dynamic_count: Option>, ) -> (CubeTensor, CubeTensor) { assert_eq!( input_keys.shape.dims[0], input_values.shape.dims[0], "Input keys and values must have the same number of elements" ); - assert_eq!(n_sort.shape.dims[0], 1, "Sort count must have one element"); assert!(sorting_bits <= 32, "Can only sort up to 32 bits"); assert!( input_keys.is_contiguous(), @@ -64,11 +69,31 @@ pub fn radix_argsort( let max_needed_wgs = max_n.div_ceil(BLOCK_SIZE); - let num_wgs = create_dispatch_buffer_1d(n_sort.clone(), BLOCK_SIZE); - let num_reduce_wgs: Tensor, 1, Int> = - Tensor::from_primitive(create_dispatch_buffer_1d(num_wgs.clone(), BLOCK_SIZE)) - * Tensor::from_ints([SortCount::BIN_COUNT, 1, 1], device); - let num_reduce_wgs: CubeTensor = num_reduce_wgs.into_primitive(); + // Handle dynamic vs static dispatch + let (num_keys_buf, num_wgs, num_reduce_wgs) = if let Some(count_buf) = dynamic_count { + let num_wgs = create_dispatch_buffer_1d(count_buf.clone(), BLOCK_SIZE); + let num_reduce_wgs: Tensor, 1, Int> = + Tensor::from_primitive(create_dispatch_buffer_1d(num_wgs.clone(), BLOCK_SIZE)) + * Tensor::from_ints([SortCount::BIN_COUNT, 1, 1], device); + let num_reduce_wgs: CubeTensor = num_reduce_wgs.into_primitive(); + ( + count_buf, + CubeCount::Dynamic(num_wgs.handle.binding()), + CubeCount::Dynamic(num_reduce_wgs.handle.binding()), + ) + } else { + // Static dispatch: use full buffer size + let num_keys_buf = { + type Backend = CubeBackend; + Tensor::::from_ints([max_n as i32], device).into_primitive() + }; + // Calculate dispatch counts matching the original formula + let num_wgs_count = max_n.div_ceil(BLOCK_SIZE); + let num_reduce_wgs_count = num_wgs_count.div_ceil(BLOCK_SIZE) * SortCount::BIN_COUNT; + let num_wgs = calc_cube_count_1d(max_n, BLOCK_SIZE); + let num_reduce_wgs = calc_cube_count_1d(num_reduce_wgs_count, 1); + (num_keys_buf, num_wgs, 
num_reduce_wgs) + }; let mut cur_keys = input_keys; let mut cur_vals = input_values; @@ -79,14 +104,14 @@ pub fn radix_argsort( let count_buf = create_tensor([(max_needed_wgs as usize) * 16], device, DType::I32); - // use safe distpatch as dynamic work count isn't verified. + // use safe dispatch as dynamic work count isn't verified. client .launch( SortCount::task(), - CubeCount::Dynamic(num_wgs.clone().handle.binding()), + num_wgs.clone(), Bindings::new().with_buffers(vec![ uniforms_buffer.handle.clone().binding(), - n_sort.handle.clone().binding(), + num_keys_buf.handle.clone().binding(), cur_keys.handle.clone().binding(), count_buf.handle.clone().binding(), ]), @@ -99,9 +124,9 @@ pub fn radix_argsort( client .launch( SortReduce::task(), - CubeCount::Dynamic(num_reduce_wgs.handle.clone().binding()), + num_reduce_wgs.clone(), Bindings::new().with_buffers(vec![ - n_sort.handle.clone().binding(), + num_keys_buf.handle.clone().binding(), count_buf.handle.clone().binding(), reduced_buf.handle.clone().binding(), ]), @@ -115,7 +140,7 @@ pub fn radix_argsort( SortScan::task(), CubeCount::Static(1, 1, 1), Bindings::new().with_buffers(vec![ - n_sort.handle.clone().binding(), + num_keys_buf.handle.clone().binding(), reduced_buf.handle.clone().binding(), ]), ) @@ -125,9 +150,9 @@ pub fn radix_argsort( client .launch( SortScanAdd::task(), - CubeCount::Dynamic(num_reduce_wgs.handle.clone().binding()), + num_reduce_wgs.clone(), Bindings::new().with_buffers(vec![ - n_sort.handle.clone().binding(), + num_keys_buf.handle.clone().binding(), reduced_buf.handle.clone().binding(), count_buf.handle.clone().binding(), ]), @@ -141,10 +166,10 @@ pub fn radix_argsort( client .launch( SortScatter::task(), - CubeCount::Dynamic(num_wgs.handle.clone().binding()), + num_wgs.clone(), Bindings::new().with_buffers(vec![ uniforms_buffer.handle.clone().binding(), - n_sort.handle.clone().binding(), + num_keys_buf.handle.clone().binding(), cur_keys.handle.clone().binding(), 
cur_vals.handle.clone().binding(), count_buf.handle.clone().binding(), @@ -200,12 +225,9 @@ mod tests { let device = Default::default(); let keys = Tensor::::from_ints(keys_inp, &device).into_primitive(); - let values = Tensor::::from_ints(values_inp.as_slice(), &device) .into_primitive(); - let num_points = Tensor::::from_ints([keys_inp.len() as i32], &device) - .into_primitive(); - let (ret_keys, ret_values) = radix_argsort(keys, values, &num_points, 32); + let (ret_keys, ret_values) = radix_argsort(keys, values, 32, None); let ret_keys = Tensor::::from_primitive(ret_keys).into_data(); @@ -253,10 +275,7 @@ mod tests { Tensor::::from_ints(keys_inp.as_slice(), &device).into_primitive(); let values = Tensor::::from_ints(values_inp.as_slice(), &device).into_primitive(); - let num_points = - Tensor::::from_ints([keys_inp.len() as i32], &device).into_primitive(); - - let (ret_keys, ret_values) = radix_argsort(keys, values, &num_points, 32); + let (ret_keys, ret_values) = radix_argsort(keys, values, 32, None); let ret_keys = Tensor::::from_primitive(ret_keys).to_data(); let ret_values = Tensor::::from_primitive(ret_values).to_data(); @@ -296,10 +315,7 @@ mod tests { Tensor::::from_ints(keys_inp.as_slice(), &device).into_primitive(); let values = Tensor::::from_ints(values_inp.as_slice(), &device).into_primitive(); - let num_points = - Tensor::::from_ints([NUM_ELEMENTS as i32], &device).into_primitive(); - - let (ret_keys, ret_values) = radix_argsort(keys, values, &num_points, 32); + let (ret_keys, ret_values) = radix_argsort(keys, values, 32, None); let ret_keys = Tensor::::from_primitive(ret_keys).to_data(); let ret_values = Tensor::::from_primitive(ret_values).to_data(); diff --git a/crates/brush-train/src/eval.rs b/crates/brush-train/src/eval.rs index c3e6c97f..662b454b 100644 --- a/crates/brush-train/src/eval.rs +++ b/crates/brush-train/src/eval.rs @@ -5,10 +5,9 @@ use anyhow::Result; use brush_dataset::scene::{sample_to_tensor_data, view_to_sample_image}; use 
brush_render::camera::Camera; use brush_render::gaussian_splats::Splats; -use brush_render::render_aux::RenderAux; -use brush_render::{AlphaMode, SplatForward}; +use brush_render::{AlphaMode, RenderAux, SplatOps, TextureMode, render_splats}; use burn::prelude::Backend; -use burn::tensor::{Tensor, TensorPrimitive, s}; +use burn::tensor::{Tensor, s}; use glam::Vec3; use image::DynamicImage; @@ -19,11 +18,11 @@ pub struct EvalSample { pub rendered: Tensor, pub psnr: Tensor, pub ssim: Tensor, - pub aux: RenderAux, + pub render_aux: RenderAux, } -pub fn eval_stats>( - splats: &Splats, +pub async fn eval_stats>( + splats: Splats, gt_cam: &Camera, gt_img: DynamicImage, alpha_mode: AlphaMode, @@ -36,22 +35,9 @@ pub fn eval_stats>( let gt_tensor = Tensor::from_data(gt_tensor, device); let gt_rgb = gt_tensor.slice(s![.., .., 0..3]); - // Render on reference black background. - let (img, aux) = { - let (img, aux) = B::render_splats( - gt_cam, - res, - splats.means.val().into_primitive().tensor(), - splats.log_scales.val().into_primitive().tensor(), - splats.rotations.val().into_primitive().tensor(), - splats.sh_coeffs.val().into_primitive().tensor(), - splats.raw_opacities.val().into_primitive().tensor(), - splats.render_mode, - Vec3::ZERO, - true, - ); - (Tensor::from_primitive(TensorPrimitive::Float(img)), aux) - }; + // Render on reference black background - async readback + let (img, render_aux) = + render_splats(splats, gt_cam, res, Vec3::ZERO, None, TextureMode::Float).await; let render_rgb = img.slice(s![.., .., 0..3]); // Simulate an 8-bit roundtrip for fair comparison. 
@@ -68,7 +54,7 @@ pub fn eval_stats>( psnr, ssim, rendered: render_rgb, - aux, + render_aux, }) } diff --git a/crates/brush-train/src/msg.rs b/crates/brush-train/src/msg.rs index 780d8c37..a65d5a6d 100644 --- a/crates/brush-train/src/msg.rs +++ b/crates/brush-train/src/msg.rs @@ -7,13 +7,13 @@ use burn::{ pub struct RefineStats { pub num_added: u32, pub num_pruned: u32, + pub total_splats: u32, } #[derive(Clone)] pub struct TrainStepStats { pub pred_image: Tensor, - pub num_intersections: Tensor, pub num_visible: Tensor, pub loss: Tensor, diff --git a/crates/brush-train/src/train.rs b/crates/brush-train/src/train.rs index c7b80fb4..d6673337 100644 --- a/crates/brush-train/src/train.rs +++ b/crates/brush-train/src/train.rs @@ -34,7 +34,7 @@ use burn::{ use burn_cubecl::cubecl::Runtime; use glam::Vec3; use hashbrown::{HashMap, HashSet}; -use tracing::trace_span; +use tracing::{Instrument, trace_span}; pub const BOUND_PERCENTILE: f32 = 0.8; @@ -99,48 +99,41 @@ impl SplatTrainer { } } - pub fn step( + pub async fn step( &mut self, batch: SceneBatch, splats: Splats, ) -> (Splats, TrainStepStats) { - let _span = trace_span!("Train step").entered(); - let mut splats = splats; let [img_h, img_w, _] = batch.img_tensor.shape.clone().try_into().unwrap(); - let camera = &batch.camera; + let camera = batch.camera.clone(); // Upload tensor early. let device = splats.device(); let has_alpha = batch.has_alpha(); let gt_tensor = Tensor::from_data(batch.img_tensor, &device); - let (pred_image, aux, refine_weight_holder) = trace_span!("Forward").in_scope(|| { - // Could generate a random background color, but so far - // results just seem worse. - let background = Vec3::ZERO; + // Forward pass - render splats asynchronously. + // Clone splats to avoid holding references across the await. 
+ let background = Vec3::ZERO; + let img_size = glam::uvec2(img_w as u32, img_h as u32); - let diff_out = render_splats( - &splats, - camera, - glam::uvec2(img_w as u32, img_h as u32), - background, - ); + let diff_out = render_splats(splats.clone(), &camera, img_size, background) + .instrument(trace_span!("Forward")) + .await; - let img = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); - - (img, diff_out.aux, diff_out.refine_weight_holder) - }); + let pred_image = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)); + let render_aux = diff_out.render_aux; + let refine_weight_holder = diff_out.refine_weight_holder; let median_scale = self.bounds.median_size(); - let num_visible = aux.num_visible().inner(); - let num_intersections = aux.num_intersections().inner(); + let num_visible = render_aux.get_num_visible().inner(); let pred_rgb = pred_image.clone().slice(s![.., .., 0..3]); let gt_rgb = gt_tensor.clone().slice(s![.., .., 0..3]); let visible: Tensor, 1> = - Tensor::from_primitive(TensorPrimitive::Float(aux.visible)); + Tensor::from_primitive(TensorPrimitive::Float(render_aux.visible)); let loss = trace_span!("Calculate losses").in_scope(|| { let l1_rgb = (pred_rgb.clone() - gt_rgb.clone()).abs(); @@ -273,7 +266,6 @@ impl SplatTrainer { let stats = TrainStepStats { pred_image: pred_image.inner(), num_visible, - num_intersections, loss: loss.inner(), lr_mean, lr_rotation, @@ -389,11 +381,14 @@ impl SplatTrainer { self.bounds = get_splat_bounds(splats.clone(), BOUND_PERCENTILE).await; client.memory_cleanup(); + let splat_count = splats.num_splats(); + ( splats, RefineStats { num_added: refine_count as u32, num_pruned: pruned_count, + total_splats: splat_count, }, ) } diff --git a/crates/brush-ui/src/app.rs b/crates/brush-ui/src/app.rs index 166fa4c4..eb0da69f 100644 --- a/crates/brush-ui/src/app.rs +++ b/crates/brush-ui/src/app.rs @@ -230,7 +230,7 @@ impl App { ); log::info!("Connecting context to Burn device & GUI context."); - let context 
= std::sync::Arc::new(UiProcess::new(burn_device.clone(), cc.egui_ctx.clone())); + let context = std::sync::Arc::new(UiProcess::new(burn_device, cc.egui_ctx.clone())); if let Some(process) = init_process { context.connect_to_process(process); @@ -249,13 +249,7 @@ impl App { // Initialize all panels with runtime state for (_, tile) in tree.tiles.iter_mut() { if let egui_tiles::Tile::Pane(pane) = tile { - pane.get_mut().as_pane_mut().init( - state.device.clone(), - state.queue.clone(), - state.renderer.clone(), - burn_device.clone(), - state.adapter.get_info(), - ); + pane.get_mut().as_pane_mut().init(state); } } diff --git a/crates/brush-ui/src/burn_texture.rs b/crates/brush-ui/src/burn_texture.rs deleted file mode 100644 index 89da1eb7..00000000 --- a/crates/brush-ui/src/burn_texture.rs +++ /dev/null @@ -1,179 +0,0 @@ -use std::sync::Arc; - -use brush_render::{MainBackend, MainBackendBase}; -use burn::tensor::{Tensor, TensorPrimitive}; -use burn_cubecl::cubecl::Runtime; -use burn_wgpu::WgpuRuntime; -use eframe::egui_wgpu::Renderer; -use egui::TextureId; -use egui::epaint::mutex::RwLock as EguiRwLock; -use wgpu::{CommandEncoderDescriptor, TexelCopyBufferLayout, TextureViewDescriptor}; - -struct TextureState { - texture: wgpu::Texture, - id: TextureId, -} - -pub struct BurnTexture { - state: Option, - device: wgpu::Device, - queue: wgpu::Queue, - renderer: Arc>, -} - -fn create_texture(size: glam::UVec2, device: &wgpu::Device) -> wgpu::Texture { - device.create_texture(&wgpu::TextureDescriptor { - label: Some("Splat backbuffer"), - size: wgpu::Extent3d { - width: size.x, - height: size.y, - depth_or_array_layers: 1, - }, - mip_level_count: 1, - sample_count: 1, - dimension: wgpu::TextureDimension::D2, - format: wgpu::TextureFormat::Rgba8Unorm, - usage: wgpu::TextureUsages::TEXTURE_BINDING - | wgpu::TextureUsages::COPY_DST - | wgpu::TextureUsages::RENDER_ATTACHMENT, - view_formats: &[wgpu::TextureFormat::Rgba8Unorm], - }) -} - -impl BurnTexture { - pub fn new( - 
renderer: Arc>, - device: wgpu::Device, - queue: wgpu::Queue, - ) -> Self { - Self { - state: None, - device, - queue, - renderer, - } - } - - pub fn update_texture(&mut self, img: Tensor) -> TextureId { - let [h, w, c] = img.shape().dims(); - assert!(c == 1, "texture should be u8 packed RGBA"); - let size = glam::uvec2(w as u32, h as u32); - - let dirty = if let Some(s) = self.state.as_ref() { - s.texture.width() != size.x || s.texture.height() != size.y - } else { - true - }; - - if dirty { - // Resizing has some really bad memory profiles, so cleanup memory when it's detected. - let client = WgpuRuntime::client(&img.device()); - client.memory_cleanup(); - - let texture = create_texture(glam::uvec2(w as u32, h as u32), &self.device); - - if let Some(s) = self.state.as_mut() { - s.texture = texture; - - self.renderer.write().update_egui_texture_from_wgpu_texture( - &self.device, - &s.texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - s.id, - ); - } else { - let id = self.renderer.write().register_native_texture( - &self.device, - &texture.create_view(&TextureViewDescriptor::default()), - wgpu::FilterMode::Linear, - ); - self.state = Some(TextureState { texture, id }); - } - } - - let Some(s) = self.state.as_ref() else { - unreachable!("Somehow failed to initialize") - }; - let texture: &wgpu::Texture = &s.texture; - let [height, width, c] = img.dims(); - - let mut encoder = self - .device - .create_command_encoder(&CommandEncoderDescriptor { - label: Some("viewer encoder"), - }); - - let padded_shape = vec![height, width.div_ceil(64) * 64, c]; - - let img_prim = img.into_primitive().tensor(); - let fusion_client = img_prim.client.clone(); - let img = fusion_client.resolve_tensor_float::(img_prim); - let img: Tensor = Tensor::from_primitive(TensorPrimitive::Float(img)); - - // Create padded tensor if needed. The bytes_per_row needs to be divisible - // by 256 in WebGPU, so 4 bytes per pixel means width needs to be divisible by 64. 
- let img = if width % 64 != 0 { - let padded: Tensor = Tensor::zeros(&padded_shape, &img.device()); - padded.slice_assign([0..height, 0..width], img) - } else { - img - }; - - let img = img.into_primitive().tensor(); - - // Get a hold of the Burn resource. - let client = &img.client; - let img_res_handle = client.get_resource(img.handle.clone().binding()); - - // Now flush commands to make sure the resource is fully ready. - client.flush(); - - // Put compute passes in encoder before copying the buffer. - let bytes_per_row = Some(4 * padded_shape[1] as u32); - - // Now copy the buffer to the texture. - encoder.copy_buffer_to_texture( - wgpu::TexelCopyBufferInfo { - buffer: &img_res_handle.resource().buffer, - layout: TexelCopyBufferLayout { - offset: img_res_handle.resource().offset, - bytes_per_row, - rows_per_image: None, - }, - }, - wgpu::TexelCopyTextureInfo { - texture, - mip_level: 0, - origin: wgpu::Origin3d { x: 0, y: 0, z: 0 }, - aspect: wgpu::TextureAspect::All, - }, - wgpu::Extent3d { - width: width as u32, - height: height as u32, - depth_or_array_layers: 1, - }, - ); - - self.queue.submit([encoder.finish()]); - - s.id - } - - pub fn id(&self) -> Option { - self.state.as_ref().map(|s| s.id) - } - - pub fn reset(&mut self) { - self.state = None; - } - - /// Get the underlying texture for additional rendering - pub fn texture(&self) -> Option<&wgpu::Texture> { - self.state.as_ref().map(|s| &s.texture) - } - - /// Get device and queue for additional rendering - pub fn device_queue(&self) -> (&wgpu::Device, &wgpu::Queue) { - (&self.device, &self.queue) - } -} diff --git a/crates/brush-ui/src/lib.rs b/crates/brush-ui/src/lib.rs index 5b6e4ee8..c3c0c949 100644 --- a/crates/brush-ui/src/lib.rs +++ b/crates/brush-ui/src/lib.rs @@ -1,13 +1,13 @@ #![recursion_limit = "256"] pub mod app; -pub mod burn_texture; pub mod camera_controls; pub mod ui_process; mod panels; mod scene; +pub mod splat_backbuffer; #[cfg(feature = "training")] mod stats; mod widget_3d; diff 
--git a/crates/brush-ui/src/panels.rs b/crates/brush-ui/src/panels.rs index 6b83e732..9d1285f0 100644 --- a/crates/brush-ui/src/panels.rs +++ b/crates/brush-ui/src/panels.rs @@ -1,8 +1,5 @@ -use std::sync::Arc; - use brush_process::message::ProcessMessage; -use eframe::egui_wgpu::Renderer; -use egui::mutex::RwLock; +use eframe::egui_wgpu::RenderState; use crate::ui_process::UiProcess; @@ -11,15 +8,7 @@ pub(crate) trait AppPane { /// Initialize runtime state after creation or deserialization. #[allow(unused_variables)] - fn init( - &mut self, - device: wgpu::Device, - queue: wgpu::Queue, - renderer: Arc>, - burn_device: burn_wgpu::WgpuDevice, - adapter_info: wgpu::AdapterInfo, - ) { - } + fn init(&mut self, state: &RenderState) {} /// Draw the pane's UI's content. fn ui(&mut self, ui: &mut egui::Ui, process: &UiProcess); diff --git a/crates/brush-ui/src/scene.rs b/crates/brush-ui/src/scene.rs index d207ef5e..a28fe541 100644 --- a/crates/brush-ui/src/scene.rs +++ b/crates/brush-ui/src/scene.rs @@ -1,47 +1,30 @@ #[cfg(feature = "training")] use crate::settings_popup::SettingsPopup; -#[cfg(feature = "training")] -use brush_process::message::TrainMessage; -#[cfg(feature = "training")] -use std::sync::Mutex; - +use crate::{splat_backbuffer::SplatBackbuffer, widget_3d::GridWidget}; use brush_process::{create_process, message::ProcessMessage}; +use brush_render::camera::{focal_to_fov, fov_to_focal}; use brush_vfs::DataSource; use core::f32; -use egui::{ - Align2, Button, Frame, RichText, containers::Popup, epaint::mutex::RwLock as EguiRwLock, -}; -use std::sync::Arc; - -use brush_render::{ - MainBackend, - camera::{Camera, focal_to_fov, fov_to_focal}, - gaussian_splats::Splats, - render_splats, -}; -use eframe::egui_wgpu::Renderer; +use eframe::egui_wgpu::RenderState; +use egui::{Align2, Button, Frame, RichText, containers::Popup}; use egui::{Color32, Rect, Slider}; -use glam::{UVec2, Vec3}; -use tracing::trace_span; -use web_time::Instant; - +use glam::Vec3; use 
serde::{Deserialize, Serialize}; +#[cfg(feature = "training")] +use std::sync::{Arc, Mutex}; +use web_time::Instant; /// Controls how often the viewport re-renders during training. #[derive(Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum RenderUpdateMode { - /// Don't re-render during training Off, - /// Re-render every 100 iterations Low, - /// Re-render every 5 iterations (default) #[default] Live, } impl RenderUpdateMode { /// Returns the iteration interval for this mode, or None if rendering is disabled. - #[cfg(feature = "training")] fn update_interval(&self) -> Option { match self { Self::Off => None, @@ -60,24 +43,11 @@ impl RenderUpdateMode { } use crate::{ - UiMode, - app::CameraSettings, - burn_texture::BurnTexture, - draw_checkerboard, + UiMode, draw_checkerboard, panels::AppPane, ui_process::{BackgroundStyle, UiProcess}, - widget_3d::Widget3D, }; -#[derive(Clone, PartialEq)] -struct RenderState { - size: UVec2, - cam: Camera, - settings: CameraSettings, - grid_opacity: f32, - frame: u32, -} - struct ErrorDisplay { headline: String, context: Vec, @@ -108,11 +78,15 @@ impl ErrorDisplay { #[derive(Default, Serialize, Deserialize)] pub struct ScenePanel { #[serde(skip)] - pub(crate) backbuffer: Option, + grid: Option, + #[serde(skip)] + backbuffer: Option, #[serde(skip)] pub(crate) last_draw: Option, #[serde(skip)] has_splats: bool, + #[serde(skip)] + splats_dirty: bool, /// Current frame for animated sequences (as float for smooth interpolation). 
#[serde(skip)] frame: f32, @@ -130,10 +104,6 @@ pub struct ScenePanel { #[serde(skip)] seen_warning_count: usize, #[serde(skip)] - last_state: Option, - #[serde(skip)] - widget_3d: Option, - #[serde(skip)] source_name: Option, #[serde(skip)] source_type: Option, @@ -250,6 +220,7 @@ impl ScenePanel { load_option } + #[allow(clippy::unused_self)] fn start_loading(&self, source: DataSource, process: &UiProcess) { process.connect_to_process(create_process( source, @@ -264,133 +235,6 @@ impl ScenePanel { )); } - pub(crate) fn draw_splats( - &mut self, - ui: &mut egui::Ui, - process: &UiProcess, - splats: Option>, - interactive: bool, - ) -> egui::Rect { - let size = ui.available_size(); - let size = glam::uvec2(size.x.round() as u32, size.y.round() as u32); - let (rect, response) = ui.allocate_exact_size( - egui::Vec2::new(size.x as f32, size.y as f32), - egui::Sense::drag(), - ); - if interactive { - process.tick_controls(&response, ui); - } - - // Get camera after modifying the controls. - let mut camera = process.current_camera(); - - let view_eff = (camera.world_to_local() * process.model_local_to_world()).inverse(); - let (_, rotation, position) = view_eff.to_scale_rotation_translation(); - camera.position = position; - camera.rotation = rotation; - - let settings = process.get_cam_settings(); - - // Adjust FOV so that the scene view shows at least what's visible in the dataset view. - // The camera has original fov_x and fov_y from the dataset. We need to ensure - // the viewport shows at least that much in both dimensions. 
- let camera_aspect = (camera.fov_x / 2.0).tan() / (camera.fov_y / 2.0).tan(); - let viewport_aspect = size.x as f64 / size.y as f64; - - if viewport_aspect > camera_aspect { - // Viewport is wider than camera - keep fov_y, expand fov_x - let focal_y = fov_to_focal(camera.fov_y, size.y); - camera.fov_x = focal_to_fov(focal_y, size.x); - } else { - // Viewport is taller than camera - keep fov_x, expand fov_y - let focal_x = fov_to_focal(camera.fov_x, size.x); - camera.fov_y = focal_to_fov(focal_x, size.y); - } - - let grid_opacity = process.get_grid_opacity(); - - let state = RenderState { - size, - cam: camera.clone(), - settings: settings.clone(), - grid_opacity, - frame: self.frame as u32, - }; - - let dirty = self.last_state != Some(state.clone()); - - if dirty { - self.last_state = Some(state); - // Check again next frame, as there might be more to animate. - ui.ctx().request_repaint(); - } - - if let Some(splats) = splats { - let pixel_size = glam::uvec2( - (size.x as f32 * ui.ctx().pixels_per_point().round()) as u32, - (size.y as f32 * ui.ctx().pixels_per_point().round()) as u32, - ); - // If this viewport is re-rendering. - if pixel_size.x > 8 && pixel_size.y > 8 && dirty { - let _span = trace_span!("Render splats").entered(); - // Could add an option for background color. - let (img, _) = render_splats( - &splats, - &camera, - pixel_size, - settings.background.unwrap_or(Vec3::ZERO), - settings.splat_scale, - ); - - if let Some(backbuffer) = &mut self.backbuffer { - backbuffer.update_texture(img); - } - - if let Some(widget_3d) = &mut self.widget_3d - && let Some(backbuffer) = &self.backbuffer - && let Some(texture) = backbuffer.texture() - { - widget_3d.render_to_texture( - &camera, - process.model_local_to_world(), - pixel_size, - texture, - grid_opacity, - ); - } - } - } - - ui.scope(|ui| { - // if training views have alpha, show a background checker. Masked images - // should still use a black background. 
- match process.background_style() { - BackgroundStyle::Checkerboard => { - draw_checkerboard(ui, rect, Color32::WHITE); - } - BackgroundStyle::Black => { - ui.painter().rect_filled(rect, 0.0, Color32::BLACK); - } - } - - if let Some(backbuffer) = &self.backbuffer - && let Some(id) = backbuffer.id() - { - ui.painter().image( - id, - rect, - Rect { - min: egui::pos2(0.0, 0.0), - max: egui::pos2(1.0, 1.0), - }, - Color32::WHITE, - ); - } - }); - - rect - } - fn draw_play_pause(&mut self, ui: &egui::Ui, rect: Rect) { // Only show play/pause if we have a multi-frame sequence that's fully loaded if self.frame_count > 1 { @@ -500,8 +344,8 @@ impl ScenePanel { impl ScenePanel { fn reset(&mut self) { self.last_draw = None; - self.last_state = None; self.has_splats = false; + self.splats_dirty = false; self.frame = 0.0; self.frame_count = 0; self.paused = false; @@ -838,30 +682,18 @@ impl AppPane for ScenePanel { let new_idx = idx.round() as usize; if new_idx != old_idx { - let old_mode = self.render_update_mode; self.render_update_mode = match new_idx { 0 => RenderUpdateMode::Off, 1 => RenderUpdateMode::Low, _ => RenderUpdateMode::Live, }; - // If enabling rendering from Off, force a redraw - if old_mode == RenderUpdateMode::Off { - self.last_state = None; - } } } } - fn init( - &mut self, - device: wgpu::Device, - queue: wgpu::Queue, - renderer: Arc>, - _burn_device: burn_wgpu::WgpuDevice, - _adapter_info: wgpu::AdapterInfo, - ) { - self.widget_3d = Some(Widget3D::new(device.clone(), queue.clone())); - self.backbuffer = Some(BurnTexture::new(renderer, device, queue)); + fn init(&mut self, state: &RenderState) { + self.grid = Some(GridWidget::new(state)); + self.backbuffer = Some(SplatBackbuffer::new(state)); // Create the settings popup now that we have the base_path #[cfg(feature = "training")] @@ -907,13 +739,14 @@ impl AppPane for ScenePanel { up_axis, frame, total_frames, + .. 
} => { self.has_splats = true; self.frame_count = *total_frames; // For non-training updates (e.g., loading), always redraw if !process.is_training() { - self.last_state = None; + self.splats_dirty = true; // When training, datasets handle this. if let Some(up_axis) = up_axis { @@ -924,16 +757,18 @@ impl AppPane for ScenePanel { if *total_frames <= 1 || *frame < *total_frames - 1 { self.frame = *frame as f32; } - } - } - #[cfg(feature = "training")] - ProcessMessage::TrainMessage(TrainMessage::TrainStep { iter, .. }) => { - // Check if we should redraw based on render update mode - if let Some(interval) = self.render_update_mode.update_interval() { - // Check if enough iterations have passed since last render - if *iter >= self.last_rendered_iter + interval || self.last_rendered_iter == 0 { - self.last_rendered_iter = *iter; - self.last_state = None; + } else { + // Check if we should redraw based on render update mode + if let Some(interval) = self.render_update_mode.update_interval() { + let iter = process.train_iter(); + + // Check if enough iterations have passed since last render + if iter >= self.last_rendered_iter + interval + || self.last_rendered_iter == 0 + { + self.last_rendered_iter = iter; + self.splats_dirty = true; + } } } } @@ -1047,15 +882,74 @@ impl AppPane for ScenePanel { ui.ctx().request_repaint(); } - // Get the splat for the current frame - let splats = process.current_splats().and_then(|sv| { - let frame_idx = self.frame as usize; - sv.get(frame_idx) - }); - let interactive = matches!(process.ui_mode(), UiMode::Default | UiMode::FullScreenSplat); - let rect = self.draw_splats(ui, process, splats, interactive); + + let size = ui.available_size(); + let size = glam::uvec2(size.x.round() as u32, size.y.round() as u32); + let (rect, response) = ui.allocate_exact_size( + egui::Vec2::new(size.x as f32, size.y as f32), + egui::Sense::drag(), + ); + if interactive { + process.tick_controls(&response, ui); + } + + // Get camera after modifying the 
controls. + let mut camera = process.current_camera(); + + let view_eff = (camera.world_to_local() * process.model_local_to_world()).inverse(); + let (_, rotation, position) = view_eff.to_scale_rotation_translation(); + camera.position = position; + camera.rotation = rotation; + + let settings = process.get_cam_settings(); + + // Adjust FOV so that the scene view shows at least what's visible in the dataset view. + let camera_aspect = (camera.fov_x / 2.0).tan() / (camera.fov_y / 2.0).tan(); + let viewport_aspect = size.x as f64 / size.y as f64; + + if viewport_aspect > camera_aspect { + let focal_y = fov_to_focal(camera.fov_y, size.y); + camera.fov_x = focal_to_fov(focal_y, size.x); + } else { + let focal_x = fov_to_focal(camera.fov_x, size.x); + camera.fov_y = focal_to_fov(focal_x, size.y); + } + + // Render the splats and grid + ui.scope(|ui| { + // if training views have alpha, show a background checker. Masked images + // should still use a black background. + match process.background_style() { + BackgroundStyle::Checkerboard => { + draw_checkerboard(ui, rect, Color32::WHITE); + } + BackgroundStyle::Black => { + ui.painter().rect_filled(rect, 0.0, Color32::BLACK); + } + } + + if let Some(backbuffer) = &mut self.backbuffer { + backbuffer.paint( + rect, + ui, + &process.current_splats(), + &camera, + self.frame as usize, + settings.background.unwrap_or(Vec3::ZERO), + settings.splat_scale, + self.splats_dirty, + ); + self.splats_dirty = false; + } + + if let Some(grid) = &mut self.grid { + let model_ltw = process.model_local_to_world(); + let grid_opacity = process.get_grid_opacity(); + grid.paint(rect, camera, model_ltw, grid_opacity, ui); + } + }); if interactive { self.draw_play_pause(ui, rect); diff --git a/crates/brush-ui/src/shaders/splat_backbuffer.wgsl b/crates/brush-ui/src/shaders/splat_backbuffer.wgsl new file mode 100644 index 00000000..c5fea9b2 --- /dev/null +++ b/crates/brush-ui/src/shaders/splat_backbuffer.wgsl @@ -0,0 +1,43 @@ +struct Uniforms { + 
img_width: u32, + img_height: u32, +} + +@group(0) @binding(0) var uniforms: Uniforms; +@group(0) @binding(1) var image_data: array; + +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) uv: vec2, +} + +@vertex +fn vs_main(@builtin(vertex_index) vertex_index: u32) -> VertexOutput { + // Fullscreen triangle using oversized triangle technique + var out: VertexOutput; + let x = f32((vertex_index << 1u) & 2u); // 0, 2, 0 for indices 0, 1, 2 + let y = f32(vertex_index & 2u); // 0, 0, 2 for indices 0, 1, 2 + out.position = vec4(x * 2.0 - 1.0, y * 2.0 - 1.0, 0.0, 1.0); + out.uv = vec2(x, 1.0 - y); + return out; +} + +@fragment +fn fs_main(in: VertexOutput) -> @location(0) vec4 { + let pixel_x = u32(in.uv.x * f32(uniforms.img_width)); + let pixel_y = u32(in.uv.y * f32(uniforms.img_height)); + + if (pixel_x >= uniforms.img_width || pixel_y >= uniforms.img_height) { + return vec4(0.0, 0.0, 0.0, 0.0); + } + + let idx = pixel_y * uniforms.img_width + pixel_x; + let packed = image_data[idx]; + + // Unpack RGBA8: R|(G<<8)|(B<<16)|(A<<24) + let r = f32(packed & 0xFFu) / 255.0; + let g = f32((packed >> 8u) & 0xFFu) / 255.0; + let b = f32((packed >> 16u) & 0xFFu) / 255.0; + let a = f32((packed >> 24u) & 0xFFu) / 255.0; + return vec4(r, g, b, a); +} diff --git a/crates/brush-ui/src/splat_backbuffer.rs b/crates/brush-ui/src/splat_backbuffer.rs new file mode 100644 index 00000000..580f81fd --- /dev/null +++ b/crates/brush-ui/src/splat_backbuffer.rs @@ -0,0 +1,342 @@ +use brush_process::slot::Slot; +use brush_render::{ + MainBackend, MainBackendBase, TextureMode, camera::Camera, gaussian_splats::Splats, + render_splats, +}; +use burn::tensor::Tensor; +use egui::Rect; +use glam::{UVec2, Vec3}; +use tokio::sync::mpsc; +use tokio_with_wasm::alias::task; + +use eframe::egui_wgpu::{self, CallbackTrait, wgpu}; + +#[derive(Clone)] +struct RenderRequest { + slot: Slot>, + ctx: egui::Context, + state: LastRenderState, +} + +#[derive(Clone, PartialEq)] +struct 
LastRenderState { + frame: usize, + camera: Camera, + background: Vec3, + splat_scale: Option, + img_size: UVec2, +} + +pub struct SplatBackbuffer { + req_send: mpsc::UnboundedSender, + img_rec: mpsc::Receiver>, + last_image: Option>, + last_state: Option, +} + +impl SplatBackbuffer { + pub fn new(state: &eframe::egui_wgpu::RenderState) -> Self { + // Create channel for render requests + let (req_send, req_rec) = mpsc::unbounded_channel(); + let (img_send, img_rec) = mpsc::channel(1); + + // Register splat backbuffer resources + state + .renderer + .write() + .callback_resources + .insert(SplatBackbufferResources::new( + &state.device, + state.target_format, + )); + + task::spawn(render_worker(req_rec, img_send)); + Self { + req_send, + img_rec, + last_image: None, + last_state: None, + } + } + + #[allow(clippy::too_many_arguments)] + pub fn paint( + &mut self, + rect: Rect, + ui: &egui::Ui, + slot: &Slot>, + camera: &Camera, + frame: usize, + background: Vec3, + splat_scale: Option, + splats_dirty: bool, + ) { + // Calculate pixel size for rendering + let ppp = ui.ctx().pixels_per_point(); + let img_size = UVec2::new( + (rect.width() * ppp).round() as u32, + (rect.height() * ppp).round() as u32, + ); + + // Check if we need to re-render + let current_state = LastRenderState { + frame, + camera: camera.clone(), + background, + splat_scale, + img_size, + }; + + let dirty = splats_dirty || self.last_state.as_ref() != Some(¤t_state); + + if dirty { + self.last_state = Some(current_state.clone()); + // Send request to worker (ignore send errors if channel closed) + let _ = self.req_send.send(RenderRequest { + slot: slot.clone(), + ctx: ui.ctx().clone(), + state: current_state, + }); + } + + while let Ok(img) = self.img_rec.try_recv() { + self.last_image = Some(img); + } + + if let Some(image) = &self.last_image { + let shape = image.shape(); + let img_height = shape.dims[0] as u32; + let img_width = shape.dims[1] as u32; + + ui.painter() + 
.add(eframe::egui_wgpu::Callback::new_paint_callback( + rect, + SplatBackbufferPainter { + last_img: image.clone(), + img_width, + img_height, + }, + )); + } + } +} + +#[repr(C)] +#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] +struct Uniforms { + img_width: u32, + img_height: u32, +} + +pub struct SplatBackbufferResources { + pipeline: wgpu::RenderPipeline, + uniform_buffer: wgpu::Buffer, + bind_group_layout: wgpu::BindGroupLayout, + // Per-frame bind group - created in prepare() with the current tensor buffer + bind_group: Option, +} + +impl SplatBackbufferResources { + pub fn new(device: &wgpu::Device, target_format: wgpu::TextureFormat) -> Self { + let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Splat Backbuffer Shader"), + source: wgpu::ShaderSource::Wgsl(include_str!("shaders/splat_backbuffer.wgsl").into()), + }); + let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Splat Backbuffer Uniform Buffer"), + size: std::mem::size_of::() as u64, + usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("Splat Backbuffer Bind Group Layout"), + entries: &[ + // Uniform buffer for image dimensions + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::VERTEX | wgpu::ShaderStages::FRAGMENT, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Uniform, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + // Storage buffer for image data (read-only) + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::FRAGMENT, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + }); + let pipeline_layout = 
device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Splat Backbuffer Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { + label: Some("Splat Backbuffer Pipeline"), + layout: Some(&pipeline_layout), + vertex: wgpu::VertexState { + module: &shader, + entry_point: Some("vs_main"), + buffers: &[], // No vertex buffers - using fullscreen triangle trick + compilation_options: wgpu::PipelineCompilationOptions::default(), + }, + fragment: Some(wgpu::FragmentState { + module: &shader, + entry_point: Some("fs_main"), + targets: &[Some(wgpu::ColorTargetState { + format: target_format, + blend: Some(wgpu::BlendState::ALPHA_BLENDING), + write_mask: wgpu::ColorWrites::ALL, + })], + compilation_options: wgpu::PipelineCompilationOptions::default(), + }), + primitive: wgpu::PrimitiveState { + topology: wgpu::PrimitiveTopology::TriangleList, + strip_index_format: None, + front_face: wgpu::FrontFace::Ccw, + cull_mode: None, + unclipped_depth: false, + polygon_mode: wgpu::PolygonMode::Fill, + conservative: false, + }, + depth_stencil: None, + multisample: wgpu::MultisampleState::default(), + multiview: None, + cache: None, + }); + + Self { + pipeline, + uniform_buffer, + bind_group_layout, + bind_group: None, + } + } +} + +struct SplatBackbufferPainter { + last_img: Tensor, + img_width: u32, + img_height: u32, +} + +impl CallbackTrait for SplatBackbufferPainter { + fn prepare( + &self, + device: &wgpu::Device, + queue: &wgpu::Queue, + _screen_descriptor: &egui_wgpu::ScreenDescriptor, + _egui_encoder: &mut wgpu::CommandEncoder, + resources: &mut egui_wgpu::CallbackResources, + ) -> Vec { + let Some(res) = resources.get_mut::() else { + return Vec::new(); + }; + + // Update uniform buffer with image dimensions + queue.write_buffer( + &res.uniform_buffer, + 0, + bytemuck::cast_slice(&[Uniforms { + img_width: self.img_width, + 
img_height: self.img_height, + }]), + ); + + // Extract the wgpu buffer from the Burn tensor + let last_img = self.last_img.clone().into_primitive().tensor(); + let prim_tensor = last_img + .client + .clone() + .resolve_tensor_int::(last_img); + let img_res_handle = prim_tensor + .client + .get_resource(prim_tensor.handle.binding()); + + // Create a new bind group with the current tensor buffer + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("Splat Backbuffer Bind Group"), + layout: &res.bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: res.uniform_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: img_res_handle.resource().buffer.as_entire_binding(), + }, + ], + }); + + res.bind_group = Some(bind_group); + Vec::new() + } + + fn paint( + &self, + _info: egui::PaintCallbackInfo, + render_pass: &mut wgpu::RenderPass<'static>, + callback_resources: &egui_wgpu::CallbackResources, + ) { + let Some(res) = callback_resources.get::() else { + return; + }; + + let Some(bind_group) = res.bind_group.as_ref() else { + return; + }; + + render_pass.set_pipeline(&res.pipeline); + render_pass.set_bind_group(0, bind_group, &[]); + render_pass.draw(0..3, 0..1); + } +} + +/// Async render worker that processes render requests. +async fn render_worker( + mut receiver: mpsc::UnboundedReceiver, + img_sender: mpsc::Sender>, +) { + loop { + // Wait for at least one request and get latest. 
+ let Some(mut request) = receiver.recv().await else { + break; + }; + while let Ok(newer) = receiver.try_recv() { + request = newer; + } + + let image = request + .slot + .act(request.state.frame, async |splats| { + let (image, _) = render_splats( + splats.clone(), + &request.state.camera, + request.state.img_size, + request.state.background, + request.state.splat_scale, + TextureMode::Packed, + ) + .await; + (splats, image) + }) + .await; + + if let Some(image) = image { + let _ = img_sender.send(image).await; + } + + // Trigger egui repaint so the new texture gets picked up. + request.ctx.request_repaint(); + } +} diff --git a/crates/brush-ui/src/stats.rs b/crates/brush-ui/src/stats.rs index 2f389f51..5b9eb3b3 100644 --- a/crates/brush-ui/src/stats.rs +++ b/crates/brush-ui/src/stats.rs @@ -1,24 +1,22 @@ -use std::sync::Arc; - use crate::{UiMode, panels::AppPane, ui_process::UiProcess}; use brush_process::message::ProcessMessage; use brush_process::message::TrainMessage; use burn_cubecl::cubecl::Runtime; -use burn_wgpu::{WgpuDevice, WgpuRuntime}; -use eframe::egui_wgpu::Renderer; -use egui::mutex::RwLock; +use burn_wgpu::WgpuRuntime; +use eframe::egui_wgpu::RenderState; use web_time::Duration; use wgpu::AdapterInfo; #[derive(Default)] pub struct StatsPanel { - device: Option, last_eval: Option, frames: u32, adapter_info: Option, last_train_step: (Duration, u32), train_eval_views: (u32, u32), training_complete: bool, + num_splats: u32, + sh_degree: u32, } fn bytes_format(bytes: u64) -> String { @@ -79,16 +77,8 @@ impl AppPane for StatsPanel { "Stats".into() } - fn init( - &mut self, - _device: wgpu::Device, - _queue: wgpu::Queue, - _renderer: Arc>, - burn_device: burn_wgpu::WgpuDevice, - adapter_info: wgpu::AdapterInfo, - ) { - self.device = Some(burn_device); - self.adapter_info = Some(adapter_info); + fn init(&mut self, state: &RenderState) { + self.adapter_info = Some(state.adapter.get_info()); } fn is_visible(&self, process: &UiProcess) -> bool { @@ -103,11 
+93,20 @@ impl AppPane for StatsPanel { self.last_train_step = (Duration::from_secs(0), 0); self.train_eval_views = (0, 0); self.training_complete = false; + self.num_splats = 0; + self.sh_degree = 0; } ProcessMessage::StartLoading { .. } => { self.last_eval = None; } - ProcessMessage::SplatsUpdated { .. } => {} + ProcessMessage::SplatsUpdated { + num_splats, + sh_degree, + .. + } => { + self.num_splats = *num_splats; + self.sh_degree = *sh_degree; + } ProcessMessage::TrainMessage(train) => match train { TrainMessage::TrainStep { iter, @@ -151,11 +150,8 @@ impl AppPane for StatsPanel { }); ui.separator(); - let (num_splats, sh_degree) = process - .current_splats() - .and_then(|sv| sv.get_main()) - .map_or((0, 0), |spl| (spl.num_splats(), spl.sh_degree())); - + let num_splats = self.num_splats; + let sh_degree = self.sh_degree; let frames = self.frames; stats_grid(ui, "model_stats_grid", |ui, v| { stat_row(ui, "Splats", format!("{num_splats}"), v); @@ -189,40 +185,39 @@ impl AppPane for StatsPanel { }); } - if let Some(device) = &self.device { - ui.add_space(10.0); - ui.heading("GPU"); - ui.separator(); + let device = process.burn_device(); + let client = WgpuRuntime::client(&device); + let memory = client.memory_usage(); - let client = WgpuRuntime::client(device); - let memory = client.memory_usage(); + ui.add_space(10.0); + ui.heading("GPU"); + ui.separator(); - stats_grid(ui, "memory_stats_grid", |ui, v| { - stat_row(ui, "Bytes in use", bytes_format(memory.bytes_in_use), v); - stat_row(ui, "Bytes reserved", bytes_format(memory.bytes_reserved), v); + stats_grid(ui, "memory_stats_grid", |ui, v| { + stat_row(ui, "Bytes in use", bytes_format(memory.bytes_in_use), v); + stat_row(ui, "Bytes reserved", bytes_format(memory.bytes_reserved), v); + stat_row( + ui, + "Active allocations", + format!("{}", memory.number_allocs), + v, + ); + }); + + // On WASM, adapter info is mostly private, not worth showing. 
+ if !cfg!(target_family = "wasm") + && let Some(adapter_info) = &self.adapter_info + { + stats_grid(ui, "gpu_info_grid", |ui, v| { + stat_row(ui, "Name", &adapter_info.name, v); + stat_row(ui, "Type", format!("{:?}", adapter_info.device_type), v); stat_row( ui, - "Active allocations", - format!("{}", memory.number_allocs), + "Driver", + format!("{}, {}", adapter_info.driver, adapter_info.driver_info), v, ); }); - - // On WASM, adapter info is mostly private, not worth showing. - if !cfg!(target_family = "wasm") - && let Some(adapter_info) = &self.adapter_info - { - stats_grid(ui, "gpu_info_grid", |ui, v| { - stat_row(ui, "Name", &adapter_info.name, v); - stat_row(ui, "Type", format!("{:?}", adapter_info.device_type), v); - stat_row( - ui, - "Driver", - format!("{}, {}", adapter_info.driver, adapter_info.driver_info), - v, - ); - }); - } } }); } diff --git a/crates/brush-ui/src/training_panel.rs b/crates/brush-ui/src/training_panel.rs index 21ac1a1c..7ead7817 100644 --- a/crates/brush-ui/src/training_panel.rs +++ b/crates/brush-ui/src/training_panel.rs @@ -244,9 +244,7 @@ impl AppPane for TrainingPanel { } } - if let Some(slot) = process.current_splats() - && process.is_training() - { + if process.is_training() { // Right-align export button ui.with_layout(egui::Layout::right_to_left(egui::Align::Center), |ui| { // Make export button more prominent when training is complete @@ -278,9 +276,10 @@ impl AppPane for TrainingPanel { } let sender = self.export_channel.0.clone(); let ctx = ui.ctx().clone(); + let slot = process.current_splats(); task::spawn(async move { - let Some(splats) = slot.get_main() else { + let Some(splats) = slot.clone_main().await else { return; }; diff --git a/crates/brush-ui/src/ui_process.rs b/crates/brush-ui/src/ui_process.rs index e5f9cf8e..5cf49e6b 100644 --- a/crates/brush-ui/src/ui_process.rs +++ b/crates/brush-ui/src/ui_process.rs @@ -65,11 +65,11 @@ impl UiProcess { self.write().background_style = style; } - pub(crate) fn 
current_splats(&self) -> Option>> { + pub(crate) fn current_splats(&self) -> Slot> { self.read() .process_handle .as_ref() - .map(|s| s.splat_view.clone()) + .map_or(Slot::default(), |s| s.splat_view.clone()) } pub fn is_loading(&self) -> bool { @@ -108,6 +108,10 @@ impl UiProcess { self.read().train_paused } + pub(crate) fn train_iter(&self) -> u32 { + self.read().train_iter + } + pub fn get_cam_settings(&self) -> CameraSettings { self.read().controls.settings.clone() } @@ -235,10 +239,17 @@ impl UiProcess { Ok(ProcessMessage::StartLoading { training, .. }) => { inner.is_training = *training; inner.is_loading = true; + inner.train_iter = 0; } Ok(ProcessMessage::DoneLoading) => { inner.is_loading = false; } + #[cfg(feature = "training")] + Ok(ProcessMessage::TrainMessage( + brush_process::message::TrainMessage::TrainStep { iter, .. }, + )) => { + inner.train_iter = *iter; + } Err(_) => { inner.is_loading = false; inner.is_training = false; @@ -281,6 +292,10 @@ impl UiProcess { inner.session_reset_requested = false; requested } + + pub fn burn_device(&self) -> WgpuDevice { + self.read().burn_device.clone() + } } struct UiProcessInner { @@ -293,6 +308,7 @@ struct UiProcessInner { ui_mode: UiMode, background_style: BackgroundStyle, train_paused: bool, + train_iter: u32, reset_layout_requested: bool, session_reset_requested: bool, ui_ctx: egui::Context, @@ -313,6 +329,7 @@ impl UiProcessInner { splat_scale: None, is_loading: false, is_training: false, + train_iter: 0, process_handle: None, ui_mode: UiMode::Default, background_style: BackgroundStyle::Black, diff --git a/crates/brush-ui/src/widget_3d.rs b/crates/brush-ui/src/widget_3d.rs index 2e6e3d76..7dfb22af 100644 --- a/crates/brush-ui/src/widget_3d.rs +++ b/crates/brush-ui/src/widget_3d.rs @@ -1,3 +1,6 @@ +use brush_render::camera::Camera; +use eframe::egui_wgpu::{self, RenderState, wgpu}; +use egui::Rect; use glam::{Mat4, Vec3}; use wgpu::util::DeviceExt; @@ -13,7 +16,7 @@ struct Vertex { struct Uniforms { 
view_proj: [[f32; 4]; 4], grid_opacity: f32, - _padding: [f32; 3], // Padding for alignment + _padding: [f32; 3], } impl Vertex { @@ -29,9 +32,42 @@ impl Vertex { } } -pub struct Widget3D { - device: wgpu::Device, - queue: wgpu::Queue, +pub struct GridWidget {} + +impl GridWidget { + pub fn new(state: &RenderState) -> Self { + state + .renderer + .write() + .callback_resources + .insert(GridWidgetResources::new(&state.device, state.target_format)); + Self {} + } + + #[expect(clippy::unused_self)] + pub fn paint( + &self, // Not used atm,but, in the future the widget might have some state. + rect: Rect, + camera: Camera, + model_transform: glam::Affine3A, + grid_opacity: f32, + ui: &egui::Ui, + ) { + if grid_opacity > 0.0 { + ui.painter() + .add(eframe::egui_wgpu::Callback::new_paint_callback( + rect, + GridWidgetPainter { + camera, + model_transform, + grid_opacity, + }, + )); + } + } +} + +struct GridWidgetResources { pipeline: wgpu::RenderPipeline, uniform_buffer: wgpu::Buffer, uniform_bind_group: wgpu::BindGroup, @@ -41,14 +77,13 @@ pub struct Widget3D { up_axis_vertex_count: u32, } -impl Widget3D { - pub fn new(device: wgpu::Device, queue: wgpu::Queue) -> Self { +impl GridWidgetResources { + pub fn new(device: &wgpu::Device, target_format: wgpu::TextureFormat) -> Self { let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { label: Some("Widget 3D Shader"), source: wgpu::ShaderSource::Wgsl(include_str!("shaders/widget_3d.wgsl").into()), }); - // Create uniform buffer let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor { label: Some("Widget 3D Uniform Buffer"), size: std::mem::size_of::() as u64, @@ -56,12 +91,11 @@ impl Widget3D { mapped_at_creation: false, }); - // Create bind group layout and bind group let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { label: Some("Widget 3D Bind Group Layout"), entries: &[wgpu::BindGroupLayoutEntry { binding: 0, - visibility: wgpu::ShaderStages::VERTEX | 
wgpu::ShaderStages::FRAGMENT, // Fragment needs access for grid_opacity + visibility: wgpu::ShaderStages::VERTEX | wgpu::ShaderStages::FRAGMENT, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, @@ -80,13 +114,13 @@ impl Widget3D { }], }); - // Create render pipeline let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: Some("Widget 3D Pipeline Layout"), bind_group_layouts: &[&bind_group_layout], push_constant_ranges: &[], }); + // Pipeline without depth stencil - draws on top of egui content let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { label: Some("Widget 3D Pipeline"), layout: Some(&pipeline_layout), @@ -100,7 +134,7 @@ impl Widget3D { module: &shader, entry_point: Some("fs_main"), targets: &[Some(wgpu::ColorTargetState { - format: wgpu::TextureFormat::Rgba8Unorm, + format: target_format, blend: Some(wgpu::BlendState::ALPHA_BLENDING), write_mask: wgpu::ColorWrites::ALL, })], @@ -115,19 +149,12 @@ impl Widget3D { polygon_mode: wgpu::PolygonMode::Fill, conservative: false, }, - depth_stencil: Some(wgpu::DepthStencilState { - format: wgpu::TextureFormat::Depth32Float, - depth_write_enabled: true, - depth_compare: wgpu::CompareFunction::Less, - stencil: wgpu::StencilState::default(), - bias: wgpu::DepthBiasState::default(), - }), + depth_stencil: None, // No depth buffer - draw on top multisample: wgpu::MultisampleState::default(), multiview: None, cache: None, }); - // Create geometry let (grid_vertices, grid_vertex_count) = Self::create_grid_geometry(); let grid_vertex_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { label: Some("Grid Vertex Buffer"), @@ -143,8 +170,6 @@ impl Widget3D { }); Self { - device, - queue, pipeline, uniform_buffer, uniform_bind_group, @@ -159,13 +184,10 @@ impl Widget3D { let mut vertices = Vec::new(); let size = 10.0; let step = 1.0; - let color = [0.3, 0.3, 0.3, 0.8]; // Semi-transparent gray + 
let color = [0.3, 0.3, 0.3, 0.8]; - // Create grid lines in XZ plane (Y=0) for OpenCV coordinates - // This creates a ground plane since Y is down in OpenCV let mut i = -size; while i <= size { - // Lines parallel to X axis vertices.push(Vertex { position: [-size, 0.0, i], color, @@ -174,8 +196,6 @@ impl Widget3D { position: [size, 0.0, i], color, }); - - // Lines parallel to Z axis vertices.push(Vertex { position: [i, 0.0, -size], color, @@ -184,7 +204,6 @@ impl Widget3D { position: [i, 0.0, size], color, }); - i += step; } @@ -192,117 +211,75 @@ impl Widget3D { } fn create_up_axis_geometry() -> (Vec, u32) { - let mut vertices = Vec::new(); - let length = 1.5; - - // Single blue line pointing up (negative Y in OpenCV coordinates) - vertices.push(Vertex { - position: [0.0, 0.0, 0.0], - color: [0.0, 0.5, 1.0, 1.0], // Light blue - }); - vertices.push(Vertex { - position: [0.0, -length, 0.0], // Negative Y is up - color: [0.0, 0.5, 1.0, 1.0], // Light blue - }); - + let vertices = vec![ + Vertex { + position: [0.0, 0.0, 0.0], + color: [0.0, 0.5, 1.0, 1.0], + }, + Vertex { + position: [0.0, -1.5, 0.0], + color: [0.0, 0.5, 1.0, 1.0], + }, + ]; (vertices, 2) } +} - pub fn render_to_texture( - &self, - camera: &brush_render::camera::Camera, - model_transform: glam::Affine3A, - size: glam::UVec2, - target_texture: &wgpu::Texture, - grid_opacity: f32, - ) { - let output_view = target_texture.create_view(&wgpu::TextureViewDescriptor::default()); - - // Create depth texture - let depth_texture = self.device.create_texture(&wgpu::TextureDescriptor { - label: Some("Widget 3D Depth Texture"), - size: wgpu::Extent3d { - width: size.x, - height: size.y, - depth_or_array_layers: 1, - }, - mip_level_count: 1, - sample_count: 1, - dimension: wgpu::TextureDimension::D2, - format: wgpu::TextureFormat::Depth32Float, - usage: wgpu::TextureUsages::RENDER_ATTACHMENT, - view_formats: &[], - }); - let depth_view = depth_texture.create_view(&wgpu::TextureViewDescriptor::default()); +/// 
Callback for rendering the 3D widget overlay via egui's paint system. +struct GridWidgetPainter { + pub camera: Camera, + pub model_transform: glam::Affine3A, + pub grid_opacity: f32, +} - // Use perspective_lh since camera uses +Z as forward - // But flip Y since camera uses Y-down while perspective_lh uses Y-up - let aspect = size.x as f32 / size.y as f32; - let proj_matrix = Mat4::perspective_lh(camera.fov_y as f32, aspect, 0.1, 1000.0); +impl egui_wgpu::CallbackTrait for GridWidgetPainter { + fn prepare( + &self, + _device: &wgpu::Device, + queue: &wgpu::Queue, + screen_descriptor: &egui_wgpu::ScreenDescriptor, + _egui_encoder: &mut wgpu::CommandEncoder, + resources: &mut egui_wgpu::CallbackResources, + ) -> Vec { + let Some(resources) = resources.get::() else { + return Vec::new(); + }; - // Y-flip to convert from Y-up to Y-down + let aspect = + screen_descriptor.size_in_pixels[0] as f32 / screen_descriptor.size_in_pixels[1] as f32; + let proj_matrix = Mat4::perspective_lh(self.camera.fov_y as f32, aspect, 0.1, 1000.0); let y_flip = Mat4::from_scale(Vec3::new(1.0, -1.0, 1.0)); - - // The camera already has model transform baked in - // To get world-space view, we need to undo the model transform by applying its inverse - let view_matrix = camera.world_to_local(); - let world_view = Mat4::from(view_matrix) * Mat4::from(model_transform.inverse()); - - // Apply Y flip and combine with projection + let view_matrix = self.camera.world_to_local(); + let world_view = Mat4::from(view_matrix) * Mat4::from(self.model_transform.inverse()); let view_proj = proj_matrix * y_flip * world_view; let uniforms = Uniforms { view_proj: view_proj.to_cols_array_2d(), - grid_opacity, + grid_opacity: self.grid_opacity, _padding: [0.0; 3], }; + queue.write_buffer( + &resources.uniform_buffer, + 0, + bytemuck::cast_slice(&[uniforms]), + ); + Vec::new() + } - self.queue - .write_buffer(&self.uniform_buffer, 0, bytemuck::cast_slice(&[uniforms])); - - // Render - let mut encoder = self - 
.device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("Widget 3D Render Encoder"), - }); - - { - let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { - label: Some("Widget 3D Render Pass"), - color_attachments: &[Some(wgpu::RenderPassColorAttachment { - view: &output_view, - resolve_target: None, - ops: wgpu::Operations { - load: wgpu::LoadOp::Load, // Load existing content instead of clearing - store: wgpu::StoreOp::Store, - }, - depth_slice: None, - })], - depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment { - view: &depth_view, - depth_ops: Some(wgpu::Operations { - load: wgpu::LoadOp::Clear(1.0), - store: wgpu::StoreOp::Store, - }), - stencil_ops: None, - }), - timestamp_writes: None, - occlusion_query_set: None, - }); - - render_pass.set_pipeline(&self.pipeline); - render_pass.set_bind_group(0, &self.uniform_bind_group, &[]); - - // Draw grid - render_pass.set_vertex_buffer(0, self.grid_vertex_buffer.slice(..)); - render_pass.draw(0..self.grid_vertex_count, 0..1); - - // Draw up axis - render_pass.set_vertex_buffer(0, self.up_axis_vertex_buffer.slice(..)); - render_pass.draw(0..self.up_axis_vertex_count, 0..1); - } - - self.queue.submit(std::iter::once(encoder.finish())); + fn paint( + &self, + _info: egui::PaintCallbackInfo, + render_pass: &mut wgpu::RenderPass<'static>, + resources: &egui_wgpu::CallbackResources, + ) { + let Some(resources) = resources.get::() else { + return; + }; + render_pass.set_pipeline(&resources.pipeline); + render_pass.set_bind_group(0, &resources.uniform_bind_group, &[]); + render_pass.set_vertex_buffer(0, resources.grid_vertex_buffer.slice(..)); + render_pass.draw(0..resources.grid_vertex_count, 0..1); + render_pass.set_vertex_buffer(0, resources.up_axis_vertex_buffer.slice(..)); + render_pass.draw(0..resources.up_axis_vertex_count, 0..1); } } diff --git a/examples/train-2d/Cargo.toml b/examples/train-2d/Cargo.toml index 5aadae07..c3f1e8c9 100644 --- 
a/examples/train-2d/Cargo.toml +++ b/examples/train-2d/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [dev-dependencies] burn.workspace = true +burn-cubecl.workspace = true brush-train.path = "../../crates/brush-train" brush-dataset.path = "../../crates/brush-dataset" brush-render.path = "../../crates/brush-render" diff --git a/examples/train-2d/examples/train-2d.rs b/examples/train-2d/examples/train-2d.rs index ed0ae44f..b465d4f4 100644 --- a/examples/train-2d/examples/train-2d.rs +++ b/examples/train-2d/examples/train-2d.rs @@ -2,18 +2,18 @@ #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release use brush_dataset::scene::{SceneBatch, sample_to_tensor_data}; +use brush_process::slot::Slot; use brush_render::{ AlphaMode, MainBackend, bounding_box::BoundingBox, camera::{Camera, focal_to_fov, fov_to_focal}, gaussian_splats::{SplatRenderMode, Splats}, - render_splats, }; use brush_train::{ RandomSplatsConfig, config::TrainConfig, create_random_splats, splats_into_autodiff, train::SplatTrainer, }; -use brush_ui::burn_texture::BurnTexture; +use brush_ui::splat_backbuffer::SplatBackbuffer; use burn::{backend::wgpu::WgpuDevice, module::AutodiffModule, prelude::Backend}; use egui::{ImageSource, TextureHandle, TextureOptions, load::SizedTexture}; use glam::{Quat, Vec2, Vec3}; @@ -22,8 +22,8 @@ use rand::SeedableRng; use tokio::sync::mpsc::{Receiver, Sender}; struct TrainStep { - splats: Splats, iter: u32, + num_splats: u32, } fn spawn_train_loop( @@ -33,8 +33,8 @@ fn spawn_train_loop( device: WgpuDevice, ctx: egui::Context, sender: Sender, + slot: Slot>, ) { - // Spawn a task that iterates over the training stream. tokio::spawn(async move { let seed = 42; @@ -57,7 +57,6 @@ fn spawn_train_loop( BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE), ); - // One batch of training data, it's the same every step so can just construct it once. 
let batch = SceneBatch { img_tensor: sample_to_tensor_data(image), alpha_mode: AlphaMode::Transparent, @@ -67,20 +66,16 @@ fn spawn_train_loop( let mut iter = 0; loop { - let (new_splats, _) = trainer.step(batch.clone(), splats); + let (new_splats, _) = trainer.step(batch.clone(), splats).await; let (new_splats, _) = trainer.refine(iter, new_splats.valid()).await; + let num_splats = new_splats.num_splats(); + slot.set(new_splats.clone()).await; + splats = splats_into_autodiff(new_splats); iter += 1; ctx.request_repaint(); - if sender - .send(TrainStep { - splats: splats.valid(), - iter, - }) - .await - .is_err() - { + if sender.send(TrainStep { iter, num_splats }).await.is_err() { break; } } @@ -91,9 +86,11 @@ struct App { image: image::DynamicImage, camera: Camera, tex_handle: TextureHandle, - backbuffer: BurnTexture, + backbuffer: SplatBackbuffer, + slot: Slot>, receiver: Receiver, last_step: Option, + splats_dirty: bool, } impl App { @@ -102,6 +99,7 @@ impl App { .wgpu_render_state .as_ref() .expect("No wgpu renderer enabled in egui"); + let device = brush_process::burn_init_device( state.adapter.clone(), state.device.clone(), @@ -132,6 +130,7 @@ impl App { let handle = cc.egui_ctx .load_texture("nearest_view_tex", color_img, TextureOptions::default()); + let slot = Slot::default(); let config = TrainConfig::default(); spawn_train_loop( @@ -141,22 +140,18 @@ impl App { device, cc.egui_ctx.clone(), sender, + slot.clone(), ); - let renderer = cc - .wgpu_render_state - .as_ref() - .expect("No wgpu renderer enabled in egui") - .renderer - .clone(); - Self { image, camera, tex_handle: handle, - backbuffer: BurnTexture::new(renderer, state.device.clone(), state.queue.clone()), + backbuffer: SplatBackbuffer::new(state), + slot, receiver, last_step: None, + splats_dirty: false, } } } @@ -165,34 +160,39 @@ impl eframe::App for App { fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) { while let Ok(step) = self.receiver.try_recv() { self.last_step = 
Some(step); + self.splats_dirty = true; } egui::CentralPanel::default().show(ctx, |ui| { - let Some(msg) = self.last_step.as_ref() else { + let Some(step) = self.last_step.as_ref() else { + ui.label("Waiting for first training step..."); return; }; - let (img, _) = render_splats( - &msg.splats, - &self.camera, - glam::uvec2(self.image.width(), self.image.height()), - Vec3::ZERO, // Just render with a black background - None, - ); - let size = egui::vec2(self.image.width() as f32, self.image.height() as f32); ui.horizontal(|ui| { - let texture_id = self.backbuffer.update_texture(img); - ui.image(ImageSource::Texture(SizedTexture::new(texture_id, size))); + let (rect, _response) = ui.allocate_exact_size(size, egui::Sense::hover()); + self.backbuffer.paint( + rect, + ui, + &self.slot, + &self.camera, + 0, + Vec3::ZERO, + None, + self.splats_dirty, + ); + self.splats_dirty = false; + ui.image(ImageSource::Texture(SizedTexture::new( self.tex_handle.id(), size, ))); }); - ui.label(format!("Splats: {}", msg.splats.num_splats())); - ui.label(format!("Step: {}", msg.iter)); + ui.label(format!("Splats: {}", step.num_splats)); + ui.label(format!("Step: {}", step.iter)); }); } } @@ -200,7 +200,6 @@ impl eframe::App for App { #[tokio::main] async fn main() { let native_options = eframe::NativeOptions { - // Build app display. viewport: egui::ViewportBuilder::default() .with_inner_size(egui::Vec2::new(1100.0, 500.0)) .with_active(true),