@@ -199,6 +199,14 @@ class ComputeRadixSort {
199199 */
200200 _hasInitialValues = false ;
201201
202+ /**
203+ * Whether the last pass skips writing sorted keys (only values are written).
204+ * When true, `sortedKeys` will contain stale data after sorting.
205+ *
206+ * @type {boolean }
207+ */
208+ _skipLastPassKeyWrite = false ;
209+
202210 /**
203211 * Creates a new ComputeRadixSort instance.
204212 *
@@ -270,14 +278,26 @@ class ComputeRadixSort {
270278 }
271279
272280 /**
273- * Gets the sorted indices buffer.
281+ * Gets the sorted indices (or values) buffer.
274282 *
275283 * @type {StorageBuffer|null }
276284 */
277285 get sortedIndices ( ) {
278286 return this . _sortedIndices ;
279287 }
280288
289+ /**
290+ * Gets the sorted keys buffer after the last sort operation. The keys end up
291+ * in one of the internal ping-pong buffers depending on the number of passes.
292+ *
293+ * @type {StorageBuffer|null }
294+ */
295+ get sortedKeys ( ) {
296+ if ( ! this . _keys0 ) return null ;
297+ const numPasses = this . _numBits / BITS_PER_PASS ;
298+ return ( numPasses % 2 === 0 ) ? this . _keys1 : this . _keys0 ;
299+ }
300+
281301 /**
282302 * Ensures bind group formats exist for the given mode. Destroys and recreates
283303 * them if switching between direct and indirect modes.
@@ -326,9 +346,10 @@ class ComputeRadixSort {
326346 * @param {number } numBits - Number of bits to sort.
327347 * @param {boolean } indirect - Whether to create indirect sort passes.
328348 * @param {boolean } hasInitialValues - Whether pass 0 reads from caller-supplied initial values.
349+ * @param {boolean } skipLastPassKeyWrite - Whether the last pass skips writing keys.
329350 * @private
330351 */
331- _createPasses ( numBits , indirect , hasInitialValues ) {
352+ _createPasses ( numBits , indirect , hasInitialValues , skipLastPassKeyWrite ) {
332353 // Destroy old passes and their shaders
333354 this . _destroyPasses ( ) ;
334355 this . _numBits = numBits ;
@@ -338,14 +359,15 @@ class ComputeRadixSort {
338359 this . _ensureBindGroupFormats ( indirect ) ;
339360 this . _indirect = indirect ;
340361 this . _hasInitialValues = hasInitialValues ;
362+ this . _skipLastPassKeyWrite = skipLastPassKeyWrite ;
341363
342364 const numPasses = numBits / BITS_PER_PASS ;
343365 const suffix = indirect ? '-Indirect' : '' ;
344366
345367 for ( let pass = 0 ; pass < numPasses ; pass ++ ) {
346368 const bitOffset = pass * BITS_PER_PASS ;
347369 const isFirstPass = pass === 0 && ! hasInitialValues ;
348- const isLastPass = pass === numPasses - 1 ;
370+ const isLastPass = skipLastPassKeyWrite && pass === numPasses - 1 ;
349371
350372 const histogramShader = this . _createShader (
351373 `RadixSort4bit-Histogram${ suffix } -${ bitOffset } ` ,
@@ -381,16 +403,18 @@ class ComputeRadixSort {
381403 * @param {number } numBits - Number of bits to sort.
382404 * @param {boolean } indirect - Whether passes should use indirect dispatch.
383405 * @param {boolean } hasInitialValues - Whether pass 0 reads caller-supplied initial values.
406+ * @param {boolean } skipLastPassKeyWrite - Whether the last pass skips writing keys.
384407 * @private
385408 */
386- _allocateBuffers ( elementCount , numBits , indirect , hasInitialValues ) {
409+ _allocateBuffers ( elementCount , numBits , indirect , hasInitialValues , skipLastPassKeyWrite ) {
387410 const workgroupCount = Math . ceil ( elementCount / ELEMENTS_PER_WORKGROUP ) ;
388411
389412 // Only reallocate buffers if we need MORE capacity (grow-only policy)
390413 const buffersNeedRealloc = workgroupCount > this . _allocatedWorkgroupCount || ! this . _keys0 ;
391414
392- // Recreate passes when numBits, indirect mode, or initial-values mode changes
393- const passesNeedRecreate = numBits !== this . _numBits || indirect !== this . _indirect || hasInitialValues !== this . _hasInitialValues ;
415+ // Recreate passes when numBits, indirect mode, initial-values mode, or key-write mode changes
416+ const passesNeedRecreate = numBits !== this . _numBits || indirect !== this . _indirect ||
417+ hasInitialValues !== this . _hasInitialValues || skipLastPassKeyWrite !== this . _skipLastPassKeyWrite ;
394418
395419 if ( buffersNeedRealloc ) {
396420
@@ -428,7 +452,7 @@ class ComputeRadixSort {
428452 this . _prefixSumKernel . resize ( this . _blockSums , BUCKET_COUNT * workgroupCount ) ;
429453
430454 if ( passesNeedRecreate ) {
431- this . _createPasses ( numBits , indirect , hasInitialValues ) ;
455+ this . _createPasses ( numBits , indirect , hasInitialValues , skipLastPassKeyWrite ) ;
432456 }
433457 }
434458
@@ -473,14 +497,20 @@ class ComputeRadixSort {
473497 * @param {StorageBuffer } keysBuffer - Input storage buffer containing u32 keys.
474498 * @param {number } elementCount - Number of elements to sort.
475499 * @param {number } [numBits] - Number of bits to sort (must be multiple of 4). Defaults to 16.
476- * @returns {StorageBuffer } Storage buffer containing sorted indices.
500+ * @param {StorageBuffer } [initialValues] - Optional buffer of initial values for pass 0.
501+ * When provided, the sort produces output values derived from this buffer instead of
502+ * sequential indices. The buffer is only read, never modified.
503+ * @param {boolean } [skipLastPassKeyWrite] - When true, the last pass skips writing sorted
504+ * keys for a small performance gain. Only use when sorted keys are not needed after sorting.
505+ * @returns {StorageBuffer } Storage buffer containing sorted indices (or values if
506+ * initialValues was provided).
477507 */
478- sort ( keysBuffer , elementCount , numBits = 16 ) {
508+ sort ( keysBuffer , elementCount , numBits = 16 , initialValues , skipLastPassKeyWrite = false ) {
479509 Debug . assert ( keysBuffer , 'ComputeRadixSort.sort: keysBuffer is required' ) ;
480510 Debug . assert ( elementCount > 0 , 'ComputeRadixSort.sort: elementCount must be > 0' ) ;
481511 Debug . assert ( numBits % BITS_PER_PASS === 0 , `ComputeRadixSort.sort: numBits must be a multiple of ${ BITS_PER_PASS } ` ) ;
482512
483- return this . _execute ( keysBuffer , elementCount , numBits , false , - 1 , null , undefined ) ;
513+ return this . _execute ( keysBuffer , elementCount , numBits , false , - 1 , null , initialValues , skipLastPassKeyWrite ) ;
484514 }
485515
486516 /**
@@ -495,15 +525,17 @@ class ComputeRadixSort {
495525 * @param {StorageBuffer } [initialValues] - Optional buffer of initial values for pass 0.
496526 * When provided, the sort produces output values derived from this buffer instead of
497527 * sequential indices. The buffer is only read, never modified.
528+ * @param {boolean } [skipLastPassKeyWrite] - When true, the last pass skips writing sorted
529+ * keys for a small performance gain. Only use when sorted keys are not needed after sorting.
498530 * @returns {StorageBuffer } Storage buffer containing sorted values.
499531 */
500- sortIndirect ( keysBuffer , maxElementCount , numBits , dispatchSlot , sortElementCountBuffer , initialValues ) {
532+ sortIndirect ( keysBuffer , maxElementCount , numBits , dispatchSlot , sortElementCountBuffer , initialValues , skipLastPassKeyWrite = false ) {
501533 Debug . assert ( keysBuffer , 'ComputeRadixSort.sortIndirect: keysBuffer is required' ) ;
502534 Debug . assert ( maxElementCount > 0 , 'ComputeRadixSort.sortIndirect: maxElementCount must be > 0' ) ;
503535 Debug . assert ( numBits % BITS_PER_PASS === 0 , `ComputeRadixSort.sortIndirect: numBits must be a multiple of ${ BITS_PER_PASS } ` ) ;
504536 Debug . assert ( sortElementCountBuffer , 'ComputeRadixSort.sortIndirect: sortElementCountBuffer is required' ) ;
505537
506- return this . _execute ( keysBuffer , maxElementCount , numBits , true , dispatchSlot , sortElementCountBuffer , initialValues ) ;
538+ return this . _execute ( keysBuffer , maxElementCount , numBits , true , dispatchSlot , sortElementCountBuffer , initialValues , skipLastPassKeyWrite ) ;
507539 }
508540
509541 /**
@@ -516,15 +548,17 @@ class ComputeRadixSort {
516548 * @param {number } dispatchSlot - Indirect dispatch slot index (-1 for direct).
517549 * @param {StorageBuffer|null } sortElementCountBuffer - GPU-written element count (null for direct).
518550 * @param {StorageBuffer } [initialValues] - Optional initial values buffer for pass 0.
551+ * @param {boolean } [skipLastPassKeyWrite] - When true, the last pass skips writing sorted
552+ * keys for a small performance gain. Only use when sorted keys are not needed after sorting.
519553 * @returns {StorageBuffer } Storage buffer containing sorted values.
520554 * @private
521555 */
522- _execute ( keysBuffer , elementCount , numBits , indirect , dispatchSlot , sortElementCountBuffer , initialValues ) {
556+ _execute ( keysBuffer , elementCount , numBits , indirect , dispatchSlot , sortElementCountBuffer , initialValues , skipLastPassKeyWrite = false ) {
523557 this . _elementCount = elementCount ;
524558 const hasInitialValues = ! ! initialValues ;
525559
526560 // Allocate buffers and create passes if needed
527- this . _allocateBuffers ( elementCount , numBits , indirect , hasInitialValues ) ;
561+ this . _allocateBuffers ( elementCount , numBits , indirect , hasInitialValues , skipLastPassKeyWrite ) ;
528562
529563 const device = this . device ;
530564 const numPasses = numBits / BITS_PER_PASS ;
0 commit comments