Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,9 @@ impl<'tcx> CodegenCx<'tcx> {
Ok(StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant)
);
let mut assign_location = |var_id: Result<Word, &str>, explicit: Option<u32>| {
let storage_class = storage_class.unwrap();
let location = decoration_locations
.entry(storage_class.unwrap())
.entry(storage_class)
.or_insert_with(|| 0);
if let Some(explicit) = explicit {
*location = explicit;
Expand All @@ -733,7 +734,46 @@ impl<'tcx> CodegenCx<'tcx> {
Decoration::Location,
std::iter::once(Operand::LiteralBit32(*location)),
);
let spirv_type = self.lookup_type(value_spirv_type);
let mut spirv_type = self.lookup_type(value_spirv_type);

// These shader types and storage classes skip the outer array or pointer of the declaration when computing
// the location layout, see bug at https://github.com/Rust-GPU/rust-gpu/issues/500.
//
// The match statment follows the rules at:
// https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#interfaces-iointerfaces-matching
#[allow(clippy::match_same_arms)]
let can_skip_outer_array =
match (execution_model, storage_class, attrs.per_primitive_ext) {
// > if the input is declared in a tessellation control or geometry shader...
(
ExecutionModel::TessellationControl | ExecutionModel::Geometry,
StorageClass::Input,
_,
) => true,
// > if the maintenance4 feature is enabled, they are declared as OpTypeVector variables, and the
// > output has a Component Count value higher than that of the input but the same Component Type
// Irrelevant: This allows a vertex shader to output a Vec4 and a fragment shader to accept a vector
// type with fewer components, like Vec3, Vec2 (or f32?). Which has no influence on locations.
// > if the output is declared in a mesh shader...
(ExecutionModel::MeshEXT | ExecutionModel::MeshNV, StorageClass::Output, _) => {
true
}
// > if the input is decorated with PerVertexKHR, and is declared in a fragment shader...
(ExecutionModel::Fragment, StorageClass::Input, Some(_)) => true,
// > if in any other case...
(_, _, _) => false,
};
if can_skip_outer_array {
spirv_type = match spirv_type {
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element, .. }
| SpirvType::Pointer {
pointee: element, ..
} => self.lookup_type(element),
e => e,
};
}

if let Some(location_size) = spirv_type.location_size(self) {
*location += location_size;
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// build-pass
// compile-flags: -Ctarget-feature=+Geometry
// compile-flags: -C llvm-args=--disassemble-globals
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
// normalize-stderr-test "; .*\n" -> ""
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
// ignore-spv1.0
// ignore-spv1.1
// ignore-spv1.2
// ignore-spv1.3
// ignore-vulkan1.0
// ignore-vulkan1.1

use spirv_std::arch::{emit_vertex, end_primitive};
use spirv_std::glam::*;
use spirv_std::spirv;

pub struct Attr1 {
pub a: Vec4,
pub b: Vec2,
pub c: f32,
}

pub struct Attr2 {
pub d: f32,
}

#[spirv(geometry(input_points = 2, output_line_strip = 2))]
pub fn main(
// #[spirv(descriptor_set = 0, binding = 0, storage_buffer)]
#[spirv(position)] position_in: Vec4,
#[spirv(position)] position_out: &mut Vec4,
// location 0
attr1_in: [f32; 2],
// location 0
attr1_out: &mut f32,
// location 1
attr2_in: [u32; 2],
// location 1
attr2_out: &mut u32,
) {
unsafe {
*attr1_out = attr1_in[0];
*attr2_out = attr2_in[0];
*position_out = position_in + vec4(-0.1, 0.0, 0.0, 0.0);
emit_vertex();

*attr1_out = attr1_in[1];
*attr2_out = attr2_in[1];
*position_out = position_in + vec4(0.1, 0.0, 0.0, 0.0);
emit_vertex();

end_primitive();
};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
OpCapability Shader
OpCapability Geometry
OpMemoryModel Logical Simple
OpEntryPoint Geometry %1 "main" %2 %3 %4 %5 %6 %7
OpExecutionMode %1 InputPoints
OpExecutionMode %1 OutputLineStrip
OpName %2 "position_in"
OpName %3 "attr1_in"
OpName %4 "attr2_in"
OpName %5 "attr1_out"
OpName %6 "attr2_out"
OpName %7 "position_out"
OpDecorate %2 BuiltIn Position
OpDecorate %3 Location 0
OpDecorate %12 ArrayStride 4
OpDecorate %4 Location 1
OpDecorate %13 ArrayStride 4
OpDecorate %5 Location 0
OpDecorate %6 Location 1
OpDecorate %7 BuiltIn Position
%14 = OpTypeFloat 32
%15 = OpTypeVector %14 4
%16 = OpTypePointer Input %15
%17 = OpTypePointer Output %15
%18 = OpTypeInt 32 0
%19 = OpConstant %18 2
%20 = OpTypeArray %14 %19
%21 = OpTypePointer Input %20
%22 = OpTypePointer Output %14
%23 = OpTypeArray %18 %19
%24 = OpTypePointer Input %23
%25 = OpTypePointer Output %18
%26 = OpTypeVoid
%27 = OpTypeFunction %26
%2 = OpVariable %16 Input
%3 = OpVariable %21 Input
%12 = OpTypeArray %14 %19
%4 = OpVariable %24 Input
%13 = OpTypeArray %18 %19
%5 = OpVariable %22 Output
%6 = OpVariable %25 Output
%28 = OpConstant %14 3184315597
%29 = OpConstant %14 0
%7 = OpVariable %17 Output
%30 = OpConstant %14 1036831949
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// build-pass
// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader
// compile-flags: -C llvm-args=--disassemble-globals
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
// normalize-stderr-test "; .*\n" -> ""
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
// ignore-spv1.0
// ignore-spv1.1
// ignore-spv1.2
// ignore-spv1.3
// ignore-vulkan1.0
// ignore-vulkan1.1

use spirv_std::arch::set_mesh_outputs_ext;
use spirv_std::glam::{UVec3, Vec4};
use spirv_std::spirv;

#[spirv(mesh_ext(
threads(1),
output_vertices = 9,
output_primitives_ext = 3,
output_triangles_ext
))]
pub fn main(
#[spirv(position)] positions: &mut [Vec4; 9],
#[spirv(primitive_triangle_indices_ext)] indices: &mut [UVec3; 3],
// location 0
out_per_vertex: &mut [u32; 9],
// location 1
out_per_vertex2: &mut [f32; 9],
// location 2
#[spirv(per_primitive_ext)] out_per_primitive: &mut [u32; 3],
// location 3
#[spirv(per_primitive_ext)] out_per_primitive2: &mut [f32; 3],
) {
unsafe {
set_mesh_outputs_ext(9, 3);
}

for i in 0..3 {
positions[i * 3 + 0] = Vec4::new(-0.5, 0.5, 0.0, 1.0);
positions[i * 3 + 1] = Vec4::new(0.5, 0.5, 0.0, 1.0);
positions[i * 3 + 2] = Vec4::new(0.0, -0.5, 0.0, 1.0);
}

for i in 0..9 {
out_per_vertex[i] = i as u32;
out_per_vertex2[i] = i as f32;
}

for i in 0..3 {
indices[i] = UVec3::new(0, 1, 2) + UVec3::splat(i as u32);
out_per_primitive[i] = 42;
out_per_primitive2[i] = 69.;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
OpCapability Shader
OpCapability MeshShadingEXT
OpExtension "SPV_EXT_mesh_shader"
OpMemoryModel Logical Simple
OpEntryPoint MeshEXT %1 "main" %2 %3 %4 %5 %6 %7
OpExecutionMode %1 LocalSize 1 1 1
OpExecutionMode %1 OutputVertices 9
OpExecutionMode %1 OutputPrimitivesNV 3
OpExecutionMode %1 OutputTrianglesNV
OpName %16 "core::ops::Range<usize>"
OpMemberName %16 0 "start"
OpMemberName %16 1 "end"
OpName %2 "positions"
OpName %3 "out_per_vertex"
OpName %4 "out_per_vertex2"
OpName %5 "indices"
OpName %6 "out_per_primitive"
OpName %7 "out_per_primitive2"
OpMemberDecorate %16 0 Offset 0
OpMemberDecorate %16 1 Offset 4
OpDecorate %2 BuiltIn Position
OpDecorate %3 Location 0
OpDecorate %4 Location 1
OpDecorate %5 BuiltIn PrimitiveTriangleIndicesEXT
OpDecorate %6 Location 2
OpDecorate %6 PerPrimitiveNV
OpDecorate %7 Location 3
OpDecorate %7 PerPrimitiveNV
%17 = OpTypeFloat 32
%18 = OpTypeVector %17 4
%19 = OpTypeInt 32 0
%20 = OpConstant %19 9
%21 = OpTypeArray %18 %20
%22 = OpTypePointer Output %21
%23 = OpTypeVector %19 3
%24 = OpConstant %19 3
%25 = OpTypeArray %23 %24
%26 = OpTypePointer Output %25
%27 = OpTypeArray %19 %20
%28 = OpTypePointer Output %27
%29 = OpTypeArray %17 %20
%30 = OpTypePointer Output %29
%31 = OpTypeArray %19 %24
%32 = OpTypePointer Output %31
%33 = OpTypeArray %17 %24
%34 = OpTypePointer Output %33
%35 = OpTypeVoid
%36 = OpTypeFunction %35
%16 = OpTypeStruct %19 %19
%37 = OpConstant %19 0
%38 = OpUndef %16
%39 = OpTypeBool
%40 = OpConstantFalse %39
%41 = OpConstant %19 1
%42 = OpTypeInt 32 1
%43 = OpConstant %42 0
%44 = OpConstant %17 3204448256
%45 = OpConstant %17 1056964608
%46 = OpConstant %17 0
%47 = OpConstant %17 1065353216
%48 = OpTypePointer Output %18
%2 = OpVariable %22 Output
%49 = OpConstant %19 2
%50 = OpTypePointer Output %19
%3 = OpVariable %28 Output
%51 = OpTypePointer Output %17
%4 = OpVariable %30 Output
%52 = OpTypePointer Output %23
%5 = OpVariable %26 Output
%6 = OpVariable %32 Output
%53 = OpConstant %19 42
%7 = OpVariable %34 Output
%54 = OpConstant %17 1116340224