@@ -45,6 +45,8 @@ pub enum SpirvType<'tcx> {
45
45
element : Word ,
46
46
/// Note: vector count is literal.
47
47
count : u32 ,
48
+ size : Size ,
49
+ align : Align ,
48
50
} ,
49
51
Matrix {
50
52
element : Word ,
@@ -131,7 +133,9 @@ impl SpirvType<'_> {
131
133
}
132
134
result
133
135
}
134
- Self :: Vector { element, count } => cx. emit_global ( ) . type_vector_id ( id, element, count) ,
136
+ Self :: Vector { element, count, .. } => {
137
+ cx. emit_global ( ) . type_vector_id ( id, element, count)
138
+ }
135
139
Self :: Matrix { element, count } => cx. emit_global ( ) . type_matrix_id ( id, element, count) ,
136
140
Self :: Array { element, count } => {
137
141
let result = cx
@@ -280,9 +284,7 @@ impl SpirvType<'_> {
280
284
Self :: Bool => Size :: from_bytes ( 1 ) ,
281
285
Self :: Integer ( width, _) | Self :: Float ( width) => Size :: from_bits ( width) ,
282
286
Self :: Adt { size, .. } => size?,
283
- Self :: Vector { element, count } => {
284
- cx. lookup_type ( element) . sizeof ( cx) ? * count. next_power_of_two ( ) as u64
285
- }
287
+ Self :: Vector { size, .. } => size,
286
288
Self :: Matrix { element, count } => cx. lookup_type ( element) . sizeof ( cx) ? * count as u64 ,
287
289
Self :: Array { element, count } => {
288
290
cx. lookup_type ( element) . sizeof ( cx) ?
@@ -311,13 +313,7 @@ impl SpirvType<'_> {
311
313
Self :: Bool => Align :: from_bytes ( 1 ) . unwrap ( ) ,
312
314
Self :: Integer ( width, _) | Self :: Float ( width) => Align :: from_bits ( width as u64 ) . unwrap ( ) ,
313
315
Self :: Adt { align, .. } => align,
314
- // Vectors have size==align
315
- Self :: Vector { .. } => Align :: from_bytes (
316
- self . sizeof ( cx)
317
- . expect ( "alignof: Vectors must be sized" )
318
- . bytes ( ) ,
319
- )
320
- . expect ( "alignof: Vectors must have power-of-2 size" ) ,
316
+ Self :: Vector { align, .. } => align,
321
317
Self :: Array { element, .. }
322
318
| Self :: RuntimeArray { element }
323
319
| Self :: Matrix { element, .. } => cx. lookup_type ( element) . alignof ( cx) ,
@@ -382,7 +378,17 @@ impl SpirvType<'_> {
382
378
SpirvType :: Bool => SpirvType :: Bool ,
383
379
SpirvType :: Integer ( width, signedness) => SpirvType :: Integer ( width, signedness) ,
384
380
SpirvType :: Float ( width) => SpirvType :: Float ( width) ,
385
- SpirvType :: Vector { element, count } => SpirvType :: Vector { element, count } ,
381
+ SpirvType :: Vector {
382
+ element,
383
+ count,
384
+ size,
385
+ align,
386
+ } => SpirvType :: Vector {
387
+ element,
388
+ count,
389
+ size,
390
+ align,
391
+ } ,
386
392
SpirvType :: Matrix { element, count } => SpirvType :: Matrix { element, count } ,
387
393
SpirvType :: Array { element, count } => SpirvType :: Array { element, count } ,
388
394
SpirvType :: RuntimeArray { element } => SpirvType :: RuntimeArray { element } ,
@@ -435,6 +441,32 @@ impl SpirvType<'_> {
435
441
} ,
436
442
}
437
443
}
444
+
445
+ pub fn simd_vector ( cx : & CodegenCx < ' _ > , span : Span , element : SpirvType < ' _ > , count : u32 ) -> Self {
446
+ Self :: Vector {
447
+ element : element. def ( span, cx) ,
448
+ count,
449
+ size : element. sizeof ( cx) . unwrap ( ) * count as u64 ,
450
+ align : element. alignof ( cx) ,
451
+ }
452
+ }
453
+
454
+ /// Now that we can have different `OpTypeVector` types with various sizes or alignments, having a statement in an
455
+ /// `asm!` block like `OpTypeVector %f32 3` doesn't correlate to a specific type anymore. It could be Vec3 or Vec3A,
456
+ /// nobody knows.
457
+ ///
458
+ /// FIXME(@firestar99) This is a giant hack that only works with base glam types, if any other type with
459
+ /// `#[spirv(vector)]` is used, this will generate a mismatched type and fail validation. Also, Vec3A doesn't work.
460
+ pub fn glam_vector_asm_hack ( cx : & CodegenCx < ' _ > , span : Span , element : Word , count : u32 ) -> Word {
461
+ let spirv_type = cx. lookup_type ( element) ;
462
+ SpirvType :: Vector {
463
+ element,
464
+ count,
465
+ size : spirv_type. sizeof ( cx) . unwrap ( ) * count as u64 ,
466
+ align : spirv_type. alignof ( cx) ,
467
+ }
468
+ . def ( span, cx)
469
+ }
438
470
}
439
471
440
472
impl < ' a > SpirvType < ' a > {
@@ -501,11 +533,18 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
501
533
. field ( "field_names" , & field_names)
502
534
. finish ( )
503
535
}
504
- SpirvType :: Vector { element, count } => f
536
+ SpirvType :: Vector {
537
+ element,
538
+ count,
539
+ size,
540
+ align,
541
+ } => f
505
542
. debug_struct ( "Vector" )
506
543
. field ( "id" , & self . id )
507
544
. field ( "element" , & self . cx . debug_type ( element) )
508
545
. field ( "count" , & count)
546
+ . field ( "size" , & size)
547
+ . field ( "align" , & align)
509
548
. finish ( ) ,
510
549
SpirvType :: Matrix { element, count } => f
511
550
. debug_struct ( "Matrix" )
@@ -668,7 +707,7 @@ impl SpirvTypePrinter<'_, '_> {
668
707
}
669
708
f. write_str ( " }" )
670
709
}
671
- SpirvType :: Vector { element, count } | SpirvType :: Matrix { element, count } => {
710
+ SpirvType :: Vector { element, count, .. } | SpirvType :: Matrix { element, count } => {
672
711
ty ( self . cx , stack, f, element) ?;
673
712
write ! ( f, "x{count}" )
674
713
}
0 commit comments