1515#include < fstream>
1616#include < string>
1717
18- static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr ;
18+ void * createRasterizer ()
19+ {
20+ return (void *)CudaRasterizer::Rasterizer::make ();
21+ }
1922
20- void * createRasterizerState ( )
23+ void deleteRasterizer ( void * rasterizer )
2124{
22- if (cudaRenderer == nullptr )
23- {
24- cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make ());
25- }
26- return (void *)cudaRenderer->createInternalState ();
25+ CudaRasterizer::Rasterizer::kill ((CudaRasterizer::Rasterizer*)rasterizer);
26+ }
27+
28+ void * createRasterizerState (void * rasterizer)
29+ {
30+ return (void *)((CudaRasterizer::Rasterizer*)rasterizer)->createInternalState ();
2731}
2832
29- void deleteRasterizerState (void * state)
33+ void deleteRasterizerState (void * rasterizer, void * state)
3034{
31- cudaRenderer ->killInternalState ((CudaRasterizer::InternalState*)state);
35+ ((CudaRasterizer::Rasterizer*)rasterizer) ->killInternalState ((CudaRasterizer::InternalState*)state);
3236}
3337
3438std::tuple<torch::Tensor, torch::Tensor>
3539RasterizeGaussiansCUDA (
40+ void * rasterizer,
3641 const torch::Tensor& background,
3742 const torch::Tensor& means3D,
3843 const torch::Tensor& colors,
@@ -58,11 +63,6 @@ RasterizeGaussiansCUDA(
5863 AT_ERROR (" means3D must have dimensions (num_points, 3)" );
5964 }
6065
61- if (cudaRenderer == nullptr )
62- {
63- cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make ());
64- }
65-
6666 const int P = means3D.size (0 );
6767 const int N = 1 ; // batch size hard-coded
6868 const int H = image_height;
@@ -82,7 +82,7 @@ RasterizeGaussiansCUDA(
8282 M = sh.size (1 );
8383 }
8484
85- cudaRenderer ->forward (P, degree, M,
85+ ((CudaRasterizer::Rasterizer*)rasterizer) ->forward (P, degree, M,
8686 background.contiguous ().data <float >(),
8787 W, H,
8888 means3D.contiguous ().data <float >(),
@@ -108,6 +108,7 @@ RasterizeGaussiansCUDA(
108108
109109std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
110110 RasterizeGaussiansBackwardCUDA (
111+ void * rasterizer,
111112 const void * internalState,
112113 const torch::Tensor& background,
113114 const torch::Tensor& means3D,
@@ -148,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
148149
149150 if (P != 0 )
150151 {
151- cudaRenderer ->backward (
152+ ((CudaRasterizer::Rasterizer*)rasterizer) ->backward (
152153 radii.contiguous ().data <int >(),
153154 (CudaRasterizer::InternalState*)internalState,
154155 P, degree, M,
@@ -182,22 +183,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
182183}
183184
184185torch::Tensor markVisible (
186+ void * rasterizer,
185187 torch::Tensor& means3D,
186188 torch::Tensor& viewmatrix,
187189 torch::Tensor& projmatrix)
188190{
189- if (cudaRenderer == nullptr )
190- {
191- cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make ());
192- }
193191
194192 const int P = means3D.size (0 );
195193
196194 torch::Tensor present = torch::full ({P}, false , means3D.options ().dtype (at::kBool ));
197195
198196 if (P != 0 )
199197 {
200- cudaRenderer ->markVisible (P,
198+ ((CudaRasterizer::Rasterizer*)rasterizer) ->markVisible (P,
201199 means3D.contiguous ().data <float >(),
202200 viewmatrix.contiguous ().data <float >(),
203201 projmatrix.contiguous ().data <float >(),
0 commit comments