@@ -6,10 +6,11 @@ use spirv::Word;
6
6
use super :: {
7
7
block:: DebugInfoInner ,
8
8
helpers:: { contains_builtin, global_needs_wrapper, map_storage_class} ,
9
- Block , BlockContext , CachedConstant , CachedExpressions , DebugInfo , EntryPointContext , Error ,
10
- Function , FunctionArgument , GlobalVariable , IdGenerator , Instruction , LocalImageType ,
11
- LocalType , LocalVariable , LogicalLayout , LookupFunctionType , LookupType , NumericType , Options ,
12
- PhysicalLayout , PipelineOptions , ResultMember , Writer , WriterFlags , BITS_PER_BYTE ,
9
+ Block , BlockContext , CachedConstant , CachedExpressions , CooperativeType , DebugInfo ,
10
+ EntryPointContext , Error , Function , FunctionArgument , GlobalVariable , IdGenerator , Instruction ,
11
+ LocalImageType , LocalType , LocalVariable , LogicalLayout , LookupFunctionType , LookupType ,
12
+ NumericType , Options , PhysicalLayout , PipelineOptions , ResultMember , Writer , WriterFlags ,
13
+ BITS_PER_BYTE ,
13
14
} ;
14
15
use crate :: {
15
16
arena:: { Handle , HandleVec , UniqueArena } ,
@@ -373,6 +374,12 @@ impl Writer {
373
374
} )
374
375
}
375
376
377
+ pub ( super ) fn get_cooperative_type_id ( & mut self , scalar : crate :: CooperativeScalar ) -> Word {
378
+ match scalar {
379
+ crate :: CooperativeScalar :: F32 => self . get_f32_type_id ( ) ,
380
+ }
381
+ }
382
+
376
383
pub ( super ) fn get_f32_pointer_type_id ( & mut self , class : spirv:: StorageClass ) -> Word {
377
384
let f32_id = self . get_f32_type_id ( ) ;
378
385
self . get_pointer_type_id ( f32_id, class)
@@ -434,7 +441,9 @@ impl Writer {
434
441
// these cases, so unwrap.
435
442
LocalType :: Numeric ( NumericType :: from_inner ( inner) . unwrap ( ) )
436
443
}
437
- crate :: TypeInner :: CooperativeMatrix { .. } => return None ,
444
+ crate :: TypeInner :: CooperativeMatrix { .. } => {
445
+ LocalType :: Cooperative ( CooperativeType :: from_inner ( inner) . unwrap ( ) )
446
+ }
438
447
crate :: TypeInner :: Pointer { base, space } => {
439
448
let base_type_id = self . get_handle_type_id ( base) ;
440
449
LocalType :: Pointer {
@@ -1331,6 +1340,14 @@ impl Writer {
1331
1340
self . require_any ( "16 bit floating-point" , & [ spirv:: Capability :: Float16 ] ) ?;
1332
1341
self . use_extension ( "SPV_KHR_16bit_storage" ) ;
1333
1342
}
1343
+ // Cooperative types and ops
1344
+ crate :: TypeInner :: CooperativeMatrix { .. } => {
1345
+ self . require_any (
1346
+ "cooperative matrix" ,
1347
+ & [ spirv:: Capability :: CooperativeMatrixKHR ] ,
1348
+ ) ?;
1349
+ self . use_extension ( "SPV_KHR_cooperative_matrix" ) ;
1350
+ }
1334
1351
_ => { }
1335
1352
}
1336
1353
Ok ( ( ) )
@@ -1357,12 +1374,31 @@ impl Writer {
1357
1374
instruction. to_words ( & mut self . logical_layout . declarations ) ;
1358
1375
}
1359
1376
1377
+ fn write_cooperative_type_declaration_local ( & mut self , id : Word , coop : CooperativeType ) {
1378
+ let instruction = match coop {
1379
+ CooperativeType :: Matrix {
1380
+ columns,
1381
+ rows,
1382
+ scalar,
1383
+ } => {
1384
+ let scalar_id = self . get_cooperative_type_id ( scalar) ;
1385
+ Instruction :: type_coop_matrix ( id, scalar_id, rows, columns)
1386
+ }
1387
+ } ;
1388
+
1389
+ instruction. to_words ( & mut self . logical_layout . declarations ) ;
1390
+ }
1391
+
1360
1392
fn write_type_declaration_local ( & mut self , id : Word , local_ty : LocalType ) {
1361
1393
let instruction = match local_ty {
1362
1394
LocalType :: Numeric ( numeric) => {
1363
1395
self . write_numeric_type_declaration_local ( id, numeric) ;
1364
1396
return ;
1365
1397
}
1398
+ LocalType :: Cooperative ( coop) => {
1399
+ self . write_cooperative_type_declaration_local ( id, coop) ;
1400
+ return ;
1401
+ }
1366
1402
LocalType :: Pointer { base, class } => Instruction :: type_pointer ( id, class, base) ,
1367
1403
LocalType :: Image ( image) => {
1368
1404
let local_type = LocalType :: Numeric ( NumericType :: Scalar ( image. sampled_type ) ) ;
0 commit comments