Skip to content

Commit e41a365

Browse files
author
Bernhard Kerbl
committed
working with instances
1 parent 3851457 commit e41a365

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

diff_gaussian_rasterization/rasterizer.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from . import _C
55

66
def rasterize_gaussians(
7+
instance,
78
means3D,
89
means2D,
910
sh,
@@ -16,6 +17,7 @@ def rasterize_gaussians(
1617
rasterizer_state
1718
):
1819
return _RasterizeGaussians.apply(
20+
instance,
1921
means3D,
2022
means2D,
2123
sh,
@@ -32,6 +34,7 @@ class _RasterizeGaussians(torch.autograd.Function):
3234
@staticmethod
3335
def forward(
3436
ctx,
37+
instance,
3538
means3D,
3639
means2D,
3740
sh,
@@ -46,6 +49,7 @@ def forward(
4649

4750
# Restructure arguments the way that the C++ lib expects them
4851
args = (
52+
instance,
4953
raster_settings.bg,
5054
means3D,
5155
colors_precomp,
@@ -72,6 +76,7 @@ def forward(
7276

7377
# Keep relevant tensors for backward
7478
ctx.raster_settings = raster_settings
79+
ctx.instance = instance
7580
ctx.rasterizer_state = rasterizer_state
7681
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh)
7782
return color, radii
@@ -80,12 +85,14 @@ def forward(
8085
def backward(ctx, grad_out_color, _):
8186

8287
# Restore necessary values from context
88+
instance = ctx.instance
8389
rasterizer_state = ctx.rasterizer_state
8490
raster_settings = ctx.raster_settings
8591
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors
8692

8793
# Restructure args as C++ method expects them
88-
args = (rasterizer_state,
94+
args = (instance,
95+
rasterizer_state,
8996
raster_settings.bg,
9097
means3D,
9198
radii,
@@ -107,6 +114,7 @@ def backward(ctx, grad_out_color, _):
107114
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
108115

109116
grads = (
117+
None,
110118
grad_means3D,
111119
grad_means2D,
112120
grad_sh,
@@ -134,33 +142,32 @@ class GaussianRasterizationSettings(NamedTuple):
134142
campos : torch.Tensor
135143
prefiltered : bool
136144

137-
def createRasterizerState():
138-
return _C.create_rasterizer_state()
139-
140-
def deleteRasterizerState(state):
141-
return _C.delete_rasterize_state(state)
142-
143145
class GaussianRasterizer(nn.Module):
144-
def __init__(self, raster_settings, rasterizer_state):
146+
def __init__(self):
145147
super().__init__()
146-
self.raster_settings = raster_settings
147-
self.rasterizer_state = rasterizer_state
148+
self.instance = _C.create_rasterizer()
149+
150+
def __del__(self):
151+
_C.delete_rasterizer(self.instance)
148152

149-
def markVisible(self, positions):
153+
def createRasterizerState(self):
154+
return _C.create_rasterizer_state(self.instance)
155+
156+
def deleteRasterizerState(self, state):
157+
_C.delete_rasterizer_state(self.instance, state)
158+
159+
def markVisible(self, raster_settings, positions):
150160
# Mark visible points (based on frustum culling for camera) with a boolean
151161
with torch.no_grad():
152-
raster_settings = self.raster_settings
153162
visible = _C.mark_visible(
163+
self.instance,
154164
positions,
155165
raster_settings.viewmatrix,
156166
raster_settings.projmatrix)
157167

158168
return visible
159169

160-
def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
161-
raster_settings = self.raster_settings
162-
rasterize_state = self.rasterizer_state
163-
170+
def forward(self, rasterizer_state, raster_settings, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
164171
if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
165172
raise Exception('Please provide excatly one of either SHs or precomputed colors!')
166173

@@ -181,6 +188,7 @@ def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None
181188

182189
# Invoke C++/CUDA rasterization routine
183190
return rasterize_gaussians(
191+
self.instance,
184192
means3D,
185193
means2D,
186194
shs,
@@ -190,6 +198,6 @@ def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None
190198
rotations,
191199
cov3D_precomp,
192200
raster_settings,
193-
rasterize_state
201+
rasterizer_state
194202
)
195203

0 commit comments

Comments
 (0)