Skip to content

Commit 9c023e5

Browse files
dzamkovvalaphee
andauthored
Implement subgroup quad ops (gfx-rs#7683)
* Rudimentary impl of quad ops, impl quad ops for spirv * Impl quad swap for hlsl, msl and wgsl, finish spv front * Cargo clippy & cargo fmt, impl valid for quad ops * Enable quad feature * Add missing feature to glsl * Simplifying code by making `SubgroupQuadSwap` an instance of `SubgroupGather` * Add `GroupNonUniformQuad` spv capability to Vulkan * Adding GPU tests for quad operations * Validate that broadcast operations use const invocation ids * Added changelog entry --------- Co-authored-by: valaphee <[email protected]>
1 parent 4cd8be5 commit 9c023e5

29 files changed

+429
-58
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Bottom level categories:
5454
#### Naga
5555

5656
- When emitting GLSL, Uniform and Storage Buffer memory layouts are now emitted even if no explicit binding is given. By @cloone8 in [#7579](https://github.com/gfx-rs/wgpu/pull/7579).
57+
- Add support for [quad operations](https://www.w3.org/TR/WGSL/#quad-builtin-functions) (requires `SUBGROUP` feature to be enabled). By @dzamkov and @valaphee in [#7683](https://github.com/gfx-rs/wgpu/pull/7683).
5758

5859
### Bug Fixes
5960

naga/src/back/dot/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,11 @@ impl StatementGraph {
379379
| crate::GatherMode::Shuffle(index)
380380
| crate::GatherMode::ShuffleDown(index)
381381
| crate::GatherMode::ShuffleUp(index)
382-
| crate::GatherMode::ShuffleXor(index) => {
382+
| crate::GatherMode::ShuffleXor(index)
383+
| crate::GatherMode::QuadBroadcast(index) => {
383384
self.dependencies.push((id, index, "index"))
384385
}
386+
crate::GatherMode::QuadSwap(_) => {}
385387
}
386388
self.dependencies.push((id, argument, "arg"));
387389
self.emits.push((id, result));
@@ -392,6 +394,12 @@ impl StatementGraph {
392394
crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
393395
crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
394396
crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
397+
crate::GatherMode::QuadBroadcast(_) => "SubgroupQuadBroadcast",
398+
crate::GatherMode::QuadSwap(direction) => match direction {
399+
crate::Direction::X => "SubgroupQuadSwapX",
400+
crate::Direction::Y => "SubgroupQuadSwapY",
401+
crate::Direction::Diagonal => "SubgroupQuadSwapDiagonal",
402+
},
395403
}
396404
}
397405
};

naga/src/back/glsl/features.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ impl FeaturesManager {
280280
out,
281281
"#extension GL_KHR_shader_subgroup_shuffle_relative : require"
282282
)?;
283+
writeln!(out, "#extension GL_KHR_shader_subgroup_quad : require")?;
283284
}
284285

285286
if self.0.contains(Features::TEXTURE_ATOMICS) {

naga/src/back/glsl/mod.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2717,6 +2717,20 @@ impl<'a, W: Write> Writer<'a, W> {
27172717
crate::GatherMode::ShuffleXor(_) => {
27182718
write!(self.out, "subgroupShuffleXor(")?;
27192719
}
2720+
crate::GatherMode::QuadBroadcast(_) => {
2721+
write!(self.out, "subgroupQuadBroadcast(")?;
2722+
}
2723+
crate::GatherMode::QuadSwap(direction) => match direction {
2724+
crate::Direction::X => {
2725+
write!(self.out, "subgroupQuadSwapHorizontal(")?;
2726+
}
2727+
crate::Direction::Y => {
2728+
write!(self.out, "subgroupQuadSwapVertical(")?;
2729+
}
2730+
crate::Direction::Diagonal => {
2731+
write!(self.out, "subgroupQuadSwapDiagonal(")?;
2732+
}
2733+
},
27202734
}
27212735
self.write_expr(argument, ctx)?;
27222736
match mode {
@@ -2725,10 +2739,12 @@ impl<'a, W: Write> Writer<'a, W> {
27252739
| crate::GatherMode::Shuffle(index)
27262740
| crate::GatherMode::ShuffleDown(index)
27272741
| crate::GatherMode::ShuffleUp(index)
2728-
| crate::GatherMode::ShuffleXor(index) => {
2742+
| crate::GatherMode::ShuffleXor(index)
2743+
| crate::GatherMode::QuadBroadcast(index) => {
27292744
write!(self.out, ", ")?;
27302745
self.write_expr(index, ctx)?;
27312746
}
2747+
crate::GatherMode::QuadSwap(_) => {}
27322748
}
27332749
writeln!(self.out, ");")?;
27342750
}

naga/src/back/hlsl/writer.rs

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,30 +2610,55 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
26102610
};
26112611
write!(self.out, " {name} = ")?;
26122612
self.named_expressions.insert(result, name);
2613-
2614-
if matches!(mode, crate::GatherMode::BroadcastFirst) {
2615-
write!(self.out, "WaveReadLaneFirst(")?;
2616-
self.write_expr(module, argument, func_ctx)?;
2617-
} else {
2618-
write!(self.out, "WaveReadLaneAt(")?;
2619-
self.write_expr(module, argument, func_ctx)?;
2620-
write!(self.out, ", ")?;
2621-
match mode {
2622-
crate::GatherMode::BroadcastFirst => unreachable!(),
2623-
crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
2624-
self.write_expr(module, index, func_ctx)?;
2625-
}
2626-
crate::GatherMode::ShuffleDown(index) => {
2627-
write!(self.out, "WaveGetLaneIndex() + ")?;
2628-
self.write_expr(module, index, func_ctx)?;
2629-
}
2630-
crate::GatherMode::ShuffleUp(index) => {
2631-
write!(self.out, "WaveGetLaneIndex() - ")?;
2632-
self.write_expr(module, index, func_ctx)?;
2613+
match mode {
2614+
crate::GatherMode::BroadcastFirst => {
2615+
write!(self.out, "WaveReadLaneFirst(")?;
2616+
self.write_expr(module, argument, func_ctx)?;
2617+
}
2618+
crate::GatherMode::QuadBroadcast(index) => {
2619+
write!(self.out, "QuadReadLaneAt(")?;
2620+
self.write_expr(module, argument, func_ctx)?;
2621+
write!(self.out, ", ")?;
2622+
self.write_expr(module, index, func_ctx)?;
2623+
}
2624+
crate::GatherMode::QuadSwap(direction) => {
2625+
match direction {
2626+
crate::Direction::X => {
2627+
write!(self.out, "QuadReadAcrossX(")?;
2628+
}
2629+
crate::Direction::Y => {
2630+
write!(self.out, "QuadReadAcrossY(")?;
2631+
}
2632+
crate::Direction::Diagonal => {
2633+
write!(self.out, "QuadReadAcrossDiagonal(")?;
2634+
}
26332635
}
2634-
crate::GatherMode::ShuffleXor(index) => {
2635-
write!(self.out, "WaveGetLaneIndex() ^ ")?;
2636-
self.write_expr(module, index, func_ctx)?;
2636+
self.write_expr(module, argument, func_ctx)?;
2637+
}
2638+
_ => {
2639+
write!(self.out, "WaveReadLaneAt(")?;
2640+
self.write_expr(module, argument, func_ctx)?;
2641+
write!(self.out, ", ")?;
2642+
match mode {
2643+
crate::GatherMode::BroadcastFirst => unreachable!(),
2644+
crate::GatherMode::Broadcast(index)
2645+
| crate::GatherMode::Shuffle(index) => {
2646+
self.write_expr(module, index, func_ctx)?;
2647+
}
2648+
crate::GatherMode::ShuffleDown(index) => {
2649+
write!(self.out, "WaveGetLaneIndex() + ")?;
2650+
self.write_expr(module, index, func_ctx)?;
2651+
}
2652+
crate::GatherMode::ShuffleUp(index) => {
2653+
write!(self.out, "WaveGetLaneIndex() - ")?;
2654+
self.write_expr(module, index, func_ctx)?;
2655+
}
2656+
crate::GatherMode::ShuffleXor(index) => {
2657+
write!(self.out, "WaveGetLaneIndex() ^ ")?;
2658+
self.write_expr(module, index, func_ctx)?;
2659+
}
2660+
crate::GatherMode::QuadBroadcast(_) => unreachable!(),
2661+
crate::GatherMode::QuadSwap(_) => unreachable!(),
26372662
}
26382663
}
26392664
}

naga/src/back/msl/writer.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4090,6 +4090,12 @@ impl<W: Write> Writer<W> {
40904090
crate::GatherMode::ShuffleXor(_) => {
40914091
write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
40924092
}
4093+
crate::GatherMode::QuadBroadcast(_) => {
4094+
write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4095+
}
4096+
crate::GatherMode::QuadSwap(_) => {
4097+
write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4098+
}
40934099
}
40944100
self.put_expression(argument, &context.expression, true)?;
40954101
match mode {
@@ -4098,10 +4104,25 @@ impl<W: Write> Writer<W> {
40984104
| crate::GatherMode::Shuffle(index)
40994105
| crate::GatherMode::ShuffleDown(index)
41004106
| crate::GatherMode::ShuffleUp(index)
4101-
| crate::GatherMode::ShuffleXor(index) => {
4107+
| crate::GatherMode::ShuffleXor(index)
4108+
| crate::GatherMode::QuadBroadcast(index) => {
41024109
write!(self.out, ", ")?;
41034110
self.put_expression(index, &context.expression, true)?;
41044111
}
4112+
crate::GatherMode::QuadSwap(direction) => {
4113+
write!(self.out, ", ")?;
4114+
match direction {
4115+
crate::Direction::X => {
4116+
write!(self.out, "1u")?;
4117+
}
4118+
crate::Direction::Y => {
4119+
write!(self.out, "2u")?;
4120+
}
4121+
crate::Direction::Diagonal => {
4122+
write!(self.out, "3u")?;
4123+
}
4124+
}
4125+
}
41054126
}
41064127
writeln!(self.out, ");")?;
41074128
}

naga/src/back/pipeline_constants.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,9 +759,11 @@ fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut S
759759
| crate::GatherMode::Shuffle(ref mut index)
760760
| crate::GatherMode::ShuffleDown(ref mut index)
761761
| crate::GatherMode::ShuffleUp(ref mut index)
762-
| crate::GatherMode::ShuffleXor(ref mut index) => {
762+
| crate::GatherMode::ShuffleXor(ref mut index)
763+
| crate::GatherMode::QuadBroadcast(ref mut index) => {
763764
adjust(index);
764765
}
766+
crate::GatherMode::QuadSwap(_) => {}
765767
}
766768
adjust(argument);
767769
adjust(result)

naga/src/back/spv/instructions.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,22 @@ impl super::Instruction {
12031203
}
12041204
instruction.add_operand(value);
12051205

1206+
instruction
1207+
}
1208+
pub(super) fn group_non_uniform_quad_swap(
1209+
result_type_id: Word,
1210+
id: Word,
1211+
exec_scope_id: Word,
1212+
value: Word,
1213+
direction: Word,
1214+
) -> Self {
1215+
let mut instruction = Self::new(Op::GroupNonUniformQuadSwap);
1216+
instruction.set_type(result_type_id);
1217+
instruction.set_result(id);
1218+
instruction.add_operand(exec_scope_id);
1219+
instruction.add_operand(value);
1220+
instruction.add_operand(direction);
1221+
12061222
instruction
12071223
}
12081224
}

naga/src/back/spv/subgroup.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,6 @@ impl BlockContext<'_> {
125125
result: Handle<crate::Expression>,
126126
block: &mut Block,
127127
) -> Result<(), Error> {
128-
self.writer.require_any(
129-
"GroupNonUniformBallot",
130-
&[spirv::Capability::GroupNonUniformBallot],
131-
)?;
132128
match *mode {
133129
crate::GatherMode::BroadcastFirst => {
134130
self.writer.require_any(
@@ -150,6 +146,12 @@ impl BlockContext<'_> {
150146
&[spirv::Capability::GroupNonUniformShuffleRelative],
151147
)?;
152148
}
149+
crate::GatherMode::QuadBroadcast(_) | crate::GatherMode::QuadSwap(_) => {
150+
self.writer.require_any(
151+
"GroupNonUniformQuad",
152+
&[spirv::Capability::GroupNonUniformQuad],
153+
)?;
154+
}
153155
}
154156

155157
let id = self.gen_id();
@@ -174,7 +176,8 @@ impl BlockContext<'_> {
174176
| crate::GatherMode::Shuffle(index)
175177
| crate::GatherMode::ShuffleDown(index)
176178
| crate::GatherMode::ShuffleUp(index)
177-
| crate::GatherMode::ShuffleXor(index) => {
179+
| crate::GatherMode::ShuffleXor(index)
180+
| crate::GatherMode::QuadBroadcast(index) => {
178181
let index_id = self.cached[index];
179182
let op = match *mode {
180183
crate::GatherMode::BroadcastFirst => unreachable!(),
@@ -187,6 +190,8 @@ impl BlockContext<'_> {
187190
crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
188191
crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
189192
crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
193+
crate::GatherMode::QuadBroadcast(_) => spirv::Op::GroupNonUniformQuadBroadcast,
194+
crate::GatherMode::QuadSwap(_) => unreachable!(),
190195
};
191196
block.body.push(Instruction::group_non_uniform_gather(
192197
op,
@@ -197,6 +202,20 @@ impl BlockContext<'_> {
197202
index_id,
198203
));
199204
}
205+
crate::GatherMode::QuadSwap(direction) => {
206+
let direction = self.get_index_constant(match direction {
207+
crate::Direction::X => 0,
208+
crate::Direction::Y => 1,
209+
crate::Direction::Diagonal => 2,
210+
});
211+
block.body.push(Instruction::group_non_uniform_quad_swap(
212+
result_type_id,
213+
id,
214+
exec_scope_id,
215+
arg_id,
216+
direction,
217+
));
218+
}
200219
}
201220
self.cached[result] = id;
202221
Ok(())

naga/src/back/wgsl/writer.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,20 @@ impl<W: Write> Writer<W> {
945945
crate::GatherMode::ShuffleXor(_) => {
946946
write!(self.out, "subgroupShuffleXor(")?;
947947
}
948+
crate::GatherMode::QuadBroadcast(_) => {
949+
write!(self.out, "quadBroadcast(")?;
950+
}
951+
crate::GatherMode::QuadSwap(direction) => match direction {
952+
crate::Direction::X => {
953+
write!(self.out, "quadSwapX(")?;
954+
}
955+
crate::Direction::Y => {
956+
write!(self.out, "quadSwapY(")?;
957+
}
958+
crate::Direction::Diagonal => {
959+
write!(self.out, "quadSwapDiagonal(")?;
960+
}
961+
},
948962
}
949963
self.write_expr(module, argument, func_ctx)?;
950964
match mode {
@@ -953,10 +967,12 @@ impl<W: Write> Writer<W> {
953967
| crate::GatherMode::Shuffle(index)
954968
| crate::GatherMode::ShuffleDown(index)
955969
| crate::GatherMode::ShuffleUp(index)
956-
| crate::GatherMode::ShuffleXor(index) => {
970+
| crate::GatherMode::ShuffleXor(index)
971+
| crate::GatherMode::QuadBroadcast(index) => {
957972
write!(self.out, ", ")?;
958973
self.write_expr(module, index, func_ctx)?;
959974
}
975+
crate::GatherMode::QuadSwap(_) => {}
960976
}
961977
writeln!(self.out, ");")?;
962978
}

0 commit comments

Comments
 (0)