Skip to content

Commit 9ed3407

Browse files
committed
macros: add spirv_vector attribute macro for declaring spirv Vector types
1 parent d49d297 commit 9ed3407

11 files changed

+450
-6
lines changed

crates/spirv-std/macros/src/lib.rs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,11 @@ mod spirv_attr_version;
8181
use crate::debug_printf::{DebugPrintfInput, debug_printf_inner};
8282
use crate::spirv_attr_version::spirv_attr_with_version;
8383
use proc_macro::TokenStream;
84-
use proc_macro2::{Delimiter, Group, TokenTree};
84+
use proc_macro2::{Delimiter, Group, TokenStream as TokenStream2, TokenTree};
8585
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
86+
use syn::punctuated::Punctuated;
87+
use syn::spanned::Spanned;
88+
use syn::{GenericParam, Token};
8689

8790
/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
8891
/// `spirv_std::image::Image<...>` type.
@@ -267,3 +270,51 @@ pub fn debug_printfln(input: TokenStream) -> TokenStream {
267270
pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
268271
sample_param_permutations::gen_sample_param_permutations(item)
269272
}
273+
274+
#[proc_macro_attribute]
275+
pub fn spirv_vector(attr: TokenStream, item: TokenStream) -> TokenStream {
276+
spirv_vector_impl(attr.into(), item.into())
277+
.unwrap_or_else(syn::Error::into_compile_error)
278+
.into()
279+
}
280+
281+
fn spirv_vector_impl(attr: TokenStream2, item: TokenStream2) -> syn::Result<TokenStream2> {
282+
// Whenever we'll properly resolve the crate symbol, replace this.
283+
let spirv_std = quote!(spirv_std);
284+
285+
// Defer all validation to our codegen backend. Rather than erroring here, emit garbage.
286+
let item = syn::parse2::<syn::ItemStruct>(item)?;
287+
let ident = &item.ident;
288+
let gens = &item.generics.params;
289+
let gen_refs = &item
290+
.generics
291+
.params
292+
.iter()
293+
.map(|p| match p {
294+
GenericParam::Lifetime(p) => p.lifetime.to_token_stream(),
295+
GenericParam::Type(p) => p.ident.to_token_stream(),
296+
GenericParam::Const(p) => p.ident.to_token_stream(),
297+
})
298+
.collect::<Punctuated<_, Token![,]>>();
299+
let where_clause = &item.generics.where_clause;
300+
let element = item
301+
.fields
302+
.iter()
303+
.next()
304+
.ok_or_else(|| syn::Error::new(item.span(), "Vector ZST not allowed"))?
305+
.ty
306+
.to_token_stream();
307+
let count = item.fields.len();
308+
309+
Ok(quote! {
310+
#[cfg_attr(target_arch = "spirv", rust_gpu::vector::v1(#attr))]
311+
#item
312+
313+
unsafe impl<#gens> #spirv_std::vector::VectorOrScalar for #ident<#gen_refs> #where_clause {
314+
type Scalar = #element;
315+
const DIM: core::num::NonZeroUsize = core::num::NonZeroUsize::new(#count).unwrap();
316+
}
317+
318+
unsafe impl<#gens> #spirv_std::vector::Vector<#element, #count> for #ident<#gen_refs> #where_clause {}
319+
})
320+
}

crates/spirv-std/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
/// Public re-export of the `spirv-std-macros` crate.
8888
#[macro_use]
8989
pub extern crate spirv_std_macros as macros;
90-
pub use macros::spirv;
90+
pub use macros::{spirv, spirv_vector};
9191

9292
pub mod arch;
9393
pub mod byte_addressable_buffer;

crates/spirv-std/src/vector.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ pub unsafe trait VectorOrScalar: Copy + Default + Send + Sync + 'static {
1818

1919
/// Abstract trait representing a SPIR-V vector type.
2020
///
21-
/// To implement this trait, your struct must be marked with:
21+
/// To derive this trait, mark your struct with:
2222
/// ```no_run
23-
/// #[cfg_attr(target_arch = "spirv", rust_gpu::vector::v1)]
23+
/// #[spirv_std::spirv_vector]
24+
/// # #[derive(Copy, Clone, Default)]
2425
/// # struct Bla(f32, f32);
2526
/// ```
2627
///
@@ -43,8 +44,8 @@ pub unsafe trait VectorOrScalar: Copy + Default + Send + Sync + 'static {
4344
///
4445
/// # Example
4546
/// ```no_run
47+
/// #[spirv_std::spirv_vector]
4648
/// #[derive(Copy, Clone, Default)]
47-
/// #[cfg_attr(target_arch = "spirv", rust_gpu::vector::v1)]
4849
/// struct MyColor {
4950
/// r: f32,
5051
/// b: f32,
@@ -55,7 +56,8 @@ pub unsafe trait VectorOrScalar: Copy + Default + Send + Sync + 'static {
5556
///
5657
/// # Safety
5758
/// Must only be implemented on types that the spirv codegen emits as valid `OpTypeVector`. This includes all structs
58-
/// marked with `#[rust_gpu::vector::v1]`, like [`glam`]'s non-SIMD "scalar" vector types.
59+
/// marked with `#[rust_gpu::vector::v1]`, which `#[spirv_std::spirv_vector]` expands into or [`glam`]'s non-SIMD
60+
/// "scalar" vector types use directly.
5961
pub unsafe trait Vector<T: Scalar, const N: usize>: VectorOrScalar<Scalar = T> {}
6062

6163
macro_rules! impl_vector {
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// build-fail
2+
3+
use core::num::NonZeroU32;
4+
use spirv_std::glam::Vec2;
5+
use spirv_std::spirv;
6+
7+
#[spirv_std::spirv_vector]
8+
#[derive(Copy, Clone, Default)]
9+
pub struct FewerFields {
10+
_v: f32,
11+
}
12+
13+
#[spirv_std::spirv_vector]
14+
#[derive(Copy, Clone, Default)]
15+
pub struct TooManyFields {
16+
_x: f32,
17+
_y: f32,
18+
_z: f32,
19+
_w: f32,
20+
_v: f32,
21+
}
22+
23+
// wrong member types fails too early
24+
25+
#[spirv_std::spirv_vector]
26+
#[derive(Copy, Clone, Default)]
27+
pub struct DifferentTypes {
28+
_x: f32,
29+
_y: u32,
30+
}
31+
32+
#[spirv(fragment)]
33+
pub fn entry(_: FewerFields, _: TooManyFields, #[spirv(flat)] _: DifferentTypes) {}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
error: `#[spirv(vector)]` must have 2, 3 or 4 members
2+
--> $DIR/invalid_vector_type_macro.rs:9:1
3+
|
4+
9 | pub struct FewerFields {
5+
| ^^^^^^^^^^^^^^^^^^^^^^
6+
7+
error: `#[spirv(vector)]` must have 2, 3 or 4 members
8+
--> $DIR/invalid_vector_type_macro.rs:15:1
9+
|
10+
15 | pub struct TooManyFields {
11+
| ^^^^^^^^^^^^^^^^^^^^^^^^
12+
13+
error: `#[spirv(vector)]` member types must all be the same
14+
--> $DIR/invalid_vector_type_macro.rs:27:1
15+
|
16+
27 | pub struct DifferentTypes {
17+
| ^^^^^^^^^^^^^^^^^^^^^^^^^
18+
19+
error: aborting due to 3 previous errors
20+
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// build-fail
2+
// normalize-stderr-test "\S*/crates/spirv-std/src/" -> "$$SPIRV_STD_SRC/"
3+
4+
use core::num::NonZeroU32;
5+
use spirv_std::glam::Vec2;
6+
use spirv_std::spirv;
7+
8+
#[spirv_std::spirv_vector]
9+
#[derive(Copy, Clone, Default)]
10+
pub struct ZstVector;
11+
12+
#[spirv_std::spirv_vector]
13+
#[derive(Copy, Clone, Default)]
14+
pub struct NotVectorField {
15+
_x: Vec2,
16+
_y: Vec2,
17+
}
18+
19+
#[spirv_std::spirv_vector]
20+
#[derive(Copy, Clone)]
21+
pub struct NotVectorField2 {
22+
_x: NonZeroU32,
23+
_y: NonZeroU32,
24+
}
25+
26+
impl Default for NotVectorField2 {
27+
fn default() -> Self {
28+
Self {
29+
_x: NonZeroU32::new(1).unwrap(),
30+
_y: NonZeroU32::new(1).unwrap(),
31+
}
32+
}
33+
}
34+
35+
#[spirv(fragment)]
36+
pub fn entry(
37+
// workaround to ZST loading
38+
#[spirv(storage_class, descriptor_set = 0, binding = 0)] _: &(ZstVector, i32),
39+
_: NotVectorField,
40+
#[spirv(flat)] _: NotVectorField2,
41+
) {
42+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
error: Vector ZST not allowed
2+
--> $DIR/invalid_vector_type_macro2.rs:9:1
3+
|
4+
9 | / #[derive(Copy, Clone, Default)]
5+
10 | | pub struct ZstVector;
6+
| |_____________________^
7+
8+
error[E0412]: cannot find type `ZstVector` in this scope
9+
--> $DIR/invalid_vector_type_macro2.rs:38:67
10+
|
11+
38 | #[spirv(storage_class, descriptor_set = 0, binding = 0)] _: &(ZstVector, i32),
12+
| ^^^^^^^^^ not found in this scope
13+
14+
error: unknown argument to spirv attribute
15+
--> $DIR/invalid_vector_type_macro2.rs:38:13
16+
|
17+
38 | #[spirv(storage_class, descriptor_set = 0, binding = 0)] _: &(ZstVector, i32),
18+
| ^^^^^^^^^^^^^
19+
20+
error[E0277]: the trait bound `Vec2: Scalar` is not satisfied
21+
--> $DIR/invalid_vector_type_macro2.rs:15:9
22+
|
23+
15 | _x: Vec2,
24+
| ^^^^ the trait `Scalar` is not implemented for `Vec2`
25+
|
26+
= help: the following other types implement trait `Scalar`:
27+
bool
28+
f32
29+
f64
30+
i16
31+
i32
32+
i64
33+
i8
34+
u16
35+
and 3 others
36+
note: required by a bound in `spirv_std::vector::VectorOrScalar::Scalar`
37+
--> $SPIRV_STD_SRC/vector.rs:13:18
38+
|
39+
13 | type Scalar: Scalar;
40+
| ^^^^^^ required by this bound in `VectorOrScalar::Scalar`
41+
42+
error[E0277]: the trait bound `Vec2: Scalar` is not satisfied
43+
--> $DIR/invalid_vector_type_macro2.rs:12:1
44+
|
45+
12 | #[spirv_std::spirv_vector]
46+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Scalar` is not implemented for `Vec2`
47+
|
48+
= help: the following other types implement trait `Scalar`:
49+
bool
50+
f32
51+
f64
52+
i16
53+
i32
54+
i64
55+
i8
56+
u16
57+
and 3 others
58+
note: required by a bound in `Vector`
59+
--> $SPIRV_STD_SRC/vector.rs:61:28
60+
|
61+
61 | pub unsafe trait Vector<T: Scalar, const N: usize>: VectorOrScalar<Scalar = T> {}
62+
| ^^^^^^ required by this bound in `Vector`
63+
= note: this error originates in the attribute macro `spirv_std::spirv_vector` (in Nightly builds, run with -Z macro-backtrace for more info)
64+
65+
error[E0277]: the trait bound `NonZero<u32>: Scalar` is not satisfied
66+
--> $DIR/invalid_vector_type_macro2.rs:22:9
67+
|
68+
22 | _x: NonZeroU32,
69+
| ^^^^^^^^^^ the trait `Scalar` is not implemented for `NonZero<u32>`
70+
|
71+
= help: the following other types implement trait `Scalar`:
72+
bool
73+
f32
74+
f64
75+
i16
76+
i32
77+
i64
78+
i8
79+
u16
80+
and 3 others
81+
note: required by a bound in `spirv_std::vector::VectorOrScalar::Scalar`
82+
--> $SPIRV_STD_SRC/vector.rs:13:18
83+
|
84+
13 | type Scalar: Scalar;
85+
| ^^^^^^ required by this bound in `VectorOrScalar::Scalar`
86+
87+
error[E0277]: the trait bound `NonZero<u32>: Scalar` is not satisfied
88+
--> $DIR/invalid_vector_type_macro2.rs:19:1
89+
|
90+
19 | #[spirv_std::spirv_vector]
91+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Scalar` is not implemented for `NonZero<u32>`
92+
|
93+
= help: the following other types implement trait `Scalar`:
94+
bool
95+
f32
96+
f64
97+
i16
98+
i32
99+
i64
100+
i8
101+
u16
102+
and 3 others
103+
note: required by a bound in `Vector`
104+
--> $SPIRV_STD_SRC/vector.rs:61:28
105+
|
106+
61 | pub unsafe trait Vector<T: Scalar, const N: usize>: VectorOrScalar<Scalar = T> {}
107+
| ^^^^^^ required by this bound in `Vector`
108+
= note: this error originates in the attribute macro `spirv_std::spirv_vector` (in Nightly builds, run with -Z macro-backtrace for more info)
109+
110+
error: aborting due to 7 previous errors
111+
112+
Some errors have detailed explanations: E0277, E0412.
113+
For more information about an error, try `rustc --explain E0277`.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// build-pass
2+
// only-vulkan1.2
3+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformShuffleRelative,+ext:SPV_KHR_vulkan_memory_model
4+
// compile-flags: -C llvm-args=--disassemble
5+
// normalize-stderr-test "OpSource .*\n" -> ""
6+
// normalize-stderr-test "OpLine .*\n" -> ""
7+
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
8+
9+
use spirv_std::arch::subgroup_shuffle_up;
10+
use spirv_std::glam::Vec3;
11+
use spirv_std::spirv;
12+
13+
#[spirv_std::spirv_vector]
14+
#[derive(Copy, Clone, Default)]
15+
pub struct MyColor {
16+
pub r: f32,
17+
pub g: f32,
18+
pub b: f32,
19+
}
20+
21+
#[spirv(compute(threads(32)))]
22+
pub fn main(
23+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &Vec3,
24+
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut MyColor,
25+
) {
26+
let color = MyColor {
27+
r: input.x,
28+
g: input.y,
29+
b: input.z,
30+
};
31+
// any function that accepts a `VectorOrScalar` would do
32+
*output = subgroup_shuffle_up(color, 5);
33+
}

0 commit comments

Comments
 (0)