Skip to content

Commit 7959016

Browse files
committed
coop: handle simple ops, end-to-end with SPIRV
1 parent 1e9597d commit 7959016

File tree

10 files changed

+120
-16
lines changed

10 files changed

+120
-16
lines changed

naga/src/back/spv/block.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
1919
crate::TypeInner::Scalar(_) => Dimension::Scalar,
2020
crate::TypeInner::Vector { .. } => Dimension::Vector,
2121
crate::TypeInner::Matrix { .. } => Dimension::Matrix,
22+
crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix,
2223
_ => unreachable!(),
2324
}
2425
}
@@ -766,6 +767,7 @@ impl BlockContext<'_> {
766767
rows,
767768
scalar,
768769
} => {
770+
//TODO: why not just rely on `Fadd` for matrices?
769771
self.write_matrix_matrix_column_op(
770772
block,
771773
id,
@@ -781,6 +783,7 @@ impl BlockContext<'_> {
781783
self.cached[expr_handle] = id;
782784
return Ok(());
783785
}
786+
crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd,
784787
_ => unimplemented!(),
785788
},
786789
crate::BinaryOperator::Subtract => match *left_ty_inner {
@@ -809,6 +812,7 @@ impl BlockContext<'_> {
809812
self.cached[expr_handle] = id;
810813
return Ok(());
811814
}
815+
crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub,
812816
_ => unimplemented!(),
813817
},
814818
crate::BinaryOperator::Multiply => {
@@ -842,10 +846,12 @@ impl BlockContext<'_> {
842846
(Dimension::Vector, Dimension::Matrix) => {
843847
spirv::Op::VectorTimesMatrix
844848
}
845-
(Dimension::Matrix, Dimension::Scalar) => {
849+
(Dimension::Matrix, Dimension::Scalar)
850+
| (Dimension::CooperativeMatrix, Dimension::Scalar) => {
846851
spirv::Op::MatrixTimesScalar
847852
}
848-
(Dimension::Scalar, Dimension::Matrix) => {
853+
(Dimension::Scalar, Dimension::Matrix)
854+
| (Dimension::Scalar, Dimension::CooperativeMatrix) => {
849855
reverse_operands = true;
850856
spirv::Op::MatrixTimesScalar
851857
}
@@ -864,6 +870,12 @@ impl BlockContext<'_> {
864870
}
865871
(Dimension::Vector, Dimension::Vector)
866872
| (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
873+
(Dimension::CooperativeMatrix, Dimension::CooperativeMatrix)
874+
//Note: technically can do `FMul` but IR doesn't have matrix per-component multiplication
875+
| (Dimension::CooperativeMatrix, _)
876+
| (_, Dimension::CooperativeMatrix) => {
877+
unimplemented!()
878+
}
867879
}
868880
}
869881
crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {

naga/src/back/spv/instructions.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -284,17 +284,18 @@ impl super::Instruction {
284284
pub(super) fn type_coop_matrix(
285285
id: Word,
286286
scalar_type_id: Word,
287-
row_count: crate::CooperativeSize,
288-
column_count: crate::CooperativeSize,
289-
role: spirv::CooperativeMatrixUse,
287+
scope_id: Word,
288+
row_count_id: Word,
289+
column_count_id: Word,
290+
matrix_use_id: Word,
290291
) -> Self {
291292
let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR);
292293
instruction.set_result(id);
293294
instruction.add_operand(scalar_type_id);
294-
instruction.add_operand(spirv::Scope::Subgroup as u32);
295-
instruction.add_operand(column_count as u32);
296-
instruction.add_operand(row_count as u32);
297-
instruction.add_operand(role as u32);
295+
instruction.add_operand(scope_id);
296+
instruction.add_operand(row_count_id);
297+
instruction.add_operand(column_count_id);
298+
instruction.add_operand(matrix_use_id);
298299
instruction
299300
}
300301

naga/src/back/spv/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ enum Dimension {
482482
Scalar,
483483
Vector,
484484
Matrix,
485+
CooperativeMatrix,
485486
}
486487

487488
/// Key used to look up an operation which we have wrapped in a helper

naga/src/back/spv/writer.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,7 +1368,9 @@ impl Writer {
13681368
"cooperative matrix",
13691369
&[spirv::Capability::CooperativeMatrixKHR],
13701370
)?;
1371+
self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?;
13711372
self.use_extension("SPV_KHR_cooperative_matrix");
1373+
self.use_extension("SPV_KHR_vulkan_memory_model");
13721374
}
13731375
_ => {}
13741376
}
@@ -1405,7 +1407,12 @@ impl Writer {
14051407
role,
14061408
} => {
14071409
let scalar_id = self.get_cooperative_type_id(scalar);
1408-
Instruction::type_coop_matrix(id, scalar_id, rows, columns, role.into())
1410+
let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
1411+
let columns_id = self.get_index_constant(columns as u32);
1412+
let rows_id = self.get_index_constant(rows as u32);
1413+
let role_id =
1414+
self.get_index_constant(spirv::CooperativeMatrixUse::from(role) as u32);
1415+
Instruction::type_coop_matrix(id, scalar_id, scope_id, rows_id, columns_id, role_id)
14091416
}
14101417
};
14111418

@@ -2669,7 +2676,14 @@ impl Writer {
26692676
}
26702677

26712678
let addressing_model = spirv::AddressingModel::Logical;
2672-
let memory_model = spirv::MemoryModel::GLSL450;
2679+
let memory_model = if self
2680+
.capabilities_used
2681+
.contains(&spirv::Capability::VulkanMemoryModel)
2682+
{
2683+
spirv::MemoryModel::Vulkan
2684+
} else {
2685+
spirv::MemoryModel::GLSL450
2686+
};
26732687
//self.check(addressing_model.required_capabilities())?;
26742688
//self.check(memory_model.required_capabilities())?;
26752689

naga/src/ir/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,15 @@ impl CooperativeScalar {
490490
Self::F32 => 4,
491491
}
492492
}
493+
494+
pub const fn to_scalar(&self) -> Scalar {
495+
match *self {
496+
Self::F32 => Scalar {
497+
kind: ScalarKind::Float,
498+
width: 4,
499+
},
500+
}
501+
}
493502
}
494503

495504
/// Role of a cooperative variable in the equation "A * B + C"

naga/src/proc/type_methods.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ impl crate::TypeInner {
115115
match *self {
116116
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar),
117117
Ti::Matrix { scalar, .. } => Some(scalar),
118+
Ti::CooperativeMatrix { scalar, .. } => Some(scalar.to_scalar()),
118119
_ => None,
119120
}
120121
}

naga/src/proc/typifier.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,17 @@ impl Clone for TypeResolution {
143143
columns,
144144
scalar,
145145
},
146+
Ti::CooperativeMatrix {
147+
columns,
148+
rows,
149+
scalar,
150+
role,
151+
} => Ti::CooperativeMatrix {
152+
columns,
153+
rows,
154+
scalar,
155+
role,
156+
},
146157
Ti::Pointer { base, space } => Ti::Pointer { base, space },
147158
Ti::ValuePointer {
148159
size,
@@ -587,6 +598,20 @@ impl<'a> ResolveContext<'a> {
587598
(&Ti::Scalar { .. }, _) => res_right.clone(),
588599
(_, &Ti::Scalar { .. }) => res_left.clone(),
589600
(&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
601+
(
602+
&Ti::CooperativeMatrix {
603+
columns: _,
604+
rows,
605+
scalar,
606+
role,
607+
},
608+
&Ti::CooperativeMatrix { columns, .. },
609+
) => TypeResolution::Value(Ti::CooperativeMatrix {
610+
columns,
611+
rows,
612+
scalar,
613+
role,
614+
}),
590615
(tl, tr) => {
591616
return Err(ResolveError::IncompatibleOperands(format!(
592617
"{tl:?} * {tr:?}"

naga/src/valid/expression.rs

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,9 @@ impl super::Validator {
788788
Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
789789
Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
790790
},
791-
Ti::Matrix { .. } => left_inner == right_inner,
791+
Ti::Matrix { .. } | Ti::CooperativeMatrix { .. } => {
792+
left_inner == right_inner
793+
}
792794
_ => false,
793795
},
794796
Bo::Divide | Bo::Modulo => match *left_inner {
@@ -818,7 +820,7 @@ impl super::Validator {
818820
scalar: scalar2, ..
819821
},
820822
) => scalar1 == scalar2,
821-
// Scalar/matrix.
823+
// Scalar * matrix.
822824
(
823825
&Ti::Scalar(Sc {
824826
kind: Sk::Float, ..
@@ -831,7 +833,7 @@ impl super::Validator {
831833
kind: Sk::Float, ..
832834
}),
833835
) => true,
834-
// Vector/vector.
836+
// Vector * vector.
835837
(
836838
&Ti::Vector {
837839
size: size1,
@@ -864,9 +866,44 @@ impl super::Validator {
864866
},
865867
&Ti::Matrix { rows, .. },
866868
) => size == rows,
869+
// Matrix * matrix.
867870
(&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
868871
columns == rows
869872
}
873+
// Coop matrix * coop matrix.
874+
(
875+
&Ti::CooperativeMatrix {
876+
columns,
877+
scalar: scalar1,
878+
role: role1,
879+
..
880+
},
881+
&Ti::CooperativeMatrix {
882+
rows,
883+
scalar: scalar2,
884+
role: role2,
885+
..
886+
},
887+
) => columns == rows && scalar1 == scalar2 && role1 == role2,
888+
// Scalar * coop matrix.
889+
(
890+
&Ti::Scalar(Sc {
891+
kind: Sk::Float, ..
892+
}),
893+
&Ti::CooperativeMatrix {
894+
scalar: crate::CooperativeScalar::F32,
895+
..
896+
},
897+
)
898+
| (
899+
&Ti::CooperativeMatrix {
900+
scalar: crate::CooperativeScalar::F32,
901+
..
902+
},
903+
&Ti::Scalar(Sc {
904+
kind: Sk::Float, ..
905+
}),
906+
) => true,
870907
_ => false,
871908
};
872909
let left_width = left_inner.scalar_width().unwrap_or(0);
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
targets = "SPIRV"
22
god_mode = true
3+
4+
[spv]
5+
debug = true
6+
version = [1, 4]
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
var<private> a: coop_mat8x8<f32, A>;
2-
var<private> b: coop_mat8x8<f32, B>;
2+
//var<private> b: coop_mat8x8<f32, B>;
33

44
@compute @workgroup_size(8, 8, 1)
55
fn main() {
6-
//let c = a * b;
6+
let a2 = a + a;
77
}

0 commit comments

Comments
 (0)