Skip to content

Commit f7e2d28

Browse files
committed
coop: first bits of Vulkan support for the type
1 parent 0a3cba4 commit f7e2d28

File tree

5 files changed

+88
-5
lines changed

5 files changed

+88
-5
lines changed

naga/src/back/spv/instructions.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,22 @@ impl super::Instruction {
281281
instruction
282282
}
283283

284+
pub(super) fn type_coop_matrix(
285+
id: Word,
286+
scalar_type_id: Word,
287+
row_count: crate::CooperativeVectorSize,
288+
column_count: crate::CooperativeVectorSize,
289+
) -> Self {
290+
let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR);
291+
instruction.set_result(id);
292+
instruction.add_operand(scalar_type_id);
293+
instruction.add_operand(spirv::Scope::Subgroup as u32);
294+
instruction.add_operand(column_count as u32);
295+
instruction.add_operand(row_count as u32);
296+
instruction.add_operand(spirv::CooperativeMatrixUse::MatrixAKHR as u32); //TODO: configure or expose
297+
instruction
298+
}
299+
284300
#[allow(clippy::too_many_arguments)]
285301
pub(super) fn type_image(
286302
id: Word,

naga/src/back/spv/mod.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,33 @@ impl NumericType {
340340
}
341341
}
342342

343+
/// A cooperative type, for use in [`LocalType`].
344+
#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
345+
enum CooperativeType {
346+
Matrix {
347+
columns: crate::CooperativeVectorSize,
348+
rows: crate::CooperativeVectorSize,
349+
scalar: crate::CooperativeScalar,
350+
},
351+
}
352+
353+
impl CooperativeType {
354+
const fn from_inner(inner: &crate::TypeInner) -> Option<Self> {
355+
match *inner {
356+
crate::TypeInner::CooperativeMatrix {
357+
columns,
358+
rows,
359+
scalar,
360+
} => Some(Self::Matrix {
361+
columns,
362+
rows,
363+
scalar,
364+
}),
365+
_ => None,
366+
}
367+
}
368+
}
369+
343370
/// A SPIR-V type constructed during code generation.
344371
///
345372
/// This is the variant of [`LookupType`] used to represent types that might not
@@ -389,6 +416,7 @@ impl NumericType {
389416
enum LocalType {
390417
/// A numeric type.
391418
Numeric(NumericType),
419+
Cooperative(CooperativeType),
392420
Pointer {
393421
base: Word,
394422
class: spirv::StorageClass,

naga/src/back/spv/writer.rs

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ use spirv::Word;
66
use super::{
77
block::DebugInfoInner,
88
helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
9-
Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, EntryPointContext, Error,
10-
Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalImageType,
11-
LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, NumericType, Options,
12-
PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
9+
Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo,
10+
EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction,
11+
LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType,
12+
NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags,
13+
BITS_PER_BYTE,
1314
};
1415
use crate::{
1516
arena::{Handle, HandleVec, UniqueArena},
@@ -375,6 +376,12 @@ impl Writer {
375376
})
376377
}
377378

379+
pub(super) fn get_cooperative_type_id(&mut self, scalar: crate::CooperativeScalar) -> Word {
380+
match scalar {
381+
crate::CooperativeScalar::F32 => self.get_f32_type_id(),
382+
}
383+
}
384+
378385
pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
379386
let f32_id = self.get_f32_type_id();
380387
self.get_pointer_type_id(f32_id, class)
@@ -436,7 +443,9 @@ impl Writer {
436443
// these cases, so unwrap.
437444
LocalType::Numeric(NumericType::from_inner(inner).unwrap())
438445
}
439-
crate::TypeInner::CooperativeMatrix { .. } => return None,
446+
crate::TypeInner::CooperativeMatrix { .. } => {
447+
LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap())
448+
}
440449
crate::TypeInner::Pointer { base, space } => {
441450
let base_type_id = self.get_handle_type_id(base);
442451
LocalType::Pointer {
@@ -1353,6 +1362,14 @@ impl Writer {
13531362
self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
13541363
self.use_extension("SPV_KHR_16bit_storage");
13551364
}
1365+
// Cooperative types and ops
1366+
crate::TypeInner::CooperativeMatrix { .. } => {
1367+
self.require_any(
1368+
"cooperative matrix",
1369+
&[spirv::Capability::CooperativeMatrixKHR],
1370+
)?;
1371+
self.use_extension("SPV_KHR_cooperative_matrix");
1372+
}
13561373
_ => {}
13571374
}
13581375
Ok(())
@@ -1379,12 +1396,31 @@ impl Writer {
13791396
instruction.to_words(&mut self.logical_layout.declarations);
13801397
}
13811398

1399+
fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) {
1400+
let instruction = match coop {
1401+
CooperativeType::Matrix {
1402+
columns,
1403+
rows,
1404+
scalar,
1405+
} => {
1406+
let scalar_id = self.get_cooperative_type_id(scalar);
1407+
Instruction::type_coop_matrix(id, scalar_id, rows, columns)
1408+
}
1409+
};
1410+
1411+
instruction.to_words(&mut self.logical_layout.declarations);
1412+
}
1413+
13821414
fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
13831415
let instruction = match local_ty {
13841416
LocalType::Numeric(numeric) => {
13851417
self.write_numeric_type_declaration_local(id, numeric);
13861418
return;
13871419
}
1420+
LocalType::Cooperative(coop) => {
1421+
self.write_cooperative_type_declaration_local(id, coop);
1422+
return;
1423+
}
13881424
LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base),
13891425
LocalType::Image(image) => {
13901426
let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));

naga/src/valid/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ bitflags::bitflags! {
186186
/// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store
187187
/// `f16`-precision values in `f32`s.
188188
const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28;
189+
/// Support for cooperative matrix types and operations
190+
const COOPERATIVE_MATRIX = 1 << 29;
189191
}
190192
}
191193

naga/src/valid/type.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ impl super::Validator {
420420
rows: _,
421421
scalar,
422422
} => {
423+
self.require_type_capability(Capabilities::COOPERATIVE_MATRIX)?;
423424
if scalar != crate::CooperativeScalar::F32 {
424425
return Err(TypeError::MatrixElementNotFloat);
425426
}

0 commit comments

Comments
 (0)