Skip to content

Commit 373d40a

Browse files
Use readback for intersections (#350)
* Readback for intersections
* Split pass
* Cleanup split pass
* Pt. 2
* More cleanup
* Cleanup
* Fmt
* Cleanup
* Cleanup
* Fmt
* Cleanup
* Async rendering
* splat backbuffer async
* Move widget to backbuffer
* Fixes for wasm
* fmt
* Cleanup
* Less locking
* WIP
* Start on native backbuffer.
* WIP
* WIP
* Render splats straight to UI
* Missing shader
* Fix live update
* Fix no default features
* Fix
* Misc clean
* Simplify some code
1 parent f65348f commit 373d40a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+2177
-2028
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ tracing = "0.1.41"
5454
tracing-tracy = "0.11.3"
5555
tracing-subscriber = "0.3.19"
5656

57-
winapi = "0.3"
57+
winapi = { version = "0.3", features = ["wincon"] }
5858

5959
tokio = { version = "1.42.0", default-features = false }
6060
tokio_with_wasm = "0.8.2"

crates/brush-bench-test/src/benches.rs

Lines changed: 80 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
use brush_dataset::scene::SceneBatch;
22
use brush_render::{
3-
AlphaMode, MainBackend,
3+
AlphaMode, MainBackend, TextureMode,
44
camera::Camera,
55
gaussian_splats::{SplatRenderMode, Splats},
66
render_splats,
@@ -144,8 +144,9 @@ fn generate_training_batch(resolution: (u32, u32), camera_pos: Vec3) -> SceneBat
144144
mod forward_rendering {
145145
use super::{
146146
AutodiffModule, Backend, Camera, ITERS_PER_SYNC, MainBackend, Quat, RESOLUTIONS,
147-
SPLAT_COUNTS, Vec3, WgpuDevice, gen_splats, render_splats,
147+
SPLAT_COUNTS, TextureMode, Vec3, WgpuDevice, gen_splats, render_splats,
148148
};
149+
use burn_cubecl::cubecl::future::block_on;
149150

150151
#[divan::bench(args = SPLAT_COUNTS)]
151152
fn render_1080p(bencher: divan::Bencher, splat_count: usize) {
@@ -160,10 +161,20 @@ mod forward_rendering {
160161
);
161162

162163
bencher.bench_local(move || {
163-
for _ in 0..ITERS_PER_SYNC {
164-
let _ = render_splats(&splats, &camera, glam::uvec2(1920, 1080), Vec3::ZERO, None);
165-
}
166-
MainBackend::sync(&device).expect("Failed to sync");
164+
block_on(async {
165+
for _ in 0..ITERS_PER_SYNC {
166+
let _ = render_splats(
167+
splats.clone(),
168+
&camera,
169+
glam::uvec2(1920, 1080),
170+
Vec3::ZERO,
171+
None,
172+
TextureMode::Float,
173+
)
174+
.await;
175+
}
176+
MainBackend::sync(&device).expect("Failed to sync");
177+
});
167178
});
168179
}
169180

@@ -180,16 +191,20 @@ mod forward_rendering {
180191
);
181192

182193
bencher.bench_local(move || {
183-
for _ in 0..ITERS_PER_SYNC {
184-
let _ = render_splats(
185-
&splats,
186-
&camera,
187-
glam::uvec2(width, height),
188-
Vec3::ZERO,
189-
None,
190-
);
191-
}
192-
MainBackend::sync(&device).expect("Failed to sync");
194+
block_on(async {
195+
for _ in 0..ITERS_PER_SYNC {
196+
let _ = render_splats(
197+
splats.clone(),
198+
&camera,
199+
glam::uvec2(width, height),
200+
Vec3::ZERO,
201+
None,
202+
TextureMode::Float,
203+
)
204+
.await;
205+
}
206+
MainBackend::sync(&device).expect("Failed to sync");
207+
});
193208
});
194209
}
195210
}
@@ -200,6 +215,7 @@ mod backward_rendering {
200215
Backend, Camera, DiffBackend, ITERS_PER_SYNC, MainBackend, Quat, RESOLUTIONS, Tensor,
201216
TensorPrimitive, Vec3, WgpuDevice, gen_splats, render_splats_diff,
202217
};
218+
use burn_cubecl::cubecl::future::block_on;
203219

204220
#[divan::bench(args = [1_000_000, 2_000_000, 5_000_000])]
205221
fn render_grad_1080p(bencher: divan::Bencher, splat_count: usize) {
@@ -214,14 +230,21 @@ mod backward_rendering {
214230
);
215231

216232
bencher.bench_local(move || {
217-
for _ in 0..ITERS_PER_SYNC {
218-
let diff_out =
219-
render_splats_diff(&splats, &camera, glam::uvec2(1920, 1080), Vec3::ZERO);
220-
let img: Tensor<DiffBackend, 3> =
221-
Tensor::from_primitive(TensorPrimitive::Float(diff_out.img));
222-
let _ = img.mean().backward();
223-
}
224-
MainBackend::sync(&device).expect("Failed to sync");
233+
block_on(async {
234+
for _ in 0..ITERS_PER_SYNC {
235+
let diff_out = render_splats_diff(
236+
splats.clone(),
237+
&camera,
238+
glam::uvec2(1920, 1080),
239+
Vec3::ZERO,
240+
)
241+
.await;
242+
let img: Tensor<DiffBackend, 3> =
243+
Tensor::from_primitive(TensorPrimitive::Float(diff_out.img));
244+
let _ = img.mean().backward();
245+
}
246+
MainBackend::sync(&device).expect("Failed to sync");
247+
});
225248
});
226249
}
227250

@@ -237,21 +260,29 @@ mod backward_rendering {
237260
glam::vec2(0.5, 0.5),
238261
);
239262
bencher.bench_local(move || {
240-
for _ in 0..ITERS_PER_SYNC {
241-
let diff_out =
242-
render_splats_diff(&splats, &camera, glam::uvec2(width, height), Vec3::ZERO);
243-
let img: Tensor<DiffBackend, 3> =
244-
Tensor::from_primitive(TensorPrimitive::Float(diff_out.img));
245-
let _ = img.mean().backward();
246-
}
247-
MainBackend::sync(&device).expect("Failed to sync");
263+
block_on(async {
264+
for _ in 0..ITERS_PER_SYNC {
265+
let diff_out = render_splats_diff(
266+
splats.clone(),
267+
&camera,
268+
glam::uvec2(width, height),
269+
Vec3::ZERO,
270+
)
271+
.await;
272+
let img: Tensor<DiffBackend, 3> =
273+
Tensor::from_primitive(TensorPrimitive::Float(diff_out.img));
274+
let _ = img.mean().backward();
275+
}
276+
MainBackend::sync(&device).expect("Failed to sync");
277+
});
248278
});
249279
}
250280
}
251281

252282
#[divan::bench_group(max_time = 4)]
253283
mod training {
254284
use brush_render::bounding_box::BoundingBox;
285+
use burn_cubecl::cubecl::future::block_on;
255286

256287
use crate::benches::ITERS_PER_SYNC;
257288

@@ -262,25 +293,23 @@ mod training {
262293

263294
#[divan::bench(args = SPLAT_COUNTS)]
264295
fn train_steps(splat_count: usize) {
265-
burn_cubecl::cubecl::future::block_on(async {
266-
let device = WgpuDevice::default();
267-
let batch1 = generate_training_batch((1920, 1080), Vec3::new(0.0, 0.0, 5.0));
268-
let batch2 = generate_training_batch((1920, 1080), Vec3::new(2.0, 0.0, 5.0));
269-
let batches = [batch1, batch2];
270-
let config = TrainConfig::default();
271-
let mut splats = gen_splats(&device, splat_count);
272-
let mut trainer = SplatTrainer::new(
273-
&config,
274-
&device,
275-
BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE),
276-
);
277-
for step in 0..ITERS_PER_SYNC {
278-
let batch = batches[step as usize % batches.len()].clone();
279-
let (new_splats, _) = trainer.step(batch, splats);
280-
splats = new_splats;
281-
}
282-
MainBackend::sync(&device).expect("Failed to sync");
283-
});
296+
let device = WgpuDevice::default();
297+
let batch1 = generate_training_batch((1920, 1080), Vec3::new(0.0, 0.0, 5.0));
298+
let batch2 = generate_training_batch((1920, 1080), Vec3::new(2.0, 0.0, 5.0));
299+
let batches = [batch1, batch2];
300+
let config = TrainConfig::default();
301+
let mut splats = gen_splats(&device, splat_count);
302+
let mut trainer = SplatTrainer::new(
303+
&config,
304+
&device,
305+
BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE),
306+
);
307+
for step in 0..ITERS_PER_SYNC {
308+
let batch = batches[step as usize % batches.len()].clone();
309+
let (new_splats, _) = block_on(trainer.step(batch, splats));
310+
splats = new_splats;
311+
}
312+
MainBackend::sync(&device).expect("Failed to sync");
284313
}
285314
}
286315

crates/brush-bench-test/src/reference.rs

Lines changed: 7 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ use burn::{
1010
Tensor,
1111
backend::{Autodiff, wgpu::WgpuDevice},
1212
prelude::Backend,
13-
tensor::{Float, Int, TensorPrimitive},
13+
tensor::TensorPrimitive,
1414
};
1515

1616
use anyhow::{Context, Result};
@@ -127,16 +127,15 @@ async fn test_reference() -> Result<()> {
127127
);
128128

129129
let diff_out = brush_render_bwd::render_splats(
130-
&splats,
130+
splats.clone(),
131131
&cam,
132132
glam::uvec2(w as u32, h as u32),
133133
Vec3::ZERO,
134-
);
134+
)
135+
.await;
135136

136-
let (out, aux) = (
137-
Tensor::from_primitive(TensorPrimitive::Float(diff_out.img)),
138-
diff_out.aux,
139-
);
137+
let out: Tensor<DiffBack, 3> = Tensor::from_primitive(TensorPrimitive::Float(diff_out.img));
138+
let render_aux = diff_out.render_aux;
140139

141140
if let Some(rec) = rec.as_ref() {
142141
rec.set_time_sequence("test case", i as i64);
@@ -148,28 +147,10 @@ async fn test_reference() -> Result<()> {
148147
)?;
149148
rec.log(
150149
"images/tile_depth",
151-
&aux.calc_tile_depth().into_rerun().await,
150+
&render_aux.calc_tile_depth().into_rerun().await,
152151
)?;
153152
}
154153

155-
let num_visible: Tensor<DiffBack, 1, Int> = aux.num_visible();
156-
let num_visible = num_visible.into_scalar_async().await.unwrap() as usize;
157-
let global_from_compact_gid: Tensor<DiffBack, 1, Int> =
158-
Tensor::from_primitive(aux.global_from_compact_gid.clone());
159-
let gs_ids = global_from_compact_gid.clone().slice([0..num_visible]);
160-
let projected_splats =
161-
Tensor::from_primitive(TensorPrimitive::Float(aux.projected_splats.clone()));
162-
let xys: Tensor<DiffBack, 2, Float> =
163-
projected_splats.clone().slice([0..num_visible, 0..2]);
164-
let xys_ref = safetensor_to_burn::<DiffBack, 2>(&tensors.tensor("xys")?, &device);
165-
let xys_ref = xys_ref.select(0, gs_ids.clone());
166-
compare("xy", xys, xys_ref, 1e-5, 2e-5);
167-
let conics: Tensor<DiffBack, 2, Float> =
168-
projected_splats.clone().slice([0..num_visible, 2..5]);
169-
let conics_ref = safetensor_to_burn::<DiffBack, 2>(&tensors.tensor("conics")?, &device);
170-
let conics_ref = conics_ref.select(0, gs_ids.clone());
171-
compare("conics", conics, conics_ref, 1e-6, 2e-5);
172-
173154
// Check if images match.
174155
compare("img", out.clone(), img_ref, 1e-5, 1e-5);
175156

crates/brush-bench-test/tests/integration.rs

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -161,8 +161,8 @@ fn test_forward_rendering() {
161161
assert!(means_data.iter().all(|&x| x.is_finite()));
162162
}
163163

164-
#[test]
165-
fn test_training_step() {
164+
#[tokio::test]
165+
async fn test_training_step() {
166166
let device = WgpuDevice::default();
167167
let batch = generate_test_batch((64, 64));
168168
let splats = generate_test_splats(&device, 500);
@@ -172,7 +172,7 @@ fn test_training_step() {
172172
&device,
173173
BoundingBox::from_min_max(Vec3::ZERO, Vec3::ONE),
174174
);
175-
let (final_splats, stats) = trainer.step(batch, splats);
175+
let (final_splats, stats) = trainer.step(batch, splats).await;
176176

177177
assert!(final_splats.num_splats() > 0);
178178
let loss = stats.loss.into_scalar();
@@ -190,8 +190,8 @@ fn test_batch_generation() {
190190
assert!(img_data.iter().all(|&x| (0.0..=1.1).contains(&x)));
191191
}
192192

193-
#[test]
194-
fn test_multi_step_training() {
193+
#[tokio::test]
194+
async fn test_multi_step_training() {
195195
let device = WgpuDevice::default();
196196
let batch = generate_test_batch((64, 64));
197197
let config = TrainConfig::default();
@@ -205,7 +205,7 @@ fn test_multi_step_training() {
205205

206206
// Run a few training steps
207207
for _ in 0..3 {
208-
let (new_splats, stats) = trainer.step(batch.clone(), splats);
208+
let (new_splats, stats) = trainer.step(batch.clone(), splats).await;
209209
splats = new_splats;
210210

211211
let loss = stats.loss.into_scalar();
@@ -216,8 +216,8 @@ fn test_multi_step_training() {
216216
assert!(splats.num_splats() > 0);
217217
}
218218

219-
#[test]
220-
fn test_gradient_validation() {
219+
#[tokio::test]
220+
async fn test_gradient_validation() {
221221
let device = WgpuDevice::default();
222222
let splats = generate_test_splats(&device, 100);
223223

@@ -231,7 +231,8 @@ fn test_gradient_validation() {
231231
);
232232
let img_size = glam::uvec2(64, 64);
233233

234-
let result = render_splats(&splats, &camera, img_size, Vec3::ZERO);
234+
// Clone splats since render_splats takes ownership and we need splats for gradient validation
235+
let result = render_splats(splats.clone(), &camera, img_size, Vec3::ZERO).await;
235236

236237
let rendered: Tensor<DiffBackend, 3> =
237238
Tensor::from_primitive(TensorPrimitive::Float(result.img));

crates/brush-kernel/build.rs

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)