@@ -6,10 +6,11 @@ use spirv::Word;
6
6
use super :: {
7
7
block:: DebugInfoInner ,
8
8
helpers:: { contains_builtin, global_needs_wrapper, map_storage_class} ,
9
- Block , BlockContext , CachedConstant , CachedExpressions , DebugInfo , EntryPointContext , Error ,
10
- Function , FunctionArgument , GlobalVariable , IdGenerator , Instruction , LocalImageType ,
11
- LocalType , LocalVariable , LogicalLayout , LookupFunctionType , LookupType , NumericType , Options ,
12
- PhysicalLayout , PipelineOptions , ResultMember , Writer , WriterFlags , BITS_PER_BYTE ,
9
+ Block , BlockContext , CachedConstant , CachedExpressions , CooperativeType , DebugInfo ,
10
+ EntryPointContext , Error , Function , FunctionArgument , GlobalVariable , IdGenerator , Instruction ,
11
+ LocalImageType , LocalType , LocalVariable , LogicalLayout , LookupFunctionType , LookupType ,
12
+ NumericType , Options , PhysicalLayout , PipelineOptions , ResultMember , Writer , WriterFlags ,
13
+ BITS_PER_BYTE ,
13
14
} ;
14
15
use crate :: {
15
16
arena:: { Handle , HandleVec , UniqueArena } ,
@@ -375,6 +376,12 @@ impl Writer {
375
376
} )
376
377
}
377
378
379
+ pub ( super ) fn get_cooperative_type_id ( & mut self , scalar : crate :: CooperativeScalar ) -> Word {
380
+ match scalar {
381
+ crate :: CooperativeScalar :: F32 => self . get_f32_type_id ( ) ,
382
+ }
383
+ }
384
+
378
385
pub ( super ) fn get_f32_pointer_type_id ( & mut self , class : spirv:: StorageClass ) -> Word {
379
386
let f32_id = self . get_f32_type_id ( ) ;
380
387
self . get_pointer_type_id ( f32_id, class)
@@ -436,7 +443,9 @@ impl Writer {
436
443
// these cases, so unwrap.
437
444
LocalType :: Numeric ( NumericType :: from_inner ( inner) . unwrap ( ) )
438
445
}
439
- crate :: TypeInner :: CooperativeMatrix { .. } => return None ,
446
+ crate :: TypeInner :: CooperativeMatrix { .. } => {
447
+ LocalType :: Cooperative ( CooperativeType :: from_inner ( inner) . unwrap ( ) )
448
+ }
440
449
crate :: TypeInner :: Pointer { base, space } => {
441
450
let base_type_id = self . get_handle_type_id ( base) ;
442
451
LocalType :: Pointer {
@@ -1353,6 +1362,14 @@ impl Writer {
1353
1362
self . require_any ( "16 bit floating-point" , & [ spirv:: Capability :: Float16 ] ) ?;
1354
1363
self . use_extension ( "SPV_KHR_16bit_storage" ) ;
1355
1364
}
1365
+ // Cooperative types and ops
1366
+ crate :: TypeInner :: CooperativeMatrix { .. } => {
1367
+ self . require_any (
1368
+ "cooperative matrix" ,
1369
+ & [ spirv:: Capability :: CooperativeMatrixKHR ] ,
1370
+ ) ?;
1371
+ self . use_extension ( "SPV_KHR_cooperative_matrix" ) ;
1372
+ }
1356
1373
_ => { }
1357
1374
}
1358
1375
Ok ( ( ) )
@@ -1379,12 +1396,31 @@ impl Writer {
1379
1396
instruction. to_words ( & mut self . logical_layout . declarations ) ;
1380
1397
}
1381
1398
1399
+ fn write_cooperative_type_declaration_local ( & mut self , id : Word , coop : CooperativeType ) {
1400
+ let instruction = match coop {
1401
+ CooperativeType :: Matrix {
1402
+ columns,
1403
+ rows,
1404
+ scalar,
1405
+ } => {
1406
+ let scalar_id = self . get_cooperative_type_id ( scalar) ;
1407
+ Instruction :: type_coop_matrix ( id, scalar_id, rows, columns)
1408
+ }
1409
+ } ;
1410
+
1411
+ instruction. to_words ( & mut self . logical_layout . declarations ) ;
1412
+ }
1413
+
1382
1414
fn write_type_declaration_local ( & mut self , id : Word , local_ty : LocalType ) {
1383
1415
let instruction = match local_ty {
1384
1416
LocalType :: Numeric ( numeric) => {
1385
1417
self . write_numeric_type_declaration_local ( id, numeric) ;
1386
1418
return ;
1387
1419
}
1420
+ LocalType :: Cooperative ( coop) => {
1421
+ self . write_cooperative_type_declaration_local ( id, coop) ;
1422
+ return ;
1423
+ }
1388
1424
LocalType :: Pointer { base, class } => Instruction :: type_pointer ( id, class, base) ,
1389
1425
LocalType :: Image ( image) => {
1390
1426
let local_type = LocalType :: Numeric ( NumericType :: Scalar ( image. sampled_type ) ) ;
0 commit comments