Skip to content

Commit 86807d8

Browse files
committed
Multi View consistent densification and pruning.
1 parent a6d6a20 commit 86807d8

File tree

15 files changed

+464
-162
lines changed

15 files changed

+464
-162
lines changed

crates/brush-process/src/train_stream.rs

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use brush_render::{
1313
use brush_rerun::visualize_tools::VisualizeTools;
1414
use brush_train::{
1515
RandomSplatsConfig, create_random_splats,
16+
density_control::compute_gaussian_score,
1617
eval::eval_stats,
1718
msg::RefineStats,
1819
splats_into_autodiff, to_init_splats,
@@ -194,15 +195,37 @@ pub(crate) async fn train_stream(
194195
.await
195196
.unwrap();
196197

197-
let train_t =
198-
(iter as f32 / train_stream_config.train_config.total_steps as f32).clamp(0.0, 1.0);
199-
200198
let refine = if iter > 0
201199
&& iter.is_multiple_of(train_stream_config.train_config.refine_every)
202-
&& train_t <= 0.95
200+
&& iter < train_stream_config.train_config.growth_stop_iter
203201
{
204202
splat_slot
205-
.act(0, async |splats| trainer.refine(iter, splats).await)
203+
.act(0, async |splats| {
204+
let gaussian_scores = compute_gaussian_score(
205+
&mut dataloader,
206+
splats.clone(),
207+
train_stream_config.train_config.n_views,
208+
train_stream_config.train_config.high_error_threshold,
209+
)
210+
.await;
211+
trainer.refine(iter, splats.clone(), gaussian_scores).await
212+
})
213+
.await
214+
.unwrap()
215+
} else if iter > train_stream_config.train_config.growth_stop_iter
216+
&& iter % train_stream_config.train_config.refine_every_final == 0
217+
{
218+
splat_slot
219+
.act(0, async |splats| {
220+
let gaussian_scores = compute_gaussian_score(
221+
&mut dataloader,
222+
splats.clone(),
223+
train_stream_config.train_config.n_views,
224+
train_stream_config.train_config.high_error_threshold,
225+
)
226+
.await;
227+
trainer.refine_final(splats.clone(), gaussian_scores).await
228+
})
206229
.await
207230
.unwrap()
208231
} else {

crates/brush-render-bwd/src/burn_glue.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,15 +279,22 @@ where
279279
// Async readback
280280
let num_intersections = project_output.read_num_intersections().await;
281281

282-
let (out_img, render_aux, compact_gid_from_isect) =
283-
<B as SplatOps<B>>::rasterize(&project_output, num_intersections, background, true);
282+
let (out_img, render_aux, compact_gid_from_isect) = <B as SplatOps<B>>::rasterize(
283+
&project_output,
284+
num_intersections,
285+
background,
286+
true,
287+
false,
288+
None,
289+
);
284290

285291
let wrapped_render_aux = RenderAux::<Autodiff<B, C>> {
286292
num_visible: render_aux.num_visible.clone(),
287293
num_intersections: render_aux.num_intersections,
288294
visible: <Autodiff<B, C> as AutodiffBackend>::from_inner(render_aux.visible.clone()),
289295
tile_offsets: render_aux.tile_offsets.clone(),
290296
img_size: render_aux.img_size,
297+
high_error_count: render_aux.high_error_count,
291298
};
292299

293300
let sh_degree = sh_degree_from_coeffs(sh_coeffs_dims[1] as u32);

crates/brush-render/src/burn_glue.rs

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use burn::tensor::{
2-
DType, Shape,
2+
DType, Shape, Tensor,
33
ops::{FloatTensor, IntTensor},
44
};
55
use burn_cubecl::{BoolElement, fusion::FusionCubeRuntime};
@@ -164,13 +164,16 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
164164
num_intersections: u32,
165165
background: Vec3,
166166
bwd_info: bool,
167+
high_error_info: bool,
168+
high_error_mask: Option<&FloatTensor<Self>>,
167169
) -> (FloatTensor<Self>, RenderAux<Self>, IntTensor<Self>) {
168170
#[derive(Debug)]
169171
struct CustomOp {
170172
img_size: glam::UVec2,
171173
num_intersections: u32,
172174
background: Vec3,
173175
bwd_info: bool,
176+
high_error_info: bool,
174177
project_uniforms: ProjectUniforms,
175178
desc: CustomOpIr,
176179
}
@@ -187,8 +190,15 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
187190
num_visible,
188191
global_from_compact_gid,
189192
cum_tiles_hit,
193+
high_error_mask,
190194
] = inputs;
191-
let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs;
195+
let [
196+
out_img,
197+
tile_offsets,
198+
compact_gid_from_isect,
199+
visible,
200+
high_error_count,
201+
] = outputs;
192202

193203
let inner_output = ProjectOutput::<MainBackendBase> {
194204
projected_splats: h.get_float_tensor::<MainBackendBase>(projected_splats),
@@ -200,22 +210,35 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
200210
img_size: self.img_size,
201211
};
202212

213+
let high_error_mask = if self.high_error_info {
214+
Some(&h.get_float_tensor::<MainBackendBase>(high_error_mask))
215+
} else {
216+
None
217+
};
218+
203219
let (img, aux, compact_gid) = MainBackendBase::rasterize(
204220
&inner_output,
205221
self.num_intersections,
206222
self.background,
207223
self.bwd_info,
224+
self.high_error_info,
225+
high_error_mask,
208226
);
209227

210228
// Register outputs
211229
h.register_float_tensor::<MainBackendBase>(&out_img.id, img);
212230
h.register_int_tensor::<MainBackendBase>(&tile_offsets.id, aux.tile_offsets);
213231
h.register_int_tensor::<MainBackendBase>(&compact_gid_from_isect.id, compact_gid);
214232
h.register_float_tensor::<MainBackendBase>(&visible.id, aux.visible);
233+
h.register_int_tensor::<MainBackendBase>(
234+
&high_error_count.id,
235+
aux.high_error_count,
236+
);
215237
}
216238
}
217239

218240
let client = project_output.projected_splats.client.clone();
241+
let device = project_output.projected_splats.client.device();
219242
let img_size = project_output.img_size;
220243
let tile_bounds = calc_tile_bounds(img_size);
221244

@@ -248,23 +271,52 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
248271
};
249272
let visible = TensorIr::uninit(client.create_empty_handle(), visible_shape, DType::F32);
250273

274+
let high_error_count_shape = if bwd_info && high_error_info {
275+
Shape::new([num_points])
276+
} else {
277+
Shape::new([1])
278+
};
279+
let high_error_count = TensorIr::uninit(
280+
client.create_empty_handle(),
281+
high_error_count_shape,
282+
DType::U32,
283+
);
284+
285+
let high_error_mask = if bwd_info && high_error_info {
286+
high_error_mask
287+
.expect("Provide high error mask if high error info is required")
288+
.clone()
289+
} else {
290+
Tensor::<Self, 2>::zeros([1, 1], &device)
291+
.into_primitive()
292+
.tensor()
293+
};
294+
251295
let input_tensors = [
252296
project_output.projected_splats.clone(),
253297
project_output.num_visible.clone(),
254298
project_output.global_from_compact_gid.clone(),
255299
project_output.cum_tiles_hit.clone(),
300+
high_error_mask.clone(),
256301
];
257302
let stream = OperationStreams::with_inputs(&input_tensors);
258303
let desc = CustomOpIr::new(
259304
"rasterize",
260305
&input_tensors.map(|t| t.into_ir()),
261-
&[out_img, tile_offsets, compact_gid_from_isect, visible],
306+
&[
307+
out_img,
308+
tile_offsets,
309+
compact_gid_from_isect,
310+
visible,
311+
high_error_count,
312+
],
262313
);
263314
let op = CustomOp {
264315
img_size,
265316
num_intersections,
266317
background,
267318
bwd_info,
319+
high_error_info,
268320
project_uniforms: project_output.project_uniforms,
269321
desc: desc.clone(),
270322
};
@@ -273,7 +325,13 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
273325
.register(stream, OperationIr::Custom(desc), op)
274326
.outputs();
275327

276-
let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs;
328+
let [
329+
out_img,
330+
tile_offsets,
331+
compact_gid_from_isect,
332+
visible,
333+
high_error_count,
334+
] = outputs;
277335

278336
(
279337
out_img,
@@ -283,6 +341,7 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
283341
visible,
284342
tile_offsets,
285343
img_size,
344+
high_error_count,
286345
},
287346
compact_gid_from_isect,
288347
)

crates/brush-render/src/gaussian_splats.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,14 @@ pub async fn render_splats<B: Backend + SplatOps<B>>(
275275
let num_intersections = project_output.read_num_intersections().await;
276276

277277
let use_float = matches!(texture_mode, TextureMode::Float);
278-
let (out_img, render_aux, _) =
279-
B::rasterize(&project_output, num_intersections, background, use_float);
278+
let (out_img, render_aux, _) = B::rasterize(
279+
&project_output,
280+
num_intersections,
281+
background,
282+
use_float,
283+
false,
284+
None,
285+
);
280286

281287
render_aux.validate();
282288

crates/brush-render/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ pub trait SplatOps<B: Backend> {
6161
num_intersections: u32,
6262
background: Vec3,
6363
bwd_info: bool,
64+
high_error_info: bool,
65+
high_error_mask: Option<&FloatTensor<B>>,
6466
) -> (FloatTensor<B>, RenderAux<B>, IntTensor<B>);
6567
}
6668

crates/brush-render/src/render.rs

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ impl SplatOps<Self> for MainBackendBase {
182182
num_intersections: u32,
183183
background: Vec3,
184184
bwd_info: bool,
185+
high_error_info: bool,
186+
high_error_mask: Option<&FloatTensor<Self>>,
185187
) -> (FloatTensor<Self>, RenderAux<Self>, IntTensor<Self>) {
186188
let _span = tracing::trace_span!("rasterize").entered();
187189

@@ -281,23 +283,47 @@ impl SplatOps<Self> for MainBackendBase {
281283
// Get total_splats from the shape of projected_splats
282284
let total_splats = project_output.projected_splats.shape.dims[0];
283285

284-
let (bindings, visible) = if bwd_info {
286+
let (bindings, visible, high_error_count) = if bwd_info {
285287
let visible = Self::float_zeros([total_splats].into(), device, FloatDType::F32);
286-
let bindings = Bindings::new()
287-
.with_buffers(vec![
288-
compact_gid_from_isect.handle.clone().binding(),
289-
tile_offsets.handle.clone().binding(),
290-
project_output.projected_splats.handle.clone().binding(),
291-
out_img.handle.clone().binding(),
292-
project_output
293-
.global_from_compact_gid
294-
.handle
295-
.clone()
296-
.binding(),
297-
visible.handle.clone().binding(),
298-
])
299-
.with_metadata(create_meta_binding(rasterize_uniforms));
300-
(bindings, visible)
288+
if high_error_info {
289+
let high_error_count =
290+
MainBackendBase::int_zeros([total_splats].into(), device, IntDType::U32);
291+
let high_error_mask = high_error_mask
292+
.expect("Provide high error mask if high error info is required");
293+
let bindings = Bindings::new()
294+
.with_buffers(vec![
295+
compact_gid_from_isect.handle.clone().binding(),
296+
tile_offsets.handle.clone().binding(),
297+
project_output.projected_splats.handle.clone().binding(),
298+
out_img.handle.clone().binding(),
299+
project_output
300+
.global_from_compact_gid
301+
.handle
302+
.clone()
303+
.binding(),
304+
visible.handle.clone().binding(),
305+
high_error_mask.handle.clone().binding(),
306+
high_error_count.handle.clone().binding(),
307+
])
308+
.with_metadata(create_meta_binding(rasterize_uniforms));
309+
(bindings, visible, high_error_count)
310+
} else {
311+
let bindings = Bindings::new()
312+
.with_buffers(vec![
313+
compact_gid_from_isect.handle.clone().binding(),
314+
tile_offsets.handle.clone().binding(),
315+
project_output.projected_splats.handle.clone().binding(),
316+
out_img.handle.clone().binding(),
317+
project_output
318+
.global_from_compact_gid
319+
.handle
320+
.clone()
321+
.binding(),
322+
visible.handle.clone().binding(),
323+
])
324+
.with_metadata(create_meta_binding(rasterize_uniforms));
325+
(bindings, visible, create_tensor([1], device, DType::U32))
326+
}
301327
} else {
302328
let bindings = Bindings::new()
303329
.with_buffers(vec![
@@ -307,10 +333,14 @@ impl SplatOps<Self> for MainBackendBase {
307333
out_img.handle.clone().binding(),
308334
])
309335
.with_metadata(create_meta_binding(rasterize_uniforms));
310-
(bindings, create_tensor([1], device, DType::F32))
336+
(
337+
bindings,
338+
create_tensor([1], device, DType::F32),
339+
create_tensor([1], device, DType::U32),
340+
)
311341
};
312342

313-
let raster_task = Rasterize::task(bwd_info);
343+
let raster_task = Rasterize::task(bwd_info, high_error_info);
314344

315345
// SAFETY: Kernel checked to have no OOB, bounded loops.
316346
unsafe {
@@ -331,6 +361,7 @@ impl SplatOps<Self> for MainBackendBase {
331361
visible,
332362
tile_offsets,
333363
img_size: project_output.img_size,
364+
high_error_count,
334365
},
335366
compact_gid_from_isect,
336367
)

crates/brush-render/src/render_aux.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ pub struct RenderAux<B: Backend> {
9292
pub visible: FloatTensor<B>,
9393
pub tile_offsets: IntTensor<B>,
9494
pub img_size: glam::UVec2,
95+
pub high_error_count: IntTensor<B>,
9596
}
9697

9798
impl<B: Backend> RenderAux<B> {

crates/brush-render/src/shaders.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub struct MapGaussiansToIntersect;
1818
#[wgsl_kernel(source = "src/shaders/rasterize.wgsl")]
1919
pub struct Rasterize {
2020
pub bwd_info: bool,
21+
pub high_error_info: bool,
2122
}
2223

2324
// Re-export helper types and constants from the kernel modules that use them

0 commit comments

Comments
 (0)