Skip to content

Commit 706d734

Browse files
committed
tool: simplify spirv proc macro
1 parent 0421550 commit 706d734

File tree

1 file changed

+13
-8
lines changed
  • crates/spirv-std/macros/src

1 file changed

+13
-8
lines changed

crates/spirv-std/macros/src/lib.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ use proc_macro2::{Delimiter, Group, Span, TokenTree};
7878

7979
use syn::{ImplItemFn, visit_mut::VisitMut};
8080

81-
use quote::{ToTokens, quote};
81+
use quote::{ToTokens, TokenStreamExt, quote};
8282
use std::fmt::Write;
8383

8484
/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
@@ -153,26 +153,31 @@ pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
153153
for tt in item {
154154
match tt {
155155
TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
156-
let mut sub_tokens = Vec::new();
156+
let mut group_tokens = proc_macro2::TokenStream::new();
157+
let mut last_token_hashtag = false;
157158
for tt in group.stream() {
159+
let is_token_hashtag =
160+
matches!(&tt, TokenTree::Punct(punct) if punct.as_char() == '#');
158161
match tt {
159162
TokenTree::Group(group)
160163
if group.delimiter() == Delimiter::Bracket
161-
&& matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv")
162-
&& matches!(sub_tokens.last(), Some(TokenTree::Punct(p)) if p.as_char() == '#') =>
164+
&& last_token_hashtag
165+
&& matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
163166
{
164167
// group matches [spirv ...]
165-
let inner = group.stream(); // group stream doesn't include the brackets
166-
sub_tokens.extend(
168+
// group stream doesn't include the brackets
169+
let inner = group.stream();
170+
group_tokens.extend(
167171
quote! { [cfg_attr(target_arch="spirv", rust_gpu::#inner)] },
168172
);
169173
}
170-
_ => sub_tokens.push(tt),
174+
_ => group_tokens.append(tt),
171175
}
176+
last_token_hashtag = is_token_hashtag;
172177
}
173178
tokens.push(TokenTree::from(Group::new(
174179
Delimiter::Parenthesis,
175-
sub_tokens.into_iter().collect(),
180+
group_tokens,
176181
)));
177182
}
178183
_ => tokens.push(tt),

0 commit comments

Comments
 (0)