@@ -182,6 +182,8 @@ impl SplatOps<Self> for MainBackendBase {
182182 num_intersections : u32 ,
183183 background : Vec3 ,
184184 bwd_info : bool ,
185+ high_error_info : bool ,
186+ high_error_mask : Option < & FloatTensor < Self > > ,
185187 ) -> ( FloatTensor < Self > , RenderAux < Self > , IntTensor < Self > ) {
186188 let _span = tracing:: trace_span!( "rasterize" ) . entered ( ) ;
187189
@@ -281,23 +283,47 @@ impl SplatOps<Self> for MainBackendBase {
281283 // Get total_splats from the shape of projected_splats
282284 let total_splats = project_output. projected_splats . shape . dims [ 0 ] ;
283285
284- let ( bindings, visible) = if bwd_info {
286+ let ( bindings, visible, high_error_count ) = if bwd_info {
285287 let visible = Self :: float_zeros ( [ total_splats] . into ( ) , device, FloatDType :: F32 ) ;
286- let bindings = Bindings :: new ( )
287- . with_buffers ( vec ! [
288- compact_gid_from_isect. handle. clone( ) . binding( ) ,
289- tile_offsets. handle. clone( ) . binding( ) ,
290- project_output. projected_splats. handle. clone( ) . binding( ) ,
291- out_img. handle. clone( ) . binding( ) ,
292- project_output
293- . global_from_compact_gid
294- . handle
295- . clone( )
296- . binding( ) ,
297- visible. handle. clone( ) . binding( ) ,
298- ] )
299- . with_metadata ( create_meta_binding ( rasterize_uniforms) ) ;
300- ( bindings, visible)
288+ if high_error_info {
289+ let high_error_count =
290+ MainBackendBase :: int_zeros ( [ total_splats] . into ( ) , device, IntDType :: U32 ) ;
291+ let high_error_mask = high_error_mask
292+ . expect ( "Provide high error mask if high error info is required" ) ;
293+ let bindings = Bindings :: new ( )
294+ . with_buffers ( vec ! [
295+ compact_gid_from_isect. handle. clone( ) . binding( ) ,
296+ tile_offsets. handle. clone( ) . binding( ) ,
297+ project_output. projected_splats. handle. clone( ) . binding( ) ,
298+ out_img. handle. clone( ) . binding( ) ,
299+ project_output
300+ . global_from_compact_gid
301+ . handle
302+ . clone( )
303+ . binding( ) ,
304+ visible. handle. clone( ) . binding( ) ,
305+ high_error_mask. handle. clone( ) . binding( ) ,
306+ high_error_count. handle. clone( ) . binding( ) ,
307+ ] )
308+ . with_metadata ( create_meta_binding ( rasterize_uniforms) ) ;
309+ ( bindings, visible, high_error_count)
310+ } else {
311+ let bindings = Bindings :: new ( )
312+ . with_buffers ( vec ! [
313+ compact_gid_from_isect. handle. clone( ) . binding( ) ,
314+ tile_offsets. handle. clone( ) . binding( ) ,
315+ project_output. projected_splats. handle. clone( ) . binding( ) ,
316+ out_img. handle. clone( ) . binding( ) ,
317+ project_output
318+ . global_from_compact_gid
319+ . handle
320+ . clone( )
321+ . binding( ) ,
322+ visible. handle. clone( ) . binding( ) ,
323+ ] )
324+ . with_metadata ( create_meta_binding ( rasterize_uniforms) ) ;
325+ ( bindings, visible, create_tensor ( [ 1 ] , device, DType :: U32 ) )
326+ }
301327 } else {
302328 let bindings = Bindings :: new ( )
303329 . with_buffers ( vec ! [
@@ -307,10 +333,14 @@ impl SplatOps<Self> for MainBackendBase {
307333 out_img. handle. clone( ) . binding( ) ,
308334 ] )
309335 . with_metadata ( create_meta_binding ( rasterize_uniforms) ) ;
310- ( bindings, create_tensor ( [ 1 ] , device, DType :: F32 ) )
336+ (
337+ bindings,
338+ create_tensor ( [ 1 ] , device, DType :: F32 ) ,
339+ create_tensor ( [ 1 ] , device, DType :: U32 ) ,
340+ )
311341 } ;
312342
313- let raster_task = Rasterize :: task ( bwd_info) ;
343+ let raster_task = Rasterize :: task ( bwd_info, high_error_info ) ;
314344
315345 // SAFETY: Kernel checked to have no OOB, bounded loops.
316346 unsafe {
@@ -331,6 +361,7 @@ impl SplatOps<Self> for MainBackendBase {
331361 visible,
332362 tile_offsets,
333363 img_size : project_output. img_size ,
364+ high_error_count,
334365 } ,
335366 compact_gid_from_isect,
336367 )
0 commit comments