Skip to content

Commit cede40a

Browse files
committed
docs
1 parent 7411e0c commit cede40a

File tree

1 file changed

+140
-16
lines changed
  • src/webgpu/shader/execution/reconvergence

1 file changed

+140
-16
lines changed

src/webgpu/shader/execution/reconvergence/util.ts

Lines changed: 140 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ function all(value: bigint, size: number): boolean {
6868
return value === ((1n << BigInt(size)) - 1n);
6969
}
7070

71+
/**
72+
* Reconvergence style being tested.
73+
*/
7174
export enum Style {
7275
// Workgroup uniform control flow
7376
Workgroup = 0,
@@ -83,6 +86,9 @@ export enum Style {
8386
WGSLv1 = 3,
8487
};
8588

89+
/**
90+
* Instruction type
91+
*/
8692
export enum OpType {
8793
// Store a ballot.
8894
// During simulation, uniform is set to false if the
@@ -172,6 +178,7 @@ export enum OpType {
172178
MAX,
173179
}
174180

181+
/** @returns The stringified version of |op|. */
175182
function serializeOpType(op: OpType): string {
176183
// prettier-ignore
177184
switch (op) {
@@ -215,6 +222,9 @@ function serializeOpType(op: OpType): string {
215222
return '';
216223
}
217224

225+
/**
226+
* Different styles of if conditions
227+
*/
218228
enum IfType {
219229
// If the mask is 0, generates a random uniform comparison
220230
// Otherwise, tests subgroup_invocation_id against a mask
@@ -254,6 +264,18 @@ class Op {
254264
}
255265
};
256266

267+
/**
268+
* Main class for testcase generation.
269+
*
270+
* Major steps involved in a test:
271+
* 1. Generation (either generate() or a predefined case)
272+
* 2. Simulation
273+
* 3. Result comparison
274+
*
275+
* The interface of the program is fixed and invariant of the particular
276+
* program being tested.
277+
*
278+
*/
257279
export class Program {
258280
// Number of invocations in the program
259281
// Max supported is 128
@@ -482,6 +504,10 @@ export class Program {
482504
}
483505
}
484506
}
507+
case 10: {
508+
this.genElect(false);
509+
break;
510+
}
485511
default: {
486512
break;
487513
}
@@ -491,6 +517,13 @@ export class Program {
491517
}
492518
}
493519

520+
/**
521+
* Ballot generation
522+
*
523+
* Can insert ballots, stores, noise into the program.
524+
* For non-maximal styles, if a ballot is generated, a store always precedes
525+
* it.
526+
*/
494527
private genBallot() {
495528
// Optionally insert ballots, stores, and noise.
496529
// Ballots and stores are used to determine correctness.
@@ -526,6 +559,13 @@ export class Program {
526559
}
527560
}
528561

562+
/**
563+
* Generate an if based on |type|
564+
*
565+
* @param type The type of the if condition, see IfType
566+
*
567+
* Generates if/else structures.
568+
*/
529569
private genIf(type: IfType) {
530570
let maskIdx = this.getRandomUint(this.numMasks);
531571
if (type == IfType.Uniform)
@@ -578,6 +618,11 @@ export class Program {
578618
this.nesting--;
579619
}
580620

621+
/**
622+
* Generate a uniform for loop
623+
*
624+
* The number of iterations is randomly selected [1, 5].
625+
*/
581626
private genForUniform() {
582627
const n = this.getRandomUint(5) + 1; // [1, 5]
583628
this.ops.push(new Op(OpType.ForUniform, n));
@@ -592,6 +637,19 @@ export class Program {
592637
this.nesting--;
593638
}
594639

640+
/**
641+
* Generate an infinite for loop
642+
*
643+
* The loop will always include an elect based break to prevent a truly
644+
* infinite loop. The maximum number of iterations is the number of
645+
* invocations in the program, but it is scaled by the loop nesting. Inside
646+
* one loop the number of iterations is halved and inside two loops the
647+
* number of iterations in quartered. This scaling is used to reduce runtime
648+
* and memory.
649+
*
650+
* The for_update also performs a ballot.
651+
*
652+
*/
595653
private genForInf() {
596654
this.ops.push(new Op(OpType.ForInf, 0));
597655
this.nesting++;
@@ -618,6 +676,14 @@ export class Program {
618676
this.nesting--;
619677
}
620678

679+
/**
680+
* Generate a for loop with variable iterations per invocation
681+
*
682+
* The loop condition is based on subgroup_invocation_id + 1. So each
683+
* invocation executes a different number of iterations, though the this is
684+
* scaled by the amount of loop nesting the same as |generateForInf|.
685+
*
686+
*/
621687
private genForVar() {
622688
// op.value is the iteration reduction factor.
623689
const reduction = this.loopNesting === 0 ? 1 : this.loopNesting === 1 ? 2 : 4;
@@ -635,6 +701,11 @@ export class Program {
635701
this.nesting--;
636702
}
637703

704+
/**
705+
* Generate a loop construct with uniform iterations
706+
*
707+
* Same as |genForUniform|, but coded as a loop construct.
708+
*/
638709
private genLoopUniform() {
639710
const n = this.getRandomUint(5) + 1;
640711
this.ops.push(new Op(OpType.LoopUniform, n));
@@ -651,6 +722,11 @@ export class Program {
651722
this.nesting--;
652723
}
653724

725+
/**
726+
* Generate an infinite loop construct
727+
*
728+
* This is the same as |genForInf| but uses a loop construct.
729+
*/
654730
private genLoopInf() {
655731
const header = this.ops.length;
656732
this.ops.push(new Op(OpType.LoopInf, 0));
@@ -679,6 +755,13 @@ export class Program {
679755
this.nesting--;
680756
}
681757

758+
/**
759+
* Generates an if based on subgroupElect()
760+
*
761+
* @param forceBreak If true, forces the then statement to contain a break
762+
* @param reduction This generates extra breaks
763+
*
764+
*/
682765
private genElect(forceBreak: boolean, reduction: number = 1) {
683766
this.ops.push(new Op(OpType.Elect, 0));
684767
this.nesting++;
@@ -711,6 +794,13 @@ export class Program {
711794
}
712795
}
713796

797+
/**
798+
* Generate a break if in a loop.
799+
*
800+
* Only generates a break within a loop, but may break out of a switch and
801+
* not just a loop. Sometimes the break uses a non-uniform if/else to break.
802+
*
803+
*/
714804
private genBreak() {
715805
if (this.loopNestingThisFunction > 0) {
716806
// Sometimes put the break in a divergent if
@@ -728,6 +818,11 @@ export class Program {
728818
}
729819
}
730820

821+
/**
822+
* Generate a continue if in a loop
823+
*
824+
* Sometimes uses a non-uniform if/else to continue.
825+
*/
731826
private genContinue() {
732827
if (this.loopNestingThisFunction > 0 && !this.isLoopInf.get(this.loopNesting)) {
733828
// Sometimes put the continue in a divergent if
@@ -745,6 +840,10 @@ export class Program {
745840
}
746841
}
747842

843+
/**
844+
* Generates a function call.
845+
*
846+
*/
748847
private genCall() {
749848
this.ops.push(new Op(OpType.Call, 0));
750849
this.callNesting++;
@@ -761,6 +860,11 @@ export class Program {
761860
this.ops.push(new Op(OpType.EndCall, 0));
762861
}
763862

863+
/**
864+
* Generates a return
865+
*
866+
* Rarely, this will return from the main function
867+
*/
764868
private genReturn() {
765869
const r = this.getRandomFloat();
766870
if (this.nesting > 0 &&
@@ -781,20 +885,28 @@ export class Program {
781885
}
782886
}
783887

888+
/**
889+
* Generate a uniform switch.
890+
*
891+
* Some dead case constructs are also generated.
892+
*/
784893
private genSwitchUniform() {
785894
const r = this.getRandomUint(5);
786895
this.ops.push(new Op(OpType.SwitchUniform, r));
787896
this.nesting++;
788897
this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting);
789898

899+
// Never taken
790900
this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r+1)));
791901
this.pickOp(1);
792902
this.ops.push(new Op(OpType.EndCase, 0));
793903

904+
// Always taken
794905
this.ops.push(new Op(OpType.CaseMask, 0xf, 1 << r));
795906
this.pickOp(1);
796907
this.ops.push(new Op(OpType.EndCase, 0));
797908

909+
// Never taken
798910
this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r+2)));
799911
this.pickOp(1);
800912
this.ops.push(new Op(OpType.EndCase, 0));
@@ -803,6 +915,10 @@ export class Program {
803915
this.nesting--;
804916
}
805917

918+
/**
919+
* Generates a non-uniform switch based on subgroup_invocation_id
920+
*
921+
*/
806922
private genSwitchVar() {
807923
this.ops.push(new Op(OpType.SwitchVar, 0));
808924
this.nesting++;
@@ -828,6 +944,10 @@ export class Program {
828944
this.nesting--;
829945
}
830946

947+
/**
948+
* Generates switch based on an active loop induction variable.
949+
*
950+
*/
831951
private genSwitchLoopCount() {
832952
const r = this.getRandomUint(this.loopNesting);
833953
this.ops.push(new Op(OpType.SwitchLoopCount, r));
@@ -850,17 +970,20 @@ export class Program {
850970
this.nesting--;
851971
}
852972

853-
// switch (subgroup_invocation_id & 3) {
854-
// default { }
855-
// case 0x3: { ... }
856-
// case 0xc: { ... }
857-
// }
858-
//
859-
// This is not generated for maximal style cases because it is not clear what
860-
// convergence should be expected. There are multiple valid lowerings of a
861-
// switch that would lead to different convergence scenarios. To test this
862-
// properly would likely require a range of values which is difficult for
863-
// this infrastructure to produce.
973+
/**
974+
* switch (subgroup_invocation_id & 3) {
975+
* default { }
976+
* case 0x3: { ... }
977+
* case 0xc: { ... }
978+
* }
979+
*
980+
* This is not generated for maximal style cases because it is not clear what
981+
* convergence should be expected. There are multiple valid lowerings of a
982+
* switch that would lead to different convergence scenarios. To test this
983+
* properly would likely require a range of values which is difficult for
984+
* this infrastructure to produce.
985+
*
986+
*/
864987
private genSwitchMulticase() {
865988
this.ops.push(new Op(OpType.SwitchVar, 0));
866989
this.nesting++;
@@ -1381,7 +1504,6 @@ ${this.functions[i]}`;
13811504
* BigInt is not the fastest value to manipulate. Care should be taken to optimize it's use.
13821505
* TODO: would it be better to roll my own 128 bitvector?
13831506
*
1384-
* TODO: reconvergence guarantees in WGSL are not as strong as this simulation
13851507
*/
13861508
public simulate(countOnly: boolean, subgroupSize: number, debug: boolean = false): number {
13871509
class State {
@@ -1430,6 +1552,8 @@ ${this.functions[i]}`;
14301552
}
14311553

14321554
// Allocate the stack based on the maximum nesting in the program.
1555+
// Note: this has proven to be considerably more performant than pushing
1556+
// and popping from the array.
14331557
let stack: State[] = new Array(this.maxProgramNesting + 1);
14341558
for (let i = 0; i < stack.length; i++) {
14351559
stack[i] = new State();
@@ -1479,10 +1603,6 @@ ${this.functions[i]}`;
14791603
continue;
14801604
}
14811605
}
1482-
case OpType.EndCase:
1483-
case OpType.Noise:
1484-
// No work
1485-
break;
14861606
default:
14871607
break;
14881608
}
@@ -1919,6 +2039,7 @@ ${this.functions[i]}`;
19192039
}
19202040
case OpType.Noise:
19212041
case OpType.EndCase: {
2042+
// No work
19222043
break;
19232044
}
19242045
default: {
@@ -1940,6 +2061,9 @@ ${this.functions[i]}`;
19402061

19412062
/**
19422063
* @returns a mask formed from |masks[idx]|
2064+
*
2065+
* @param idx The index in |this.masks| to use.
2066+
*
19432067
*/
19442068
private getValueMask(idx: number): bigint {
19452069
const x = this.masks[4*idx];

0 commit comments

Comments
 (0)