Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ tracing = "0.1.41"
tracing-tracy = "0.11.3"
tracing-subscriber = "0.3.19"

winapi = "0.3"
winapi = { version = "0.3", features = ["wincon"] }

tokio = { version = "1.42.0", default-features = false }
tokio_with_wasm = "0.8.2"
Expand Down
131 changes: 80 additions & 51 deletions crates/brush-bench-test/src/benches.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use brush_dataset::scene::SceneBatch;
use brush_render::{
AlphaMode, MainBackend,
AlphaMode, MainBackend, TextureMode,
camera::Camera,
gaussian_splats::{SplatRenderMode, Splats},
render_splats,
Expand Down Expand Up @@ -144,8 +144,9 @@ fn generate_training_batch(resolution: (u32, u32), camera_pos: Vec3) -> SceneBat
mod forward_rendering {
use super::{
AutodiffModule, Backend, Camera, ITERS_PER_SYNC, MainBackend, Quat, RESOLUTIONS,
SPLAT_COUNTS, Vec3, WgpuDevice, gen_splats, render_splats,
SPLAT_COUNTS, TextureMode, Vec3, WgpuDevice, gen_splats, render_splats,
};
use burn_cubecl::cubecl::future::block_on;

#[divan::bench(args = SPLAT_COUNTS)]
fn render_1080p(bencher: divan::Bencher, splat_count: usize) {
Expand All @@ -160,10 +161,20 @@ mod forward_rendering {
);

bencher.bench_local(move || {
for _ in 0..ITERS_PER_SYNC {
let _ = render_splats(&splats, &camera, glam::uvec2(1920, 1080), Vec3::ZERO, None);
}
MainBackend::sync(&device).expect("Failed to sync");
block_on(async {
for _ in 0..ITERS_PER_SYNC {
let _ = render_splats(
splats.clone(),
&camera,
glam::uvec2(1920, 1080),
Vec3::ZERO,
None,
TextureMode::Float,
)
.await;
}
MainBackend::sync(&device).expect("Failed to sync");
});
});
}

Expand All @@ -180,16 +191,20 @@ mod forward_rendering {
);

bencher.bench_local(move || {
for _ in 0..ITERS_PER_SYNC {
let _ = render_splats(
&splats,
&camera,
glam::uvec2(width, height),
Vec3::ZERO,
None,
);
}
MainBackend::sync(&device).expect("Failed to sync");
block_on(async {
for _ in 0..ITERS_PER_SYNC {
let _ = render_splats(
splats.clone(),
&camera,
glam::uvec2(width, height),
Vec3::ZERO,
None,
TextureMode::Float,
)
.await;
}
MainBackend::sync(&device).expect("Failed to sync");
});
});
}
}
Expand All @@ -200,6 +215,7 @@ mod backward_rendering {
Backend, Camera, DiffBackend, ITERS_PER_SYNC, MainBackend, Quat, RESOLUTIONS, Tensor,
TensorPrimitive, Vec3, WgpuDevice, gen_splats, render_splats_diff,
};
use burn_cubecl::cubecl::future::block_on;

#[divan::bench(args = [1_000_000, 2_000_000, 5_000_000])]
fn render_grad_1080p(bencher: divan::Bencher, splat_count: usize) {
Expand All @@ -214,14 +230,21 @@ mod backward_rendering {
);

bencher.bench_local(move || {
for _ in 0..ITERS_PER_SYNC {
let diff_out =
render_splats_diff(&splats, &camera, glam::uvec2(1920, 1080), Vec3::ZERO);
let img: Tensor<DiffBackend, 3> =
Tensor::from_primitive(TensorPrimitive::Float(diff_out.img));
let _ = img.mean().backward();
}
MainBackend::sync(&device).expect("Failed to sync");
block_on(async {
for _ in 0..ITERS_PER_SYNC {
let diff_out = render_splats_diff(
splats.clone(),
&camera,
glam::uvec2(1920, 1080),
Vec3::ZERO,
)
.await;
let img: Tensor<DiffBackend, 3> =
Tensor::from_primitive(TensorPrimitive::Float(diff_out.img));
let _ = img.mean().backward();
}
MainBackend::sync(&device).expect("Failed to sync");
});
});
}

Expand All @@ -237,21 +260,29 @@ mod backward_rendering {
glam::vec2(0.5, 0.5),
);
bencher.bench_local(move || {
for _ in 0..ITERS_PER_SYNC {
let diff_out =
render_splats_diff(&splats, &camera, glam::uvec2(width, height), Vec3::ZERO);
let img: Tensor<DiffBackend, 3> =
Tensor::from_primitive(TensorPrimitive::Float(diff_out.img));
let _ = img.mean().backward();
}
MainBackend::sync(&device).expect("Failed to sync");
block_on(async {
for _ in 0..ITERS_PER_SYNC {
let diff_out = render_splats_diff(
splats.clone(),
&camera,
glam::uvec2(width, height),
Vec3::ZERO,
)
.await;
let img: Tensor<DiffBackend, 3> =
Tensor::from_primitive(TensorPrimitive::Float(diff_out.img));
let _ = img.mean().backward();
}
MainBackend::sync(&device).expect("Failed to sync");
});
});
}
}

#[divan::bench_group(max_time = 4)]
mod training {
use brush_render::bounding_box::BoundingBox;
use burn_cubecl::cubecl::future::block_on;

use crate::benches::ITERS_PER_SYNC;

Expand All @@ -262,25 +293,23 @@ mod training {

#[divan::bench(args = SPLAT_COUNTS)]
fn train_steps(splat_count: usize) {
burn_cubecl::cubecl::future::block_on(async {
let device = WgpuDevice::default();
let batch1 = generate_training_batch((1920, 1080), Vec3::new(0.0, 0.0, 5.0));
let batch2 = generate_training_batch((1920, 1080), Vec3::new(2.0, 0.0, 5.0));
let batches = [batch1, batch2];
let config = TrainConfig::default();
let mut splats = gen_splats(&device, splat_count);
let mut trainer = SplatTrainer::new(
&config,
&device,
BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE),
);
for step in 0..ITERS_PER_SYNC {
let batch = batches[step as usize % batches.len()].clone();
let (new_splats, _) = trainer.step(batch, splats);
splats = new_splats;
}
MainBackend::sync(&device).expect("Failed to sync");
});
let device = WgpuDevice::default();
let batch1 = generate_training_batch((1920, 1080), Vec3::new(0.0, 0.0, 5.0));
let batch2 = generate_training_batch((1920, 1080), Vec3::new(2.0, 0.0, 5.0));
let batches = [batch1, batch2];
let config = TrainConfig::default();
let mut splats = gen_splats(&device, splat_count);
let mut trainer = SplatTrainer::new(
&config,
&device,
BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE),
);
for step in 0..ITERS_PER_SYNC {
let batch = batches[step as usize % batches.len()].clone();
let (new_splats, _) = block_on(trainer.step(batch, splats));
splats = new_splats;
}
MainBackend::sync(&device).expect("Failed to sync");
}
}

Expand Down
33 changes: 7 additions & 26 deletions crates/brush-bench-test/src/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use burn::{
Tensor,
backend::{Autodiff, wgpu::WgpuDevice},
prelude::Backend,
tensor::{Float, Int, TensorPrimitive},
tensor::TensorPrimitive,
};

use anyhow::{Context, Result};
Expand Down Expand Up @@ -127,16 +127,15 @@ async fn test_reference() -> Result<()> {
);

let diff_out = brush_render_bwd::render_splats(
&splats,
splats.clone(),
&cam,
glam::uvec2(w as u32, h as u32),
Vec3::ZERO,
);
)
.await;

let (out, aux) = (
Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)),
diff_out.aux,
);
let out: Tensor<DiffBack, 3> = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img));
let render_aux = diff_out.render_aux;

if let Some(rec) = rec.as_ref() {
rec.set_time_sequence("test case", i as i64);
Expand All @@ -148,28 +147,10 @@ async fn test_reference() -> Result<()> {
)?;
rec.log(
"images/tile_depth",
&aux.calc_tile_depth().into_rerun().await,
&render_aux.calc_tile_depth().into_rerun().await,
)?;
}

let num_visible: Tensor<DiffBack, 1, Int> = aux.num_visible();
let num_visible = num_visible.into_scalar_async().await.unwrap() as usize;
let global_from_compact_gid: Tensor<DiffBack, 1, Int> =
Tensor::from_primitive(aux.global_from_compact_gid.clone());
let gs_ids = global_from_compact_gid.clone().slice([0..num_visible]);
let projected_splats =
Tensor::from_primitive(TensorPrimitive::Float(aux.projected_splats.clone()));
let xys: Tensor<DiffBack, 2, Float> =
projected_splats.clone().slice([0..num_visible, 0..2]);
let xys_ref = safetensor_to_burn::<DiffBack, 2>(&tensors.tensor("xys")?, &device);
let xys_ref = xys_ref.select(0, gs_ids.clone());
compare("xy", xys, xys_ref, 1e-5, 2e-5);
let conics: Tensor<DiffBack, 2, Float> =
projected_splats.clone().slice([0..num_visible, 2..5]);
let conics_ref = safetensor_to_burn::<DiffBack, 2>(&tensors.tensor("conics")?, &device);
let conics_ref = conics_ref.select(0, gs_ids.clone());
compare("conics", conics, conics_ref, 1e-6, 2e-5);

// Check if images match.
compare("img", out.clone(), img_ref, 1e-5, 1e-5);

Expand Down
19 changes: 10 additions & 9 deletions crates/brush-bench-test/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ fn test_forward_rendering() {
assert!(means_data.iter().all(|&x| x.is_finite()));
}

#[test]
fn test_training_step() {
#[tokio::test]
async fn test_training_step() {
let device = WgpuDevice::default();
let batch = generate_test_batch((64, 64));
let splats = generate_test_splats(&device, 500);
Expand All @@ -172,7 +172,7 @@ fn test_training_step() {
&device,
BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE),
);
let (final_splats, stats) = trainer.step(batch, splats);
let (final_splats, stats) = trainer.step(batch, splats).await;

assert!(final_splats.num_splats() > 0);
let loss = stats.loss.into_scalar();
Expand All @@ -190,8 +190,8 @@ fn test_batch_generation() {
assert!(img_data.iter().all(|&x| (0.0..=1.1).contains(&x)));
}

#[test]
fn test_multi_step_training() {
#[tokio::test]
async fn test_multi_step_training() {
let device = WgpuDevice::default();
let batch = generate_test_batch((64, 64));
let config = TrainConfig::default();
Expand All @@ -205,7 +205,7 @@ fn test_multi_step_training() {

// Run a few training steps
for _ in 0..3 {
let (new_splats, stats) = trainer.step(batch.clone(), splats);
let (new_splats, stats) = trainer.step(batch.clone(), splats).await;
splats = new_splats;

let loss = stats.loss.into_scalar();
Expand All @@ -216,8 +216,8 @@ fn test_multi_step_training() {
assert!(splats.num_splats() > 0);
}

#[test]
fn test_gradient_validation() {
#[tokio::test]
async fn test_gradient_validation() {
let device = WgpuDevice::default();
let splats = generate_test_splats(&device, 100);

Expand All @@ -231,7 +231,8 @@ fn test_gradient_validation() {
);
let img_size = glam::uvec2(64, 64);

let result = render_splats(&splats, &camera, img_size, Vec3::ZERO);
// Clone splats since render_splats takes ownership and we need splats for gradient validation
let result = render_splats(splats.clone(), &camera, img_size, Vec3::ZERO).await;

let rendered: Tensor<DiffBackend, 3> =
Tensor::from_primitive(TensorPrimitive::Float(result.img));
Expand Down
36 changes: 0 additions & 36 deletions crates/brush-kernel/build.rs

This file was deleted.

Loading