@@ -41,6 +41,8 @@ def forward(
4141 raster_settings ,
4242 ):
4343
44+ rasterizer_state = _C .create_rasterizer_state ()
45+
4446 # Restructure arguments the way that the C++ lib expects them
4547 args = (
4648 raster_settings .bg ,
@@ -61,25 +63,29 @@ def forward(
6163 raster_settings .sh_degree ,
6264 raster_settings .campos ,
6365 raster_settings .prefiltered ,
66+ rasterizer_state
6467 )
6568
6669 # Invoke C++/CUDA rasterizer
6770 color , radii = _C .rasterize_gaussians (* args )
6871
6972 # Keep relevant tensors for backward
7073 ctx .raster_settings = raster_settings
74+ ctx .rasterizer_state = rasterizer_state
7175 ctx .save_for_backward (colors_precomp , means3D , scales , rotations , cov3Ds_precomp , radii , sh )
7276 return color , radii
7377
7478 @staticmethod
7579 def backward (ctx , grad_out_color , _ ):
7680
7781 # Restore necessary values from context
82+ rasterizer_state = ctx .rasterizer_state
7883 raster_settings = ctx .raster_settings
7984 colors_precomp , means3D , scales , rotations , cov3Ds_precomp , radii , sh = ctx .saved_tensors
8085
8186 # Restructure args as C++ method expects them
82- args = (raster_settings .bg ,
87+ args = (rasterizer_state ,
88+ raster_settings .bg ,
8389 means3D ,
8490 radii ,
8591 colors_precomp ,
@@ -111,6 +117,8 @@ def backward(ctx, grad_out_color, _):
111117 None ,
112118 )
113119
120+ _C .delete_rasterizer_state (rasterizer_state )
121+
114122 return grads
115123
116124class GaussianRasterizationSettings (NamedTuple ):
0 commit comments