Skip to content

Commit 61aca28

Browse files
committed
coop: fixes and changelog
1 parent e30213c commit 61aca28

File tree

8 files changed

+117
-46
lines changed

8 files changed

+117
-46
lines changed

CHANGELOG.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Bottom level categories:
4444

4545
#### Deferred command buffer actions: `map_buffer_on_submit` and `on_submitted_work_done`
4646

47-
You may schedule buffer mapping and a submission-complete callback to run automatically after you submit, directly from encoders, command buffers, and passes.
47+
You may schedule buffer mapping and a submission-complete callback to run automatically after you submit, directly from encoders, command buffers, and passes.
4848

4949
```rust
5050
// Record some GPU work so the submission isn't empty and touches `buffer`.
@@ -150,7 +150,7 @@ By @cwfitzgerald in [#8163](https://github.com/gfx-rs/wgpu/pull/8163).
150150

151151
#### Multi-draw indirect is now unconditionally supported when indirect draws are supported
152152

153-
We have removed `Features::MULTI_DRAW_INDIRECT` as it was unconditionally available on all platforms.
153+
We have removed `Features::MULTI_DRAW_INDIRECT` as it was unconditionally available on all platforms.
154154
`RenderPass::multi_draw_indirect` is now available if the device supports downlevel flag `DownlevelFlags::INDIRECT_EXECUTION`.
155155

156156
If you are using spirv-passthrough with multi-draw indirect and `gl_DrawID`, you can know if `MULTI_DRAW_INDIRECT` is being emulated
@@ -166,6 +166,8 @@ By @cwfitzgerald in [#8162](https://github.com/gfx-rs/wgpu/pull/8162).
166166

167167
- Added support for external textures based on WebGPU's [`GPUExternalTexture`](https://www.w3.org/TR/webgpu/#gpuexternaltexture). These allow shaders to transparently operate on potentially multiplanar source texture data in either RGB or YCbCr formats via WGSL's `texture_external` type. This is gated behind the `Features::EXTERNAL_TEXTURE` feature, which is currently only supported on DX12. By @jamienicol in [#4386](https://github.com/gfx-rs/wgpu/issues/4386).
168168

169+
- Added support for cooperative load/store operations in shaders. Currently, only WGSL is supported as input, and only SPIR-V and Metal (MSL) are supported as output. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
170+
169171
### Changes
170172

171173
#### General

naga/src/back/spv/block.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3691,6 +3691,18 @@ impl BlockContext<'_> {
36913691
stride,
36923692
row_major,
36933693
} => {
3694+
let pointer_id = match self.write_access_chain(
3695+
pointer,
3696+
&mut block,
3697+
AccessTypeAdjustment::None,
3698+
)? {
3699+
ExpressionPointer::Ready { pointer_id } => pointer_id,
3700+
ExpressionPointer::Conditional { .. } => {
3701+
return Err(Error::FeatureNotImplemented(
3702+
"Cooperative load/store out-of-bounds handling",
3703+
));
3704+
}
3705+
};
36943706
let layout = if row_major {
36953707
spirv::CooperativeMatrixLayout::RowMajorKHR
36963708
} else {
@@ -3701,7 +3713,7 @@ impl BlockContext<'_> {
37013713
if store {
37023714
block.body.push(Instruction::coop_store(
37033715
self.cached[target],
3704-
self.cached[pointer],
3716+
pointer_id,
37053717
layout_id,
37063718
stride_id,
37073719
));
@@ -3711,7 +3723,7 @@ impl BlockContext<'_> {
37113723
block.body.push(Instruction::coop_load(
37123724
result_type_id,
37133725
id,
3714-
self.cached[pointer],
3726+
pointer_id,
37153727
layout_id,
37163728
stride_id,
37173729
));

naga/src/valid/function.rs

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,32 +1641,20 @@ impl super::Validator {
16411641
}
16421642
};
16431643

1644-
let ty_inner =
1645-
context.resolve_type_inner(pointer, &self.valid_expression_set)?;
1644+
let ty_inner = context.resolve_pointer_type(pointer);
16461645
//TODO: validate stride
1647-
let (pty_array, space) = match *ty_inner {
1646+
let (pty_scalar, space) = match *ty_inner {
16481647
crate::TypeInner::Pointer { base, space } => (base, space),
16491648
_ => {
16501649
return Err(FunctionError::InvalidCooperativeDataPointer(pointer)
1651-
.with_span_handle(pointer, context.expressions))
1652-
}
1653-
};
1654-
let pty_scalar = match context.types[pty_array].inner {
1655-
crate::TypeInner::Array {
1656-
base,
1657-
size: _,
1658-
stride: _,
1659-
} => base,
1660-
_ => {
1661-
return Err(FunctionError::InvalidCooperativeDataPointer(pointer)
1662-
.with_span_handle(pointer, context.expressions))
1650+
.with_span_handle(pointer, context.expressions));
16631651
}
16641652
};
16651653
let space = match context.types[pty_scalar].inner {
16661654
crate::TypeInner::Scalar(s) if s == target_scalar => space,
16671655
_ => {
16681656
return Err(FunctionError::InvalidCooperativeDataPointer(pointer)
1669-
.with_span_handle(pointer, context.expressions))
1657+
.with_span_handle(pointer, context.expressions));
16701658
}
16711659
};
16721660

naga/tests/in/wgsl/cooperative-matrix.wgsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ var<storage, read_write> ext: array<f32>;
66
@compute @workgroup_size(8, 8, 1)
77
fn main() {
88
var c = coop_mat8x8<f32, C>();
9-
coopLoad(c, &ext);
9+
coopLoad(c, &ext[4]);
1010
var d = coopMultiplyAdd(a, b, c);
11-
coopStore(c, &ext);
11+
coopStore(c, &ext[0]);
1212
}

naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,40 +119,56 @@
119119
ZeroValue(4),
120120
LocalVariable(0),
121121
GlobalVariable(2),
122+
AccessIndex(
123+
base: 2,
124+
index: 4,
125+
),
122126
GlobalVariable(0),
123127
GlobalVariable(1),
124128
CooperativeMultiplyAdd(
125-
a: 3,
126-
b: 4,
129+
a: 4,
130+
b: 5,
127131
c: 1,
128132
),
129133
LocalVariable(1),
130134
GlobalVariable(2),
135+
AccessIndex(
136+
base: 8,
137+
index: 0,
138+
),
131139
],
132140
named_expressions: {},
133141
body: [
134142
CooperativeLoadStore(
135143
store: false,
136144
target: 1,
137-
pointer: 2,
145+
pointer: 3,
138146
stride: None,
139147
row_major: false,
140148
),
141149
Emit((
142-
start: 5,
143-
end: 6,
150+
start: 3,
151+
end: 4,
152+
)),
153+
Emit((
154+
start: 6,
155+
end: 7,
144156
)),
145157
Store(
146-
pointer: 6,
147-
value: 5,
158+
pointer: 7,
159+
value: 6,
148160
),
149161
CooperativeLoadStore(
150162
store: true,
151163
target: 1,
152-
pointer: 7,
164+
pointer: 9,
153165
stride: None,
154166
row_major: false,
155167
),
168+
Emit((
169+
start: 9,
170+
end: 10,
171+
)),
156172
Return(
157173
value: None,
158174
),

naga/tests/out/ir/wgsl-cooperative-matrix.ron

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,40 +119,56 @@
119119
ZeroValue(4),
120120
LocalVariable(0),
121121
GlobalVariable(2),
122+
AccessIndex(
123+
base: 2,
124+
index: 4,
125+
),
122126
GlobalVariable(0),
123127
GlobalVariable(1),
124128
CooperativeMultiplyAdd(
125-
a: 3,
126-
b: 4,
129+
a: 4,
130+
b: 5,
127131
c: 1,
128132
),
129133
LocalVariable(1),
130134
GlobalVariable(2),
135+
AccessIndex(
136+
base: 8,
137+
index: 0,
138+
),
131139
],
132140
named_expressions: {},
133141
body: [
134142
CooperativeLoadStore(
135143
store: false,
136144
target: 1,
137-
pointer: 2,
145+
pointer: 3,
138146
stride: None,
139147
row_major: false,
140148
),
141149
Emit((
142-
start: 5,
143-
end: 6,
150+
start: 3,
151+
end: 4,
152+
)),
153+
Emit((
154+
start: 6,
155+
end: 7,
144156
)),
145157
Store(
146-
pointer: 6,
147-
value: 5,
158+
pointer: 7,
159+
value: 6,
148160
),
149161
CooperativeLoadStore(
150162
store: true,
151163
target: 1,
152-
pointer: 7,
164+
pointer: 9,
153165
stride: None,
154166
row_major: false,
155167
),
168+
Emit((
169+
start: 9,
170+
end: 10,
171+
)),
156172
Return(
157173
value: None,
158174
),
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// language: metal1.0
2+
#include <metal_stdlib>
3+
#include <simd/simd.h>
4+
5+
using metal::uint;
6+
7+
struct _mslBufferSizes {
8+
uint size2;
9+
};
10+
11+
typedef float type_3[1];
12+
metal::simdgroup_float8x8 NagaCooperativeMultiplyAdd(const metal::simdgroup_float8x8& a, const metal::simdgroup_float8x8& b, const metal::simdgroup_float8x8& c) {
13+
metal::simdgroup_float8x8 d;
14+
metal::simdgroup_multiply_accumulate(d,a,b,c);
15+
return d;
16+
}
17+
18+
19+
kernel void main_(
20+
device type_3 const& ext [[user(fake0)]]
21+
, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]]
22+
) {
23+
metal::simdgroup_float8x8 a = {};
24+
metal::simdgroup_float8x8 b = {};
25+
metal::simdgroup_float8x8 c = metal::simdgroup_float8x8 {};
26+
metal::simdgroup_float8x8 d = {};
27+
metal::simdgroup_load(c, ext[4]);
28+
d = NagaCooperativeMultiplyAdd(a, b, c);
29+
metal::simdgroup_store(c, ext[0]);
30+
return;
31+
}

naga/tests/out/spv/wgsl-cooperative-matrix.spvasm

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; SPIR-V
22
; Version: 1.4
33
; Generator: rspirv
4-
; Bound: 37
4+
; Bound: 41
55
OpCapability Shader
66
OpCapability CooperativeMatrixKHR
77
OpCapability VulkanMemoryModel
@@ -20,9 +20,9 @@ var<storage, read_write> ext: array<f32>;
2020
@compute @workgroup_size(8, 8, 1)
2121
fn main() {
2222
var c = coop_mat8x8<f32, C>();
23-
coopLoad(c, &ext);
23+
coopLoad(c, &ext[4]);
2424
var d = coopMultiplyAdd(a, b, c);
25-
coopStore(c, &ext);
25+
coopStore(c, &ext[0]);
2626
}
2727
"
2828
OpName %15 "a"
@@ -62,6 +62,8 @@ OpMemberDecorate %22 0 Offset 0
6262
%29 = OpConstantNull %13
6363
%31 = OpTypePointer Function %13
6464
%33 = OpConstantNull %13
65+
%35 = OpTypePointer StorageBuffer %4
66+
%36 = OpConstant %7 4
6567
%25 = OpFunction %2 None %26
6668
%24 = OpLabel
6769
%30 = OpVariable %31 Function %29
@@ -70,13 +72,17 @@ OpMemberDecorate %22 0 Offset 0
7072
OpBranch %34
7173
%34 = OpLabel
7274
OpLine %3 9 5
73-
%35 = OpCooperativeMatrixLoadKHR %13 %28 %11
74-
OpStore %30 %35
75+
%37 = OpAccessChain %35 %28 %36
76+
%38 = OpCooperativeMatrixLoadKHR %13 %37 %11
77+
OpStore %30 %38
78+
OpLine %3 9 18
7579
OpLine %3 10 13
76-
%36 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30
80+
%39 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30
7781
OpLine %3 10 5
78-
OpStore %32 %36
82+
OpStore %32 %39
7983
OpLine %3 11 5
80-
OpCooperativeMatrixStoreKHR %28 %30 %11
84+
%40 = OpAccessChain %35 %28 %9
85+
OpCooperativeMatrixStoreKHR %40 %30 %11
86+
OpLine %3 11 19
8187
OpReturn
8288
OpFunctionEnd

0 commit comments

Comments
 (0)