44from . import _C
55
66def rasterize_gaussians (
7+ instance ,
78 means3D ,
89 means2D ,
910 sh ,
@@ -16,6 +17,7 @@ def rasterize_gaussians(
1617 rasterizer_state
1718):
1819 return _RasterizeGaussians .apply (
20+ instance ,
1921 means3D ,
2022 means2D ,
2123 sh ,
@@ -32,6 +34,7 @@ class _RasterizeGaussians(torch.autograd.Function):
3234 @staticmethod
3335 def forward (
3436 ctx ,
37+ instance ,
3538 means3D ,
3639 means2D ,
3740 sh ,
@@ -46,6 +49,7 @@ def forward(
4649
4750 # Restructure arguments the way that the C++ lib expects them
4851 args = (
52+ instance ,
4953 raster_settings .bg ,
5054 means3D ,
5155 colors_precomp ,
@@ -72,6 +76,7 @@ def forward(
7276
7377 # Keep relevant tensors for backward
7478 ctx .raster_settings = raster_settings
79+ ctx .instance = instance
7580 ctx .rasterizer_state = rasterizer_state
7681 ctx .save_for_backward (colors_precomp , means3D , scales , rotations , cov3Ds_precomp , radii , sh )
7782 return color , radii
@@ -80,12 +85,14 @@ def forward(
8085 def backward (ctx , grad_out_color , _ ):
8186
8287 # Restore necessary values from context
88+ instance = ctx .instance
8389 rasterizer_state = ctx .rasterizer_state
8490 raster_settings = ctx .raster_settings
8591 colors_precomp , means3D , scales , rotations , cov3Ds_precomp , radii , sh = ctx .saved_tensors
8692
8793 # Restructure args as C++ method expects them
88- args = (rasterizer_state ,
94+ args = (instance ,
95+ rasterizer_state ,
8996 raster_settings .bg ,
9097 means3D ,
9198 radii ,
@@ -107,6 +114,7 @@ def backward(ctx, grad_out_color, _):
107114 grad_means2D , grad_colors_precomp , grad_opacities , grad_means3D , grad_cov3Ds_precomp , grad_sh , grad_scales , grad_rotations = _C .rasterize_gaussians_backward (* args )
108115
109116 grads = (
117+ None ,
110118 grad_means3D ,
111119 grad_means2D ,
112120 grad_sh ,
@@ -134,33 +142,32 @@ class GaussianRasterizationSettings(NamedTuple):
134142 campos : torch .Tensor
135143 prefiltered : bool
136144
137- def createRasterizerState ():
138- return _C .create_rasterizer_state ()
139-
140- def deleteRasterizerState (state ):
141- return _C .delete_rasterize_state (state )
142-
143145class GaussianRasterizer (nn .Module ):
144- def __init__ (self , raster_settings , rasterizer_state ):
146+ def __init__ (self ):
145147 super ().__init__ ()
146- self .raster_settings = raster_settings
147- self .rasterizer_state = rasterizer_state
148+ self .instance = _C .create_rasterizer ()
149+
150+ def __del__ (self ):
151+ _C .delete_rasterizer (self .instance )
148152
149- def markVisible (self , positions ):
153+ def createRasterizerState (self ):
154+ return _C .create_rasterizer_state (self .instance )
155+
156+ def deleteRasterizerState (self , state ):
157+ _C .delete_rasterizer_state (self .instance , state )
158+
159+ def markVisible (self , raster_settings , positions ):
150160 # Mark visible points (based on frustum culling for camera) with a boolean
151161 with torch .no_grad ():
152- raster_settings = self .raster_settings
153162 visible = _C .mark_visible (
163+ self .instance ,
154164 positions ,
155165 raster_settings .viewmatrix ,
156166 raster_settings .projmatrix )
157167
158168 return visible
159169
160- def forward (self , means3D , means2D , opacities , shs = None , colors_precomp = None , scales = None , rotations = None , cov3D_precomp = None ):
161- raster_settings = self .raster_settings
162- rasterize_state = self .rasterizer_state
163-
170+ def forward (self , rasterizer_state , raster_settings , means3D , means2D , opacities , shs = None , colors_precomp = None , scales = None , rotations = None , cov3D_precomp = None ):
164171 if (shs is None and colors_precomp is None ) or (shs is not None and colors_precomp is not None ):
165172 raise Exception ('Please provide excatly one of either SHs or precomputed colors!' )
166173
@@ -181,6 +188,7 @@ def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None
181188
182189 # Invoke C++/CUDA rasterization routine
183190 return rasterize_gaussians (
191+ self .instance ,
184192 means3D ,
185193 means2D ,
186194 shs ,
@@ -190,6 +198,6 @@ def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None
190198 rotations ,
191199 cov3D_precomp ,
192200 raster_settings ,
193- rasterize_state
201+ rasterizer_state
194202 )
195203
0 commit comments