Skip to content

Commit 3851457

Browse files
author
Bernhard Kerbl
committed
changes
1 parent 8a219fd commit 3851457

File tree

5 files changed

+37
-23
lines changed

5 files changed

+37
-23
lines changed

cuda_rasterizer/rasterizer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ namespace CudaRasterizer
7474
virtual ~Rasterizer() {};
7575

7676
static Rasterizer* make(int resizeMultipliyer = 2);
77+
78+
static void kill(Rasterizer* rasterizer);
7779
};
7880
};
7981

cuda_rasterizer/rasterizer_impl.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplie
137137
return new CudaRasterizer::RasterizerImpl(resizeMultiplier);
138138
}
139139

140+
void CudaRasterizer::Rasterizer::kill(Rasterizer* rasterizer)
141+
{
142+
delete rasterizer;
143+
}
144+
140145
// Mark Gaussians as visible/invisible, based on view frustum testing
141146
void CudaRasterizer::RasterizerImpl::markVisible(
142147
int P,

ext.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
88
m.def("mark_visible", &markVisible);
99
m.def("create_rasterizer_state", &createRasterizerState);
1010
m.def("delete_rasterizer_state", &deleteRasterizerState);
11+
m.def("create_rasterizer", &createRasterizer);
12+
m.def("delete_rasterizer", &deleteRasterizer);
1113
}

rasterize_points.cu

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,29 @@
1515
#include <fstream>
1616
#include <string>
1717

18-
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
18+
void* createRasterizer()
19+
{
20+
return (void*)CudaRasterizer::Rasterizer::make();
21+
}
1922

20-
void* createRasterizerState()
23+
void deleteRasterizer(void* rasterizer)
2124
{
22-
if (cudaRenderer == nullptr)
23-
{
24-
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
25-
}
26-
return (void*)cudaRenderer->createInternalState();
25+
CudaRasterizer::Rasterizer::kill((CudaRasterizer::Rasterizer*)rasterizer);
26+
}
27+
28+
void* createRasterizerState(void* rasterizer)
29+
{
30+
return (void*)((CudaRasterizer::Rasterizer*)rasterizer)->createInternalState();
2731
}
2832

29-
void deleteRasterizerState(void* state)
33+
void deleteRasterizerState(void* rasterizer, void* state)
3034
{
31-
cudaRenderer->killInternalState((CudaRasterizer::InternalState*)state);
35+
((CudaRasterizer::Rasterizer*)rasterizer)->killInternalState((CudaRasterizer::InternalState*)state);
3236
}
3337

3438
std::tuple<torch::Tensor, torch::Tensor>
3539
RasterizeGaussiansCUDA(
40+
void* rasterizer,
3641
const torch::Tensor& background,
3742
const torch::Tensor& means3D,
3843
const torch::Tensor& colors,
@@ -58,11 +63,6 @@ RasterizeGaussiansCUDA(
5863
AT_ERROR("means3D must have dimensions (num_points, 3)");
5964
}
6065

61-
if (cudaRenderer == nullptr)
62-
{
63-
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
64-
}
65-
6666
const int P = means3D.size(0);
6767
const int N = 1; // batch size hard-coded
6868
const int H = image_height;
@@ -82,7 +82,7 @@ RasterizeGaussiansCUDA(
8282
M = sh.size(1);
8383
}
8484

85-
cudaRenderer->forward(P, degree, M,
85+
((CudaRasterizer::Rasterizer*)rasterizer)->forward(P, degree, M,
8686
background.contiguous().data<float>(),
8787
W, H,
8888
means3D.contiguous().data<float>(),
@@ -108,6 +108,7 @@ RasterizeGaussiansCUDA(
108108

109109
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
110110
RasterizeGaussiansBackwardCUDA(
111+
void* rasterizer,
111112
const void* internalState,
112113
const torch::Tensor& background,
113114
const torch::Tensor& means3D,
@@ -148,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
148149

149150
if(P != 0)
150151
{
151-
cudaRenderer->backward(
152+
((CudaRasterizer::Rasterizer*)rasterizer)->backward(
152153
radii.contiguous().data<int>(),
153154
(CudaRasterizer::InternalState*)internalState,
154155
P, degree, M,
@@ -182,22 +183,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
182183
}
183184

184185
torch::Tensor markVisible(
186+
void* rasterizer,
185187
torch::Tensor& means3D,
186188
torch::Tensor& viewmatrix,
187189
torch::Tensor& projmatrix)
188190
{
189-
if (cudaRenderer == nullptr)
190-
{
191-
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
192-
}
193191

194192
const int P = means3D.size(0);
195193

196194
torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
197195

198196
if(P != 0)
199197
{
200-
cudaRenderer->markVisible(P,
198+
((CudaRasterizer::Rasterizer*)rasterizer)->markVisible(P,
201199
means3D.contiguous().data<float>(),
202200
viewmatrix.contiguous().data<float>(),
203201
projmatrix.contiguous().data<float>(),

rasterize_points.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
std::tuple<torch::Tensor, torch::Tensor>
1010
RasterizeGaussiansCUDA(
11+
void* rasterizer,
1112
const torch::Tensor& background,
1213
const torch::Tensor& means3D,
1314
const torch::Tensor& colors,
@@ -30,6 +31,7 @@ RasterizeGaussiansCUDA(
3031

3132
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
3233
RasterizeGaussiansBackwardCUDA(
34+
void* rasterizer,
3335
const void* internalState,
3436
const torch::Tensor& background,
3537
const torch::Tensor& means3D,
@@ -48,11 +50,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
4850
const int degree,
4951
const torch::Tensor& campos);
5052

51-
void* createRasterizerState();
53+
void* createRasterizerState(void* rasterizer);
5254

53-
void deleteRasterizerState(void* state);
55+
void deleteRasterizerState(void* rasterizer, void* state);
56+
57+
void* createRasterizer();
58+
59+
void deleteRasterizer(void* rasterizer);
5460

5561
torch::Tensor markVisible(
62+
void* rasterizer,
5663
torch::Tensor& means3D,
5764
torch::Tensor& viewmatrix,
5865
torch::Tensor& projmatrix);

0 commit comments

Comments
 (0)