@@ -19,6 +19,7 @@ fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
19
19
crate :: TypeInner :: Scalar ( _) => Dimension :: Scalar ,
20
20
crate :: TypeInner :: Vector { .. } => Dimension :: Vector ,
21
21
crate :: TypeInner :: Matrix { .. } => Dimension :: Matrix ,
22
+ crate :: TypeInner :: CooperativeMatrix { .. } => Dimension :: CooperativeMatrix ,
22
23
_ => unreachable ! ( ) ,
23
24
}
24
25
}
@@ -766,6 +767,7 @@ impl BlockContext<'_> {
766
767
rows,
767
768
scalar,
768
769
} => {
770
+ //TODO: why not just rely on `Fadd` for matrices?
769
771
self . write_matrix_matrix_column_op (
770
772
block,
771
773
id,
@@ -781,6 +783,7 @@ impl BlockContext<'_> {
781
783
self . cached [ expr_handle] = id;
782
784
return Ok ( ( ) ) ;
783
785
}
786
+ crate :: TypeInner :: CooperativeMatrix { .. } => spirv:: Op :: FAdd ,
784
787
_ => unimplemented ! ( ) ,
785
788
} ,
786
789
crate :: BinaryOperator :: Subtract => match * left_ty_inner {
@@ -809,6 +812,7 @@ impl BlockContext<'_> {
809
812
self . cached [ expr_handle] = id;
810
813
return Ok ( ( ) ) ;
811
814
}
815
+ crate :: TypeInner :: CooperativeMatrix { .. } => spirv:: Op :: FSub ,
812
816
_ => unimplemented ! ( ) ,
813
817
} ,
814
818
crate :: BinaryOperator :: Multiply => {
@@ -842,10 +846,12 @@ impl BlockContext<'_> {
842
846
( Dimension :: Vector , Dimension :: Matrix ) => {
843
847
spirv:: Op :: VectorTimesMatrix
844
848
}
845
- ( Dimension :: Matrix , Dimension :: Scalar ) => {
849
+ ( Dimension :: Matrix , Dimension :: Scalar )
850
+ | ( Dimension :: CooperativeMatrix , Dimension :: Scalar ) => {
846
851
spirv:: Op :: MatrixTimesScalar
847
852
}
848
- ( Dimension :: Scalar , Dimension :: Matrix ) => {
853
+ ( Dimension :: Scalar , Dimension :: Matrix )
854
+ | ( Dimension :: Scalar , Dimension :: CooperativeMatrix ) => {
849
855
reverse_operands = true ;
850
856
spirv:: Op :: MatrixTimesScalar
851
857
}
@@ -864,6 +870,12 @@ impl BlockContext<'_> {
864
870
}
865
871
( Dimension :: Vector , Dimension :: Vector )
866
872
| ( Dimension :: Scalar , Dimension :: Scalar ) => spirv:: Op :: IMul ,
873
+ ( Dimension :: CooperativeMatrix , Dimension :: CooperativeMatrix )
874
+ //Note: technically can do `FMul` but IR doesn't have matrix per-component multiplication
875
+ | ( Dimension :: CooperativeMatrix , _)
876
+ | ( _, Dimension :: CooperativeMatrix ) => {
877
+ unimplemented ! ( )
878
+ }
867
879
}
868
880
}
869
881
crate :: BinaryOperator :: Divide => match left_ty_inner. scalar_kind ( ) {
0 commit comments