Skip to content

Commit 87360f2

Browse files
committed
coop: rewire IR using native variables load/store
1 parent 321ddf4 commit 87360f2

21 files changed

+282
-238
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ By @cwfitzgerald in [#8162](https://github.com/gfx-rs/wgpu/pull/8162).
166166

167167
- Added support for external textures based on WebGPU's [`GPUExternalTexture`](https://www.w3.org/TR/webgpu/#gpuexternaltexture). These allow shaders to transparently operate on potentially multiplanar source texture data in either RGB or YCbCr formats via WGSL's `texture_external` type. This is gated behind the `Features::EXTERNAL_TEXTURE` feature, which is currently only supported on DX12. By @jamienicol in [#4386](https://github.com/gfx-rs/wgpu/issues/4386).
168168

169-
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V with METAL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
169+
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V,METAL, and WGSL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
170170

171171
### Changes
172172

naga/src/back/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,16 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
311311
}
312312
}
313313

314+
impl crate::TypeInner {
315+
/// Returns true if a variable of this type is a handle.
316+
pub const fn is_handle(&self) -> bool {
317+
match *self {
318+
Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure { .. } => true,
319+
_ => false,
320+
}
321+
}
322+
}
323+
314324
impl crate::Statement {
315325
/// Returns true if the statement directly terminates the current block.
316326
///

naga/src/back/msl/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,10 @@ pub enum Error {
228228
UnsupportedArrayOf(String),
229229
#[error("array of type '{0:?}' is not supported")]
230230
UnsupportedArrayOfType(Handle<crate::Type>),
231-
#[error("ray tracing is not supported prior to MSL 2.3")]
231+
#[error("ray tracing is not supported prior to MSL 2.4")]
232232
UnsupportedRayTracing,
233+
#[error("cooperative matrix is not supported prior to MSL 2.3")]
234+
UnsupportedCooperativeMatrix,
233235
#[error("overrides should not be present at this stage")]
234236
Override,
235237
#[error("bitcasting to {0:?} is not supported")]

naga/src/back/msl/writer.rs

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ impl Display for TypeContext<'_> {
236236
rows,
237237
scalar,
238238
} => put_numeric_type(out, scalar, &[rows, columns]),
239+
// Requires Metal-2.3
239240
crate::TypeInner::CooperativeMatrix {
240241
columns,
241242
rows,
@@ -244,8 +245,7 @@ impl Display for TypeContext<'_> {
244245
} => {
245246
write!(
246247
out,
247-
"{}::simdgroup_{}{}x{}",
248-
NAMESPACE,
248+
"{NAMESPACE}::simdgroup_{}{}x{}",
249249
scalar.to_msl_name(),
250250
columns as u32,
251251
rows as u32,
@@ -485,6 +485,7 @@ enum WrappedFunction {
485485
class: crate::ImageClass,
486486
},
487487
CooperativeMultiplyAdd {
488+
space: crate::AddressSpace,
488489
columns: crate::CooperativeSize,
489490
rows: crate::CooperativeSize,
490491
intermediate: crate::CooperativeSize,
@@ -2842,6 +2843,9 @@ impl<W: Write> Writer<W> {
28422843
write!(self.out, "}}")?;
28432844
}
28442845
crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2846+
if context.lang_version < (2, 3) {
2847+
return Err(Error::UnsupportedCooperativeMatrix);
2848+
}
28452849
write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
28462850
self.put_expression(a, context, true)?;
28472851
write!(self.out, ", ")?;
@@ -4239,10 +4243,14 @@ impl<W: Write> Writer<W> {
42394243
row_major,
42404244
} => {
42414245
let op_str = if store { "store" } else { "load" };
4242-
write!(self.out, "{level}{NAMESPACE}::simdgroup_{op_str}(")?;
4246+
write!(self.out, "{level}simdgroup_{op_str}(")?;
42434247
self.put_expression(target, &context.expression, true)?;
4244-
write!(self.out, ", ")?;
4245-
self.put_expression(pointer, &context.expression, true)?;
4248+
write!(self.out, ", &")?;
4249+
self.put_access_chain(
4250+
pointer,
4251+
context.expression.policies.index,
4252+
&context.expression,
4253+
)?;
42464254
write!(self.out, ", ")?;
42474255
self.put_expression(stride, &context.expression, true)?;
42484256
if row_major {
@@ -6312,6 +6320,7 @@ template <typename A>
63126320
&mut self,
63136321
module: &crate::Module,
63146322
func_ctx: &back::FunctionCtx,
6323+
space: crate::AddressSpace,
63156324
a: Handle<crate::Expression>,
63166325
b: Handle<crate::Expression>,
63176326
) -> BackendResult {
@@ -6329,6 +6338,7 @@ template <typename A>
63296338
_ => unreachable!(),
63306339
};
63316340
let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6341+
space,
63326342
columns: b_c,
63336343
rows: a_r,
63346344
intermediate: a_c,
@@ -6337,15 +6347,11 @@ template <typename A>
63376347
if !self.wrapped_functions.insert(wrapped) {
63386348
return Ok(());
63396349
}
6340-
let scalar_name = match scalar.width {
6341-
2 => "half",
6342-
4 => "float",
6343-
8 => "double",
6344-
_ => unreachable!(),
6345-
};
6350+
let space_name = space.to_msl_name().unwrap_or_default();
6351+
let scalar_name = scalar.to_msl_name();
63466352
writeln!(
63476353
self.out,
6348-
"{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6354+
"{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
63496355
b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
63506356
)?;
63516357
let l1 = back::Level(1);
@@ -6354,10 +6360,7 @@ template <typename A>
63546360
"{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
63556361
b_c as u32, a_r as u32
63566362
)?;
6357-
writeln!(
6358-
self.out,
6359-
"{l1}{NAMESPACE}::simdgroup_multiply_accumulate(d,a,b,c);"
6360-
)?;
6363+
writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
63616364
writeln!(self.out, "{l1}return d;")?;
63626365
writeln!(self.out, "}}")?;
63636366
writeln!(self.out)?;
@@ -6439,7 +6442,8 @@ template <typename A>
64396442
self.write_wrapped_image_query(module, func_ctx, image, query)?;
64406443
}
64416444
crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6442-
self.write_wrapped_cooperative_multiply_add(module, func_ctx, a, b)?;
6445+
let space = crate::AddressSpace::Private;
6446+
self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
64436447
}
64446448
_ => {}
64456449
}
@@ -6632,7 +6636,6 @@ template <typename A>
66326636
names: &self.names,
66336637
handle,
66346638
usage: fun_info[handle],
6635-
66366639
reference: true,
66376640
};
66386641
let separator =

naga/src/back/spv/block.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3726,9 +3726,20 @@ impl BlockContext<'_> {
37263726
layout_id,
37273727
self.cached[stride],
37283728
));
3729-
block
3730-
.body
3731-
.push(Instruction::store(self.cached[target], id, None));
3729+
match self.write_access_chain(
3730+
target,
3731+
&mut block,
3732+
AccessTypeAdjustment::None,
3733+
)? {
3734+
ExpressionPointer::Ready {
3735+
pointer_id: target_id,
3736+
} => {
3737+
block.body.push(Instruction::store(target_id, id, None));
3738+
}
3739+
ExpressionPointer::Conditional { .. } => {
3740+
unimplemented!()
3741+
}
3742+
};
37323743
}
37333744
}
37343745
}

naga/src/back/spv/writer.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -971,14 +971,13 @@ impl Writer {
971971
}
972972
}
973973

974-
// Handle globals are pre-emitted and should be loaded automatically.
975-
//
976-
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
977974
match ir_module.types[var.ty].inner {
975+
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
978976
crate::TypeInner::BindingArray { .. } => {
979977
gv.access_id = gv.var_id;
980978
}
981979
_ => {
980+
// Handle globals are pre-emitted and should be loaded automatically.
982981
if var.space == crate::AddressSpace::Handle {
983982
let var_type_id = self.get_handle_type_id(var.ty);
984983
let id = self.id_gen.next();
@@ -1064,6 +1063,7 @@ impl Writer {
10641063
}
10651064
}),
10661065
);
1066+
10671067
context
10681068
.function
10691069
.variables

naga/src/back/wgsl/writer.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -993,13 +993,13 @@ impl<W: Write> Writer<W> {
993993
} => {
994994
let op_str = if store { "Store" } else { "Load" };
995995
let suffix = if row_major { "T" } else { "" };
996-
write!(self.out, "coop{op_str}{suffix}(")?;
996+
write!(self.out, "{level}coop{op_str}{suffix}(")?;
997997
self.write_expr(module, target, func_ctx)?;
998998
write!(self.out, ", ")?;
999999
self.write_expr(module, pointer, func_ctx)?;
10001000
write!(self.out, ", ")?;
10011001
self.write_expr(module, stride, func_ctx)?;
1002-
write!(self.out, ")")?
1002+
writeln!(self.out, ");")?
10031003
}
10041004
}
10051005

@@ -1118,6 +1118,13 @@ impl<W: Write> Writer<W> {
11181118
// If the plain form of the expression is not what we need, emit the
11191119
// operator necessary to correct that.
11201120
let plain = self.plain_form_indirection(expr, module, func_ctx);
1121+
log::trace!(
1122+
"expression {:?}={:?} is {:?}, expected {:?}",
1123+
expr,
1124+
func_ctx.expressions[expr],
1125+
plain,
1126+
requested,
1127+
);
11211128
match (requested, plain) {
11221129
(Indirection::Ordinary, Indirection::Reference) => {
11231130
write!(self.out, "(&")?;

naga/src/front/wgsl/lower/mod.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
524524
span: Span,
525525
) -> Result<'source, Handle<ir::Expression>> {
526526
let mut eval = self.as_const_evaluator();
527+
log::debug!("appending {expr:?}");
527528
eval.try_eval_and_append(expr, span)
528529
.map_err(|e| Box::new(Error::ConstantEvaluatorError(e.into(), span)))
529530
}
@@ -846,6 +847,15 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
846847
fn ensure_type_exists(&mut self, inner: ir::TypeInner) -> Handle<ir::Type> {
847848
self.as_global().ensure_type_exists(None, inner)
848849
}
850+
851+
fn _get_runtime_expression(&self, expr: Handle<ir::Expression>) -> &ir::Expression {
852+
match self.expr_type {
853+
ExpressionContextType::Runtime(ref ctx) => &ctx.function.expressions[expr],
854+
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
855+
unreachable!()
856+
}
857+
}
858+
}
849859
}
850860

851861
struct ArgumentContext<'ctx, 'source> {
@@ -955,6 +965,13 @@ impl<T> Typed<T> {
955965
Self::Plain(expr) => Typed::Plain(f(expr)?),
956966
})
957967
}
968+
969+
fn ref_or<E>(self, error: E) -> core::result::Result<T, E> {
970+
match self {
971+
Self::Reference(v) => Ok(v),
972+
Self::Plain(_) => Err(error),
973+
}
974+
}
958975
}
959976

960977
/// A single vector component or swizzle.
@@ -1677,12 +1694,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
16771694
.as_expression(block, &mut emitter)
16781695
.interrupt_emitter(ir::Expression::LocalVariable(var), Span::UNDEFINED)?;
16791696
block.extend(emitter.finish(&ctx.function.expressions));
1680-
let typed = if ctx.module.types[ty].inner.is_handle() {
1681-
Typed::Plain(handle)
1682-
} else {
1683-
Typed::Reference(handle)
1684-
};
1685-
ctx.local_table.insert(v.handle, Declared::Runtime(typed));
1697+
ctx.local_table
1698+
.insert(v.handle, Declared::Runtime(Typed::Reference(handle)));
16861699

16871700
match initializer {
16881701
Some(initializer) => ir::Statement::Store {
@@ -1977,12 +1990,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
19771990
let value_span = ctx.ast_expressions.get_span(value);
19781991
let target = self
19791992
.expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?;
1980-
let target_handle = match target {
1981-
Typed::Reference(handle) => handle,
1982-
Typed::Plain(_) => {
1983-
return Err(Box::new(Error::BadIncrDecrReferenceType(value_span)))
1984-
}
1985-
};
1993+
let target_handle = target.ref_or(Error::BadIncrDecrReferenceType(value_span))?;
19861994

19871995
let mut ectx = ctx.as_expression(block, &mut emitter);
19881996
let scalar = match *resolve_inner!(ectx, target_handle) {
@@ -2139,10 +2147,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
21392147
LoweredGlobalDecl::Var(handle) => {
21402148
let expr = ir::Expression::GlobalVariable(handle);
21412149
let v = &ctx.module.global_variables[handle];
2142-
let force_value = ctx.module.types[v.ty].inner.is_handle();
21432150
match v.space {
21442151
ir::AddressSpace::Handle => Typed::Plain(expr),
2145-
_ if force_value => Typed::Plain(expr),
21462152
_ => Typed::Reference(expr),
21472153
}
21482154
}

naga/src/proc/type_methods.rs

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,17 +191,6 @@ impl crate::TypeInner {
191191
}
192192
}
193193

194-
/// Returns true if a variable of this type is a handle.
195-
pub const fn is_handle(&self) -> bool {
196-
match *self {
197-
Self::Image { .. }
198-
| Self::Sampler { .. }
199-
| Self::AccelerationStructure { .. }
200-
| Self::CooperativeMatrix { .. } => true,
201-
_ => false,
202-
}
203-
}
204-
205194
/// Attempt to calculate the size of this type. Returns `None` if the size
206195
/// exceeds the limit of [`crate::valid::MAX_TYPE_SIZE`].
207196
pub fn try_size(&self, gctx: super::GlobalCtx) -> Option<u32> {

naga/src/proc/typifier.rs

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,7 @@ impl<'a> ResolveContext<'a> {
454454
}
455455
crate::Expression::GlobalVariable(h) => {
456456
let var = &self.global_vars[h];
457-
let ty = &types[var.ty].inner;
458-
if var.space == crate::AddressSpace::Handle || ty.is_handle() {
457+
if var.space == crate::AddressSpace::Handle {
459458
TypeResolution::Handle(var.ty)
460459
} else {
461460
TypeResolution::Value(Ti::Pointer {
@@ -466,15 +465,10 @@ impl<'a> ResolveContext<'a> {
466465
}
467466
crate::Expression::LocalVariable(h) => {
468467
let var = &self.local_vars[h];
469-
let ty = &types[var.ty].inner;
470-
if ty.is_handle() {
471-
TypeResolution::Handle(var.ty)
472-
} else {
473-
TypeResolution::Value(Ti::Pointer {
474-
base: var.ty,
475-
space: crate::AddressSpace::Function,
476-
})
477-
}
468+
TypeResolution::Value(Ti::Pointer {
469+
base: var.ty,
470+
space: crate::AddressSpace::Function,
471+
})
478472
}
479473
crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
480474
Ti::Pointer { base, space: _ } => {
@@ -493,7 +487,7 @@ impl<'a> ResolveContext<'a> {
493487
None => Ti::Scalar(scalar),
494488
}),
495489
ref other => {
496-
log::error!("Pointer type {other:?}");
490+
log::error!("Pointer {pointer:?} type {other:?}");
497491
return Err(ResolveError::InvalidPointer(pointer));
498492
}
499493
},

0 commit comments

Comments
 (0)