Skip to content

Commit eb2d673

Browse files
mvaligurskyMartin Valigursky
andauthored
feat: add compute-based tiled GSplat renderer (WebGPU) (#8531)
Adds GSplatComputeGlobalRenderer, a WebGPU compute shader pipeline that renders Gaussian splats via per-tile sorting and rasterization. The pipeline takes depth-sorted splats and runs 7 compute passes: tile count, prefix sum, expand, prepare sort, radix sort, tile ranges, and rasterize. Each tile's splats are blended front-to-back with per-thread early-out on transmittance saturation. Half-precision colors in shared memory reduce register pressure in the rasterizer. Also adds Camera.beforePasses for scheduling compute work before the main render pass, and extends ComputeRadixSort with a sortedKeys getter and configurable skipLastPassKeyWrite optimization. Tile intersection uses a StopThePop closest-point test by default with a FlashGS exact conic-edge fallback. Distance-adaptive contribution culling and configurable minPixelSize match the quad-based renderer. Currently disabled (USE_COMPUTE_RENDERER = false) pending further integration testing. Made-with: Cursor Co-authored-by: Martin Valigursky <mvaligursky@snapchat.com>
1 parent 99919b6 commit eb2d673

13 files changed

+1487
-18
lines changed

src/scene/camera.js

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ class Camera {
6868
*/
6969
framePasses = [];
7070

71+
/**
72+
* Frame passes that execute before this camera's main rendering. These are added to the
73+
* frame graph at first camera use, before any render actions or framePasses.
74+
*
75+
* @type {FramePass[]}
76+
*/
77+
beforePasses = [];
78+
7179
/** @type {number} */
7280
jitter = 0;
7381

src/scene/graphics/compute-radix-sort.js

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,14 @@ class ComputeRadixSort {
199199
*/
200200
_hasInitialValues = false;
201201

202+
/**
203+
* Whether the last pass skips writing sorted keys (only values are written).
204+
* When true, `sortedKeys` will contain stale data after sorting.
205+
*
206+
* @type {boolean}
207+
*/
208+
_skipLastPassKeyWrite = false;
209+
202210
/**
203211
* Creates a new ComputeRadixSort instance.
204212
*
@@ -270,14 +278,26 @@ class ComputeRadixSort {
270278
}
271279

272280
/**
273-
* Gets the sorted indices buffer.
281+
* Gets the sorted indices (or values) buffer.
274282
*
275283
* @type {StorageBuffer|null}
276284
*/
277285
get sortedIndices() {
278286
return this._sortedIndices;
279287
}
280288

289+
/**
290+
* Gets the sorted keys buffer after the last sort operation. The keys end up
291+
* in one of the internal ping-pong buffers depending on the number of passes.
292+
*
293+
* @type {StorageBuffer|null}
294+
*/
295+
get sortedKeys() {
296+
if (!this._keys0) return null;
297+
const numPasses = this._numBits / BITS_PER_PASS;
298+
return (numPasses % 2 === 0) ? this._keys1 : this._keys0;
299+
}
300+
281301
/**
282302
* Ensures bind group formats exist for the given mode. Destroys and recreates
283303
* them if switching between direct and indirect modes.
@@ -326,9 +346,10 @@ class ComputeRadixSort {
326346
* @param {number} numBits - Number of bits to sort.
327347
* @param {boolean} indirect - Whether to create indirect sort passes.
328348
* @param {boolean} hasInitialValues - Whether pass 0 reads from caller-supplied initial values.
349+
* @param {boolean} skipLastPassKeyWrite - Whether the last pass skips writing keys.
329350
* @private
330351
*/
331-
_createPasses(numBits, indirect, hasInitialValues) {
352+
_createPasses(numBits, indirect, hasInitialValues, skipLastPassKeyWrite) {
332353
// Destroy old passes and their shaders
333354
this._destroyPasses();
334355
this._numBits = numBits;
@@ -338,14 +359,15 @@ class ComputeRadixSort {
338359
this._ensureBindGroupFormats(indirect);
339360
this._indirect = indirect;
340361
this._hasInitialValues = hasInitialValues;
362+
this._skipLastPassKeyWrite = skipLastPassKeyWrite;
341363

342364
const numPasses = numBits / BITS_PER_PASS;
343365
const suffix = indirect ? '-Indirect' : '';
344366

345367
for (let pass = 0; pass < numPasses; pass++) {
346368
const bitOffset = pass * BITS_PER_PASS;
347369
const isFirstPass = pass === 0 && !hasInitialValues;
348-
const isLastPass = pass === numPasses - 1;
370+
const isLastPass = skipLastPassKeyWrite && pass === numPasses - 1;
349371

350372
const histogramShader = this._createShader(
351373
`RadixSort4bit-Histogram${suffix}-${bitOffset}`,
@@ -381,16 +403,18 @@ class ComputeRadixSort {
381403
* @param {number} numBits - Number of bits to sort.
382404
* @param {boolean} indirect - Whether passes should use indirect dispatch.
383405
* @param {boolean} hasInitialValues - Whether pass 0 reads caller-supplied initial values.
406+
* @param {boolean} skipLastPassKeyWrite - Whether the last pass skips writing keys.
384407
* @private
385408
*/
386-
_allocateBuffers(elementCount, numBits, indirect, hasInitialValues) {
409+
_allocateBuffers(elementCount, numBits, indirect, hasInitialValues, skipLastPassKeyWrite) {
387410
const workgroupCount = Math.ceil(elementCount / ELEMENTS_PER_WORKGROUP);
388411

389412
// Only reallocate buffers if we need MORE capacity (grow-only policy)
390413
const buffersNeedRealloc = workgroupCount > this._allocatedWorkgroupCount || !this._keys0;
391414

392-
// Recreate passes when numBits, indirect mode, or initial-values mode changes
393-
const passesNeedRecreate = numBits !== this._numBits || indirect !== this._indirect || hasInitialValues !== this._hasInitialValues;
415+
// Recreate passes when numBits, indirect mode, initial-values mode, or key-write mode changes
416+
const passesNeedRecreate = numBits !== this._numBits || indirect !== this._indirect ||
417+
hasInitialValues !== this._hasInitialValues || skipLastPassKeyWrite !== this._skipLastPassKeyWrite;
394418

395419
if (buffersNeedRealloc) {
396420

@@ -428,7 +452,7 @@ class ComputeRadixSort {
428452
this._prefixSumKernel.resize(this._blockSums, BUCKET_COUNT * workgroupCount);
429453

430454
if (passesNeedRecreate) {
431-
this._createPasses(numBits, indirect, hasInitialValues);
455+
this._createPasses(numBits, indirect, hasInitialValues, skipLastPassKeyWrite);
432456
}
433457
}
434458

@@ -473,14 +497,20 @@ class ComputeRadixSort {
473497
* @param {StorageBuffer} keysBuffer - Input storage buffer containing u32 keys.
474498
* @param {number} elementCount - Number of elements to sort.
475499
* @param {number} [numBits] - Number of bits to sort (must be multiple of 4). Defaults to 16.
476-
* @returns {StorageBuffer} Storage buffer containing sorted indices.
500+
* @param {StorageBuffer} [initialValues] - Optional buffer of initial values for pass 0.
501+
* When provided, the sort produces output values derived from this buffer instead of
502+
* sequential indices. The buffer is only read, never modified.
503+
* @param {boolean} [skipLastPassKeyWrite] - When true, the last pass skips writing sorted
504+
* keys for a small performance gain. Only use when sorted keys are not needed after sorting.
505+
* @returns {StorageBuffer} Storage buffer containing sorted indices (or values if
506+
* initialValues was provided).
477507
*/
478-
sort(keysBuffer, elementCount, numBits = 16) {
508+
sort(keysBuffer, elementCount, numBits = 16, initialValues, skipLastPassKeyWrite = false) {
479509
Debug.assert(keysBuffer, 'ComputeRadixSort.sort: keysBuffer is required');
480510
Debug.assert(elementCount > 0, 'ComputeRadixSort.sort: elementCount must be > 0');
481511
Debug.assert(numBits % BITS_PER_PASS === 0, `ComputeRadixSort.sort: numBits must be a multiple of ${BITS_PER_PASS}`);
482512

483-
return this._execute(keysBuffer, elementCount, numBits, false, -1, null, undefined);
513+
return this._execute(keysBuffer, elementCount, numBits, false, -1, null, initialValues, skipLastPassKeyWrite);
484514
}
485515

486516
/**
@@ -495,15 +525,17 @@ class ComputeRadixSort {
495525
* @param {StorageBuffer} [initialValues] - Optional buffer of initial values for pass 0.
496526
* When provided, the sort produces output values derived from this buffer instead of
497527
* sequential indices. The buffer is only read, never modified.
528+
* @param {boolean} [skipLastPassKeyWrite] - When true, the last pass skips writing sorted
529+
* keys for a small performance gain. Only use when sorted keys are not needed after sorting.
498530
* @returns {StorageBuffer} Storage buffer containing sorted values.
499531
*/
500-
sortIndirect(keysBuffer, maxElementCount, numBits, dispatchSlot, sortElementCountBuffer, initialValues) {
532+
sortIndirect(keysBuffer, maxElementCount, numBits, dispatchSlot, sortElementCountBuffer, initialValues, skipLastPassKeyWrite = false) {
501533
Debug.assert(keysBuffer, 'ComputeRadixSort.sortIndirect: keysBuffer is required');
502534
Debug.assert(maxElementCount > 0, 'ComputeRadixSort.sortIndirect: maxElementCount must be > 0');
503535
Debug.assert(numBits % BITS_PER_PASS === 0, `ComputeRadixSort.sortIndirect: numBits must be a multiple of ${BITS_PER_PASS}`);
504536
Debug.assert(sortElementCountBuffer, 'ComputeRadixSort.sortIndirect: sortElementCountBuffer is required');
505537

506-
return this._execute(keysBuffer, maxElementCount, numBits, true, dispatchSlot, sortElementCountBuffer, initialValues);
538+
return this._execute(keysBuffer, maxElementCount, numBits, true, dispatchSlot, sortElementCountBuffer, initialValues, skipLastPassKeyWrite);
507539
}
508540

509541
/**
@@ -516,15 +548,17 @@ class ComputeRadixSort {
516548
* @param {number} dispatchSlot - Indirect dispatch slot index (-1 for direct).
517549
* @param {StorageBuffer|null} sortElementCountBuffer - GPU-written element count (null for direct).
518550
* @param {StorageBuffer} [initialValues] - Optional initial values buffer for pass 0.
551+
* @param {boolean} [skipLastPassKeyWrite] - When true, the last pass skips writing sorted
552+
* keys for a small performance gain. Only use when sorted keys are not needed after sorting.
519553
* @returns {StorageBuffer} Storage buffer containing sorted values.
520554
* @private
521555
*/
522-
_execute(keysBuffer, elementCount, numBits, indirect, dispatchSlot, sortElementCountBuffer, initialValues) {
556+
_execute(keysBuffer, elementCount, numBits, indirect, dispatchSlot, sortElementCountBuffer, initialValues, skipLastPassKeyWrite = false) {
523557
this._elementCount = elementCount;
524558
const hasInitialValues = !!initialValues;
525559

526560
// Allocate buffers and create passes if needed
527-
this._allocateBuffers(elementCount, numBits, indirect, hasInitialValues);
561+
this._allocateBuffers(elementCount, numBits, indirect, hasInitialValues, skipLastPassKeyWrite);
528562

529563
const device = this.device;
530564
const numPasses = numBits / BITS_PER_PASS;
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import { FramePass } from '../../platform/graphics/frame-pass.js';
2+
3+
/**
4+
* @import { GSplatComputeGlobalRenderer } from './gsplat-compute-global-renderer.js'
5+
*/
6+
7+
/**
8+
* A frame pass for the global tiled compute renderer. Registered as a camera beforePass so
9+
* it runs before the main render pass. On each frame it resizes the offscreen output texture
10+
* to match the camera's render target, then dispatches the full 7-pass compute pipeline
11+
* (count/prefix-sum/expand/sort/ranges/rasterize). The rasterized result is later composited
12+
* into the render target via a full-screen quad with premultiplied blending.
13+
*
14+
* @ignore
15+
*/
16+
class FramePassGSplatComputeGlobal extends FramePass {
17+
/** @type {GSplatComputeGlobalRenderer} */
18+
renderer;
19+
20+
/**
21+
* @param {GSplatComputeGlobalRenderer} renderer - The compute renderer that owns this pass.
22+
*/
23+
constructor(renderer) {
24+
super(renderer.device);
25+
this.renderer = renderer;
26+
this.name = 'FramePassGSplatComputeGlobal';
27+
}
28+
29+
frameUpdate() {
30+
const renderer = this.renderer;
31+
const camera = renderer.cameraNode.camera;
32+
const rt = camera.renderTarget;
33+
const rtWidth = rt ? rt.width : this.device.width;
34+
const rtHeight = rt ? rt.height : this.device.height;
35+
const rect = camera.rect;
36+
const width = Math.floor(rtWidth * rect.z);
37+
const height = Math.floor(rtHeight * rect.w);
38+
renderer.resizeOutputTexture(width, height);
39+
}
40+
41+
execute() {
42+
this.renderer.dispatch();
43+
}
44+
}
45+
46+
export { FramePassGSplatComputeGlobal };

0 commit comments

Comments
 (0)