Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions crates/brush-process/src/train_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use brush_render::{
use brush_rerun::visualize_tools::VisualizeTools;
use brush_train::{
RandomSplatsConfig, create_random_splats,
density_control::compute_gaussian_score,
eval::eval_stats,
msg::RefineStats,
splats_into_autodiff, to_init_splats,
Expand Down Expand Up @@ -194,15 +195,37 @@ pub(crate) async fn train_stream(
.await
.unwrap();

let train_t =
(iter as f32 / train_stream_config.train_config.total_steps as f32).clamp(0.0, 1.0);

let refine = if iter > 0
&& iter.is_multiple_of(train_stream_config.train_config.refine_every)
&& train_t <= 0.95
&& iter < train_stream_config.train_config.growth_stop_iter
{
splat_slot
.act(0, async |splats| trainer.refine(iter, splats).await)
.act(0, async |splats| {
let gaussian_scores = compute_gaussian_score(
&mut dataloader,
splats.clone(),
train_stream_config.train_config.n_views,
train_stream_config.train_config.high_error_threshold,
)
.await;
trainer.refine(iter, splats.clone(), gaussian_scores).await
})
.await
.unwrap()
} else if iter > train_stream_config.train_config.growth_stop_iter
&& iter % train_stream_config.train_config.refine_every_final == 0
{
splat_slot
.act(0, async |splats| {
let gaussian_scores = compute_gaussian_score(
&mut dataloader,
splats.clone(),
train_stream_config.train_config.n_views,
train_stream_config.train_config.high_error_threshold,
)
.await;
trainer.refine_final(splats.clone(), gaussian_scores).await
})
.await
.unwrap()
} else {
Expand Down
11 changes: 9 additions & 2 deletions crates/brush-render-bwd/src/burn_glue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,22 @@ where
// Async readback
let num_intersections = project_output.read_num_intersections().await;

let (out_img, render_aux, compact_gid_from_isect) =
<B as SplatOps<B>>::rasterize(&project_output, num_intersections, background, true);
let (out_img, render_aux, compact_gid_from_isect) = <B as SplatOps<B>>::rasterize(
&project_output,
num_intersections,
background,
true,
false,
None,
);

let wrapped_render_aux = RenderAux::<Autodiff<B, C>> {
num_visible: render_aux.num_visible.clone(),
num_intersections: render_aux.num_intersections,
visible: <Autodiff<B, C> as AutodiffBackend>::from_inner(render_aux.visible.clone()),
tile_offsets: render_aux.tile_offsets.clone(),
img_size: render_aux.img_size,
high_error_count: render_aux.high_error_count,
};

let sh_degree = sh_degree_from_coeffs(sh_coeffs_dims[1] as u32);
Expand Down
67 changes: 63 additions & 4 deletions crates/brush-render/src/burn_glue.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use burn::tensor::{
DType, Shape,
DType, Shape, Tensor,
ops::{FloatTensor, IntTensor},
};
use burn_cubecl::{BoolElement, fusion::FusionCubeRuntime};
Expand Down Expand Up @@ -164,13 +164,16 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
num_intersections: u32,
background: Vec3,
bwd_info: bool,
high_error_info: bool,
high_error_mask: Option<&FloatTensor<Self>>,
) -> (FloatTensor<Self>, RenderAux<Self>, IntTensor<Self>) {
#[derive(Debug)]
struct CustomOp {
img_size: glam::UVec2,
num_intersections: u32,
background: Vec3,
bwd_info: bool,
high_error_info: bool,
project_uniforms: ProjectUniforms,
desc: CustomOpIr,
}
Expand All @@ -187,8 +190,15 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
num_visible,
global_from_compact_gid,
cum_tiles_hit,
high_error_mask,
] = inputs;
let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs;
let [
out_img,
tile_offsets,
compact_gid_from_isect,
visible,
high_error_count,
] = outputs;

let inner_output = ProjectOutput::<MainBackendBase> {
projected_splats: h.get_float_tensor::<MainBackendBase>(projected_splats),
Expand All @@ -200,22 +210,35 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
img_size: self.img_size,
};

let high_error_mask = if self.high_error_info {
Some(&h.get_float_tensor::<MainBackendBase>(high_error_mask))
} else {
None
};

let (img, aux, compact_gid) = MainBackendBase::rasterize(
&inner_output,
self.num_intersections,
self.background,
self.bwd_info,
self.high_error_info,
high_error_mask,
);

// Register outputs
h.register_float_tensor::<MainBackendBase>(&out_img.id, img);
h.register_int_tensor::<MainBackendBase>(&tile_offsets.id, aux.tile_offsets);
h.register_int_tensor::<MainBackendBase>(&compact_gid_from_isect.id, compact_gid);
h.register_float_tensor::<MainBackendBase>(&visible.id, aux.visible);
h.register_int_tensor::<MainBackendBase>(
&high_error_count.id,
aux.high_error_count,
);
}
}

let client = project_output.projected_splats.client.clone();
let device = project_output.projected_splats.client.device();
let img_size = project_output.img_size;
let tile_bounds = calc_tile_bounds(img_size);

Expand Down Expand Up @@ -248,23 +271,52 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
};
let visible = TensorIr::uninit(client.create_empty_handle(), visible_shape, DType::F32);

let high_error_count_shape = if bwd_info && high_error_info {
Shape::new([num_points])
} else {
Shape::new([1])
};
let high_error_count = TensorIr::uninit(
client.create_empty_handle(),
high_error_count_shape,
DType::U32,
);

let high_error_mask = if bwd_info && high_error_info {
high_error_mask
.expect("Provide high error mask if high error info is required")
.clone()
} else {
Tensor::<Self, 2>::zeros([1, 1], device)
.into_primitive()
.tensor()
};

let input_tensors = [
project_output.projected_splats.clone(),
project_output.num_visible.clone(),
project_output.global_from_compact_gid.clone(),
project_output.cum_tiles_hit.clone(),
high_error_mask,
];
let stream = OperationStreams::with_inputs(&input_tensors);
let desc = CustomOpIr::new(
"rasterize",
&input_tensors.map(|t| t.into_ir()),
&[out_img, tile_offsets, compact_gid_from_isect, visible],
&[
out_img,
tile_offsets,
compact_gid_from_isect,
visible,
high_error_count,
],
);
let op = CustomOp {
img_size,
num_intersections,
background,
bwd_info,
high_error_info,
project_uniforms: project_output.project_uniforms,
desc: desc.clone(),
};
Expand All @@ -273,7 +325,13 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
.register(stream, OperationIr::Custom(desc), op)
.outputs();

let [out_img, tile_offsets, compact_gid_from_isect, visible] = outputs;
let [
out_img,
tile_offsets,
compact_gid_from_isect,
visible,
high_error_count,
] = outputs;

(
out_img,
Expand All @@ -283,6 +341,7 @@ impl SplatOps<Self> for Fusion<MainBackendBase> {
visible,
tile_offsets,
img_size,
high_error_count,
},
compact_gid_from_isect,
)
Expand Down
10 changes: 8 additions & 2 deletions crates/brush-render/src/gaussian_splats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,14 @@ pub async fn render_splats<B: Backend + SplatOps<B>>(
let num_intersections = project_output.read_num_intersections().await;

let use_float = matches!(texture_mode, TextureMode::Float);
let (out_img, render_aux, _) =
B::rasterize(&project_output, num_intersections, background, use_float);
let (out_img, render_aux, _) = B::rasterize(
&project_output,
num_intersections,
background,
use_float,
false,
None,
);

render_aux.validate();

Expand Down
2 changes: 2 additions & 0 deletions crates/brush-render/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ pub trait SplatOps<B: Backend> {
num_intersections: u32,
background: Vec3,
bwd_info: bool,
high_error_info: bool,
high_error_mask: Option<&FloatTensor<B>>,
) -> (FloatTensor<B>, RenderAux<B>, IntTensor<B>);
}

Expand Down
67 changes: 49 additions & 18 deletions crates/brush-render/src/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ impl SplatOps<Self> for MainBackendBase {
num_intersections: u32,
background: Vec3,
bwd_info: bool,
high_error_info: bool,
high_error_mask: Option<&FloatTensor<Self>>,
) -> (FloatTensor<Self>, RenderAux<Self>, IntTensor<Self>) {
let _span = tracing::trace_span!("rasterize").entered();

Expand Down Expand Up @@ -281,23 +283,47 @@ impl SplatOps<Self> for MainBackendBase {
// Get total_splats from the shape of projected_splats
let total_splats = project_output.projected_splats.shape.dims[0];

let (bindings, visible) = if bwd_info {
let (bindings, visible, high_error_count) = if bwd_info {
let visible = Self::float_zeros([total_splats].into(), device, FloatDType::F32);
let bindings = Bindings::new()
.with_buffers(vec![
compact_gid_from_isect.handle.clone().binding(),
tile_offsets.handle.clone().binding(),
project_output.projected_splats.handle.clone().binding(),
out_img.handle.clone().binding(),
project_output
.global_from_compact_gid
.handle
.clone()
.binding(),
visible.handle.clone().binding(),
])
.with_metadata(create_meta_binding(rasterize_uniforms));
(bindings, visible)
if high_error_info {
let high_error_count =
Self::int_zeros([total_splats].into(), device, IntDType::U32);
let high_error_mask = high_error_mask
.expect("Provide high error mask if high error info is required");
let bindings = Bindings::new()
.with_buffers(vec![
compact_gid_from_isect.handle.clone().binding(),
tile_offsets.handle.clone().binding(),
project_output.projected_splats.handle.clone().binding(),
out_img.handle.clone().binding(),
project_output
.global_from_compact_gid
.handle
.clone()
.binding(),
visible.handle.clone().binding(),
high_error_mask.handle.clone().binding(),
high_error_count.handle.clone().binding(),
])
.with_metadata(create_meta_binding(rasterize_uniforms));
(bindings, visible, high_error_count)
} else {
let bindings = Bindings::new()
.with_buffers(vec![
compact_gid_from_isect.handle.clone().binding(),
tile_offsets.handle.clone().binding(),
project_output.projected_splats.handle.clone().binding(),
out_img.handle.clone().binding(),
project_output
.global_from_compact_gid
.handle
.clone()
.binding(),
visible.handle.clone().binding(),
])
.with_metadata(create_meta_binding(rasterize_uniforms));
(bindings, visible, create_tensor([1], device, DType::U32))
}
} else {
let bindings = Bindings::new()
.with_buffers(vec![
Expand All @@ -307,10 +333,14 @@ impl SplatOps<Self> for MainBackendBase {
out_img.handle.clone().binding(),
])
.with_metadata(create_meta_binding(rasterize_uniforms));
(bindings, create_tensor([1], device, DType::F32))
(
bindings,
create_tensor([1], device, DType::F32),
create_tensor([1], device, DType::U32),
)
};

let raster_task = Rasterize::task(bwd_info);
let raster_task = Rasterize::task(bwd_info, high_error_info);

// SAFETY: Kernel checked to have no OOB, bounded loops.
unsafe {
Expand All @@ -331,6 +361,7 @@ impl SplatOps<Self> for MainBackendBase {
visible,
tile_offsets,
img_size: project_output.img_size,
high_error_count,
},
compact_gid_from_isect,
)
Expand Down
1 change: 1 addition & 0 deletions crates/brush-render/src/render_aux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub struct RenderAux<B: Backend> {
pub visible: FloatTensor<B>,
pub tile_offsets: IntTensor<B>,
pub img_size: glam::UVec2,
pub high_error_count: IntTensor<B>,
}

impl<B: Backend> RenderAux<B> {
Expand Down
1 change: 1 addition & 0 deletions crates/brush-render/src/shaders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct MapGaussiansToIntersect;
#[wgsl_kernel(source = "src/shaders/rasterize.wgsl")]
pub struct Rasterize {
pub bwd_info: bool,
pub high_error_info: bool,
}

// Re-export helper types and constants from the kernel modules that use them
Expand Down
Loading