Skip to content

Commit 8a219fd

Browse files
author
Bernhard Kerbl
committed
changes:
1 parent 64f7c22 commit 8a219fd

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

diff_gaussian_rasterization/rasterizer.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def rasterize_gaussians(
1313
rotations,
1414
cov3Ds_precomp,
1515
raster_settings,
16+
rasterizer_state
1617
):
1718
return _RasterizeGaussians.apply(
1819
means3D,
@@ -24,6 +25,7 @@ def rasterize_gaussians(
2425
rotations,
2526
cov3Ds_precomp,
2627
raster_settings,
28+
rasterizer_state
2729
)
2830

2931
class _RasterizeGaussians(torch.autograd.Function):
@@ -39,10 +41,9 @@ def forward(
3941
rotations,
4042
cov3Ds_precomp,
4143
raster_settings,
44+
rasterizer_state
4245
):
4346

44-
rasterizer_state = _C.create_rasterizer_state()
45-
4647
# Restructure arguments the way that the C++ lib expects them
4748
args = (
4849
raster_settings.bg,
@@ -115,10 +116,9 @@ def backward(ctx, grad_out_color, _):
115116
grad_rotations,
116117
grad_cov3Ds_precomp,
117118
None,
119+
None,
118120
)
119121

120-
_C.delete_rasterizer_state(rasterizer_state)
121-
122122
return grads
123123

124124
class GaussianRasterizationSettings(NamedTuple):
@@ -134,10 +134,17 @@ class GaussianRasterizationSettings(NamedTuple):
134134
campos : torch.Tensor
135135
prefiltered : bool
136136

137+
def createRasterizerState():
138+
return _C.create_rasterizer_state()
139+
140+
def deleteRasterizerState(state):
141+
return _C.delete_rasterize_state(state)
142+
137143
class GaussianRasterizer(nn.Module):
138-
def __init__(self, raster_settings):
144+
def __init__(self, raster_settings, rasterizer_state):
139145
super().__init__()
140146
self.raster_settings = raster_settings
147+
self.rasterizer_state = rasterizer_state
141148

142149
def markVisible(self, positions):
143150
# Mark visible points (based on frustum culling for camera) with a boolean
@@ -151,8 +158,8 @@ def markVisible(self, positions):
151158
return visible
152159

153160
def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
154-
155161
raster_settings = self.raster_settings
162+
rasterize_state = self.rasterizer_state
156163

157164
if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
158165
raise Exception('Please provide excatly one of either SHs or precomputed colors!')
@@ -183,5 +190,6 @@ def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None
183190
rotations,
184191
cov3D_precomp,
185192
raster_settings,
193+
rasterize_state
186194
)
187195

0 commit comments

Comments
 (0)