Skip to content

Commit 5517547

Browse files
GS SH View Direction (#16804)
Change space for compressed PLY and flip SH direction. followup https://forum.babylonjs.com/t/gaussian-splatting-seems-to-not-work-since-version-8/57903/37
1 parent ed001a6 commit 5517547

File tree

6 files changed

+27
-9
lines changed

6 files changed

+27
-9
lines changed

packages/dev/core/src/Materials/GaussianSplatting/gaussianSplattingMaterial.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ export class GaussianSplattingMaterial extends PushMaterial {
201201
"focal",
202202
"eyePosition",
203203
"kernelSize",
204+
"viewDirectionFactor",
204205
];
205206
const samplers = ["covariancesATexture", "covariancesBTexture", "centersTexture", "colorsTexture", "shTexture0", "shTexture1", "shTexture2"];
206207
const uniformBuffers = ["Scene", "Mesh"];
@@ -292,6 +293,7 @@ export class GaussianSplattingMaterial extends PushMaterial {
292293
}
293294

294295
effect.setFloat2("focal", focal, focal);
296+
effect.setVector3("viewDirectionFactor", gsMesh.viewDirectionFactor);
295297
effect.setFloat("kernelSize", gsMaterial && gsMaterial.kernelSize ? gsMaterial.kernelSize : GaussianSplattingMaterial.KernelSize);
296298
scene.bindEyePosition(effect, "eyePosition", true);
297299

packages/dev/core/src/Materials/Node/Blocks/GaussianSplatting/gaussianSplattingBlock.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
133133
state._emitUniformFromString("invViewport", NodeMaterialBlockConnectionPointTypes.Vector2);
134134
state._emitUniformFromString("kernelSize", NodeMaterialBlockConnectionPointTypes.Float);
135135
state._emitUniformFromString("eyePosition", NodeMaterialBlockConnectionPointTypes.Vector3);
136+
state._emitUniformFromString("viewDirectionFactor", NodeMaterialBlockConnectionPointTypes.Vector3);
136137
state.attributes.push(VertexBuffer.PositionKind);
137138
state.sharedData.nodeMaterial.backFaceCulling = false;
138139

@@ -163,13 +164,14 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
163164
state.compilationString += `let worldRot: mat3x3f = mat3x3f(${world.associatedVariableName}[0].xyz, ${world.associatedVariableName}[1].xyz, ${world.associatedVariableName}[2].xyz);`;
164165
state.compilationString += `let normWorldRot: mat3x3f = inverseMat3(worldRot);`;
165166
state.compilationString += `var dir: vec3f = normalize(normWorldRot * (${splatPosition.associatedVariableName}.xyz - uniforms.eyePosition));\n`;
167+
state.compilationString += `dir *= uniforms.viewDirectionFactor;\n`;
166168
} else {
167169
state.compilationString += `mat3 worldRot = mat3(${world.associatedVariableName});`;
168170
state.compilationString += `mat3 normWorldRot = inverseMat3(worldRot);`;
169171
state.compilationString += `vec3 dir = normalize(normWorldRot * (${splatPosition.associatedVariableName}.xyz - eyePosition));\n`;
172+
state.compilationString += `dir *= viewDirectionFactor;\n`;
170173
}
171174

172-
state.compilationString += `dir *= vec3${addF}(1.,1.,-1.);\n`;
173175
state.compilationString += `${state._declareOutput(sh)} = computeSH(splat, dir);\n`;
174176
state.compilationString += `#else\n`;
175177
state.compilationString += `${state._declareOutput(sh)} = vec3${addF}(0.,0.,0.);\n`;

packages/dev/core/src/Meshes/GaussianSplatting/gaussianSplattingMesh.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,14 @@ export class GaussianSplattingMesh extends Mesh {
307307
// batch size between 2 yield calls during the PLY to splat conversion.
308308
private static _PlyConversionBatchSize = 32768;
309309
private _shDegree = 0;
310+
private _viewDirectionFactor = new Vector3(1, 1, -1);
311+
312+
/**
313+
* View direction factor used to compute the SH view direction in the shader.
314+
*/
315+
public get viewDirectionFactor() {
316+
return this._viewDirectionFactor;
317+
}
310318

311319
/**
312320
* SH degree. 0 = no sh (default). 1 = 3 parameters. 2 = 8 parameters. 3 = 15 parameters.
@@ -931,8 +939,8 @@ export class GaussianSplattingMesh extends Mesh {
931939
const compressedChunk = compressedChunks![chunkIndex];
932940
Unpack111011(value, temp3);
933941
position[0] = Scalar.Lerp(compressedChunk.min.x, compressedChunk.max.x, temp3.x);
934-
position[1] = -Scalar.Lerp(compressedChunk.min.y, compressedChunk.max.y, temp3.y);
935-
position[2] = -Scalar.Lerp(compressedChunk.min.z, compressedChunk.max.z, temp3.z);
942+
position[1] = Scalar.Lerp(compressedChunk.min.y, compressedChunk.max.y, temp3.y);
943+
position[2] = Scalar.Lerp(compressedChunk.min.z, compressedChunk.max.z, temp3.z);
936944
}
937945
break;
938946
case PLYValue.PACKED_ROTATION:
@@ -941,8 +949,8 @@ export class GaussianSplattingMesh extends Mesh {
941949

942950
r0 = q.x;
943951
r1 = q.y;
944-
r2 = -q.z;
945-
r3 = -q.w;
952+
r2 = q.z;
953+
r3 = q.w;
946954
}
947955
break;
948956
case PLYValue.PACKED_SCALE:

packages/dev/core/src/Shaders/gaussianSplatting.vertex.fx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ uniform vec2 dataTextureSize;
1919
uniform vec2 focal;
2020
uniform float kernelSize;
2121
uniform vec3 eyePosition;
22+
uniform vec3 viewDirectionFactor;
2223

2324
uniform sampler2D covariancesATexture;
2425
uniform sampler2D covariancesBTexture;
@@ -56,7 +57,7 @@ void main () {
5657
mat3 normWorldRot = inverseMat3(worldRot);
5758

5859
vec3 dir = normalize(normWorldRot * (worldPos.xyz - eyePosition));
59-
dir *= vec3(1.,1.,-1.); // convert to Babylon Space
60+
dir *= viewDirectionFactor;
6061
vColor.xyz = splat.color.xyz + computeSH(splat, dir);
6162
#endif
6263

packages/dev/core/src/ShadersWGSL/gaussianSplatting.vertex.fx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ uniform dataTextureSize: vec2f;
1616
uniform focal: vec2f;
1717
uniform kernelSize: f32;
1818
uniform eyePosition: vec3f;
19+
uniform viewDirectionFactor: vec3f;
1920

2021
// textures
2122
var covariancesATexture: texture_2d<f32>;
@@ -53,7 +54,7 @@ fn main(input : VertexInputs) -> FragmentInputs {
5354
let normWorldRot: mat3x3f = inverseMat3(worldRot);
5455

5556
var dir: vec3f = normalize(normWorldRot * (worldPos.xyz - uniforms.eyePosition.xyz));
56-
dir *= vec3f(1.,1.,-1.); // convert to Babylon Space
57+
dir *= viewDirectionFactor;
5758
vertexOutputs.vColor = vec4f(splat.color.xyz + computeSH(splat, dir), splat.color.w);
5859
#else
5960
vertexOutputs.vColor = splat.color;

packages/dev/loaders/src/SPLAT/splatFileLoader.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ interface IParsedPLY {
4747
hasVertexColors?: boolean;
4848
sh?: Uint8Array[];
4949
trainedWithAntialiasing?: boolean;
50+
compressed?: boolean;
5051
}
5152

5253
/**
@@ -380,6 +381,9 @@ export class SPLATFileLoader implements ISceneLoaderPluginAsync, ISceneLoaderPlu
380381
gaussianSplatting._parentContainer = this._assetContainer;
381382
babylonMeshesArray.push(gaussianSplatting);
382383
gaussianSplatting.updateData(parsedPLY.data, parsedPLY.sh);
384+
if (parsedPLY.compressed) {
385+
gaussianSplatting.viewDirectionFactor.set(-1, -1, 1);
386+
}
383387
}
384388
break;
385389
case Mode.PointCloud:
@@ -577,7 +581,7 @@ export class SPLATFileLoader implements ISceneLoaderPluginAsync, ISceneLoaderPlu
577581
// early exit for chunked/quantized ply
578582
if (chunkCount) {
579583
return await new Promise((resolve) => {
580-
resolve({ mode: Mode.Splat, data: splatsData.buffer, sh: splatsData.sh, faces: faces, hasVertexColors: false });
584+
resolve({ mode: Mode.Splat, data: splatsData.buffer, sh: splatsData.sh, faces: faces, hasVertexColors: false, compressed: true });
581585
});
582586
}
583587
// count available properties. if all necessary are present then it's a splat. Otherwise, it's a point cloud
@@ -599,7 +603,7 @@ export class SPLATFileLoader implements ISceneLoaderPluginAsync, ISceneLoaderPlu
599603
const currentMode = faceCount ? Mode.Mesh : hasMandatoryProperties ? Mode.Splat : Mode.PointCloud;
600604
// parsed ready ready to be used as a splat
601605
return await new Promise((resolve) => {
602-
resolve({ mode: currentMode, data: splatsData.buffer, sh: splatsData.sh, faces: faces, hasVertexColors: !!propertyColorCount });
606+
resolve({ mode: currentMode, data: splatsData.buffer, sh: splatsData.sh, faces: faces, hasVertexColors: !!propertyColorCount, compressed: false });
603607
});
604608
});
605609
}

0 commit comments

Comments
 (0)