Skip to content

Commit 07f5941

Browse files
committed
plumb backgrounds up to python bindings
Signed-off-by: Mark Harris <mharris@nvidia.com>
1 parent a2021f9 commit 07f5941

File tree

5 files changed

+137
-65
lines changed

5 files changed

+137
-65
lines changed

src/fvdb/GaussianSplat3d.cpp

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,8 @@ GaussianSplat3d::renderCropFromProjectedGaussiansImpl(
391391
const ssize_t cropWidth,
392392
const ssize_t cropHeight,
393393
const ssize_t cropOriginW,
394-
const ssize_t cropOriginH) {
394+
const ssize_t cropOriginH,
395+
const std::optional<torch::Tensor> &backgrounds) {
395396
FVDB_FUNC_RANGE();
396397
// Negative values mean use the whole image, but all values must be negative
397398
if (cropWidth <= 0 || cropHeight <= 0 || cropOriginW < 0 || cropOriginH < 0) {
@@ -410,6 +411,11 @@ GaussianSplat3d::renderCropFromProjectedGaussiansImpl(
410411
// Rasterize projected Gaussians to pixels (differentiable)
411412
// NOTE: projectGaussians* performs input checking, we need to apply some further
412413
// checking before GaussianRasterizeToPixels
414+
std::optional<torch::Tensor> backgroundsOpt = std::nullopt;
415+
if (backgrounds.has_value()) {
416+
backgroundsOpt = backgrounds.value();
417+
}
418+
413419
auto outputs = detail::autograd::RasterizeGaussiansToPixels::apply(
414420
projectedGaussians.perGaussian2dMean,
415421
projectedGaussians.perGaussianConic,
@@ -422,7 +428,8 @@ GaussianSplat3d::renderCropFromProjectedGaussiansImpl(
422428
tileSize,
423429
projectedGaussians.tileOffsets,
424430
projectedGaussians.tileGaussianIds,
425-
false);
431+
false,
432+
backgroundsOpt);
426433
torch::Tensor renderedImage = outputs[0];
427434
torch::Tensor renderedAlphas = outputs[1];
428435

@@ -674,9 +681,10 @@ GaussianSplat3d::renderFromProjectedGaussians(
674681
const ssize_t cropHeight,
675682
const ssize_t cropOriginW,
676683
const ssize_t cropOriginH,
677-
const size_t tileSize) {
684+
const size_t tileSize,
685+
const std::optional<torch::Tensor> &backgrounds) {
678686
return renderCropFromProjectedGaussiansImpl(
679-
projectedGaussians, tileSize, cropWidth, cropHeight, cropOriginW, cropOriginH);
687+
projectedGaussians, tileSize, cropWidth, cropHeight, cropOriginW, cropOriginH, backgrounds);
680688
}
681689

682690
std::tuple<torch::Tensor, torch::Tensor>
@@ -691,7 +699,8 @@ GaussianSplat3d::renderImages(const torch::Tensor &worldToCameraMatrices,
691699
const size_t tileSize,
692700
const float minRadius2d,
693701
const float eps2d,
694-
const bool antialias) {
702+
const bool antialias,
703+
const std::optional<torch::Tensor> &backgrounds) {
695704
RenderSettings settings;
696705
settings.imageWidth = imageWidth;
697706
settings.imageHeight = imageHeight;
@@ -705,7 +714,10 @@ GaussianSplat3d::renderImages(const torch::Tensor &worldToCameraMatrices,
705714
settings.tileSize = tileSize;
706715
settings.renderMode = RenderSettings::RenderMode::RGB;
707716

708-
return renderImpl(worldToCameraMatrices, projectionMatrices, settings);
717+
const ProjectedGaussianSplats state =
718+
projectGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings);
719+
return renderCropFromProjectedGaussiansImpl(
720+
state, settings.tileSize, settings.imageWidth, settings.imageHeight, 0, 0, backgrounds);
709721
}
710722

711723
std::tuple<torch::Tensor, torch::Tensor>
@@ -719,7 +731,8 @@ GaussianSplat3d::renderDepths(const torch::Tensor &worldToCameraMatrices,
719731
const size_t tileSize,
720732
const float minRadius2d,
721733
const float eps2d,
722-
const bool antialias) {
734+
const bool antialias,
735+
const std::optional<torch::Tensor> &backgrounds) {
723736
RenderSettings settings;
724737
settings.imageWidth = imageWidth;
725738
settings.imageHeight = imageHeight;
@@ -732,7 +745,10 @@ GaussianSplat3d::renderDepths(const torch::Tensor &worldToCameraMatrices,
732745
settings.tileSize = tileSize;
733746
settings.renderMode = RenderSettings::RenderMode::DEPTH;
734747

735-
return renderImpl(worldToCameraMatrices, projectionMatrices, settings);
748+
const ProjectedGaussianSplats state =
749+
projectGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings);
750+
return renderCropFromProjectedGaussiansImpl(
751+
state, settings.tileSize, settings.imageWidth, settings.imageHeight, 0, 0, backgrounds);
736752
}
737753

738754
std::tuple<torch::Tensor, torch::Tensor>
@@ -864,7 +880,8 @@ GaussianSplat3d::renderImagesAndDepths(const torch::Tensor &worldToCameraMatrice
864880
const size_t tileSize,
865881
const float minRadius2d,
866882
const float eps2d,
867-
const bool antialias) {
883+
const bool antialias,
884+
const std::optional<torch::Tensor> &backgrounds) {
868885
RenderSettings settings;
869886
settings.imageWidth = imageWidth;
870887
settings.imageHeight = imageHeight;
@@ -878,7 +895,10 @@ GaussianSplat3d::renderImagesAndDepths(const torch::Tensor &worldToCameraMatrice
878895
settings.tileSize = tileSize;
879896
settings.renderMode = RenderSettings::RenderMode::RGBD;
880897

881-
return renderImpl(worldToCameraMatrices, projectionMatrices, settings);
898+
const ProjectedGaussianSplats state =
899+
projectGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings);
900+
return renderCropFromProjectedGaussiansImpl(
901+
state, settings.tileSize, settings.imageWidth, settings.imageHeight, 0, 0, backgrounds);
882902
}
883903

884904
GaussianSplat3d
@@ -1105,7 +1125,8 @@ gaussianRenderJagged(const JaggedTensor &means, // [N1 + N2 + ..., 3]
11051125
const bool render_depth_channel,
11061126
const bool return_debug_info,
11071127
const bool render_depth_only,
1108-
const bool ortho) {
1128+
const bool ortho,
1129+
const std::optional<torch::Tensor> &backgrounds) {
11091130
const int ccz = viewmats.rsize(0); // number of cameras
11101131
const int ggz = means.rsize(0); // number of gaussians
11111132
const int D = render_depth_only ? 1 : sh_coeffs.rsize(-1); // Dimension of output
@@ -1264,6 +1285,11 @@ gaussianRenderJagged(const JaggedTensor &means, // [N1 + N2 + ..., 3]
12641285
}
12651286

12661287
// Rasterize projected Gaussians to pixels [differentiable]
1288+
std::optional<torch::Tensor> backgroundsOpt = std::nullopt;
1289+
if (backgrounds.has_value()) {
1290+
backgroundsOpt = backgrounds.value();
1291+
}
1292+
12671293
auto outputs =
12681294
detail::autograd::RasterizeGaussiansToPixels::apply(means2d,
12691295
conics,
@@ -1276,7 +1302,8 @@ gaussianRenderJagged(const JaggedTensor &means, // [N1 + N2 + ..., 3]
12761302
tile_size,
12771303
tile_offsets,
12781304
tile_gaussian_ids,
1279-
false);
1305+
false,
1306+
backgroundsOpt);
12801307
torch::Tensor renderedImages = outputs[0];
12811308
torch::Tensor renderedAlphaImages = outputs[1];
12821309

src/fvdb/GaussianSplat3d.h

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -841,17 +841,19 @@ class GaussianSplat3d {
841841
/// @param cropOriginH Origin of the cropped image in the height dimension (use -1 for no
842842
/// cropping)
843843
/// @param tileSize Size of the tiles used for rendering
844+
/// @param backgrounds Optional [C, D] tensor of background colors for each camera
844845
/// @return Tuple of two tensors:
845846
/// images: A [C, H, W, D|1|D+1] tensor containing the rendered image
846847
/// (or depth or image and depth) for each camera
847848
/// alphas: A [C, H, W, 1] tensor containing the alpha values of the rendered images
848849
std::tuple<torch::Tensor, torch::Tensor>
849850
renderFromProjectedGaussians(const GaussianSplat3d::ProjectedGaussianSplats &projectedGaussians,
850-
const ssize_t cropWidth = -1,
851-
const ssize_t cropHeight = -1,
852-
const ssize_t cropOriginW = -1,
853-
const ssize_t cropOriginH = -1,
854-
const size_t tileSize = 16);
851+
const ssize_t cropWidth = -1,
852+
const ssize_t cropHeight = -1,
853+
const ssize_t cropOriginW = -1,
854+
const ssize_t cropOriginH = -1,
855+
const size_t tileSize = 16,
856+
const std::optional<torch::Tensor> &backgrounds = std::nullopt);
855857

856858
/// @brief Render images of this Gaussian splat scene from the given camera matrices and
857859
/// projection matrices.
@@ -867,6 +869,7 @@ class GaussianSplat3d {
867869
/// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored
868870
/// @param eps2d Blur factor for antialiasing (only used if antialias is true)
869871
/// @param antialias Whether to antialias the image
872+
/// @param backgrounds Optional [C, D] tensor of background colors for each camera
870873
/// @return Tuple of two tensors:
871874
/// images: A [C, H, W, D] tensor containing the rendered image for each camera
872875
/// alphas: A [C, H, W, 1] tensor containing the alpha values of the rendered images
@@ -877,12 +880,13 @@ class GaussianSplat3d {
877880
const size_t imageHeight,
878881
const float near,
879882
const float far,
880-
const ProjectionType projectionType = ProjectionType::PERSPECTIVE,
881-
const int64_t shDegreeToUse = -1,
882-
const size_t tileSize = 16,
883-
const float minRadius2d = 0.0,
884-
const float eps2d = 0.3,
885-
const bool antialias = false);
883+
const ProjectionType projectionType = ProjectionType::PERSPECTIVE,
884+
const int64_t shDegreeToUse = -1,
885+
const size_t tileSize = 16,
886+
const float minRadius2d = 0.0,
887+
const float eps2d = 0.3,
888+
const bool antialias = false,
889+
const std::optional<torch::Tensor> &backgrounds = std::nullopt);
886890

887891
/// @brief Render depths of this Gaussian splat scene from the given camera matrices and
888892
/// projection matrices.
@@ -897,6 +901,7 @@ class GaussianSplat3d {
897901
/// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored
898902
/// @param eps2d Blur factor for antialiasing (only used if antialias is true)
899903
/// @param antialias Whether to antialias the image
904+
/// @param backgrounds Optional [C, 1] tensor of background depths for each camera
900905
/// @return Tuple of two tensors:
901906
/// images: A [C, H, W, 1] tensor containing the rendered depths for each camera
902907
/// alphas: A [C, H, W, 1] tensor containing the alpha values of the rendered depths
@@ -907,11 +912,12 @@ class GaussianSplat3d {
907912
const size_t imageHeight,
908913
const float near,
909914
const float far,
910-
const ProjectionType projectionType = ProjectionType::PERSPECTIVE,
911-
const size_t tileSize = 16,
912-
const float minRadius2d = 0.0,
913-
const float eps2d = 0.3,
914-
const bool antialias = false);
915+
const ProjectionType projectionType = ProjectionType::PERSPECTIVE,
916+
const size_t tileSize = 16,
917+
const float minRadius2d = 0.0,
918+
const float eps2d = 0.3,
919+
const bool antialias = false,
920+
const std::optional<torch::Tensor> &backgrounds = std::nullopt);
915921

916922
std::tuple<torch::Tensor, torch::Tensor>
917923
renderImagesAndDepths(const torch::Tensor &worldToCameraMatrices,
@@ -925,7 +931,8 @@ class GaussianSplat3d {
925931
const size_t tileSize = 16,
926932
const float minRadius2d = 0.0,
927933
const float eps2d = 0.3,
928-
const bool antialias = false);
934+
const bool antialias = false,
935+
const std::optional<torch::Tensor> &backgrounds = std::nullopt);
929936

930937
/// @brief Render the number of contributing Gaussians for each pixel in the image.
931938
/// @param worldToCameraMatrices [C, 4, 4] World-to-camera matrices
@@ -1132,13 +1139,14 @@ class GaussianSplat3d {
11321139
const torch::Tensor &projectionMatrices,
11331140
const fvdb::detail::ops::RenderSettings &settings);
11341141

1135-
std::tuple<torch::Tensor, torch::Tensor>
1136-
renderCropFromProjectedGaussiansImpl(const ProjectedGaussianSplats &state,
1137-
const size_t tileSize,
1138-
const ssize_t cropWidth,
1139-
const ssize_t cropHeight,
1140-
const ssize_t cropOriginW,
1141-
const ssize_t cropOriginH);
1142+
std::tuple<torch::Tensor, torch::Tensor> renderCropFromProjectedGaussiansImpl(
1143+
const ProjectedGaussianSplats &state,
1144+
const size_t tileSize,
1145+
const ssize_t cropWidth,
1146+
const ssize_t cropHeight,
1147+
const ssize_t cropOriginW,
1148+
const ssize_t cropOriginH,
1149+
const std::optional<torch::Tensor> &backgrounds = std::nullopt);
11421150

11431151
/// @brief Implements index set with a tensor of booleans or integer indices
11441152
/// @param indexOrMask A 1D tensor of indices in the range [0, numGaussians-1] or a boolean mask
@@ -1260,17 +1268,18 @@ gaussianRenderJagged(const JaggedTensor &means, // [N1 + N2 + ..., 3]
12601268
const JaggedTensor &Ks, // [C1 + C2 + ..., 3, 3]
12611269
const uint32_t image_width,
12621270
const uint32_t image_height,
1263-
const float near_plane = 0.01,
1264-
const float far_plane = 1e10,
1265-
const int sh_degree_to_use = -1,
1266-
const int tile_size = 16,
1267-
const float radius_clip = 0.0,
1268-
const float eps2d = 0.3,
1269-
const bool antialias = false,
1270-
const bool render_depth_channel = false,
1271-
const bool return_debug_info = false,
1272-
const bool render_depth_only = false,
1273-
const bool ortho = false);
1271+
const float near_plane = 0.01,
1272+
const float far_plane = 1e10,
1273+
const int sh_degree_to_use = -1,
1274+
const int tile_size = 16,
1275+
const float radius_clip = 0.0,
1276+
const float eps2d = 0.3,
1277+
const bool antialias = false,
1278+
const bool render_depth_channel = false,
1279+
const bool return_debug_info = false,
1280+
const bool render_depth_only = false,
1281+
const bool ortho = false,
1282+
const std::optional<torch::Tensor> &backgrounds = std::nullopt);
12741283

12751284
} // namespace fvdb
12761285

0 commit comments

Comments
 (0)