@@ -439,47 +439,59 @@ function checkCompute(
439439 numInvs : number ,
440440 filter : ( id : number , size : number ) => boolean
441441) : Error | undefined {
442- const mapping = new Map < number , number [ ] > ( ) ;
442+ const sub_unique_id_to_inv_idx = new Map < number , number [ ] > ( ) ;
443443 const empty = [ ...iterRange ( 128 , x => - 1 ) ] ;
444- for ( let i = 0 ; i < numInvs ; i ++ ) {
445- const id = metadata [ i ] ;
446- const subgroup_id = metadata [ i + numInvs ] ;
447- const v = mapping . get ( subgroup_id ) ?? Array . from ( empty ) ;
448- v [ id ] = i ;
449- mapping . set ( subgroup_id , v ) ;
444+ for ( let inv = 0 ; inv < numInvs ; inv ++ ) {
445+ const id = metadata [ inv ] ;
446+ const subgroup_unique_id = metadata [ inv + numInvs ] ;
447+ const v = sub_unique_id_to_inv_idx . get ( subgroup_unique_id ) ?? Array . from ( empty ) ;
448+ v [ id ] = inv ;
449+ sub_unique_id_to_inv_idx . set ( subgroup_unique_id , v ) ;
450450 }
451451
452- for ( let i = 0 ; i < numInvs ; i ++ ) {
453- const id = metadata [ i ] ;
454- const subgroup_id = metadata [ i + numInvs ] ;
455-
456- const subgroupMapping = mapping . get ( subgroup_id ) ?? empty ;
452+ for ( let inv = 0 ; inv < numInvs ; inv ++ ) {
453+ const id = metadata [ inv ] ;
454+ const subgroup_unique_id = metadata [ inv + numInvs ] ;
455+ const sub_inv_id_to_inv_idx = sub_unique_id_to_inv_idx . get ( subgroup_unique_id ) ?? empty ;
457456
458- const res = output [ i ] ;
459- const size = output [ i + numInvs ] ;
457+ const res = output [ inv ] ;
458+ const size = output [ inv + numInvs ] ;
460459
460+ // subgroup id predicated in shader
461461 if ( ! filter ( id , size ) ) {
462462 continue ;
463463 }
464464
465- let inputValue = input [ i ] ;
465+ let inputValue = input [ inv ] ;
466466 if ( op !== 'subgroupShuffle' ) {
467- inputValue = input [ subgroupMapping [ 0 ] ] ;
467+ // Because we use 'subgroupBroadcastFirst' without predication.
468+ const first_subgroup_inv_id = 0 ;
469+ inputValue = input [ sub_inv_id_to_inv_idx [ first_subgroup_inv_id ] ] ;
468470 }
469471
470- const index = getShuffledId ( id , inputValue , op ) ;
471- if ( index < 0 || index >= 128 || subgroupMapping [ index ] === - 1 ) {
472+ const shuffled_target_id = getShuffledId ( id , inputValue , op ) ;
473+ if (
474+ shuffled_target_id < 0 ||
475+ shuffled_target_id >= 128 ||
476+ sub_inv_id_to_inv_idx [ shuffled_target_id ] === - 1
477+ ) {
472478 continue ;
473479 }
474480
475- if ( ! filter ( index , size ) ) {
481+ // subgroup id predicated in shader
482+ if ( ! filter ( shuffled_target_id , size ) ) {
476483 continue ;
477484 }
478485
479- if ( res !== subgroupMapping [ index ] ) {
480- return new Error ( `Invocation ${ i } : unexpected result
481- - expected: ${ subgroupMapping [ index ] }
482- - got: ${ res } ` ) ;
486+ if ( res !== sub_inv_id_to_inv_idx [ shuffled_target_id ] ) {
487+ return new Error ( `Invocation ${ inv } : unexpected result
488+ - expected: ${ sub_inv_id_to_inv_idx [ shuffled_target_id ] }
489+ - got: ${ res }
490+ - id = ${ id }
491+ - size = ${ size }
492+ - inputValue = ${ inputValue }
493+ - shuffled_target_id = ${ shuffled_target_id }
494+ - subgroup_unique_id = ${ subgroup_unique_id } ` ) ;
483495 }
484496 }
485497
@@ -510,7 +522,7 @@ enable subgroups;
510522diagnostic(off, subgroup_uniformity);
511523
512524@group(0) @binding(0)
513- var<storage> input : array<u32>;
525+ var<storage> input : array<u32, ${ wgThreads } >;
514526
515527struct Output {
516528 res : array<u32, ${ wgThreads } >,
@@ -578,17 +590,17 @@ g.test('compute,split')
578590 const testcase = kPredicateCases [ t . params . predicate ] ;
579591 const wgThreads = t . params . wgSize [ 0 ] * t . params . wgSize [ 1 ] * t . params . wgSize [ 2 ] ;
580592
581- let value = `input[lid ]` ;
593+ let value = `input[lidx ]` ;
582594 if ( t . params . op !== 'subgroupShuffle' ) {
583- value = `subgroupBroadcastFirst(input[lid ])` ;
595+ value = `subgroupBroadcastFirst(input[lidx ])` ;
584596 }
585597
586598 const wgsl = `
587599enable subgroups;
588600diagnostic(off, subgroup_uniformity);
589601
590602@group(0) @binding(0)
591- var<storage> input : array<u32>;
603+ var<storage> input : array<u32, ${ wgThreads } >;
592604
593605struct Output {
594606 res : array<u32, ${ wgThreads } >,
@@ -608,24 +620,26 @@ var<storage, read_write> metadata : Metadata;
608620
609621@compute @workgroup_size(${ t . params . wgSize [ 0 ] } , ${ t . params . wgSize [ 1 ] } , ${ t . params . wgSize [ 2 ] } )
610622fn main(
611- @builtin(local_invocation_index) lid : u32,
623+ @builtin(local_invocation_index) lidx : u32,
612624 @builtin(subgroup_invocation_id) id : u32,
613625 @builtin(subgroup_size) subgroupSize : u32,
614626) {
615627 _ = input[0];
616- metadata.id[lid] = id;
617- metadata.subgroup_id[lid] = subgroupBroadcastFirst(lid + 1); // avoid 0
628+ metadata.id[lidx] = id;
629+ // Made from lidx but not lidx to avoid value confusion.
630+ var fake_unique_id = lidx + 1000;
631+ metadata.subgroup_id[lidx] = subgroupBroadcastFirst(fake_unique_id);
618632
619- output.size[lid ] = subgroupSize;
633+ output.size[lidx ] = subgroupSize;
620634 let value = ${ value } ;
621635 if ${ testcase . cond } {
622- output.res[lid ] = ${ t . params . op } (lid , value);
636+ output.res[lidx ] = ${ t . params . op } (lidx , value);
623637 } else {
624638 return;
625639 }
626640}` ;
627641
628- const inputArray = new Uint32Array ( [ ...iterRange ( 128 , x => x ) ] ) ;
642+ const inputArray = new Uint32Array ( [ ...iterRange ( wgThreads , x => x % 128 ) ] ) ;
629643 const numUintsPerOutput = 2 ;
630644 await runComputeTest (
631645 t ,
0 commit comments