@@ -45,6 +45,8 @@ pub enum SpirvType<'tcx> {
45
45
element : Word ,
46
46
/// Note: vector count is literal.
47
47
count : u32 ,
48
+ size : Size ,
49
+ align : Align ,
48
50
} ,
49
51
Matrix {
50
52
element : Word ,
@@ -131,7 +133,9 @@ impl SpirvType<'_> {
131
133
}
132
134
result
133
135
}
134
- Self :: Vector { element, count } => cx. emit_global ( ) . type_vector_id ( id, element, count) ,
136
+ Self :: Vector { element, count, .. } => {
137
+ cx. emit_global ( ) . type_vector_id ( id, element, count)
138
+ }
135
139
Self :: Matrix { element, count } => cx. emit_global ( ) . type_matrix_id ( id, element, count) ,
136
140
Self :: Array { element, count } => {
137
141
let result = cx
@@ -280,9 +284,7 @@ impl SpirvType<'_> {
280
284
Self :: Bool => Size :: from_bytes ( 1 ) ,
281
285
Self :: Integer ( width, _) | Self :: Float ( width) => Size :: from_bits ( width) ,
282
286
Self :: Adt { size, .. } => size?,
283
- Self :: Vector { element, count } => {
284
- cx. lookup_type ( element) . sizeof ( cx) ? * count. next_power_of_two ( ) as u64
285
- }
287
+ Self :: Vector { size, .. } => size,
286
288
Self :: Matrix { element, count } => cx. lookup_type ( element) . sizeof ( cx) ? * count as u64 ,
287
289
Self :: Array { element, count } => {
288
290
cx. lookup_type ( element) . sizeof ( cx) ?
@@ -310,14 +312,7 @@ impl SpirvType<'_> {
310
312
311
313
Self :: Bool => Align :: from_bytes ( 1 ) . unwrap ( ) ,
312
314
Self :: Integer ( width, _) | Self :: Float ( width) => Align :: from_bits ( width as u64 ) . unwrap ( ) ,
313
- Self :: Adt { align, .. } => align,
314
- // Vectors have size==align
315
- Self :: Vector { .. } => Align :: from_bytes (
316
- self . sizeof ( cx)
317
- . expect ( "alignof: Vectors must be sized" )
318
- . bytes ( ) ,
319
- )
320
- . expect ( "alignof: Vectors must have power-of-2 size" ) ,
315
+ Self :: Adt { align, .. } | Self :: Vector { align, .. } => align,
321
316
Self :: Array { element, .. }
322
317
| Self :: RuntimeArray { element }
323
318
| Self :: Matrix { element, .. } => cx. lookup_type ( element) . alignof ( cx) ,
@@ -382,7 +377,17 @@ impl SpirvType<'_> {
382
377
SpirvType :: Bool => SpirvType :: Bool ,
383
378
SpirvType :: Integer ( width, signedness) => SpirvType :: Integer ( width, signedness) ,
384
379
SpirvType :: Float ( width) => SpirvType :: Float ( width) ,
385
- SpirvType :: Vector { element, count } => SpirvType :: Vector { element, count } ,
380
+ SpirvType :: Vector {
381
+ element,
382
+ count,
383
+ size,
384
+ align,
385
+ } => SpirvType :: Vector {
386
+ element,
387
+ count,
388
+ size,
389
+ align,
390
+ } ,
386
391
SpirvType :: Matrix { element, count } => SpirvType :: Matrix { element, count } ,
387
392
SpirvType :: Array { element, count } => SpirvType :: Array { element, count } ,
388
393
SpirvType :: RuntimeArray { element } => SpirvType :: RuntimeArray { element } ,
@@ -435,6 +440,33 @@ impl SpirvType<'_> {
435
440
} ,
436
441
}
437
442
}
443
+
444
+ pub fn simd_vector ( cx : & CodegenCx < ' _ > , span : Span , element : SpirvType < ' _ > , count : u32 ) -> Self {
445
+ Self :: Vector {
446
+ element : element. def ( span, cx) ,
447
+ count,
448
+ size : element. sizeof ( cx) . unwrap ( ) * count as u64 ,
449
+ align : element. alignof ( cx) ,
450
+ }
451
+ }
452
+
453
+ /// Now that we can have different `OpTypeVector` types with various sizes or alignments, having a statement in an
454
+ /// `asm!` block like `OpTypeVector %f32 3` doesn't correlate to a specific type anymore. It could be Vec3 or
455
+ /// `Vec3A`, nobody knows.
456
+ ///
457
+ /// FIXME(@firestar99) This is a giant hack that only works with base glam types, if any other type with
458
+ /// `#[spirv(vector)]` is used, this will generate a mismatched type and fail validation.
459
+ /// Also, `Vec3A` doesn't work.
460
+ pub fn glam_vector_asm_hack ( cx : & CodegenCx < ' _ > , span : Span , element : Word , count : u32 ) -> Word {
461
+ let spirv_type = cx. lookup_type ( element) ;
462
+ SpirvType :: Vector {
463
+ element,
464
+ count,
465
+ size : spirv_type. sizeof ( cx) . unwrap ( ) * count as u64 ,
466
+ align : spirv_type. alignof ( cx) ,
467
+ }
468
+ . def ( span, cx)
469
+ }
438
470
}
439
471
440
472
impl < ' a > SpirvType < ' a > {
@@ -501,11 +533,18 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
501
533
. field ( "field_names" , & field_names)
502
534
. finish ( )
503
535
}
504
- SpirvType :: Vector { element, count } => f
536
+ SpirvType :: Vector {
537
+ element,
538
+ count,
539
+ size,
540
+ align,
541
+ } => f
505
542
. debug_struct ( "Vector" )
506
543
. field ( "id" , & self . id )
507
544
. field ( "element" , & self . cx . debug_type ( element) )
508
545
. field ( "count" , & count)
546
+ . field ( "size" , & size)
547
+ . field ( "align" , & align)
509
548
. finish ( ) ,
510
549
SpirvType :: Matrix { element, count } => f
511
550
. debug_struct ( "Matrix" )
@@ -668,7 +707,7 @@ impl SpirvTypePrinter<'_, '_> {
668
707
}
669
708
f. write_str ( " }" )
670
709
}
671
- SpirvType :: Vector { element, count } | SpirvType :: Matrix { element, count } => {
710
+ SpirvType :: Vector { element, count, .. } | SpirvType :: Matrix { element, count } => {
672
711
ty ( self . cx , stack, f, element) ?;
673
712
write ! ( f, "x{count}" )
674
713
}
0 commit comments