diff --git a/fvdb/__init__.pyi b/fvdb/__init__.pyi index d189b4d3..5e588e25 100644 --- a/fvdb/__init__.pyi +++ b/fvdb/__init__.pyi @@ -101,6 +101,10 @@ __all__ = [ "GaussianSplat3d", "ProjectedGaussianSplats", "ConvolutionPlan", + "ProjectionType", + "ShOrderingMode", + "Grid", + # JaggedTensor operations # Concatenation of jagged tensors or grid/grid batches "jcat", "gcat", diff --git a/fvdb/_fvdb_cpp.pyi b/fvdb/_fvdb_cpp.pyi index 4eb22523..af1a8941 100644 --- a/fvdb/_fvdb_cpp.pyi +++ b/fvdb/_fvdb_cpp.pyi @@ -167,6 +167,21 @@ class GaussianSplat3d: antialias: bool = ..., backgrounds: Optional[torch.Tensor] = ..., ) -> tuple[torch.Tensor, torch.Tensor]: ... + def sparse_render_depths( + self, + pixels_to_render: JaggedTensor, + world_to_camera_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + near: float, + far: float, + projection_type: ProjectionType = ..., + tile_size: int = ..., + min_radius_2d: float = ..., + eps_2d: float = ..., + antialias: bool = ..., + ) -> tuple[JaggedTensor, JaggedTensor]: ... def render_from_projected_gaussians( self, projected_gaussians: ProjectedGaussianSplats, @@ -193,6 +208,22 @@ class GaussianSplat3d: antialias: bool = ..., backgrounds: Optional[torch.Tensor] = ..., ) -> tuple[torch.Tensor, torch.Tensor]: ... + def sparse_render_images( + self, + pixels_to_render: JaggedTensor, + world_to_camera_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + near: float, + far: float, + projection_type: ProjectionType = ..., + sh_degree_to_use: int = ..., + tile_size: int = ..., + min_radius_2d: float = ..., + eps_2d: float = ..., + antialias: bool = ..., + ) -> tuple[JaggedTensor, JaggedTensor]: ... def render_images_and_depths( self, world_to_camera_matrices: torch.Tensor, @@ -209,6 +240,22 @@ class GaussianSplat3d: antialias: bool = ..., backgrounds: Optional[torch.Tensor] = ..., ) -> tuple[torch.Tensor, torch.Tensor]: ... 
+ def sparse_render_images_and_depths( + self, + pixels_to_render: JaggedTensor, + world_to_camera_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + near: float, + far: float, + projection_type: ProjectionType = ..., + sh_degree_to_use: int = ..., + tile_size: int = ..., + min_radius_2d: float = ..., + eps_2d: float = ..., + antialias: bool = ..., + ) -> tuple[JaggedTensor, JaggedTensor]: ... def render_num_contributing_gaussians( self, world_to_camera_matrices: torch.Tensor, diff --git a/fvdb/gaussian_splatting.py b/fvdb/gaussian_splatting.py index 429f9d47..c6c38486 100644 --- a/fvdb/gaussian_splatting.py +++ b/fvdb/gaussian_splatting.py @@ -2,20 +2,21 @@ # SPDX-License-Identifier: Apache-2.0 # import pathlib -from typing import Any, Mapping, Sequence, overload +from typing import Any, Mapping, Sequence, TypeVar, overload import torch from fvdb.enums import ProjectionType -from . import JaggedTensor as JaggedTensorCpp from ._fvdb_cpp import GaussianSplat3d as GaussianSplat3dCpp -from ._fvdb_cpp import JaggedTensor +from ._fvdb_cpp import JaggedTensor as JaggedTensorCpp from ._fvdb_cpp import ProjectedGaussianSplats as ProjectedGaussianSplatsCpp from .grid import Grid from .grid_batch import GridBatch from .jagged_tensor import JaggedTensor from .types import DeviceIdentifier, cast_check, resolve_device +JaggedTensorOrTensorT = TypeVar("JaggedTensorOrTensorT", JaggedTensor, torch.Tensor) + class ProjectedGaussianSplats: """ @@ -256,7 +257,7 @@ class GaussianSplat3d: - Rendering images with arbitrary channels using spherical harmonics for view-dependent color representation (:meth:`render_images`, :meth:`render_images_and_depths`). - Rendering depth maps (:meth:`render_depths`, :meth:`render_images_and_depths`). - - Rendering features at arbitrary sparse pixel locations (:meth:`sparse_render_features`). 
+ - Rendering features at arbitrary sparse pixel locations (:meth:`sparse_render_images`, :meth:`sparse_render_images_and_depths`). - Rendering depths at arbitrary sparse pixel locations (:meth:`sparse_render_depths`). - Computing which gaussians contribute to each pixel in an image plane (:meth:`render_num_contributing_gaussians`, :meth:`render_contributing_gaussian_ids`). @@ -1699,6 +1700,99 @@ def render_depths( backgrounds=backgrounds, ) + def sparse_render_depths( + self, + pixels_to_render: JaggedTensorOrTensorT, + world_to_camera_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + near: float, + far: float, + projection_type=ProjectionType.PERSPECTIVE, + tile_size: int = 16, + min_radius_2d: float = 0.0, + eps_2d: float = 0.3, + antialias: bool = False, + ) -> tuple[JaggedTensorOrTensorT, JaggedTensorOrTensorT]: + """ + Render ``C`` collections of sparse depth values from this :class:`GaussianSplat3d` from ``C`` camera views + at the specified pixel locations. + + Example: + + ..
code-block:: python + + # Assume gaussian_splat_3d is an instance of GaussianSplat3d + # pixels_to_render is a tensor of shape [C, P, 2] containing pixel coordinates to render + # Render sparse depth values from C camera views at specified pixel locations + # depth_values is a tensor of shape [C, P, 1] + # alpha_values is a tensor of shape [C, P, 1] + depth_values, alpha_values = gaussian_splat_3d.sparse_render_depths( + pixels_to_render, # tensor of shape [C, P, 2] + world_to_camera_matrices, # tensor of shape [C, 4, 4] + projection_matrices, # tensor of shape [C, 3, 3] + image_width, # width of the images + image_height, # height of the images + near, # near clipping plane + far) # far clipping plane + + true_depths = depth_values / alpha_values # Get true depth values by dividing by alpha + + Args: + pixels_to_render (torch.Tensor | JaggedTensor): A tensor of shape ``(C, P, 2)`` or a JaggedTensor where ``C`` is the number of camera views, + and ``P`` is the number of pixel coordinates to render per camera. Each pixel coordinate is represented as (y, x) (row, col). + world_to_camera_matrices (torch.Tensor): Tensor of shape ``(C, 4, 4)`` representing the + world-to-camera transformation matrices for C cameras. Each matrix transforms points + from world coordinates to camera coordinates. + projection_matrices (torch.Tensor): Tensor of shape ``(C, 3, 3)`` representing the projection matrices for ``C`` cameras. + Each matrix projects points in camera space into homogeneous pixel coordinates. + image_width (int): The width of the images to be rendered. Note these are the same for all images being rendered. + image_height (int): The height of the images to be rendered. Note these are the same for all images being rendered. + near (float): The near clipping plane distance for the projection. + far (float): The far clipping plane distance for the projection. + projection_type (ProjectionType): The type of projection to use.
Default is :attr:`fvdb.ProjectionType.PERSPECTIVE`. + tile_size (int): The size of the tiles to use for rendering. Default is 16. You shouldn't set this parameter unless you really know what you are doing. + min_radius_2d (float): The minimum radius (in pixels) below which Gaussians are ignored during rendering. Default is 0.0. + eps_2d (float): A value used to pad Gaussians when projecting them onto the image plane, to avoid very small projected Gaussians which create artifacts and + numerical issues. + antialias (bool): If ``True``, applies opacity correction to the projected Gaussians when using ``eps_2d > 0.0``. + + + Returns: + depth_values (torch.Tensor | JaggedTensor): A tensor of shape ``(C, P, 1)`` or a JaggedTensor where ``C`` is the number of camera views, + and ``P`` is the number of pixel coordinates rendered per camera. Each element represents the depth value at that pixel. + alpha_values (torch.Tensor | JaggedTensor): A tensor of shape ``(C, P, 1)`` or a JaggedTensor where ``C`` is the number of camera views, + and ``P`` is the number of pixel coordinates rendered per camera. Each element represents the alpha value (opacity) at that pixel such that ``0 <= alpha < 1``, + and 0 means the pixel is fully transparent, and 1 means the pixel is fully opaque.
+ """ + if isinstance(pixels_to_render, torch.Tensor): + pixels_to_render_impl = JaggedTensorCpp(pixels_to_render) + elif isinstance(pixels_to_render, JaggedTensor): + pixels_to_render_impl: JaggedTensorCpp = pixels_to_render._impl + else: + raise TypeError("pixels_to_render must be either a torch.Tensor or a fvdb.JaggedTensor") + + ret_depths, ret_alphas = self._impl.sparse_render_depths( + pixels_to_render=pixels_to_render_impl, + world_to_camera_matrices=world_to_camera_matrices, + projection_matrices=projection_matrices, + image_width=image_width, + image_height=image_height, + near=near, + far=far, + projection_type=self._proj_type_to_cpp(projection_type), + tile_size=tile_size, + min_radius_2d=min_radius_2d, + eps_2d=eps_2d, + antialias=antialias, + ) + + if isinstance(pixels_to_render, torch.Tensor): + return ret_depths.jdata, ret_alphas.jdata + else: + return JaggedTensor(impl=ret_depths), JaggedTensor(impl=ret_alphas) + def render_images( self, world_to_camera_matrices: torch.Tensor, @@ -1783,6 +1877,200 @@ def render_images( backgrounds=backgrounds, ) + def sparse_render_images( + self, + pixels_to_render: JaggedTensorOrTensorT, + world_to_camera_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + near: float, + far: float, + projection_type=ProjectionType.PERSPECTIVE, + sh_degree_to_use: int = -1, + tile_size: int = 16, + min_radius_2d: float = 0.0, + eps_2d: float = 0.3, + antialias: bool = False, + ) -> tuple[JaggedTensorOrTensorT, JaggedTensorOrTensorT]: + """ + Render ``C`` collections of multi-channel features (see :attr:`num_channels`) from this :class:`GaussianSplat3d` from ``C`` camera views + at the specified pixel locations. + + Example: + + ..
code-block:: python + + # Assume gaussian_splat_3d is an instance of GaussianSplat3d + # pixels_to_render is a tensor of shape [C, P, 2] containing pixel coordinates to render + # Render sparse images from C camera views at specified pixel locations + # features is a tensor of shape [C, P, D] where D is the number of channels + # alphas is a tensor of shape [C, P, 1] + features, alphas = gaussian_splat_3d.sparse_render_images( + pixels_to_render, # tensor of shape [C, P, 2] + world_to_camera_matrices, # tensor of shape [C, 4, 4] + projection_matrices, # tensor of shape [C, 3, 3] + image_width, # width of the images + image_height, # height of the images + near, # near clipping plane + far) # far clipping plane + + Args: + pixels_to_render (torch.Tensor | JaggedTensor): A tensor of shape ``(C, P, 2)`` or a :class:`~fvdb.JaggedTensor` where ``C`` is the number of camera views, + and ``P`` is the number of pixel coordinates to render per camera. Each pixel coordinate is represented as (y, x) (row, col). + world_to_camera_matrices (torch.Tensor): Tensor of shape ``(C, 4, 4)`` representing the + world-to-camera transformation matrices for C cameras. Each matrix transforms points + from world coordinates to camera coordinates. + projection_matrices (torch.Tensor): Tensor of shape ``(C, 3, 3)`` representing the projection matrices for ``C`` cameras. + Each matrix projects points in camera space into homogeneous pixel coordinates. + image_width (int): The width of the images to be rendered. Note these are the same for all images being rendered. + image_height (int): The height of the images to be rendered. Note these are the same for all images being rendered. + near (float): The near clipping plane distance for the projection. + far (float): The far clipping plane distance for the projection. + projection_type (ProjectionType): The type of projection to use. Default is :attr:`fvdb.ProjectionType.PERSPECTIVE`. 
+ sh_degree_to_use (int): The degree of spherical harmonics to use for rendering. -1 means use all available SH bases. + 0 means use only the first SH base (constant color). Note that you can't use more SH bases than available in the GaussianSplat3d instance. + Default is -1. + tile_size (int): The size of the tiles to use for rendering. Default is 16. You shouldn't set this parameter unless you really know what you are doing. + min_radius_2d (float): The minimum radius (in pixels) below which Gaussians are ignored during rendering. + eps_2d (float): A value used to pad Gaussians when projecting them onto the image plane, to avoid very small projected Gaussians which create artifacts and + numerical issues. + antialias (bool): If ``True``, applies opacity correction to the projected Gaussians when using ``eps_2d > 0.0``. + + Returns: + features (torch.Tensor | JaggedTensor): A tensor of shape ``(C, P, D)`` or a + :class:`~fvdb.JaggedTensor` where ``C`` is the number of camera views, + ``P`` is the number of pixel coordinates rendered per camera, and ``D`` is the number of channels. + alpha_images (torch.Tensor | JaggedTensor): A tensor of shape ``(C, P, 1)`` or a :class:`~fvdb.JaggedTensor` + where ``C`` is the number of camera views, and ``P`` is the number of pixel coordinates rendered per camera. + Each element represents the alpha value (opacity) at that pixel such that ``0 <= alpha < 1``, + and 0 means the pixel is fully transparent, and 1 means the pixel is fully opaque.
+ """ + if isinstance(pixels_to_render, torch.Tensor): + pixels_to_render_impl = JaggedTensorCpp(pixels_to_render) + elif isinstance(pixels_to_render, JaggedTensor): + pixels_to_render_impl: JaggedTensorCpp = pixels_to_render._impl + else: + raise TypeError("pixels_to_render must be either a torch.Tensor or a fvdb.JaggedTensor") + + ret_features, ret_alphas = self._impl.sparse_render_images( + pixels_to_render=pixels_to_render_impl, + world_to_camera_matrices=world_to_camera_matrices, + projection_matrices=projection_matrices, + image_width=image_width, + image_height=image_height, + near=near, + far=far, + projection_type=self._proj_type_to_cpp(projection_type), + sh_degree_to_use=sh_degree_to_use, + tile_size=tile_size, + min_radius_2d=min_radius_2d, + eps_2d=eps_2d, + antialias=antialias, + ) + + if isinstance(pixels_to_render, torch.Tensor): + return ret_features.jdata, ret_alphas.jdata + else: + return JaggedTensor(impl=ret_features), JaggedTensor(impl=ret_alphas) + + def sparse_render_images_and_depths( + self, + pixels_to_render: JaggedTensorOrTensorT, + world_to_camera_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + near: float, + far: float, + projection_type=ProjectionType.PERSPECTIVE, + sh_degree_to_use: int = -1, + tile_size: int = 16, + min_radius_2d: float = 0.0, + eps_2d: float = 0.3, + antialias: bool = False, + ) -> tuple[JaggedTensorOrTensorT, JaggedTensorOrTensorT]: + """ + Render ``C`` collections of sparse multi-channel features (see :attr:`num_channels`) with depth as + the last channel from this :class:`GaussianSplat3d` from ``C`` camera views at the specified pixel locations. + + Example: + .. 
code-block:: python + + # Assume gaussian_splat_3d is an instance of GaussianSplat3d + # pixels_to_render is a tensor of shape [C, P, 2] containing pixel coordinates to render + # Render sparse images with depth from C camera views at specified pixel locations + # features is a tensor of shape [C, P, D + 1] where D is the number of channels + # alphas is a tensor of shape [C, P, 1] + features, alphas = gaussian_splat_3d.sparse_render_images_and_depths( + pixels_to_render, # tensor of shape [C, P, 2] + world_to_camera_matrices, # tensor of shape [C, 4, 4] + projection_matrices, # tensor of shape [C, 3, 3] + image_width, # width of the images + image_height, # height of the images + near, # near clipping plane + far) # far clipping plane + + Args: + pixels_to_render (torch.Tensor | JaggedTensor): A tensor of shape ``(C, P, 2)`` or a :class:`~fvdb.JaggedTensor` where ``C`` is the number of camera views, + and ``P`` is the number of pixel coordinates to render per camera. Each pixel coordinate is represented as (y, x) (row, col). + world_to_camera_matrices (torch.Tensor): Tensor of shape ``(C, 4, 4)`` representing the + world-to-camera transformation matrices for C cameras. Each matrix transforms points + from world coordinates to camera coordinates. + projection_matrices (torch.Tensor): Tensor of shape ``(C, 3, 3)`` representing the projection matrices for ``C`` cameras. + Each matrix projects points in camera space into homogeneous pixel coordinates. + image_width (int): The width of the images to be rendered. Note these are the same for all images being rendered. + image_height (int): The height of the images to be rendered. Note these are the same for all images being rendered. + near (float): The near clipping plane distance for the projection. + far (float): The far clipping plane distance for the projection. + projection_type (ProjectionType): The type of projection to use. Default is :attr:`fvdb.ProjectionType.PERSPECTIVE`. 
+ sh_degree_to_use (int): The degree of spherical harmonics to use for rendering. -1 means use all available SH bases. + 0 means use only the first SH base (constant color). Note that you can't use more SH bases than available in the GaussianSplat3d instance. + Default is -1. + tile_size (int): The size of the tiles to use for rendering. Default is 16. You shouldn't set this parameter unless you really know what you are doing. + min_radius_2d (float): The minimum radius (in pixels) below which Gaussians are ignored during rendering. + eps_2d (float): A value used to pad Gaussians when projecting them onto the image plane, to avoid very small projected Gaussians which create artifacts and + numerical issues. + antialias (bool): If ``True``, applies opacity correction to the projected Gaussians when using ``eps_2d > 0.0``. + + Returns: + features_with_depths (torch.Tensor | JaggedTensor): A tensor of shape ``(C, P, D + 1)`` or a + :class:`~fvdb.JaggedTensor` where ``C`` is the number of camera views, + ``P`` is the number of pixel coordinates rendered per camera, and ``D`` is the number of channels. The last channel + represents the depth value at that pixel. + alpha_images (torch.Tensor | JaggedTensor): A tensor of shape ``(C, P, 1)`` or a :class:`~fvdb.JaggedTensor` + where ``C`` is the number of camera views, and ``P`` is the number of pixel coordinates rendered per camera. + Each element represents the alpha value (opacity) at that pixel such that ``0 <= alpha < 1``, + and 0 means the pixel is fully transparent, and 1 means the pixel is fully opaque.
+ """ + if isinstance(pixels_to_render, torch.Tensor): + pixels_to_render_impl = JaggedTensorCpp(pixels_to_render) + elif isinstance(pixels_to_render, JaggedTensor): + pixels_to_render_impl: JaggedTensorCpp = pixels_to_render._impl + else: + raise TypeError("pixels_to_render must be either a torch.Tensor or a fvdb.JaggedTensor") + + ret_features, ret_alphas = self._impl.sparse_render_images_and_depths( + pixels_to_render=pixels_to_render_impl, + world_to_camera_matrices=world_to_camera_matrices, + projection_matrices=projection_matrices, + image_width=image_width, + image_height=image_height, + near=near, + far=far, + projection_type=self._proj_type_to_cpp(projection_type), + sh_degree_to_use=sh_degree_to_use, + tile_size=tile_size, + min_radius_2d=min_radius_2d, + eps_2d=eps_2d, + antialias=antialias, + ) + + if isinstance(pixels_to_render, torch.Tensor): + return ret_features.jdata, ret_alphas.jdata + else: + return JaggedTensor(impl=ret_features), JaggedTensor(impl=ret_alphas) + def render_images_and_depths( self, world_to_camera_matrices: torch.Tensor, @@ -1815,7 +2103,7 @@ def render_images_and_depths( # Render images with depth maps from C camera views. # images is a tensor of shape [C, H, W, D + 1] where D is the number of channels # alpha_images is a tensor of shape [C, H, W, 1] - images, alpha_images = gaussian_splat_3d.render_images( + images, alpha_images = gaussian_splat_3d.render_images_and_depths( world_to_camera_matrices, # tensor of shape [C, 4, 4] projection_matrices, # tensor of shape [C, 3, 3] image_width, # width of the images @@ -1901,7 +2189,7 @@ def render_num_contributing_gaussians( # Render images from C camera views. 
# images is a tensor of shape [C, H, W, D] where D is the number of channels # alpha_images is a tensor of shape [C, H, W, 1] - num_gaussians, alpha_images = gaussian_splat_3d.render_images( + num_gaussians, alpha_images = gaussian_splat_3d.render_num_contributing_gaussians( world_to_camera_matrices, # tensor of shape [C, 4, 4] projection_matrices, # tensor of shape [C, 3, 3] image_width, # width of the images diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 53fc097a..d9a6f0fc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -39,8 +39,10 @@ set(FVDB_CPP_FILES fvdb/Config.cpp fvdb/detail/autograd/AvgPoolGrid.cpp fvdb/detail/autograd/EvaluateSphericalHarmonics.cpp - fvdb/detail/autograd/GaussianRender.cpp fvdb/detail/autograd/Inject.cpp + fvdb/detail/autograd/GaussianProjection.cpp + fvdb/detail/autograd/GaussianRasterize.cpp + fvdb/detail/autograd/GaussianRasterizeSparse.cpp fvdb/detail/autograd/JaggedReduce.cpp fvdb/detail/autograd/MaxPoolGrid.cpp fvdb/detail/autograd/ReadFromDense.cpp diff --git a/src/fvdb/GaussianSplat3d.cpp b/src/fvdb/GaussianSplat3d.cpp index 87b55097..a2ce7825 100644 --- a/src/fvdb/GaussianSplat3d.cpp +++ b/src/fvdb/GaussianSplat3d.cpp @@ -2,12 +2,14 @@ // SPDX-License-Identifier: Apache-2.0 // #include +#include #include #include // Autograd headers #include -#include +#include +#include // Ops headers #include @@ -443,6 +445,47 @@ GaussianSplat3d::renderImpl(const torch::Tensor &worldToCameraMatrices, state, settings.tileSize, settings.imageWidth, settings.imageHeight, 0, 0); } +std::tuple +GaussianSplat3d::sparseRenderImpl(const JaggedTensor &pixelsToRender, + const torch::Tensor &worldToCameraMatrices, + const torch::Tensor &projectionMatrices, + const fvdb::detail::ops::RenderSettings &settings) { + FVDB_FUNC_RANGE(); + + const ProjectedGaussianSplats &state = + projectGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings); + + const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, 
pixelMap] = + fvdb::detail::ops::computeSparseInfo(settings.tileSize, + state.tileOffsets.size(2), + state.tileOffsets.size(1), + pixelsToRender); + + auto rasterizeResult = + detail::autograd::RasterizeGaussiansToPixelsSparse::apply(pixelsToRender, + state.perGaussian2dMean, + state.perGaussianConic, + state.perGaussianRenderQuantity, + state.perGaussianOpacity, + settings.imageWidth, + settings.imageHeight, + 0, + 0, + settings.tileSize, + state.tileOffsets, + state.tileGaussianIds, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + false); + auto renderedPixelsJData = rasterizeResult[0]; + + auto renderedAlphasJData = rasterizeResult[1]; + return {pixelsToRender.jagged_like(renderedPixelsJData), + pixelsToRender.jagged_like(renderedAlphasJData)}; +} + std::tuple GaussianSplat3d::renderNumContributingGaussiansImpl( const torch::Tensor &worldToCameraMatrices, @@ -778,6 +821,92 @@ GaussianSplat3d::renderNumContributingGaussians(const torch::Tensor &worldToCame return renderNumContributingGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings); } +std::tuple +GaussianSplat3d::sparseRenderDepths(const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &worldToCameraMatrices, + const torch::Tensor &projectionMatrices, + const size_t imageWidth, + const size_t imageHeight, + const float near, + const float far, + const ProjectionType projectionType, + const size_t tileSize, + const float minRadius2d, + const float eps2d, + const bool antialias) { + RenderSettings settings; + settings.imageWidth = imageWidth; + settings.imageHeight = imageHeight; + settings.nearPlane = near; + settings.farPlane = far; + settings.projectionType = projectionType; + settings.shDegreeToUse = 0; + settings.tileSize = tileSize; + settings.radiusClip = minRadius2d; + settings.eps2d = eps2d; + settings.renderMode = RenderSettings::RenderMode::DEPTH; + + return sparseRenderImpl(pixelsToRender, worldToCameraMatrices, projectionMatrices, settings); +} + +std::tuple 
+GaussianSplat3d::sparseRenderImages(const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &worldToCameraMatrices, + const torch::Tensor &projectionMatrices, + const size_t imageWidth, + const size_t imageHeight, + const float near, + const float far, + const ProjectionType projectionType, + const int64_t shDegreeToUse, + const size_t tileSize, + const float minRadius2d, + const float eps2d, + const bool antialias) { + RenderSettings settings; + settings.imageWidth = imageWidth; + settings.imageHeight = imageHeight; + settings.nearPlane = near; + settings.farPlane = far; + settings.projectionType = projectionType; + settings.shDegreeToUse = shDegreeToUse; + settings.tileSize = tileSize; + settings.radiusClip = minRadius2d; + settings.eps2d = eps2d; + settings.renderMode = RenderSettings::RenderMode::RGB; + + return sparseRenderImpl(pixelsToRender, worldToCameraMatrices, projectionMatrices, settings); +} + +std::tuple +GaussianSplat3d::sparseRenderImagesAndDepths(const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &worldToCameraMatrices, + const torch::Tensor &projectionMatrices, + const size_t imageWidth, + const size_t imageHeight, + const float near, + const float far, + const ProjectionType projectionType, + const int64_t shDegreeToUse, + const size_t tileSize, + const float minRadius2d, + const float eps2d, + const bool antialias) { + RenderSettings settings; + settings.imageWidth = imageWidth; + settings.imageHeight = imageHeight; + settings.nearPlane = near; + settings.farPlane = far; + settings.projectionType = projectionType; + settings.shDegreeToUse = shDegreeToUse; + settings.tileSize = tileSize; + settings.radiusClip = minRadius2d; + settings.eps2d = eps2d; + settings.renderMode = RenderSettings::RenderMode::RGBD; + + return sparseRenderImpl(pixelsToRender, worldToCameraMatrices, projectionMatrices, settings); +} + std::tuple GaussianSplat3d::sparseRenderNumContributingGaussians(const fvdb::JaggedTensor &pixelsToRender, const 
torch::Tensor &worldToCameraMatrices, diff --git a/src/fvdb/GaussianSplat3d.h b/src/fvdb/GaussianSplat3d.h index ec4780bb..8e8935ce 100644 --- a/src/fvdb/GaussianSplat3d.h +++ b/src/fvdb/GaussianSplat3d.h @@ -934,6 +934,50 @@ class GaussianSplat3d { const bool antialias = false, const std::optional &backgrounds = std::nullopt); + std::tuple + sparseRenderImages(const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &worldToCameraMatrices, + const torch::Tensor &projectionMatrices, + const size_t imageWidth, + const size_t imageHeight, + const float near, + const float far, + const ProjectionType projectionType = ProjectionType::PERSPECTIVE, + const int64_t shDegreeToUse = -1, + const size_t tileSize = 16, + const float minRadius2d = 0.0, + const float eps2d = 0.3, + const bool antialias = false); + + std::tuple + sparseRenderDepths(const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &worldToCameraMatrices, + const torch::Tensor &projectionMatrices, + const size_t imageWidth, + const size_t imageHeight, + const float near, + const float far, + const ProjectionType projectionType = ProjectionType::PERSPECTIVE, + const size_t tileSize = 16, + const float minRadius2d = 0.0, + const float eps2d = 0.3, + const bool antialias = false); + + std::tuple + sparseRenderImagesAndDepths(const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &worldToCameraMatrices, + const torch::Tensor &projectionMatrices, + const size_t imageWidth, + const size_t imageHeight, + const float near, + const float far, + const ProjectionType projectionType = ProjectionType::PERSPECTIVE, + const int64_t shDegreeToUse = -1, + const size_t tileSize = 16, + const float minRadius2d = 0.0, + const float eps2d = 0.3, + const bool antialias = false); + /// @brief Render the number of contributing Gaussians for each pixel in the image. 
/// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices /// @param projectionMatrices [C, 4, 4] Projection matrices @@ -1241,9 +1285,9 @@ class GaussianSplat3d { /// The mask must have the same length as the number of Gaussians in this scene. GaussianSplat3d tensorIndexGetImpl(const torch::Tensor &indexOrMask) const; - /// @brief Render the gaussian splatting scene - /// This function returns a single render quantity (RGB, depth, RGB+D) and - /// single alpha value per pixel. + /// @brief Render the scene described by the Gaussian splats from the specified views. + /// This function returns a single render quantity (RGB, depth, RGB+D) and + /// single alpha value per pixel. /// @param worldToCameraMatrices [C, 4, 4] /// @param projectionMatrices [C, 3, 3] /// @param settings @@ -1253,6 +1297,20 @@ class GaussianSplat3d { const torch::Tensor &projectionMatrices, const fvdb::detail::ops::RenderSettings &settings); + /// @brief Render the scene described by the Gaussian splats at the specified pixels in the + /// specified views. This function returns a single render quantity (RGB, depth, RGB+D) + /// and single alpha value per pixel. + /// @param pixelsToRender [P1 + P2 + ..., 2] JaggedTensor of pixels per camera to render. + /// @param worldToCameraMatrices [C, 4, 4] + /// @param projectionMatrices [C, 3, 3] + /// @param settings + /// @return Tuple of (render quantity, alpha value) + std::tuple + sparseRenderImpl(const JaggedTensor &pixelsToRender, + const torch::Tensor &worldToCameraMatrices, + const torch::Tensor &projectionMatrices, + const fvdb::detail::ops::RenderSettings &settings); + /// @brief Render the number of contributing Gaussians for each pixel in the image. 
/// @param worldToCameraMatrices [C, 4, 4] /// @param projectionMatrices [C, 3, 3] diff --git a/src/fvdb/detail/autograd/GaussianRender.cpp b/src/fvdb/detail/autograd/GaussianProjection.cpp similarity index 68% rename from src/fvdb/detail/autograd/GaussianRender.cpp rename to src/fvdb/detail/autograd/GaussianProjection.cpp index 1aff1fb5..d2e84551 100644 --- a/src/fvdb/detail/autograd/GaussianRender.cpp +++ b/src/fvdb/detail/autograd/GaussianProjection.cpp @@ -1,7 +1,7 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include +#include #include #include #include @@ -11,9 +11,7 @@ #include #include -namespace fvdb { -namespace detail { -namespace autograd { +namespace fvdb::detail::autograd { ProjectGaussians::VariableList ProjectGaussians::forward(ProjectGaussians::AutogradContext *ctx, @@ -205,163 +203,6 @@ ProjectGaussians::backward(ProjectGaussians::AutogradContext *ctx, Variable()}; } -RasterizeGaussiansToPixels::VariableList -RasterizeGaussiansToPixels::forward( - RasterizeGaussiansToPixels::AutogradContext *ctx, - const RasterizeGaussiansToPixels::Variable &means2d, // [C, N, 2] - const RasterizeGaussiansToPixels::Variable &conics, // [C, N, 3] - const RasterizeGaussiansToPixels::Variable &colors, // [C, N, 3] - const RasterizeGaussiansToPixels::Variable &opacities, // [N] - const uint32_t imageWidth, - const uint32_t imageHeight, - const uint32_t imageOriginW, - const uint32_t imageOriginH, - const uint32_t tileSize, - const RasterizeGaussiansToPixels::Variable &tileOffsets, // [C, tile_height, tile_width] - const RasterizeGaussiansToPixels::Variable &tileGaussianIds, // [n_isects] - const bool absgrad, - std::optional backgrounds) { - FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixels::forward"); - // const int C = means2d.size(0); - // const int N = means2d.size(1); - - auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { - return ops::dispatchGaussianRasterizeForward(means2d, - conics, - colors, 
- opacities, - imageWidth, - imageHeight, - imageOriginW, - imageOriginH, - tileSize, - tileOffsets, - tileGaussianIds, - backgrounds); - }); - Variable renderedColors = std::get<0>(variables); - Variable renderedAlphas = std::get<1>(variables); - Variable lastIds = std::get<2>(variables); - - if (backgrounds.has_value()) { - ctx->save_for_backward({means2d, - conics, - colors, - opacities, - tileOffsets, - tileGaussianIds, - renderedAlphas, - lastIds, - backgrounds.value()}); - ctx->saved_data["has_backgrounds"] = true; - } else { - ctx->save_for_backward({means2d, - conics, - colors, - opacities, - tileOffsets, - tileGaussianIds, - renderedAlphas, - lastIds}); - ctx->saved_data["has_backgrounds"] = false; - } - ctx->saved_data["imageWidth"] = (int64_t)imageWidth; - ctx->saved_data["imageHeight"] = (int64_t)imageHeight; - ctx->saved_data["tileSize"] = (int64_t)tileSize; - ctx->saved_data["imageOriginW"] = (int64_t)imageOriginW; - ctx->saved_data["imageOriginH"] = (int64_t)imageOriginH; - ctx->saved_data["absgrad"] = absgrad; - - return {renderedColors, renderedAlphas}; -} - -RasterizeGaussiansToPixels::VariableList -RasterizeGaussiansToPixels::backward(RasterizeGaussiansToPixels::AutogradContext *ctx, - RasterizeGaussiansToPixels::VariableList gradOutput) { - FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixels::backward"); - Variable dLossDRenderedColors = gradOutput.at(0); - Variable dLossDRenderedAlphas = gradOutput.at(1); - - // ensure the gradients are contiguous if they are not None - if (dLossDRenderedColors.defined()) { - dLossDRenderedColors = dLossDRenderedColors.contiguous(); - } - if (dLossDRenderedAlphas.defined()) { - dLossDRenderedAlphas = dLossDRenderedAlphas.contiguous(); - } - - VariableList saved = ctx->get_saved_variables(); - Variable means2d = saved.at(0); - Variable conics = saved.at(1); - Variable colors = saved.at(2); - Variable opacities = saved.at(3); - Variable tileOffsets = saved.at(4); - Variable tileGaussianIds = saved.at(5); - 
Variable renderedAlphas = saved.at(6); - Variable lastIds = saved.at(7); - - const bool hasBackgrounds = ctx->saved_data["has_backgrounds"].toBool(); - std::optional backgrounds = std::nullopt; - if (hasBackgrounds) { - backgrounds = saved.at(8); - } - - const int imageWidth = (int)ctx->saved_data["imageWidth"].toInt(); - const int imageHeight = (int)ctx->saved_data["imageHeight"].toInt(); - const int tileSize = (int)ctx->saved_data["tileSize"].toInt(); - const int imageOriginW = (int)ctx->saved_data["imageOriginW"].toInt(); - const int imageOriginH = (int)ctx->saved_data["imageOriginH"].toInt(); - const bool absgrad = ctx->saved_data["absgrad"].toBool(); - - auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { - return ops::dispatchGaussianRasterizeBackward(means2d, - conics, - colors, - opacities, - imageWidth, - imageHeight, - imageOriginW, - imageOriginH, - tileSize, - tileOffsets, - tileGaussianIds, - renderedAlphas, - lastIds, - dLossDRenderedColors, - dLossDRenderedAlphas, - absgrad, - -1, - backgrounds); - }); - Variable dLossDMean2dAbs; - if (absgrad) { - dLossDMean2dAbs = std::get<0>(variables); - // means2d.absgrad = dLossDMean2dAbs; - } else { - dLossDMean2dAbs = Variable(); - } - Variable dLossDMeans2d = std::get<1>(variables); - Variable dLossDConics = std::get<2>(variables); - Variable dLossDColors = std::get<3>(variables); - Variable dLossDOpacities = std::get<4>(variables); - - return { - dLossDMeans2d, - dLossDConics, - dLossDColors, - dLossDOpacities, - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), // backgrounds gradient (not needed, so return empty) - }; -} - ProjectGaussiansJagged::VariableList ProjectGaussiansJagged::forward( ProjectGaussiansJagged::AutogradContext *ctx, @@ -494,6 +335,4 @@ ProjectGaussiansJagged::backward(ProjectGaussiansJagged::AutogradContext *ctx, Variable()}; } -} // namespace autograd -} // namespace detail -} // namespace fvdb +} 
// namespace fvdb::detail::autograd diff --git a/src/fvdb/detail/autograd/GaussianRender.h b/src/fvdb/detail/autograd/GaussianProjection.h similarity index 63% rename from src/fvdb/detail/autograd/GaussianRender.h rename to src/fvdb/detail/autograd/GaussianProjection.h index 2bced159..d0ed5c33 100644 --- a/src/fvdb/detail/autograd/GaussianRender.h +++ b/src/fvdb/detail/autograd/GaussianProjection.h @@ -1,14 +1,12 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_AUTOGRAD_GAUSSIANRENDER_H -#define FVDB_DETAIL_AUTOGRAD_GAUSSIANRENDER_H +#ifndef FVDB_DETAIL_AUTOGRAD_GAUSSIANPROJECTION_H +#define FVDB_DETAIL_AUTOGRAD_GAUSSIANPROJECTION_H #include -namespace fvdb { -namespace detail { -namespace autograd { +namespace fvdb::detail::autograd { struct ProjectGaussians : public torch::autograd::Function { using VariableList = torch::autograd::variable_list; @@ -37,29 +35,6 @@ struct ProjectGaussians : public torch::autograd::Function { static VariableList backward(AutogradContext *ctx, VariableList gradOutput); }; -struct RasterizeGaussiansToPixels : public torch::autograd::Function { - using VariableList = torch::autograd::variable_list; - using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - - static VariableList forward(AutogradContext *ctx, - const Variable &means2d, // [C, N, 2] - const Variable &conics, // [C, N, 3] - const Variable &colors, // [C, N, 3] - const Variable &opacities, // [N] - const uint32_t imageWidth, - const uint32_t imageHeight, - const uint32_t imageOriginW, - const uint32_t imageOriginH, - const uint32_t tileSize, - const Variable &tileOffsets, // [C, tile_height, tile_width] - const Variable &tileGaussianIds, // [n_isects] - const bool absgrad, - std::optional backgrounds = std::nullopt); // [C, D] - - static VariableList backward(AutogradContext *ctx, VariableList gradOutput); -}; - struct ProjectGaussiansJagged : public 
torch::autograd::Function { using VariableList = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; @@ -84,8 +59,6 @@ struct ProjectGaussiansJagged : public torch::autograd::Function +#include +#include +#include +#include + +namespace fvdb::detail::autograd { + +RasterizeGaussiansToPixels::VariableList +RasterizeGaussiansToPixels::forward( + RasterizeGaussiansToPixels::AutogradContext *ctx, + const RasterizeGaussiansToPixels::Variable &means2d, // [C, N, 2] + const RasterizeGaussiansToPixels::Variable &conics, // [C, N, 3] + const RasterizeGaussiansToPixels::Variable &colors, // [C, N, 3] + const RasterizeGaussiansToPixels::Variable &opacities, // [N] + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const RasterizeGaussiansToPixels::Variable &tileOffsets, // [C, tile_height, tile_width] + const RasterizeGaussiansToPixels::Variable &tileGaussianIds, // [n_isects] + const bool absgrad, + std::optional backgrounds) { + FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixels::forward"); + // const int C = means2d.size(0); + // const int N = means2d.size(1); + + auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { + return ops::dispatchGaussianRasterizeForward(means2d, + conics, + colors, + opacities, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + backgrounds); + }); + Variable renderedColors = std::get<0>(variables); + Variable renderedAlphas = std::get<1>(variables); + Variable lastIds = std::get<2>(variables); + + if (backgrounds.has_value()) { + ctx->save_for_backward({means2d, + conics, + colors, + opacities, + tileOffsets, + tileGaussianIds, + renderedAlphas, + lastIds, + backgrounds.value()}); + ctx->saved_data["has_backgrounds"] = true; + } else { + ctx->save_for_backward({means2d, + conics, + colors, + opacities, + tileOffsets, + tileGaussianIds, + 
renderedAlphas, + lastIds}); + ctx->saved_data["has_backgrounds"] = false; + } + ctx->saved_data["imageWidth"] = (int64_t)imageWidth; + ctx->saved_data["imageHeight"] = (int64_t)imageHeight; + ctx->saved_data["tileSize"] = (int64_t)tileSize; + ctx->saved_data["imageOriginW"] = (int64_t)imageOriginW; + ctx->saved_data["imageOriginH"] = (int64_t)imageOriginH; + ctx->saved_data["absgrad"] = absgrad; + + return {renderedColors, renderedAlphas}; +} + +RasterizeGaussiansToPixels::VariableList +RasterizeGaussiansToPixels::backward(RasterizeGaussiansToPixels::AutogradContext *ctx, + RasterizeGaussiansToPixels::VariableList gradOutput) { + FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixels::backward"); + Variable dLossDRenderedColors = gradOutput.at(0); + Variable dLossDRenderedAlphas = gradOutput.at(1); + + // ensure the gradients are contiguous if they are not None + if (dLossDRenderedColors.defined()) { + dLossDRenderedColors = dLossDRenderedColors.contiguous(); + } + if (dLossDRenderedAlphas.defined()) { + dLossDRenderedAlphas = dLossDRenderedAlphas.contiguous(); + } + + VariableList saved = ctx->get_saved_variables(); + Variable means2d = saved.at(0); + Variable conics = saved.at(1); + Variable colors = saved.at(2); + Variable opacities = saved.at(3); + Variable tileOffsets = saved.at(4); + Variable tileGaussianIds = saved.at(5); + Variable renderedAlphas = saved.at(6); + Variable lastIds = saved.at(7); + + const bool hasBackgrounds = ctx->saved_data["has_backgrounds"].toBool(); + std::optional backgrounds = std::nullopt; + if (hasBackgrounds) { + backgrounds = saved.at(8); + } + + const int imageWidth = (int)ctx->saved_data["imageWidth"].toInt(); + const int imageHeight = (int)ctx->saved_data["imageHeight"].toInt(); + const int tileSize = (int)ctx->saved_data["tileSize"].toInt(); + const int imageOriginW = (int)ctx->saved_data["imageOriginW"].toInt(); + const int imageOriginH = (int)ctx->saved_data["imageOriginH"].toInt(); + const bool absgrad = 
ctx->saved_data["absgrad"].toBool(); + + auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { + return ops::dispatchGaussianRasterizeBackward(means2d, + conics, + colors, + opacities, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + renderedAlphas, + lastIds, + dLossDRenderedColors, + dLossDRenderedAlphas, + absgrad, + -1, + backgrounds); + }); + Variable dLossDMean2dAbs; + if (absgrad) { + dLossDMean2dAbs = std::get<0>(variables); + // means2d.absgrad = dLossDMean2dAbs; + } else { + dLossDMean2dAbs = Variable(); + } + Variable dLossDMeans2d = std::get<1>(variables); + Variable dLossDConics = std::get<2>(variables); + Variable dLossDColors = std::get<3>(variables); + Variable dLossDOpacities = std::get<4>(variables); + + return { + dLossDMeans2d, + dLossDConics, + dLossDColors, + dLossDOpacities, + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), // backgrounds gradient (not needed, so return empty) + }; +} + +} // namespace fvdb::detail::autograd diff --git a/src/fvdb/detail/autograd/GaussianRasterize.h b/src/fvdb/detail/autograd/GaussianRasterize.h new file mode 100644 index 00000000..f60fffca --- /dev/null +++ b/src/fvdb/detail/autograd/GaussianRasterize.h @@ -0,0 +1,36 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZE_H +#define FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZE_H + +#include + +namespace fvdb::detail::autograd { + +struct RasterizeGaussiansToPixels : public torch::autograd::Function { + using VariableList = torch::autograd::variable_list; + using AutogradContext = torch::autograd::AutogradContext; + using Variable = torch::autograd::Variable; + + static VariableList forward(AutogradContext *ctx, + const Variable &means2d, // [C, N, 2] + const Variable &conics, // [C, N, 3] + const Variable &colors, // [C, N, 3] + const 
Variable &opacities, // [N] + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const Variable &tileOffsets, // [C, tile_height, tile_width] + const Variable &tileGaussianIds, // [n_isects] + const bool absgrad, + std::optional backgrounds = std::nullopt); // [C, D] + + static VariableList backward(AutogradContext *ctx, VariableList gradOutput); +}; + +} // namespace fvdb::detail::autograd + +#endif // FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZE_H diff --git a/src/fvdb/detail/autograd/GaussianRasterizeSparse.cpp b/src/fvdb/detail/autograd/GaussianRasterizeSparse.cpp new file mode 100644 index 00000000..44865376 --- /dev/null +++ b/src/fvdb/detail/autograd/GaussianRasterizeSparse.cpp @@ -0,0 +1,198 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include + +namespace fvdb::detail::autograd { + +RasterizeGaussiansToPixelsSparse::VariableList +RasterizeGaussiansToPixelsSparse::forward( + RasterizeGaussiansToPixelsSparse::AutogradContext *ctx, + const JaggedTensor &pixelsToRender, // [C, num_pixels, 2] + const RasterizeGaussiansToPixelsSparse::Variable &means2d, // [C, N, 2] + const RasterizeGaussiansToPixelsSparse::Variable &conics, // [C, N, 3] + const RasterizeGaussiansToPixelsSparse::Variable &colors, // [C, N, 3] + const RasterizeGaussiansToPixelsSparse::Variable &opacities, // [N] + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const RasterizeGaussiansToPixelsSparse::Variable &tileOffsets, // [C, tile_height, tile_width] + const RasterizeGaussiansToPixelsSparse::Variable &tileGaussianIds, // [n_isects] + const RasterizeGaussiansToPixelsSparse::Variable &activeTiles, // [num_active_tiles] + const RasterizeGaussiansToPixelsSparse::Variable + &tilePixelMask, // [num_active_tiles, 
tileSize, tileSize] + const RasterizeGaussiansToPixelsSparse::Variable &tilePixelCumsum, // [num_active_tiles + 1] + const RasterizeGaussiansToPixelsSparse::Variable &pixelMap, // [num_pixels] + const bool absgrad) { + FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixelsSparse::forward"); + // const int C = means2d.size(0); + // const int N = means2d.size(1); + + auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { + return ops::dispatchGaussianSparseRasterizeForward(pixelsToRender, + means2d, + conics, + colors, + opacities, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap); + }); + JaggedTensor renderedColors = std::get<0>(variables); + JaggedTensor renderedAlphas = std::get<1>(variables); + JaggedTensor lastIds = std::get<2>(variables); + + const auto joffsets = pixelsToRender.joffsets(); + const auto jidx = pixelsToRender.jidx(); + const auto jlidx = pixelsToRender.jlidx(); + const auto numOuterLists = pixelsToRender.num_outer_lists(); + + ctx->save_for_backward({means2d, + conics, + colors, + opacities, + tileOffsets, + tileGaussianIds, + pixelsToRender.jdata(), + renderedColors.jdata(), + renderedAlphas.jdata(), + lastIds.jdata(), + joffsets, + jidx, + jlidx, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap}); + ctx->saved_data["imageWidth"] = (int64_t)imageWidth; + ctx->saved_data["imageHeight"] = (int64_t)imageHeight; + ctx->saved_data["tileSize"] = (int64_t)tileSize; + ctx->saved_data["imageOriginW"] = (int64_t)imageOriginW; + ctx->saved_data["imageOriginH"] = (int64_t)imageOriginH; + ctx->saved_data["numOuterLists"] = (int64_t)numOuterLists; + ctx->saved_data["absgrad"] = absgrad; + + return {renderedColors.jdata(), renderedAlphas.jdata()}; +} + +RasterizeGaussiansToPixelsSparse::VariableList +RasterizeGaussiansToPixelsSparse::backward( + RasterizeGaussiansToPixelsSparse::AutogradContext *ctx, + 
RasterizeGaussiansToPixelsSparse::VariableList gradOutput) { + FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixelsSparse::backward"); + Variable dLossDRenderedFeaturesJData = gradOutput.at(0); + Variable dLossDRenderedAlphasJData = gradOutput.at(1); + + // ensure the gradients are contiguous if they are not None + if (dLossDRenderedFeaturesJData.defined()) { + dLossDRenderedFeaturesJData = dLossDRenderedFeaturesJData.contiguous(); + } + if (dLossDRenderedAlphasJData.defined()) { + dLossDRenderedAlphasJData = dLossDRenderedAlphasJData.contiguous(); + } + + VariableList saved = ctx->get_saved_variables(); + Variable means2d = saved.at(0); + Variable conics = saved.at(1); + Variable features = saved.at(2); + Variable opacities = saved.at(3); + Variable tileOffsets = saved.at(4); + Variable tileGaussianIds = saved.at(5); + Variable pixelsToRenderJData = saved.at(6); + Variable renderedColorsJData = saved.at(7); + Variable renderedAlphasJData = saved.at(8); + Variable lastIdsJData = saved.at(9); + Variable joffsets = saved.at(10); + Variable jidx = saved.at(11); + Variable jlidx = saved.at(12); + Variable activeTiles = saved.at(13); + Variable tilePixelMask = saved.at(14); + Variable tilePixelCumsum = saved.at(15); + Variable pixelMap = saved.at(16); + + const int imageWidth = (int)ctx->saved_data["imageWidth"].toInt(); + const int imageHeight = (int)ctx->saved_data["imageHeight"].toInt(); + const int tileSize = (int)ctx->saved_data["tileSize"].toInt(); + const int imageOriginW = (int)ctx->saved_data["imageOriginW"].toInt(); + const int imageOriginH = (int)ctx->saved_data["imageOriginH"].toInt(); + const int64_t numOuterLists = ctx->saved_data["numOuterLists"].toInt(); + const bool absgrad = ctx->saved_data["absgrad"].toBool(); + + auto pixelsToRender = JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( + pixelsToRenderJData, joffsets, jidx, jlidx, numOuterLists); + auto renderedAlphas = pixelsToRender.jagged_like(renderedAlphasJData); + auto lastIds = 
pixelsToRender.jagged_like(lastIdsJData); + + auto dLossDRenderedFeatures = pixelsToRender.jagged_like(dLossDRenderedFeaturesJData); + auto dLossDRenderedAlphas = pixelsToRender.jagged_like(dLossDRenderedAlphasJData); + + auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { + return ops::dispatchGaussianSparseRasterizeBackward(pixelsToRender, + means2d, + conics, + features, + opacities, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + renderedAlphas, + lastIds, + dLossDRenderedFeatures, + dLossDRenderedAlphas, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + absgrad); + }); + Variable dLossDMean2dAbs; + if (absgrad) { + dLossDMean2dAbs = std::get<0>(variables); + } else { + dLossDMean2dAbs = Variable(); + } + Variable dLossDMeans2d = std::get<1>(variables); + Variable dLossDConics = std::get<2>(variables); + Variable dLossDColors = std::get<3>(variables); + Variable dLossDOpacities = std::get<4>(variables); + + return { + Variable(), + dLossDMeans2d, + dLossDConics, + dLossDColors, + dLossDOpacities, + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + Variable(), + }; +} + +} // namespace fvdb::detail::autograd diff --git a/src/fvdb/detail/autograd/GaussianRasterizeSparse.h b/src/fvdb/detail/autograd/GaussianRasterizeSparse.h new file mode 100644 index 00000000..f8829b41 --- /dev/null +++ b/src/fvdb/detail/autograd/GaussianRasterizeSparse.h @@ -0,0 +1,44 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZESPARSE_H +#define FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZESPARSE_H + +#include + +#include + +namespace fvdb::detail::autograd { + +struct RasterizeGaussiansToPixelsSparse + : public torch::autograd::Function { + using VariableList = torch::autograd::variable_list; + using 
AutogradContext = torch::autograd::AutogradContext; + using Variable = torch::autograd::Variable; + + static VariableList + forward(AutogradContext *ctx, + const JaggedTensor &pixelsToRender, // [C, num_pixels, 2] + const Variable &means2d, // [C, N, 2] + const Variable &conics, // [C, N, 3] + const Variable &features, // [C, N, D] + const Variable &opacities, // [N] + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const Variable &tileOffsets, // [C, tile_height, tile_width] + const Variable &tileGaussianIds, // [n_isects] + const Variable &activeTiles, // [num_active_tiles] + const Variable &tilePixelMask, // [num_active_tiles, tileSize, tileSize] + const Variable &tilePixelCumsum, // [num_active_tiles + 1] + const Variable &pixelMap, // [num_pixels] + const bool absgrad); + + static VariableList backward(AutogradContext *ctx, VariableList gradOutput); +}; + +} // namespace fvdb::detail::autograd + +#endif // FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZESPARSE_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu b/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu index cc1103c4..3d247e91 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu +++ b/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -1711,6 +1712,35 @@ dispatchGaussianSparseRasterizeBackward( } } +template <> +std::tuple +dispatchGaussianSparseRasterizeBackward( + const fvdb::JaggedTensor &pixelsToRender, // [C, NumPixels, 2] + const torch::Tensor &means2d, // [C, N, 2] + const torch::Tensor &conics, // [C, N, 3] + const torch::Tensor &features, // [C, N, D] + const torch::Tensor &opacities, // [N] + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, // [C, tile_height, 
tile_width] + const torch::Tensor &tileGaussianIds, // [n_isects] + const fvdb::JaggedTensor &renderedAlphas, // [C lists: varying sizes, each element [1]] + const fvdb::JaggedTensor &lastIds, // [C lists: varying sizes] + const fvdb::JaggedTensor &dLossDRenderedFeatures, // [C lists: varying sizes, each element [D]] + const fvdb::JaggedTensor &dLossDRenderedAlphas, // [C lists: varying sizes, each element [1]] + const torch::Tensor &activeTiles, // [AT] + const torch::Tensor &tilePixelMask, // [AT, wordsPerTile] + const torch::Tensor &tilePixelCumsum, // [AT] + const torch::Tensor &pixelMap, // [AP] + const bool absGrad, + const int64_t numSharedChannelsOverride, + const at::optional &backgrounds) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "PrivateUse1 implementation not available"); +} + template <> std::tuple dispatchGaussianSparseRasterizeBackward( diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu b/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu index c3b8e296..ad30caf5 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu +++ b/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu @@ -13,6 +13,7 @@ #include +#include #include #include @@ -877,6 +878,31 @@ dispatchGaussianSparseRasterizeForward( } } +template <> +std::tuple +dispatchGaussianSparseRasterizeForward( + // sparse pixel coordinates + const fvdb::JaggedTensor &pixelsToRender, // [C, maxPixelsPerCamera, 2] + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] + const torch::Tensor &conics, // [C, N, 3] + const torch::Tensor &features, // [C, N, D] + const torch::Tensor &opacities, // [N] + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] + const torch::Tensor &tileGaussianIds, // [n_isects] + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor 
&tilePixelCumsum, + const torch::Tensor &pixelMap, + const at::optional &backgrounds) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "PrivateUse1 implementation not available"); +} + template <> std::tuple dispatchGaussianSparseRasterizeForward( diff --git a/src/python/GaussianSplatBinding.cpp b/src/python/GaussianSplatBinding.cpp index ddfb2005..94e38d72 100644 --- a/src/python/GaussianSplatBinding.cpp +++ b/src/python/GaussianSplatBinding.cpp @@ -1,6 +1,8 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // +#include + #include "TypeCasters.h" #include @@ -229,6 +231,53 @@ bind_gaussian_splat3d(py::module &m) { py::arg("antialias") = false, py::arg("backgrounds") = std::nullopt) + .def("sparse_render_images", + &fvdb::GaussianSplat3d::sparseRenderImages, + py::arg("pixels_to_render"), + py::arg("world_to_camera_matrices"), + py::arg("projection_matrices"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("near"), + py::arg("far"), + py::arg("projection_type") = fvdb::GaussianSplat3d::ProjectionType::PERSPECTIVE, + py::arg("sh_degree_to_use") = -1, + py::arg("tile_size") = 16, + py::arg("min_radius_2d") = 0.0, + py::arg("eps_2d") = 0.3, + py::arg("antialias") = false) + + .def("sparse_render_depths", + &fvdb::GaussianSplat3d::sparseRenderDepths, + py::arg("pixels_to_render"), + py::arg("world_to_camera_matrices"), + py::arg("projection_matrices"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("near"), + py::arg("far"), + py::arg("projection_type") = fvdb::GaussianSplat3d::ProjectionType::PERSPECTIVE, + py::arg("tile_size") = 16, + py::arg("min_radius_2d") = 0.0, + py::arg("eps_2d") = 0.3, + py::arg("antialias") = false) + + .def("sparse_render_images_and_depths", + &fvdb::GaussianSplat3d::sparseRenderImagesAndDepths, + py::arg("pixels_to_render"), + py::arg("world_to_camera_matrices"), + py::arg("projection_matrices"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("near"), + 
py::arg("far"), + py::arg("projection_type") = fvdb::GaussianSplat3d::ProjectionType::PERSPECTIVE, + py::arg("sh_degree_to_use") = -1, + py::arg("tile_size") = 16, + py::arg("min_radius_2d") = 0.0, + py::arg("eps_2d") = 0.3, + py::arg("antialias") = false) + .def("render_num_contributing_gaussians", &fvdb::GaussianSplat3d::renderNumContributingGaussians, py::arg("world_to_camera_matrices"), diff --git a/tests/unit/test_gsplat.py b/tests/unit/test_gaussian_splat_3d.py similarity index 86% rename from tests/unit/test_gsplat.py rename to tests/unit/test_gaussian_splat_3d.py index a1d5ec60..3af06fc0 100644 --- a/tests/unit/test_gsplat.py +++ b/tests/unit/test_gaussian_splat_3d.py @@ -1493,7 +1493,6 @@ def test_gaussian_render_jagged(self): False, # return debug info False, # ortho ) - torch.cuda.synchronize() pixels = self._tensors_to_pixel(render_colors, render_alphas) differ, cmp = compare_images(pixels, str(self.data_path / "regression_gaussian_render_jagged_result.png")) @@ -2122,7 +2121,7 @@ def test_gaussian_contributors_scene_sparse_render(self): # Select tensors from reference_ids at the specified pixel positions reference_ids_list = image_reference_ids.unbind() - selected_tensors = [reference_ids_list[idx.item()] for idx in pixel_indices] + selected_tensors = [reference_ids_list[idx.item()] for idx in pixel_indices] # type: ignore selected_reference_ids = JaggedTensor(selected_tensors) self.assertTrue(image_sparse_ids == selected_reference_ids) @@ -2140,7 +2139,7 @@ def test_gaussian_contributors_scene_sparse_render(self): # Select tensors from reference_weights at the specified pixel positions reference_weights_list = image_reference_weights.unbind() - selected_tensors = [reference_weights_list[idx.item()] for idx in pixel_indices] + selected_tensors = [reference_weights_list[idx.item()] for idx in pixel_indices] # type: ignore selected_reference_weights = JaggedTensor(selected_tensors) self.assertTrue(image_sparse_weights == selected_reference_weights) @@ 
-2229,7 +2228,7 @@ def test_gaussian_contributors_scene_dense_pixels_sparse_render(self): # Select tensors from reference_ids at the specified pixel positions reference_ids_list = image_reference_ids.unbind() - selected_tensors = [reference_ids_list[idx.item()] for idx in pixel_indices] + selected_tensors = [reference_ids_list[idx.item()] for idx in pixel_indices] # type: ignore selected_reference_ids = JaggedTensor(selected_tensors) self.assertTrue(image_sparse_ids == selected_reference_ids) @@ -2247,12 +2246,385 @@ def test_gaussian_contributors_scene_dense_pixels_sparse_render(self): # Select tensors from reference_weights at the specified pixel positions reference_weights_list = image_reference_weights.unbind() - selected_tensors = [reference_weights_list[idx.item()] for idx in pixel_indices] + selected_tensors = [reference_weights_list[idx.item()] for idx in pixel_indices] # type: ignore selected_reference_weights = JaggedTensor(selected_tensors) self.assertTrue(image_sparse_weights == selected_reference_weights) +class TestGaussianRenderSparse(BaseGaussianTestCase): + def setUp(self): + super().setUp() + + def test_gaussian_render_sparse_depth(self): + # Generate random pixel coordinates within image bounds + + idx = torch.randperm(self.width * self.height)[:5000] + x_coords = idx % self.width + y_coords = idx // self.width + pixels_to_render = JaggedTensor([torch.stack([y_coords, x_coords], 1)]).to(self.device) + + sparse_depth, sparse_alphas = self.gs3d.sparse_render_depths( + pixels_to_render, + self.cam_to_world_mats[0:1], + self.projection_mats[0:1], + self.width, + self.height, + self.near_plane, # near_plane + self.far_plane, # far_plane + ) + + dense_depth, dense_alphas = self.gs3d.render_depths( + self.cam_to_world_mats[0:1], + self.projection_mats[0:1], + self.width, + self.height, + self.near_plane, # near_plane + self.far_plane, # far_plane + ) + + dense_depth_pixels = dense_depth[0, y_coords, x_coords] + dense_alphas_pixels = dense_alphas[0, 
y_coords, x_coords] + + self.assertTrue( + torch.allclose(sparse_depth.jdata, dense_depth_pixels, atol=1e-5, rtol=1e-8), + "Sparse depth render does not match dense depth render at specified pixels", + ) + self.assertTrue( + torch.allclose(sparse_alphas.jdata, dense_alphas_pixels, atol=1e-5, rtol=1e-8), + "Sparse alpha render does not match dense alpha render at specified pixels", + ) + + def test_gaussian_render_sparse_depth_backward(self): + # Generate random pixel coordinates within image bounds + + idx = torch.randperm(self.width * self.height)[:5000] + x_coords = idx % self.width + y_coords = idx // self.width + pixels_to_render = JaggedTensor([torch.stack([y_coords, x_coords], 1)]).to(self.device) + + sparse_depth, sparse_alphas = self.gs3d.sparse_render_depths( + pixels_to_render, + self.cam_to_world_mats[0:1], + self.projection_mats[0:1], + self.width, + self.height, + self.near_plane, # near_plane + self.far_plane, # far_plane + ) + + l1 = torch.mean(sparse_depth.jdata) + sparse_alphas.jdata.sum() + l1.backward() + + assert self.gs3d.means.grad is not None, "Gradients not computed for means in sparse depth render" + assert self.gs3d.quats.grad is not None, "Gradients not computed for quats in sparse depth render" + assert self.gs3d.log_scales.grad is not None, "Gradients not computed for log_scales in sparse depth render" + assert ( + self.gs3d.logit_opacities.grad is not None + ), "Gradients not computed for logit_opacities in sparse depth render" + sparse_means_grad = self.gs3d.means.grad.clone() + sparse_quats_grad = self.gs3d.quats.grad.clone() + sparse_log_scales_grad = self.gs3d.log_scales.grad.clone() + sparse_logit_opacities_grad = self.gs3d.logit_opacities.grad + self.gs3d.means.grad.zero_() + self.gs3d.quats.grad.zero_() + self.gs3d.log_scales.grad.zero_() + self.gs3d.logit_opacities.grad.zero_() + + dense_depth, dense_alphas = self.gs3d.render_depths( + self.cam_to_world_mats[0:1], + self.projection_mats[0:1], + self.width, + self.height, + 
self.near_plane, # near_plane + self.far_plane, # far_plane + ) + + dense_depth_pixels = dense_depth[0, y_coords, x_coords] + dense_alphas_pixels = dense_alphas[0, y_coords, x_coords] + + l2 = torch.mean(dense_depth_pixels) + dense_alphas_pixels.sum() + l2.backward() + + dense_means_grad = self.gs3d.means.grad.clone() + dense_quats_grad = self.gs3d.quats.grad.clone() + dense_log_scales_grad = self.gs3d.log_scales.grad.clone() + dense_logit_opacities_grad = self.gs3d.logit_opacities.grad.clone() + + self.assertTrue( + torch.allclose(sparse_means_grad, dense_means_grad, atol=1e-4, rtol=1e-8), + "Sparse means grad does not match dense means grad at specified pixels", + ) + self.assertTrue( + torch.allclose(sparse_quats_grad, dense_quats_grad, atol=1e-4, rtol=1e-8), + "Sparse quats grad does not match dense quats grad at specified pixels", + ) + self.assertTrue( + torch.allclose(sparse_log_scales_grad, dense_log_scales_grad, atol=1e-4, rtol=1e-8), + "Sparse log scales grad does not match dense log scales grad at specified pixels", + ) + self.assertTrue( + torch.allclose(sparse_logit_opacities_grad, dense_logit_opacities_grad, atol=1e-4, rtol=1e-8), + "Sparse logit opacities grad does not match dense logit opacities grad at specified pixels", + ) + + def test_gaussian_render_sparse_features(self): + # Generate random pixel coordinates within image bounds + + idx = torch.randperm(self.width * self.height)[:5000] + x_coords = idx % self.width + y_coords = idx // self.width + pixels_to_render = JaggedTensor([torch.stack([y_coords, x_coords], 1)]).to(self.device) + + sparse_features, sparse_alphas = self.gs3d.sparse_render_images( + pixels_to_render, + self.cam_to_world_mats[0:1], + self.projection_mats[0:1], + self.width, + self.height, + self.near_plane, # near_plane + self.far_plane, # far_plane + ) + + dense_features, dense_alphas = self.gs3d.render_images( + self.cam_to_world_mats[0:1], + self.projection_mats[0:1], + self.width, + self.height, + self.near_plane, # 
near_plane
+            self.far_plane,  # far_plane
+        )
+
+        dense_depth_pixels = dense_features[0, y_coords, x_coords]
+        dense_alphas_pixels = dense_alphas[0, y_coords, x_coords]
+
+        self.assertTrue(
+            torch.allclose(sparse_features.jdata, dense_depth_pixels, atol=1e-5, rtol=1e-8),
+            "Sparse depth render does not match dense depth render at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_alphas.jdata, dense_alphas_pixels, atol=1e-5, rtol=1e-8),
+            "Sparse alpha render does not match dense alpha render at specified pixels",
+        )
+
+    def test_gaussian_render_sparse_features_backward(self):
+        """Check that sparse-render gradients match dense-render gradients at the same pixels.
+
+        Renders a random subset of pixels with ``sparse_render_images``, backpropagates a
+        scalar loss, and snapshots the parameter gradients. The same loss is then built from
+        a dense ``render_images`` call indexed at those pixels, and the two sets of gradients
+        are compared with ``torch.allclose``.
+        """
+        # Pick up to 5000 distinct flat pixel indices and split them into coordinates.
+        # The pixel list handed to the renderer is stacked in (y, x) column order.
+        idx = torch.randperm(self.width * self.height)[:5000]
+        x_coords = idx % self.width
+        y_coords = idx // self.width
+        pixels_to_render = JaggedTensor([torch.stack([y_coords, x_coords], 1)]).to(self.device)
+
+        # Only the first camera is rendered (slice 0:1), matching the `[0, ...]` indexing below.
+        sparse_features, sparse_alphas = self.gs3d.sparse_render_images(
+            pixels_to_render,
+            self.cam_to_world_mats[0:1],
+            self.projection_mats[0:1],
+            self.width,
+            self.height,
+            self.near_plane,  # near_plane
+            self.far_plane,  # far_plane
+        )
+
+        l1 = torch.mean(sparse_features.jdata) + sparse_alphas.jdata.sum()
+        l1.backward()
+
+        assert self.gs3d.means.grad is not None, "Gradients not computed for means in sparse features render"
+        assert self.gs3d.quats.grad is not None, "Gradients not computed for quats in sparse features render"
+        assert self.gs3d.log_scales.grad is not None, "Gradients not computed for log_scales in sparse features render"
+        assert (
+            self.gs3d.logit_opacities.grad is not None
+        ), "Gradients not computed for logit_opacities in sparse features render"
+        assert self.gs3d.sh0.grad is not None, "Gradients not computed for sh0 in sparse features render"
+        assert self.gs3d.shN.grad is not None, "Gradients not computed for shN in sparse features render"
+        # Snapshot with clone(): the zero_() calls below mutate each .grad tensor in place,
+        # so a plain reference would be wiped before the comparison.
+        sparse_means_grad = self.gs3d.means.grad.clone()
+        sparse_quats_grad = self.gs3d.quats.grad.clone()
+        sparse_log_scales_grad = self.gs3d.log_scales.grad.clone()
+        sparse_logit_opacities_grad = self.gs3d.logit_opacities.grad.clone()
+        sparse_sh0_grad = self.gs3d.sh0.grad.clone()
+        sparse_shN_grad = self.gs3d.shN.grad.clone()
+        # Reset the gradients so the dense pass below is measured on its own.
+        self.gs3d.means.grad.zero_()
+        self.gs3d.quats.grad.zero_()
+        self.gs3d.log_scales.grad.zero_()
+        self.gs3d.logit_opacities.grad.zero_()
+        self.gs3d.sh0.grad.zero_()
+        self.gs3d.shN.grad.zero_()
+
+        dense_features, dense_alphas = self.gs3d.render_images(
+            self.cam_to_world_mats[0:1],
+            self.projection_mats[0:1],
+            self.width,
+            self.height,
+            self.near_plane,  # near_plane
+            self.far_plane,  # far_plane
+        )
+
+        # Sample the dense render at the same pixels so both losses cover identical terms.
+        dense_features_pixels = dense_features[0, y_coords, x_coords]
+        dense_alphas_pixels = dense_alphas[0, y_coords, x_coords]
+
+        l2 = torch.mean(dense_features_pixels) + dense_alphas_pixels.sum()
+        l2.backward()
+
+        dense_means_grad = self.gs3d.means.grad.clone()
+        dense_quats_grad = self.gs3d.quats.grad.clone()
+        dense_log_scales_grad = self.gs3d.log_scales.grad.clone()
+        dense_logit_opacities_grad = self.gs3d.logit_opacities.grad.clone()
+        dense_sh0_grad = self.gs3d.sh0.grad.clone()
+        dense_shN_grad = self.gs3d.shN.grad.clone()
+
+        self.assertTrue(
+            torch.allclose(sparse_means_grad, dense_means_grad, atol=1e-4, rtol=1e-8),
+            "Sparse means grad does not match dense means grad at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_quats_grad, dense_quats_grad, atol=1e-4, rtol=1e-8),
+            "Sparse quats grad does not match dense quats grad at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_log_scales_grad, dense_log_scales_grad, atol=1e-4, rtol=1e-8),
+            "Sparse log scales grad does not match dense log scales grad at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_logit_opacities_grad, dense_logit_opacities_grad, atol=1e-4, rtol=1e-8),
+            "Sparse logit opacities grad does not match dense logit opacities grad at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_sh0_grad, dense_sh0_grad, atol=1e-4, rtol=1e-8),
+            "Sparse sh0 grad does not match dense sh0 grad at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_shN_grad, dense_shN_grad, atol=1e-4, rtol=1e-8),
+            "Sparse shN grad does not match dense shN grad at specified pixels",
+        )
+
+    def test_gaussian_render_sparse_features_and_depths(self):
+        """Check that ``sparse_render_images_and_depths`` matches the dense render at the same pixels.
+
+        Both the rendered values and the alphas of the sparse call are compared against
+        ``render_images_and_depths`` output indexed at the sampled pixel coordinates.
+        """
+        # Pick up to 5000 distinct flat pixel indices and split them into coordinates.
+        # The pixel list handed to the renderer is stacked in (y, x) column order.
+        idx = torch.randperm(self.width * self.height)[:5000]
+        x_coords = idx % self.width
+        y_coords = idx // self.width
+        pixels_to_render = JaggedTensor([torch.stack([y_coords, x_coords], 1)]).to(self.device)
+
+        sparse_features, sparse_alphas = self.gs3d.sparse_render_images_and_depths(
+            pixels_to_render,
+            self.cam_to_world_mats[0:1],
+            self.projection_mats[0:1],
+            self.width,
+            self.height,
+            self.near_plane,  # near_plane
+            self.far_plane,  # far_plane
+        )
+
+        dense_features, dense_alphas = self.gs3d.render_images_and_depths(
+            self.cam_to_world_mats[0:1],
+            self.projection_mats[0:1],
+            self.width,
+            self.height,
+            self.near_plane,  # near_plane
+            self.far_plane,  # far_plane
+        )
+
+        # Sample the dense render at the pixels that were rendered sparsely.
+        dense_features_pixels = dense_features[0, y_coords, x_coords]
+        dense_alphas_pixels = dense_alphas[0, y_coords, x_coords]
+
+        self.assertTrue(
+            torch.allclose(sparse_features.jdata, dense_features_pixels, atol=1e-5, rtol=1e-8),
+            "Sparse image/depth render does not match dense image/depth render at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_alphas.jdata, dense_alphas_pixels, atol=1e-5, rtol=1e-8),
+            "Sparse alpha render does not match dense alpha render at specified pixels",
+        )
+
+    def test_gaussian_render_sparse_features_and_depths_backward(self):
+        """Check that gradients of ``sparse_render_images_and_depths`` match the dense equivalent.
+
+        Renders a random subset of pixels sparsely, backpropagates a scalar loss, and
+        snapshots the parameter gradients. The same loss is then built from a dense
+        ``render_images_and_depths`` call indexed at those pixels, and the two sets of
+        gradients are compared with ``torch.allclose``.
+        """
+        # Pick up to 5000 distinct flat pixel indices and split them into coordinates.
+        # The pixel list handed to the renderer is stacked in (y, x) column order.
+        idx = torch.randperm(self.width * self.height)[:5000]
+        x_coords = idx % self.width
+        y_coords = idx // self.width
+        pixels_to_render = JaggedTensor([torch.stack([y_coords, x_coords], 1)]).to(self.device)
+
+        # Only the first camera is rendered (slice 0:1), matching the `[0, ...]` indexing below.
+        sparse_features, sparse_alphas = self.gs3d.sparse_render_images_and_depths(
+            pixels_to_render,
+            self.cam_to_world_mats[0:1],
+            self.projection_mats[0:1],
+            self.width,
+            self.height,
+            self.near_plane,  # near_plane
+            self.far_plane,  # far_plane
+        )
+
+        l1 = torch.mean(sparse_features.jdata) + sparse_alphas.jdata.sum()
+        l1.backward()
+
+        assert self.gs3d.means.grad is not None, "Gradients not computed for means in sparse features render"
+        assert self.gs3d.quats.grad is not None, "Gradients not computed for quats in sparse features render"
+        assert self.gs3d.log_scales.grad is not None, "Gradients not computed for log_scales in sparse features render"
+        assert (
+            self.gs3d.logit_opacities.grad is not None
+        ), "Gradients not computed for logit_opacities in sparse features render"
+        assert self.gs3d.sh0.grad is not None, "Gradients not computed for sh0 in sparse features render"
+        assert self.gs3d.shN.grad is not None, "Gradients not computed for shN in sparse features render"
+        # Snapshot with clone(): the zero_() calls below mutate each .grad tensor in place,
+        # so a plain reference would be wiped before the comparison. Every parameter,
+        # including logit_opacities, must be cloned for the check to be meaningful.
+        sparse_means_grad = self.gs3d.means.grad.clone()
+        sparse_quats_grad = self.gs3d.quats.grad.clone()
+        sparse_log_scales_grad = self.gs3d.log_scales.grad.clone()
+        sparse_logit_opacities_grad = self.gs3d.logit_opacities.grad.clone()
+        sparse_sh0_grad = self.gs3d.sh0.grad.clone()
+        sparse_shN_grad = self.gs3d.shN.grad.clone()
+        # Reset the gradients so the dense pass below is measured on its own.
+        self.gs3d.means.grad.zero_()
+        self.gs3d.quats.grad.zero_()
+        self.gs3d.log_scales.grad.zero_()
+        self.gs3d.logit_opacities.grad.zero_()
+        self.gs3d.sh0.grad.zero_()
+        self.gs3d.shN.grad.zero_()
+
+        dense_features, dense_alphas = self.gs3d.render_images_and_depths(
+            self.cam_to_world_mats[0:1],
+            self.projection_mats[0:1],
+            self.width,
+            self.height,
+            self.near_plane,  # near_plane
+            self.far_plane,  # far_plane
+        )
+
+        # Sample the dense render at the same pixels so both losses cover identical terms.
+        dense_features_pixels = dense_features[0, y_coords, x_coords]
+        dense_alphas_pixels = dense_alphas[0, y_coords, x_coords]
+
+        l2 = torch.mean(dense_features_pixels) + dense_alphas_pixels.sum()
+        l2.backward()
+
+        dense_means_grad = self.gs3d.means.grad.clone()
+        dense_quats_grad = self.gs3d.quats.grad.clone()
+        dense_log_scales_grad = self.gs3d.log_scales.grad.clone()
+        dense_logit_opacities_grad = self.gs3d.logit_opacities.grad.clone()
+        dense_sh0_grad = self.gs3d.sh0.grad.clone()
+        dense_shN_grad = self.gs3d.shN.grad.clone()
+
+        self.assertTrue(
+            torch.allclose(sparse_means_grad, dense_means_grad, atol=1e-4, rtol=1e-8),
+            "Sparse means grad does not match dense means grad at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_quats_grad, dense_quats_grad, atol=1e-4, rtol=1e-8),
+            "Sparse quats grad does not match dense quats grad at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_log_scales_grad, dense_log_scales_grad, atol=1e-4, rtol=1e-8),
+            "Sparse log scales grad does not match dense log scales grad at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_logit_opacities_grad, dense_logit_opacities_grad, atol=1e-4, rtol=1e-8),
+            "Sparse logit opacities grad does not match dense logit opacities grad at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_sh0_grad, dense_sh0_grad, atol=1e-4, rtol=1e-8),
+            "Sparse sh0 grad does not match dense sh0 grad at specified pixels",
+        )
+        self.assertTrue(
+            torch.allclose(sparse_shN_grad, dense_shN_grad, atol=1e-4, rtol=1e-8),
+            "Sparse shN grad does not match dense shN grad at specified pixels",
+        )
+
+
 class TestGaussianRenderBackgrounds(BaseGaussianTestCase):
     """Test background color support in Gaussian rendering"""