 from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases


-def resize_image(image: torch.Tensor, d: int):
+def resize_image(image: torch.Tensor, d: int) -> torch.Tensor:
     """
     Downscale images using the same 'area' method in opencv

-    :param image shape [H, W, C]
+    :param image shape [B, H, W, C]
     :param d downscale factor (must be 2, 4, 8, etc.)

-    return downscaled image in shape [H//d, W//d, C]
+    return downscaled image in shape [B, H//d, W//d, C]
     """
     import torch.nn.functional as tf

-    image = image.to(torch.float32)
     weight = (1.0 / (d * d)) * torch.ones((1, 1, d, d), dtype=torch.float32, device=image.device)
-    return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)
+
+    B, H, W, C = image.shape
+    image = image.permute(0, 3, 1, 2)  # [B, C, H, W]
+    image = image.reshape(B * C, 1, H, W)  # combine batch and channel dimensions for conv2d
+
+    downscaled = tf.conv2d(image, weight, stride=d)
+    downscaled = downscaled.reshape(B, C, downscaled.shape[-2], downscaled.shape[-1])
+    downscaled = downscaled.permute(0, 2, 3, 1)  # [B, H//d, W//d, C]
+
+    return downscaled


 @torch_compile()
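The rewrite keeps the same 'area' semantics: a uniform `1/(d*d)` kernel applied with stride `d` is exactly non-overlapping average pooling, it just now runs over a flattened batch-times-channel axis. A quick sanity sketch (the shapes and the comparison against `avg_pool2d` are illustrative, not part of this PR):

```python
import torch
import torch.nn.functional as F

# Area downscaling by d is the mean over each d x d block, so the conv
# in resize_image should agree with avg_pool2d on a random batch.
image = torch.rand(4, 64, 96, 3)  # [B, H, W, C]
down = resize_image(image, 2)     # [B, 32, 48, 3]

ref = F.avg_pool2d(image.permute(0, 3, 1, 2), 2).permute(0, 2, 3, 1)
assert down.shape == (4, 32, 48, 3)
assert torch.allclose(down, ref, atol=1e-6)
```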
@@ -482,32 +490,31 @@ def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idx: int, H: int, W: int)
         )
         return out["rgb"]

-    def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
-        """Takes in a camera and returns a dictionary of outputs.
+    def get_outputs(self, cameras: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
+        """Takes in cameras and returns a dictionary of outputs.

         Args:
-            camera: The camera(s) for which output images are rendered. It should have
+            cameras: The camera(s) for which output images are rendered. It should have
             all the needed information to compute the outputs.

         Returns:
             Outputs of model. (ie. rendered colors)
         """
-        if not isinstance(camera, Cameras):
+        if not isinstance(cameras, Cameras):
             print("Called get_outputs with not a camera")
             return {}

         if self.training:
-            assert camera.shape[0] == 1, "Only one camera at a time"
-            optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)
+            optimized_camera_to_world = self.camera_optimizer.apply_to_camera(cameras)
         else:
-            optimized_camera_to_world = camera.camera_to_worlds
+            optimized_camera_to_world = cameras.camera_to_worlds

         # cropping
         if self.crop_box is not None and not self.training:
             crop_ids = self.crop_box.within(self.means).squeeze()
             if crop_ids.sum() == 0:
                 return self.get_empty_outputs(
-                    int(camera.width.item()), int(camera.height.item()), self.background_color
+                    int(cameras.width.item()), int(cameras.height.item()), self.background_color
                 )
             else:
                 crop_ids = None
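The key enabler here is dropping the `assert camera.shape[0] == 1` guard: `apply_to_camera` and everything downstream now see a `Cameras` object that may hold several cameras. A minimal sketch of the resulting call pattern (the `model` and `train_cameras` names are hypothetical, not from this diff):

```python
# Hypothetical usage after this change: a Cameras object holding
# several cameras can be passed to get_outputs in one call.
batch_cams = train_cameras[:4]           # Cameras with shape (4,)
outputs = model.get_outputs(batch_cams)  # e.g. rgb of shape [4, H, W, 3] in training
```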
@@ -530,12 +537,16 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)

         camera_scale_fac = self._get_downscale_factor()
-        camera.rescale_output_resolution(1 / camera_scale_fac)
-        viewmat = get_viewmat(optimized_camera_to_world)
-        K = camera.get_intrinsics_matrices().cuda()
-        W, H = int(camera.width.item()), int(camera.height.item())
+        cameras.rescale_output_resolution(1 / camera_scale_fac)
+        viewmats = get_viewmat(optimized_camera_to_world)
+        Ks = cameras.get_intrinsics_matrices().cuda()
+
+        W, H = (
+            int(cameras.width[0]),
+            int(cameras.height[0]),
+        )  # assume all cameras have the same resolution
         self.last_size = (H, W)
-        camera.rescale_output_resolution(camera_scale_fac)  # type: ignore
+        cameras.rescale_output_resolution(camera_scale_fac)  # type: ignore

         # apply the compensation of screen space blurring to gaussians
         if self.config.rasterize_mode not in ["antialiased", "classic"]:
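Note the bracketing: `rescale_output_resolution(1 / camera_scale_fac)` shrinks the intrinsics before `get_intrinsics_matrices()` is read out, and the second call with `camera_scale_fac` undoes it. For a single K, the effect on the intrinsics is roughly the following (a sketch of the scaling behavior, not nerfstudio's implementation):

```python
import torch

# Rough sketch of what rescaling output resolution by `scale` does to
# a pinhole intrinsics matrix K: fx, fy, cx, cy all scale together.
def rescale_K(K: torch.Tensor, scale: float) -> torch.Tensor:
    K = K.clone()
    K[..., 0, 0] *= scale  # fx
    K[..., 1, 1] *= scale  # fy
    K[..., 0, 2] *= scale  # cx
    K[..., 1, 2] *= scale  # cy
    return K
```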
@@ -558,8 +569,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             scales=torch.exp(scales_crop),
             opacities=torch.sigmoid(opacities_crop).squeeze(-1),
             colors=colors_crop,
-            viewmats=viewmat,  # [1, 4, 4]
-            Ks=K,  # [1, 3, 3]
+            viewmats=viewmats,  # [C, 4, 4]
+            Ks=Ks,  # [C, 3, 3]
             width=W,
             height=H,
             packed=False,
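With the plural `viewmats`/`Ks`, gsplat's `rasterization` broadcasts over the leading camera dimension, so batching falls out of passing stacked matrices. A toy, self-contained call showing the batched shape contract (the gaussian values here are random placeholders, not the crop tensors from this file):

```python
import torch
from gsplat.rendering import rasterization

N, C, W, H = 1000, 4, 640, 480  # N gaussians, C cameras
K = torch.tensor([[300.0, 0.0, 320.0], [0.0, 300.0, 240.0], [0.0, 0.0, 1.0]], device="cuda")

render, alpha, info = rasterization(
    means=torch.rand(N, 3, device="cuda"),          # [N, 3]
    quats=torch.rand(N, 4, device="cuda"),          # [N, 4]
    scales=torch.rand(N, 3, device="cuda") * 0.01,  # [N, 3]
    opacities=torch.rand(N, device="cuda"),         # [N]
    colors=torch.rand(N, 3, device="cuda"),         # plain RGB; SH coeffs need sh_degree
    viewmats=torch.eye(4, device="cuda").repeat(C, 1, 1),  # [C, 4, 4] world-to-camera
    Ks=K.repeat(C, 1, 1),                                  # [C, 3, 3] intrinsics
    width=W,
    height=H,
)
assert render.shape == (C, H, W, 3) and alpha.shape == (C, H, W, 1)
```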
@@ -585,24 +596,28 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:

         # apply bilateral grid
         if self.config.use_bilateral_grid and self.training:
-            if camera.metadata is not None and "cam_idx" in camera.metadata:
-                rgb = self._apply_bilateral_grid(rgb, camera.metadata["cam_idx"], H, W)
+            if cameras.metadata is not None and "cam_idx" in cameras.metadata:
+                rgb = self._apply_bilateral_grid(rgb, cameras.metadata["cam_idx"], H, W)

         if render_mode == "RGB+ED":
             depth_im = render[:, ..., 3:4]
-            depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0)
+            depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max())
         else:
             depth_im = None

         if background.shape[0] == 3 and not self.training:
             background = background.expand(H, W, 3)

-        return {
-            "rgb": rgb.squeeze(0),  # type: ignore
-            "depth": depth_im,  # type: ignore
-            "accumulation": alpha.squeeze(0),  # type: ignore
-            "background": background,  # type: ignore
-        }  # type: ignore
+        outputs = {
+            "rgb": rgb,
+            "depth": depth_im,
+            "accumulation": alpha,
+            "background": background,
+        }
+
+        if self.training:
+            return outputs
+        return {k: v.squeeze(0) if k != "background" else v for k, v in outputs.items()}
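On the eval path the comprehension drops the leading camera dimension so existing single-image consumers keep working; `background` is exempt because it never had that dimension. A hypothetical shape check (the names `model`, `eval_camera`, `H`, `W` are illustrative):

```python
outputs = model.get_outputs(eval_camera)           # single camera, not training
assert outputs["rgb"].shape == (H, W, 3)           # squeezed from [1, H, W, 3]
assert outputs["accumulation"].shape == (H, W, 1)  # squeezed from [1, H, W, 1]
```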

     def get_gt_img(self, image: torch.Tensor):
         """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
@@ -622,7 +637,7 @@ def composite_with_background(self, image, background) -> torch.Tensor:
             image: the image to composite
             background: the background color
         """
-        if image.shape[2] == 4:
+        if image.shape[-1] == 4:
             alpha = image[..., -1].unsqueeze(-1).repeat((1, 1, 3))
             return alpha * image[..., :3] + (1 - alpha) * background
         else:
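Indexing `shape[-1]` instead of `shape[2]` makes the alpha check robust to a leading batch dimension. One caveat worth flagging: the unchanged `repeat((1, 1, 3))` on the following line still assumes a 3-D `[H, W, 4]` image, since `repeat` needs one factor per tensor dimension. A broadcast-based variant that handles both cases (an illustration, not code from this PR):

```python
import torch

def composite_with_background(image: torch.Tensor, background: torch.Tensor) -> torch.Tensor:
    """Alpha-composite an [..., H, W, 4] image over a background color."""
    if image.shape[-1] == 4:
        alpha = image[..., -1:]  # [..., H, W, 1], broadcasts against the background
        return alpha * image[..., :3] + (1 - alpha) * background
    return image
```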
@@ -671,7 +686,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
             pred_img = pred_img * mask

         Ll1 = torch.abs(gt_img - pred_img).mean()
-        simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...])
+        simloss = 1 - self.ssim(gt_img.permute(0, 3, 1, 2), pred_img.permute(0, 3, 1, 2))
         if self.config.use_scale_regularization and self.step % 10 == 0:
             scale_exp = torch.exp(self.scales)
             scale_reg = (
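The SSIM change follows the same batching logic: `gt_img` and `pred_img` now arrive as `[B, H, W, C]`, so the old `[None, ...]` unsqueeze is replaced by a permute into the `[B, C, H, W]` layout the metric expects. A standalone sketch, assuming `self.ssim` is torchmetrics' `StructuralSimilarityIndexMeasure` as in splatfacto:

```python
import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure

ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
gt = torch.rand(4, 64, 64, 3)    # [B, H, W, C] batch of ground-truth images
pred = torch.rand(4, 64, 64, 3)  # [B, H, W, C] batch of predictions
simloss = 1 - ssim(pred.permute(0, 3, 1, 2), gt.permute(0, 3, 1, 2))  # NCHW input
```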