Skip to content

Commit c089cb2

Browse files
committed
abi layout: give Vector a dynamic size and alignment
1 parent 3c5a806 commit c089cb2

File tree

8 files changed

+167
-67
lines changed

8 files changed

+167
-67
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 79 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,8 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
639639
SpirvType::Vector {
640640
element: elem_spirv,
641641
count: count as u32,
642+
size: self.size,
643+
align: self.align.abi,
642644
}
643645
.def(span, cx)
644646
}
@@ -1220,43 +1222,93 @@ fn trans_intrinsic_type<'tcx>(
12201222
}
12211223
}
12221224
IntrinsicType::Matrix => {
1223-
let span = def_id_for_spirv_type_adt(ty)
1224-
.map(|did| cx.tcx.def_span(did))
1225-
.expect("#[spirv(matrix)] must be added to a type which has DefId");
1226-
1227-
let field_types = (0..ty.fields.count())
1228-
.map(|i| ty.field(cx, i).spirv_type(span, cx))
1229-
.collect::<Vec<_>>();
1230-
if field_types.len() < 2 {
1231-
return Err(cx
1232-
.tcx
1233-
.dcx()
1234-
.span_err(span, "#[spirv(matrix)] type must have at least two fields"));
1235-
}
1236-
let elem_type = field_types[0];
1237-
if !field_types.iter().all(|&ty| ty == elem_type) {
1238-
return Err(cx.tcx.dcx().span_err(
1239-
span,
1240-
"#[spirv(matrix)] type fields must all be the same type",
1241-
));
1242-
}
1243-
match cx.lookup_type(elem_type) {
1225+
let (element, count) =
1226+
trans_glam_like_struct(cx, span, ty, args, "`#[spirv(matrix)]`")?;
1227+
match cx.lookup_type(element) {
12441228
SpirvType::Vector { .. } => (),
12451229
ty => {
12461230
return Err(cx
12471231
.tcx
12481232
.dcx()
1249-
.struct_span_err(span, "#[spirv(matrix)] type fields must all be vectors")
1250-
.with_note(format!("field type is {}", ty.debug(elem_type, cx)))
1233+
.struct_span_err(span, "`#[spirv(matrix)]` type fields must all be vectors")
1234+
.with_note(format!("field type is {}", ty.debug(element, cx)))
12511235
.emit());
12521236
}
12531237
}
1254-
1255-
Ok(SpirvType::Matrix {
1256-
element: elem_type,
1257-
count: field_types.len() as u32,
1238+
Ok(SpirvType::Matrix { element, count }.def(span, cx))
1239+
}
1240+
IntrinsicType::Vector => {
1241+
let (element, count) =
1242+
trans_glam_like_struct(cx, span, ty, args, "`#[spirv(vector)]`")?;
1243+
match cx.lookup_type(element) {
1244+
SpirvType::Float { .. } => (),
1245+
SpirvType::Integer { .. } => (),
1246+
ty => {
1247+
return Err(cx
1248+
.tcx
1249+
.dcx()
1250+
.struct_span_err(
1251+
span,
1252+
"`#[spirv(vector)]` type fields must all be floats or integers",
1253+
)
1254+
.with_note(format!("field type is {}", ty.debug(element, cx)))
1255+
.emit());
1256+
}
1257+
}
1258+
Ok(SpirvType::Vector {
1259+
element,
1260+
count,
1261+
size: ty.size,
1262+
align: ty.align.abi,
12581263
}
12591264
.def(span, cx))
12601265
}
12611266
}
12621267
}
1268+
1269+
/// A struct with multiple fields of the same kind
1270+
/// Used for `#[spirv(vector)]` and `#[spirv(matrix)]`
1271+
fn trans_glam_like_struct<'tcx>(
1272+
cx: &CodegenCx<'tcx>,
1273+
span: Span,
1274+
ty: TyAndLayout<'tcx>,
1275+
args: GenericArgsRef<'tcx>,
1276+
err_attr_name: &str,
1277+
) -> Result<(Word, u32), ErrorGuaranteed> {
1278+
let tcx = cx.tcx;
1279+
if let Some(adt) = ty.ty.ty_adt_def()
1280+
&& adt.is_struct()
1281+
{
1282+
let (count, element) = adt
1283+
.non_enum_variant()
1284+
.fields
1285+
.iter()
1286+
.map(|f| f.ty(tcx, args))
1287+
.dedup_with_count()
1288+
.exactly_one()
1289+
.map_err(|_| {
1290+
tcx.dcx().span_err(
1291+
span,
1292+
format!("{err_attr_name} member types must all be the same"),
1293+
)
1294+
})?;
1295+
1296+
let element = cx.layout_of(element);
1297+
let element_word = element.spirv_type(span, cx);
1298+
let count = u32::try_from(count)
1299+
.ok()
1300+
.filter(|count| *count >= 2)
1301+
.ok_or_else(|| {
1302+
tcx.dcx().span_err(
1303+
span,
1304+
format!("{err_attr_name} must have at least 2 members"),
1305+
)
1306+
})?;
1307+
1308+
Ok((element_word, count))
1309+
} else {
1310+
Err(tcx
1311+
.dcx()
1312+
.span_err(span, "#[spirv(vector)] type must be a struct"))
1313+
}
1314+
}

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ pub enum IntrinsicType {
6666
RuntimeArray,
6767
TypedBuffer,
6868
Matrix,
69+
Vector,
6970
}
7071

7172
#[derive(Copy, Clone, Debug, PartialEq, Eq)]

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,12 @@ fn memset_dynamic_scalar(
317317
byte_width: usize,
318318
is_float: bool,
319319
) -> Word {
320-
let composite_type = SpirvType::Vector {
321-
element: SpirvType::Integer(8, false).def(builder.span(), builder),
322-
count: byte_width as u32,
323-
}
320+
let composite_type = SpirvType::simd_vector(
321+
builder,
322+
builder.span(),
323+
SpirvType::Integer(8, false),
324+
byte_width as u32,
325+
)
324326
.def(builder.span(), builder);
325327
let composite = builder
326328
.emit()
@@ -417,7 +419,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
417419
_ => self.fatal(format!("memset on float width {width} not implemented yet")),
418420
},
419421
SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"),
420-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
422+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
421423
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
422424
self.constant_composite(
423425
ty.def(self.span(), self),
@@ -478,7 +480,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
478480
)
479481
.unwrap()
480482
}
481-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
483+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
482484
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
483485
self.emit()
484486
.composite_construct(
@@ -2966,11 +2968,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
29662968
}
29672969

29682970
fn vector_splat(&mut self, num_elts: usize, elt: Self::Value) -> Self::Value {
2969-
let result_type = SpirvType::Vector {
2970-
element: elt.ty,
2971-
count: num_elts as u32,
2972-
}
2973-
.def(self.span(), self);
2971+
let result_type =
2972+
SpirvType::simd_vector(self, self.span(), self.lookup_type(elt.ty), num_elts as u32)
2973+
.def(self.span(), self);
29742974
if self.builder.lookup_const(elt).is_some() {
29752975
self.constant_composite(result_type, iter::repeat_n(elt.def(self), num_elts))
29762976
} else {

crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
113113
let val = self.load_u32(array, dynamic_word_index, constant_word_offset);
114114
self.bitcast(val, result_type)
115115
}
116-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
116+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => self
117117
.load_vec_mat_arr(
118118
original_type,
119119
result_type,
@@ -312,7 +312,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
312312
let value_u32 = self.bitcast(value, u32_ty);
313313
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
314314
}
315-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
315+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => self
316316
.store_vec_mat_arr(
317317
original_type,
318318
value,

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -370,11 +370,11 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
370370
self.err("OpTypeStruct in asm! is not supported yet");
371371
return;
372372
}
373-
Op::TypeVector => SpirvType::Vector {
374-
element: inst.operands[0].unwrap_id_ref(),
375-
count: inst.operands[1].unwrap_literal_bit32(),
373+
Op::TypeVector => {
374+
let element = inst.operands[0].unwrap_id_ref();
375+
let count = inst.operands[1].unwrap_literal_bit32();
376+
SpirvType::glam_vector_asm_hack(self.cx, self.span(), element, count)
376377
}
377-
.def(self.span(), self),
378378
Op::TypeMatrix => SpirvType::Matrix {
379379
element: inst.operands[0].unwrap_id_ref(),
380380
count: inst.operands[1].unwrap_literal_bit32(),
@@ -722,6 +722,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
722722
SpirvType::Vector {
723723
element: ty,
724724
count: 4,
725+
..
725726
},
726727
)
727728
| (
@@ -762,11 +763,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
762763
}
763764
.def(DUMMY_SP, cx),
764765

765-
TyPat::Vector4(pat) => SpirvType::Vector {
766-
element: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
767-
count: 4,
768-
}
769-
.def(DUMMY_SP, cx),
766+
TyPat::Vector4(pat) => SpirvType::glam_vector_asm_hack(
767+
cx,
768+
DUMMY_SP,
769+
subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
770+
4,
771+
),
770772

771773
TyPat::SampledImage(pat) => SpirvType::SampledImage {
772774
image_type: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,

crates/rustc_codegen_spirv/src/codegen_cx/constant.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,12 @@ impl ConstCodegenMethods for CodegenCx<'_> {
200200
self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)))
201201
}
202202
fn const_vector(&self, elts: &[Self::Value]) -> Self::Value {
203-
let vector_ty = SpirvType::Vector {
204-
element: elts[0].ty,
205-
count: elts.len() as u32,
206-
}
203+
let vector_ty = SpirvType::simd_vector(
204+
self,
205+
DUMMY_SP,
206+
self.lookup_type(elts[0].ty),
207+
elts.len() as u32,
208+
)
207209
.def(DUMMY_SP, self);
208210
self.constant_composite(vector_ty, elts.iter().map(|elt| elt.def_cx(self)))
209211
}

crates/rustc_codegen_spirv/src/spirv_type.rs

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ pub enum SpirvType<'tcx> {
4545
element: Word,
4646
/// Note: vector count is literal.
4747
count: u32,
48+
size: Size,
49+
align: Align,
4850
},
4951
Matrix {
5052
element: Word,
@@ -131,7 +133,9 @@ impl SpirvType<'_> {
131133
}
132134
result
133135
}
134-
Self::Vector { element, count } => cx.emit_global().type_vector_id(id, element, count),
136+
Self::Vector { element, count, .. } => {
137+
cx.emit_global().type_vector_id(id, element, count)
138+
}
135139
Self::Matrix { element, count } => cx.emit_global().type_matrix_id(id, element, count),
136140
Self::Array { element, count } => {
137141
let result = cx
@@ -280,9 +284,7 @@ impl SpirvType<'_> {
280284
Self::Bool => Size::from_bytes(1),
281285
Self::Integer(width, _) | Self::Float(width) => Size::from_bits(width),
282286
Self::Adt { size, .. } => size?,
283-
Self::Vector { element, count } => {
284-
cx.lookup_type(element).sizeof(cx)? * count.next_power_of_two() as u64
285-
}
287+
Self::Vector { size, .. } => size,
286288
Self::Matrix { element, count } => cx.lookup_type(element).sizeof(cx)? * count as u64,
287289
Self::Array { element, count } => {
288290
cx.lookup_type(element).sizeof(cx)?
@@ -311,13 +313,7 @@ impl SpirvType<'_> {
311313
Self::Bool => Align::from_bytes(1).unwrap(),
312314
Self::Integer(width, _) | Self::Float(width) => Align::from_bits(width as u64).unwrap(),
313315
Self::Adt { align, .. } => align,
314-
// Vectors have size==align
315-
Self::Vector { .. } => Align::from_bytes(
316-
self.sizeof(cx)
317-
.expect("alignof: Vectors must be sized")
318-
.bytes(),
319-
)
320-
.expect("alignof: Vectors must have power-of-2 size"),
316+
Self::Vector { align, .. } => align,
321317
Self::Array { element, .. }
322318
| Self::RuntimeArray { element }
323319
| Self::Matrix { element, .. } => cx.lookup_type(element).alignof(cx),
@@ -382,7 +378,17 @@ impl SpirvType<'_> {
382378
SpirvType::Bool => SpirvType::Bool,
383379
SpirvType::Integer(width, signedness) => SpirvType::Integer(width, signedness),
384380
SpirvType::Float(width) => SpirvType::Float(width),
385-
SpirvType::Vector { element, count } => SpirvType::Vector { element, count },
381+
SpirvType::Vector {
382+
element,
383+
count,
384+
size,
385+
align,
386+
} => SpirvType::Vector {
387+
element,
388+
count,
389+
size,
390+
align,
391+
},
386392
SpirvType::Matrix { element, count } => SpirvType::Matrix { element, count },
387393
SpirvType::Array { element, count } => SpirvType::Array { element, count },
388394
SpirvType::RuntimeArray { element } => SpirvType::RuntimeArray { element },
@@ -435,6 +441,32 @@ impl SpirvType<'_> {
435441
},
436442
}
437443
}
444+
445+
pub fn simd_vector(cx: &CodegenCx<'_>, span: Span, element: SpirvType<'_>, count: u32) -> Self {
446+
Self::Vector {
447+
element: element.def(span, cx),
448+
count,
449+
size: element.sizeof(cx).unwrap() * count as u64,
450+
align: element.alignof(cx),
451+
}
452+
}
453+
454+
/// Now that we can have different `OpTypeVector` types with various sizes or alignments, having a statement in an
455+
/// `asm!` block like `OpTypeVector %f32 3` doesn't correlate to a specific type anymore. It could be Vec3 or Vec3A,
456+
/// nobody knows.
457+
///
458+
/// FIXME(@firestar99) This is a giant hack that only works with base glam types, if any other type with
459+
/// `#[spirv(vector)]` is used, this will generate a mismatched type and fail validation. Also, Vec3A doesn't work.
460+
pub fn glam_vector_asm_hack(cx: &CodegenCx<'_>, span: Span, element: Word, count: u32) -> Word {
461+
let spirv_type = cx.lookup_type(element);
462+
SpirvType::Vector {
463+
element,
464+
count,
465+
size: spirv_type.sizeof(cx).unwrap() * count as u64,
466+
align: spirv_type.alignof(cx),
467+
}
468+
.def(span, cx)
469+
}
438470
}
439471

440472
impl<'a> SpirvType<'a> {
@@ -501,11 +533,18 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
501533
.field("field_names", &field_names)
502534
.finish()
503535
}
504-
SpirvType::Vector { element, count } => f
536+
SpirvType::Vector {
537+
element,
538+
count,
539+
size,
540+
align,
541+
} => f
505542
.debug_struct("Vector")
506543
.field("id", &self.id)
507544
.field("element", &self.cx.debug_type(element))
508545
.field("count", &count)
546+
.field("size", &size)
547+
.field("align", &align)
509548
.finish(),
510549
SpirvType::Matrix { element, count } => f
511550
.debug_struct("Matrix")
@@ -668,7 +707,7 @@ impl SpirvTypePrinter<'_, '_> {
668707
}
669708
f.write_str(" }")
670709
}
671-
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
710+
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
672711
ty(self.cx, stack, f, element)?;
673712
write!(f, "x{count}")
674713
}

0 commit comments

Comments
 (0)