Skip to content

Commit 2360c2c

Browse files
committed
add depth rasterization
for now, there's only one buffer and this only serves visualization purpose, a better implementation is to use two color attachments and do the blending in one pass, which should be much faster
1 parent 66c25a4 commit 2360c2c

File tree

5 files changed

+282
-7
lines changed

5 files changed

+282
-7
lines changed

fast_gauss/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ class GaussianRasterizationSettings(NamedTuple):
1616
scale_modifier: float
1717
viewmatrix: torch.Tensor
1818
projmatrix: torch.Tensor
19-
sh_degree: int
2019
campos: torch.Tensor
21-
prefiltered: bool
22-
debug: bool
20+
sh_degree: int = 3
21+
prefiltered: bool = True
22+
debug: bool = False
23+
use_depth: bool = False
2324

2425

2526
class GaussianRasterizer:

fast_gauss/gsplat_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ def opengl_options(self):
8787
def compile_shaders(self):
8888
try:
8989
self.gsplat_program = shaders.compileProgram(
90-
shaders.compileShader(load_shader_source('gsplat.vert'), gl.GL_VERTEX_SHADER),
91-
shaders.compileShader(load_shader_source('gsplat.geom'), gl.GL_GEOMETRY_SHADER),
92-
shaders.compileShader(load_shader_source('gsplat.frag'), gl.GL_FRAGMENT_SHADER)
90+
shaders.compileShader(load_shader_source('dsplat.vert'), gl.GL_VERTEX_SHADER),
91+
shaders.compileShader(load_shader_source('dsplat.geom'), gl.GL_GEOMETRY_SHADER),
92+
shaders.compileShader(load_shader_source('dsplat.frag'), gl.GL_FRAGMENT_SHADER)
9393
)
9494
except Exception as e:
9595
print(str(e).encode('utf-8').decode('unicode_escape'))
@@ -102,6 +102,7 @@ def use_gl_program(self, program: shaders.ShaderProgram):
102102
self.uniforms.focal = gl.glGetUniformLocation(program, "focal")
103103
self.uniforms.principal = gl.glGetUniformLocation(program, "principal")
104104
self.uniforms.basisViewport = gl.glGetUniformLocation(program, "basisViewport")
105+
self.uniforms.useDepth = gl.glGetUniformLocation(program, "useDepth")
105106

106107
def upload_gl_uniforms(self, raster_settings: 'GaussianRasterizationSettings'):
107108
# FIXME: Possible nasty synchronization issue: raster_settings might contain cuda tensors
@@ -137,6 +138,7 @@ def upload_gl_uniforms(self, raster_settings: 'GaussianRasterizationSettings'):
137138
gl.glUniform2f(self.uniforms.focal, 0.5 * raster_settings.image_width / raster_settings.tanfovx, 0.5 * raster_settings.image_height / raster_settings.tanfovy) # focal in pixel space
138139
gl.glUniform2f(self.uniforms.principal, cx, cy) # focal
139140
gl.glUniform2f(self.uniforms.basisViewport, 1 / raster_settings.image_width, 1 / raster_settings.image_height) # focal
141+
gl.glUniform1i(self.uniforms.useDepth, 1 if raster_settings.use_depth else 0)
140142

141143
def init_gl_buffers(self, v: int = 0):
142144
from cuda import cudart
@@ -398,6 +400,6 @@ def rasterize_gaussians(
398400
image, alpha = image.permute(2, 0, 1), alpha.permute(2, 0, 1)
399401

400402
# FIXME: Alpha channel seems to be bugged
401-
return image.float(), torch.ones_like(alpha).float()
403+
return image.float(), alpha.float()
402404
else:
403405
return None, None

fast_gauss/shaders/dsplat.frag

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#version 330
2+
#pragma vscode_glsllint_stage : frag
3+
4+
/**
5+
Almost empty fragment shader for computing the final output color.
6+
Note that the geometry should've been emitted in the order they want to be rendered (back to front) and blended normally
7+
*/
8+
9+
in vec2 vPosition;
10+
flat in vec4 vColor;
11+
flat in float vDepth;
12+
13+
uniform bool useDepth = false;
14+
uniform float eight = 8;
15+
uniform float minAlpha = 1 / 255;
16+
uniform float maxAlpha = 0.99;
17+
18+
layout(location = 0) out vec4 write_color;
19+
// layout(location = 1) out vec4 write_depth;
20+
21+
void main() {
22+
// Compute the positional squared distance from the center of the splat to the current fragment.
23+
float A = dot(vPosition, vPosition);
24+
25+
// Since the positional data in vPosition has been scaled by sqrt(8), the squared result will be
26+
// scaled by a factor of 8. If the squared result is larger than 8, it means it is outside the ellipse
27+
// defined by the rectangle formed by vPosition. It also means it's farther
28+
// away than sqrt(8) standard deviations from the mean.
29+
if (A > eight) discard;
30+
float power = -0.5 * A;
31+
// if (power > 0.0f)
32+
// discard;
33+
34+
// Since the rendered splat is scaled by sqrt(8), the inverse covariance matrix that is part of
35+
// the gaussian formula becomes the identity matrix. We're then left with (X - mean) * (X - mean),
36+
// and since 'mean' is zero, we have X * X, which is the same as A:
37+
float opacity = exp(power) * vColor.a;
38+
// float opacity = exp(-0.5 * A) * vColor.a;
39+
if (opacity < minAlpha)
40+
discard;
41+
// opacity = min(maxAlpha, opacity);
42+
43+
if (useDepth)
44+
write_color = vec4(vec3(vDepth), opacity);
45+
else
46+
write_color = vec4(vColor.rgb, opacity);
47+
}

fast_gauss/shaders/dsplat.geom

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#version 330
2+
#pragma vscode_glsllint_stage : geom
3+
4+
/**
5+
Given center point and computed basis vectors, emit the four quad positions
6+
This computes the virtual position to be used by fragment shader to compute the gaussian fall off
7+
The RGBA color of the GS is also passed as is to the geometry shader
8+
*/
9+
10+
uniform vec2 basisViewport;
11+
uniform float discardAlpha = 0.0001;
12+
uniform float sqrt8 = sqrt(8);
13+
// uniform float sqrt8 = 3;
14+
15+
layout(points) in;
16+
layout(triangle_strip, max_vertices = 4) out;
17+
18+
in vec2 basisVector0[];
19+
in vec2 basisVector1[];
20+
in vec4 gColor[];
21+
in float gDepth[];
22+
23+
out vec2 vPosition;
24+
flat out vec4 vColor; // pass through
25+
flat out float vDepth; // pass through
26+
27+
vec2 computeNDCOffset(vec2 basisVector0, vec2 basisVector1, vec2 vPosition) {
28+
return (vPosition.x * basisVector0 + vPosition.y * basisVector1) * basisViewport * 2.0;
29+
}
30+
31+
void main() {
32+
if (gColor[0].a < discardAlpha) {
33+
return; // will not emit any quad for later rendering
34+
}
35+
36+
vec2 ndcOffset;
37+
vec3 ndcCenter = gl_in[0].gl_Position.xyz;
38+
39+
vPosition.x = -1;
40+
vPosition.y = -1;
41+
ndcOffset = computeNDCOffset(basisVector0[0], basisVector1[0], vPosition); // compute offset
42+
gl_Position = vec4(ndcCenter.xy + ndcOffset, ndcCenter.z, 1.0); // store output
43+
vPosition *= sqrt8; // store output
44+
vColor = gColor[0]; // pass through
45+
vDepth = gDepth[0]; // pass through
46+
EmitVertex();
47+
48+
vPosition.x = -1;
49+
vPosition.y = 1;
50+
ndcOffset = computeNDCOffset(basisVector0[0], basisVector1[0], vPosition); // compute offset
51+
gl_Position = vec4(ndcCenter.xy + ndcOffset, ndcCenter.z, 1.0); // store output
52+
vPosition *= sqrt8; // store output
53+
vColor = gColor[0]; // pass through
54+
vDepth = gDepth[0]; // pass through
55+
EmitVertex();
56+
57+
vPosition.x = 1;
58+
vPosition.y = -1;
59+
ndcOffset = computeNDCOffset(basisVector0[0], basisVector1[0], vPosition); // compute offset
60+
gl_Position = vec4(ndcCenter.xy + ndcOffset, ndcCenter.z, 1.0); // store output
61+
vPosition *= sqrt8; // store output
62+
vColor = gColor[0]; // pass through
63+
vDepth = gDepth[0]; // pass through
64+
EmitVertex();
65+
66+
vPosition.x = 1;
67+
vPosition.y = 1;
68+
ndcOffset = computeNDCOffset(basisVector0[0], basisVector1[0], vPosition); // compute offset
69+
gl_Position = vec4(ndcCenter.xy + ndcOffset, ndcCenter.z, 1.0); // store output
70+
vPosition *= sqrt8; // store output
71+
vColor = gColor[0]; // pass through
72+
vDepth = gDepth[0]; // pass through
73+
EmitVertex();
74+
75+
EndPrimitive();
76+
}

fast_gauss/shaders/dsplat.vert

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
#version 330
2+
#pragma vscode_glsllint_stage : vert
3+
4+
/**
5+
Compute covariance relatec computation & projection etc
6+
This emits the 2 2D basis vector for geometry shader's quad construction
7+
The RGBA color of the GS is also passed as is to the geometry shader
8+
*/
9+
10+
// uniform int H;
11+
// uniform int W;
12+
// uniform float n;
13+
// uniform float f;
14+
// uniform mat3x3 K;
15+
16+
uniform mat4x4 P;
17+
uniform mat4x4 VM;
18+
uniform vec2 focal;
19+
uniform vec2 principal;
20+
uniform vec2 basisViewport;
21+
22+
// uniform float discardAlpha = 0.1;
23+
uniform float discardAlpha = 0.0001;
24+
// uniform float discardAlpha = 0.15;
25+
uniform float maxScreenSpaceSplatSize = 2048.0;
26+
uniform float sqrt8 = sqrt(8);
27+
// uniform float sqrt8 = 3;
28+
29+
layout(location = 0) in vec3 aPos; // xyz
30+
layout(location = 1) in vec3 aCov0_3; // cov6
31+
layout(location = 2) in vec3 aCov3_6; // cov6
32+
layout(location = 3) in vec4 aColor; // rgba
33+
34+
out vec2 basisVector0;
35+
out vec2 basisVector1;
36+
out vec4 gColor; // pass through
37+
out float gDepth; // pass through
38+
39+
void main() {
40+
if (aColor.a < discardAlpha) {
41+
gColor.a = 0.0; // will not emit things in geometry shader
42+
return;
43+
}
44+
45+
// Compute the view and clip space coordinates of the center of the ellipse
46+
vec4 viewCenter = VM * vec4(aPos, 1.0);
47+
vec4 clipCenter = P * vec4(aPos, 1.0);
48+
clipCenter = clipCenter / clipCenter.w; // perspective division
49+
50+
// Construct the 3D covariance matrix
51+
mat3 Vrk = mat3(
52+
aCov0_3[0], aCov0_3[1], aCov0_3[2],
53+
aCov0_3[1], aCov3_6[0], aCov3_6[1],
54+
aCov0_3[2], aCov3_6[1], aCov3_6[2]);
55+
56+
// Construct the Jacobian of the affine approximation of the projection matrix. It will be used to transform the
57+
float width = 1 / basisViewport[0];
58+
float height = 1 / basisViewport[1];
59+
float fx = focal[0];
60+
float fy = focal[1];
61+
float cx = principal[0];
62+
float cy = principal[1];
63+
float x = viewCenter.x;
64+
float y = viewCenter.y;
65+
float z = viewCenter.z;
66+
67+
float tan_fovx = 0.5 * width / fx;
68+
float tan_fovy = 0.5 * height / fy;
69+
float lim_x_pos = (width - cx) / fx + 0.3 * tan_fovx;
70+
float lim_x_neg = cx / fx + 0.3 * tan_fovx;
71+
float lim_y_pos = (height - cy) / fy + 0.3 * tan_fovy;
72+
float lim_y_neg = cy / fy + 0.3 * tan_fovy;
73+
74+
float rz = 1.0 / z;
75+
float rz2 = rz * rz;
76+
float tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz));
77+
float ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz));
78+
79+
mat3 J = mat3(
80+
fx * rz, 0., -fx * tx * rz2,
81+
0., fy * rz, -fy * ty * rz2,
82+
0., 0., 0.);
83+
84+
// Concatenate the projection approximation with the model-view transformation
85+
mat3 W = transpose(mat3(VM));
86+
mat3 T = W * J;
87+
88+
// Transform the 3D covariance matrix (Vrk) to compute the 2D covariance matrix
89+
mat3 cov2Dm = transpose(T) * Vrk * T;
90+
cov2Dm[0][0] += 0.3;
91+
cov2Dm[1][1] += 0.3;
92+
93+
// We are interested in the upper-left 2x2 portion of the projected 3D covariance matrix because
94+
// we only care about the X and Y values. We want the X-diagonal, cov2Dm[0][0],
95+
// the Y-diagonal, cov2Dm[1][1], and the correlation between the two cov2Dm[0][1]. We don't
96+
// need cov2Dm[1][0] because it is a symetric matrix.
97+
vec3 cov2Dv = vec3(cov2Dm[0][0], cov2Dm[0][1], cov2Dm[1][1]);
98+
99+
// We now need to solve for the eigen-values and eigen vectors of the 2D covariance matrix
100+
// so that we can determine the 2D basis for the splat. This is done using the method described
101+
// here: https://people.math.harvard.edu/~knill/teaching/math21b2004/exhibits/2dmatrices/index.html
102+
// After calculating the eigen-values and eigen-vectors, we calculate the basis for rendering the splat
103+
// by normalizing the eigen-vectors and then multiplying them by (sqrt(8) * eigen-value), which is
104+
// equal to scaling them by sqrt(8) standard deviations.
105+
//
106+
// This is a different approach than in the original work at INRIA. In that work they compute the
107+
// max extents of the projected splat in screen space to form a screen-space aligned bounding rectangle
108+
// which forms the geometry that is actually rasterized. The dimensions of that bounding box are 3.0
109+
// times the maximum eigen-value, or 3 standard deviations. They then use the inverse 2D covariance
110+
// matrix (called 'conic') in the CUDA rendering thread to determine fragment opacity by calculating the
111+
// full gaussian: exp(-0.5 * (X - mean) * conic * (X - mean)) * splat opacity
112+
float a = cov2Dv.x;
113+
float d = cov2Dv.z;
114+
float b = cov2Dv.y;
115+
float D = a * d - b * b;
116+
117+
if (D <= 0.0 || cov2Dv.x <= 0.0 || cov2Dv.z <= 0.0) {
118+
// Illegal cov matrix, this point should be pruned with zero gradients
119+
gColor.a = 0.0; // will not emit things
120+
return;
121+
}
122+
123+
float trace = a + d;
124+
float traceOver2 = 0.5 * trace;
125+
float term2 = sqrt(max(0.1f, traceOver2 * traceOver2 - D));
126+
float eigenValue0 = traceOver2 + term2;
127+
float eigenValue1 = traceOver2 - term2;
128+
129+
// if (eigenValue1 < 0) {
130+
// gColor.a = 0.0; // will not emit things
131+
// return;
132+
// }
133+
if (eigenValue0 <= 0.01 || eigenValue1 <= 0.01 || eigenValue0 < eigenValue1 || (eigenValue0 / eigenValue1) > 10000.0) {
134+
gColor.a = 0.0; // will not emit things
135+
return;
136+
}
137+
138+
vec2 eigenVector0 = normalize(vec2(b, eigenValue0 - a));
139+
// since the eigen vectors are orthogonal, we derive the second one from the first
140+
vec2 eigenVector1 = vec2(eigenVector0.y, -eigenVector0.x);
141+
142+
// We use sqrt(8) standard deviations instead of 3 to eliminate more of the splat with a very low opacity.
143+
basisVector0 = eigenVector0 * min(sqrt8 * sqrt(eigenValue0), maxScreenSpaceSplatSize);
144+
basisVector1 = eigenVector1 * min(sqrt8 * sqrt(eigenValue1), maxScreenSpaceSplatSize);
145+
146+
gl_Position = clipCenter; // doing a perspective projection to ndc space
147+
gColor = aColor; // passing through
148+
gDepth = clipCenter.z; // passing through
149+
}

0 commit comments

Comments
 (0)