Skip to content

Commit 61cf26e

Browse files
committed
tool: add spirv_recursive_for_testing to fix compiletest invalid-target
1 parent 28c2691 commit 61cf26e

File tree

1 file changed

+58
-12
lines changed
  • crates/spirv-std/macros/src

1 file changed

+58
-12
lines changed

crates/spirv-std/macros/src/lib.rs

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
mod image;
7575

7676
use proc_macro::TokenStream;
77-
use proc_macro2::{Delimiter, Group, Span, TokenTree};
77+
use proc_macro2::{Delimiter, Group, Ident, Span, TokenTree};
7878

7979
use syn::{ImplItemFn, visit_mut::VisitMut};
8080

@@ -146,11 +146,10 @@ pub fn Image(item: TokenStream) -> TokenStream {
146146
#[proc_macro_attribute]
147147
pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
148148
let spirv = format_ident!("{}", &spirv_attr_with_version());
149-
let mut tokens: Vec<TokenTree> = Vec::new();
150149

151150
// prepend with #[rust_gpu::spirv(..)]
152151
let attr: proc_macro2::TokenStream = attr.into();
153-
tokens.extend(quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] });
152+
let mut tokens = quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] };
154153

155154
let item: proc_macro2::TokenStream = item.into();
156155
for tt in item {
@@ -182,18 +181,65 @@ pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
182181
}
183182
last_token_hashtag = is_token_hashtag;
184183
}
185-
tokens.push(TokenTree::from(Group::new(
186-
Delimiter::Parenthesis,
187-
group_tokens,
188-
)));
184+
let mut out = Group::new(Delimiter::Parenthesis, group_tokens);
185+
out.set_span(group.span());
186+
tokens.append(out);
189187
}
190-
_ => tokens.push(tt),
188+
_ => tokens.append(tt),
191189
}
192190
}
193-
tokens
194-
.into_iter()
195-
.collect::<proc_macro2::TokenStream>()
196-
.into()
191+
tokens.into()
192+
}
193+
194+
/// For testing only! Is not reexported in `spirv-std`, but reachable via
195+
/// `spirv_std::macros::spirv_recursive_for_testing`.
196+
///
197+
/// May be more expensive than plain `spirv`, since we're checking a lot more symbols. So I've opted to
198+
/// have this be a separate macro, instead of modifying the standard `spirv` one.
199+
#[proc_macro_attribute]
200+
pub fn spirv_recursive_for_testing(attr: TokenStream, item: TokenStream) -> TokenStream {
201+
fn recurse(spirv: &Ident, stream: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
202+
let mut last_token_hashtag = false;
203+
stream.into_iter().map(|tt| {
204+
let mut is_token_hashtag = false;
205+
let out = match tt {
206+
TokenTree::Group(group)
207+
if group.delimiter() == Delimiter::Bracket
208+
&& last_token_hashtag
209+
&& matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
210+
{
211+
// group matches [spirv ...]
212+
// group stream doesn't include the brackets
213+
let inner = group
214+
.stream()
215+
.into_iter()
216+
.skip(1)
217+
.collect::<proc_macro2::TokenStream>();
218+
quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] }
219+
},
220+
TokenTree::Group(group) => {
221+
let mut out = Group::new(group.delimiter(), recurse(spirv, group.stream()));
222+
out.set_span(group.span());
223+
TokenTree::Group(out).into()
224+
},
225+
TokenTree::Punct(punct) => {
226+
is_token_hashtag = punct.as_char() == '#';
227+
TokenTree::Punct(punct).into()
228+
}
229+
tt => tt.into(),
230+
};
231+
last_token_hashtag = is_token_hashtag;
232+
out
233+
}).collect()
234+
}
235+
236+
let attr: proc_macro2::TokenStream = attr.into();
237+
let item: proc_macro2::TokenStream = item.into();
238+
239+
// prepend with #[rust_gpu::spirv(..)]
240+
let spirv = format_ident!("{}", &spirv_attr_with_version());
241+
let inner = recurse(&spirv, item);
242+
quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] #inner }.into()
197243
}
198244

199245
/// Marks a function as runnable only on the GPU, and will panic on

0 commit comments

Comments
 (0)