Skip to content

Commit 82df04d

Browse files
committed
coop: first bits of Vulkan support for the type
1 parent 889a54c commit 82df04d

File tree

5 files changed

+88
-5
lines changed

5 files changed

+88
-5
lines changed

naga/src/back/spv/instructions.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,22 @@ impl super::Instruction {
281281
instruction
282282
}
283283

284+
pub(super) fn type_coop_matrix(
285+
id: Word,
286+
scalar_type_id: Word,
287+
row_count: crate::CooperativeVectorSize,
288+
column_count: crate::CooperativeVectorSize,
289+
) -> Self {
290+
let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR);
291+
instruction.set_result(id);
292+
instruction.add_operand(scalar_type_id);
293+
instruction.add_operand(spirv::Scope::Subgroup as u32);
294+
instruction.add_operand(column_count as u32);
295+
instruction.add_operand(row_count as u32);
296+
instruction.add_operand(spirv::CooperativeMatrixUse::MatrixAKHR as u32); //TODO: configure or expose
297+
instruction
298+
}
299+
284300
#[allow(clippy::too_many_arguments)]
285301
pub(super) fn type_image(
286302
id: Word,

naga/src/back/spv/mod.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,33 @@ impl NumericType {
338338
}
339339
}
340340

341+
/// A cooperative type, for use in [`LocalType`].
342+
#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
343+
enum CooperativeType {
344+
Matrix {
345+
columns: crate::CooperativeVectorSize,
346+
rows: crate::CooperativeVectorSize,
347+
scalar: crate::CooperativeScalar,
348+
},
349+
}
350+
351+
impl CooperativeType {
352+
const fn from_inner(inner: &crate::TypeInner) -> Option<Self> {
353+
match *inner {
354+
crate::TypeInner::CooperativeMatrix {
355+
columns,
356+
rows,
357+
scalar,
358+
} => Some(Self::Matrix {
359+
columns,
360+
rows,
361+
scalar,
362+
}),
363+
_ => None,
364+
}
365+
}
366+
}
367+
341368
/// A SPIR-V type constructed during code generation.
342369
///
343370
/// This is the variant of [`LookupType`] used to represent types that might not
@@ -387,6 +414,7 @@ impl NumericType {
387414
enum LocalType {
388415
/// A numeric type.
389416
Numeric(NumericType),
417+
Cooperative(CooperativeType),
390418
Pointer {
391419
base: Word,
392420
class: spirv::StorageClass,

naga/src/back/spv/writer.rs

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ use spirv::Word;
66
use super::{
77
block::DebugInfoInner,
88
helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
9-
Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, EntryPointContext, Error,
10-
Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalImageType,
11-
LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, NumericType, Options,
12-
PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
9+
Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo,
10+
EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction,
11+
LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType,
12+
NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags,
13+
BITS_PER_BYTE,
1314
};
1415
use crate::{
1516
arena::{Handle, HandleVec, UniqueArena},
@@ -373,6 +374,12 @@ impl Writer {
373374
})
374375
}
375376

377+
pub(super) fn get_cooperative_type_id(&mut self, scalar: crate::CooperativeScalar) -> Word {
378+
match scalar {
379+
crate::CooperativeScalar::F32 => self.get_f32_type_id(),
380+
}
381+
}
382+
376383
pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
377384
let f32_id = self.get_f32_type_id();
378385
self.get_pointer_type_id(f32_id, class)
@@ -434,7 +441,9 @@ impl Writer {
434441
// these cases, so unwrap.
435442
LocalType::Numeric(NumericType::from_inner(inner).unwrap())
436443
}
437-
crate::TypeInner::CooperativeMatrix { .. } => return None,
444+
crate::TypeInner::CooperativeMatrix { .. } => {
445+
LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap())
446+
}
438447
crate::TypeInner::Pointer { base, space } => {
439448
let base_type_id = self.get_handle_type_id(base);
440449
LocalType::Pointer {
@@ -1331,6 +1340,14 @@ impl Writer {
13311340
self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
13321341
self.use_extension("SPV_KHR_16bit_storage");
13331342
}
1343+
// Cooperative types and ops
1344+
crate::TypeInner::CooperativeMatrix { .. } => {
1345+
self.require_any(
1346+
"cooperative matrix",
1347+
&[spirv::Capability::CooperativeMatrixKHR],
1348+
)?;
1349+
self.use_extension("SPV_KHR_cooperative_matrix");
1350+
}
13341351
_ => {}
13351352
}
13361353
Ok(())
@@ -1357,12 +1374,31 @@ impl Writer {
13571374
instruction.to_words(&mut self.logical_layout.declarations);
13581375
}
13591376

1377+
fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) {
1378+
let instruction = match coop {
1379+
CooperativeType::Matrix {
1380+
columns,
1381+
rows,
1382+
scalar,
1383+
} => {
1384+
let scalar_id = self.get_cooperative_type_id(scalar);
1385+
Instruction::type_coop_matrix(id, scalar_id, rows, columns)
1386+
}
1387+
};
1388+
1389+
instruction.to_words(&mut self.logical_layout.declarations);
1390+
}
1391+
13601392
fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
13611393
let instruction = match local_ty {
13621394
LocalType::Numeric(numeric) => {
13631395
self.write_numeric_type_declaration_local(id, numeric);
13641396
return;
13651397
}
1398+
LocalType::Cooperative(coop) => {
1399+
self.write_cooperative_type_declaration_local(id, coop);
1400+
return;
1401+
}
13661402
LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base),
13671403
LocalType::Image(image) => {
13681404
let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));

naga/src/valid/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ bitflags::bitflags! {
186186
/// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store
187187
/// `f16`-precision values in `f32`s.
188188
const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
189+
/// Support for cooperative matrix types and operations
190+
const COOPERATIVE_MATRIX = 1 << 29;
189191
}
190192
}
191193

naga/src/valid/type.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ impl super::Validator {
420420
rows: _,
421421
scalar,
422422
} => {
423+
self.require_type_capability(Capabilities::COOPERATIVE_MATRIX)?;
423424
if scalar != crate::CooperativeScalar::F32 {
424425
return Err(TypeError::MatrixElementNotFloat);
425426
}

0 commit comments

Comments
 (0)