Skip to content

Commit adce7fa

Browse files
petermcneeleychromiumPeter McNeeley
andauthored
Avoid buffer overrun (#4422)
Co-authored-by: Peter McNeeley <[email protected]>
1 parent 68676e3 commit adce7fa

File tree

1 file changed

+47
-33
lines changed

1 file changed

+47
-33
lines changed

src/webgpu/shader/execution/expression/call/builtin/subgroupShuffle.spec.ts

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -439,47 +439,59 @@ function checkCompute(
439439
numInvs: number,
440440
filter: (id: number, size: number) => boolean
441441
): Error | undefined {
442-
const mapping = new Map<number, number[]>();
442+
const sub_unique_id_to_inv_idx = new Map<number, number[]>();
443443
const empty = [...iterRange(128, x => -1)];
444-
for (let i = 0; i < numInvs; i++) {
445-
const id = metadata[i];
446-
const subgroup_id = metadata[i + numInvs];
447-
const v = mapping.get(subgroup_id) ?? Array.from(empty);
448-
v[id] = i;
449-
mapping.set(subgroup_id, v);
444+
for (let inv = 0; inv < numInvs; inv++) {
445+
const id = metadata[inv];
446+
const subgroup_unique_id = metadata[inv + numInvs];
447+
const v = sub_unique_id_to_inv_idx.get(subgroup_unique_id) ?? Array.from(empty);
448+
v[id] = inv;
449+
sub_unique_id_to_inv_idx.set(subgroup_unique_id, v);
450450
}
451451

452-
for (let i = 0; i < numInvs; i++) {
453-
const id = metadata[i];
454-
const subgroup_id = metadata[i + numInvs];
455-
456-
const subgroupMapping = mapping.get(subgroup_id) ?? empty;
452+
for (let inv = 0; inv < numInvs; inv++) {
453+
const id = metadata[inv];
454+
const subgroup_unique_id = metadata[inv + numInvs];
455+
const sub_inv_id_to_inv_idx = sub_unique_id_to_inv_idx.get(subgroup_unique_id) ?? empty;
457456

458-
const res = output[i];
459-
const size = output[i + numInvs];
457+
const res = output[inv];
458+
const size = output[inv + numInvs];
460459

460+
// subgroup id predicated in shader
461461
if (!filter(id, size)) {
462462
continue;
463463
}
464464

465-
let inputValue = input[i];
465+
let inputValue = input[inv];
466466
if (op !== 'subgroupShuffle') {
467-
inputValue = input[subgroupMapping[0]];
467+
// Because we use 'subgroupBroadcastFirst' without predication.
468+
const first_subgroup_inv_id = 0;
469+
inputValue = input[sub_inv_id_to_inv_idx[first_subgroup_inv_id]];
468470
}
469471

470-
const index = getShuffledId(id, inputValue, op);
471-
if (index < 0 || index >= 128 || subgroupMapping[index] === -1) {
472+
const shuffled_target_id = getShuffledId(id, inputValue, op);
473+
if (
474+
shuffled_target_id < 0 ||
475+
shuffled_target_id >= 128 ||
476+
sub_inv_id_to_inv_idx[shuffled_target_id] === -1
477+
) {
472478
continue;
473479
}
474480

475-
if (!filter(index, size)) {
481+
// subgroup id predicated in shader
482+
if (!filter(shuffled_target_id, size)) {
476483
continue;
477484
}
478485

479-
if (res !== subgroupMapping[index]) {
480-
return new Error(`Invocation ${i}: unexpected result
481-
- expected: ${subgroupMapping[index]}
482-
- got: ${res}`);
486+
if (res !== sub_inv_id_to_inv_idx[shuffled_target_id]) {
487+
return new Error(`Invocation ${inv}: unexpected result
488+
- expected: ${sub_inv_id_to_inv_idx[shuffled_target_id]}
489+
- got: ${res}
490+
- id = ${id}
491+
- size = ${size}
492+
- inputValue = ${inputValue}
493+
- shuffled_target_id = ${shuffled_target_id}
494+
- subgroup_unique_id = ${subgroup_unique_id}`);
483495
}
484496
}
485497

@@ -510,7 +522,7 @@ enable subgroups;
510522
diagnostic(off, subgroup_uniformity);
511523
512524
@group(0) @binding(0)
513-
var<storage> input : array<u32>;
525+
var<storage> input : array<u32, ${wgThreads}>;
514526
515527
struct Output {
516528
res : array<u32, ${wgThreads}>,
@@ -578,17 +590,17 @@ g.test('compute,split')
578590
const testcase = kPredicateCases[t.params.predicate];
579591
const wgThreads = t.params.wgSize[0] * t.params.wgSize[1] * t.params.wgSize[2];
580592

581-
let value = `input[lid]`;
593+
let value = `input[lidx]`;
582594
if (t.params.op !== 'subgroupShuffle') {
583-
value = `subgroupBroadcastFirst(input[lid])`;
595+
value = `subgroupBroadcastFirst(input[lidx])`;
584596
}
585597

586598
const wgsl = `
587599
enable subgroups;
588600
diagnostic(off, subgroup_uniformity);
589601
590602
@group(0) @binding(0)
591-
var<storage> input : array<u32>;
603+
var<storage> input : array<u32, ${wgThreads}>;
592604
593605
struct Output {
594606
res : array<u32, ${wgThreads}>,
@@ -608,24 +620,26 @@ var<storage, read_write> metadata : Metadata;
608620
609621
@compute @workgroup_size(${t.params.wgSize[0]}, ${t.params.wgSize[1]}, ${t.params.wgSize[2]})
610622
fn main(
611-
@builtin(local_invocation_index) lid : u32,
623+
@builtin(local_invocation_index) lidx : u32,
612624
@builtin(subgroup_invocation_id) id : u32,
613625
@builtin(subgroup_size) subgroupSize : u32,
614626
) {
615627
_ = input[0];
616-
metadata.id[lid] = id;
617-
metadata.subgroup_id[lid] = subgroupBroadcastFirst(lid + 1); // avoid 0
628+
metadata.id[lidx] = id;
629+
// Made from lidx but not lidx to avoid value confusion.
630+
var fake_unique_id = lidx + 1000;
631+
metadata.subgroup_id[lidx] = subgroupBroadcastFirst(fake_unique_id);
618632
619-
output.size[lid] = subgroupSize;
633+
output.size[lidx] = subgroupSize;
620634
let value = ${value};
621635
if ${testcase.cond} {
622-
output.res[lid] = ${t.params.op}(lid, value);
636+
output.res[lidx] = ${t.params.op}(lidx, value);
623637
} else {
624638
return;
625639
}
626640
}`;
627641

628-
const inputArray = new Uint32Array([...iterRange(128, x => x)]);
642+
const inputArray = new Uint32Array([...iterRange(wgThreads, x => x % 128)]);
629643
const numUintsPerOutput = 2;
630644
await runComputeTest(
631645
t,

0 commit comments

Comments
 (0)