Skip to content

Commit 4a7a6a1

Browse files
committed
macros: move gen_sample_param_permutations macro to separate mod
1 parent 57694c6 commit 4a7a6a1

File tree

2 files changed

+208
-201
lines changed

2 files changed

+208
-201
lines changed

crates/spirv-std/macros/src/lib.rs

Lines changed: 4 additions & 201 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,15 @@
7373

7474
mod debug_printf;
7575
mod image;
76+
mod sample_param_permutations;
7677

7778
#[path = "../../../spirv_attr_version.rs"]
7879
mod spirv_attr_version;
7980

80-
use proc_macro::TokenStream;
81-
use proc_macro2::{Delimiter, Group, Span, TokenTree};
82-
83-
use syn::{ImplItemFn, visit_mut::VisitMut};
84-
8581
use crate::debug_printf::{DebugPrintfInput, debug_printf_inner};
8682
use crate::spirv_attr_version::spirv_attr_with_version;
83+
use proc_macro::TokenStream;
84+
use proc_macro2::{Delimiter, Group, TokenTree};
8785
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
8886

8987
/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
@@ -259,192 +257,6 @@ pub fn debug_printfln(input: TokenStream) -> TokenStream {
259257
debug_printf_inner(input)
260258
}
261259

262-
const SAMPLE_PARAM_COUNT: usize = 4;
263-
const SAMPLE_PARAM_GENERICS: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "G", "S"];
264-
const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "(G,G)", "S"];
265-
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Grad", "Sample"];
266-
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "grad", "sample_index"];
267-
const SAMPLE_PARAM_GRAD_INDEX: usize = 2; // Grad requires some special handling because it uses 2 arguments
268-
const SAMPLE_PARAM_EXPLICIT_LOD_MASK: usize = 0b0110; // which params require the use of ExplicitLod rather than ImplicitLod
269-
270-
fn is_grad(i: usize) -> bool {
271-
i == SAMPLE_PARAM_GRAD_INDEX
272-
}
273-
274-
struct SampleImplRewriter(usize, syn::Type);
275-
276-
impl SampleImplRewriter {
277-
pub fn rewrite(mask: usize, f: &syn::ItemImpl) -> syn::ItemImpl {
278-
let mut new_impl = f.clone();
279-
let mut ty_str = String::from("SampleParams<");
280-
281-
// based on the mask, form a `SampleParams` type string and add the generic parameters to the `impl<>` generics
282-
// example type string: `"SampleParams<SomeTy<B>, NoneTy, NoneTy>"`
283-
for i in 0..SAMPLE_PARAM_COUNT {
284-
if mask & (1 << i) != 0 {
285-
new_impl.generics.params.push(syn::GenericParam::Type(
286-
syn::Ident::new(SAMPLE_PARAM_GENERICS[i], Span::call_site()).into(),
287-
));
288-
ty_str.push_str("SomeTy<");
289-
ty_str.push_str(SAMPLE_PARAM_TYPES[i]);
290-
ty_str.push('>');
291-
} else {
292-
ty_str.push_str("NoneTy");
293-
}
294-
ty_str.push(',');
295-
}
296-
ty_str.push('>');
297-
let ty: syn::Type = syn::parse(ty_str.parse().unwrap()).unwrap();
298-
299-
// use the type to insert it into the generic argument of the trait we're implementing
300-
// e.g., `ImageWithMethods<Dummy>` becomes `ImageWithMethods<SampleParams<SomeTy<B>, NoneTy, NoneTy>>`
301-
if let Some(t) = &mut new_impl.trait_
302-
&& let syn::PathArguments::AngleBracketed(a) =
303-
&mut t.1.segments.last_mut().unwrap().arguments
304-
&& let Some(syn::GenericArgument::Type(t)) = a.args.last_mut()
305-
{
306-
*t = ty.clone();
307-
}
308-
309-
// rewrite the implemented functions
310-
SampleImplRewriter(mask, ty).visit_item_impl_mut(&mut new_impl);
311-
new_impl
312-
}
313-
314-
// generates an operands string for use in the assembly, e.g. "Bias %bias Lod %lod", based on the mask
315-
#[allow(clippy::needless_range_loop)]
316-
fn get_operands(&self) -> String {
317-
let mut op = String::new();
318-
for i in 0..SAMPLE_PARAM_COUNT {
319-
if self.0 & (1 << i) != 0 {
320-
if is_grad(i) {
321-
op.push_str("Grad %grad_x %grad_y ");
322-
} else {
323-
op.push_str(SAMPLE_PARAM_OPERANDS[i]);
324-
op.push_str(" %");
325-
op.push_str(SAMPLE_PARAM_NAMES[i]);
326-
op.push(' ');
327-
}
328-
}
329-
}
330-
op
331-
}
332-
333-
// generates list of assembly loads for the data, e.g. "%bias = OpLoad _ {bias}", etc.
334-
#[allow(clippy::needless_range_loop)]
335-
fn add_loads(&self, t: &mut Vec<TokenTree>) {
336-
for i in 0..SAMPLE_PARAM_COUNT {
337-
if self.0 & (1 << i) != 0 {
338-
if is_grad(i) {
339-
t.push(TokenTree::Literal(proc_macro2::Literal::string(
340-
"%grad_x = OpLoad _ {grad_x}",
341-
)));
342-
t.push(TokenTree::Punct(proc_macro2::Punct::new(
343-
',',
344-
proc_macro2::Spacing::Alone,
345-
)));
346-
t.push(TokenTree::Literal(proc_macro2::Literal::string(
347-
"%grad_y = OpLoad _ {grad_y}",
348-
)));
349-
t.push(TokenTree::Punct(proc_macro2::Punct::new(
350-
',',
351-
proc_macro2::Spacing::Alone,
352-
)));
353-
} else {
354-
let s = format!("%{0} = OpLoad _ {{{0}}}", SAMPLE_PARAM_NAMES[i]);
355-
t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
356-
t.push(TokenTree::Punct(proc_macro2::Punct::new(
357-
',',
358-
proc_macro2::Spacing::Alone,
359-
)));
360-
}
361-
}
362-
}
363-
}
364-
365-
// generates list of register specifications, e.g. `bias = in(reg) &params.bias.0, ...` as separate tokens
366-
#[allow(clippy::needless_range_loop)]
367-
fn add_regs(&self, t: &mut Vec<TokenTree>) {
368-
for i in 0..SAMPLE_PARAM_COUNT {
369-
if self.0 & (1 << i) != 0 {
370-
// HACK(eddyb) the extra `{...}` force the pointers to be to
371-
// fresh variables holding value copies, instead of the originals,
372-
// allowing `OpLoad _` inference to pick the appropriate type.
373-
let s = if is_grad(i) {
374-
"grad_x=in(reg) &{params.grad.0.0},grad_y=in(reg) &{params.grad.0.1},"
375-
.to_string()
376-
} else {
377-
format!("{0} = in(reg) &{{params.{0}.0}},", SAMPLE_PARAM_NAMES[i])
378-
};
379-
let ts: proc_macro2::TokenStream = s.parse().unwrap();
380-
t.extend(ts);
381-
}
382-
}
383-
}
384-
}
385-
386-
impl VisitMut for SampleImplRewriter {
387-
fn visit_impl_item_fn_mut(&mut self, item: &mut ImplItemFn) {
388-
// rewrite the last parameter of this method to be of type `SampleParams<...>` we generated earlier
389-
if let Some(syn::FnArg::Typed(p)) = item.sig.inputs.last_mut() {
390-
*p.ty.as_mut() = self.1.clone();
391-
}
392-
syn::visit_mut::visit_impl_item_fn_mut(self, item);
393-
}
394-
395-
fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
396-
if m.path.is_ident("asm") {
397-
// this is where the asm! block is manipulated
398-
let t = m.tokens.clone();
399-
let mut new_t = Vec::new();
400-
let mut altered = false;
401-
402-
for tt in t {
403-
match tt {
404-
TokenTree::Literal(l) => {
405-
if let Ok(l) = syn::parse::<syn::LitStr>(l.to_token_stream().into()) {
406-
// found a string literal
407-
let s = l.value();
408-
if s.contains("$PARAMS") {
409-
altered = true;
410-
// add load instructions before the sampling instruction
411-
self.add_loads(&mut new_t);
412-
// and insert image operands
413-
let s = s.replace("$PARAMS", &self.get_operands());
414-
let lod_type = if self.0 & SAMPLE_PARAM_EXPLICIT_LOD_MASK != 0 {
415-
"ExplicitLod"
416-
} else {
417-
"ImplicitLod "
418-
};
419-
let s = s.replace("$LOD", lod_type);
420-
421-
new_t.push(TokenTree::Literal(proc_macro2::Literal::string(
422-
s.as_str(),
423-
)));
424-
} else {
425-
new_t.push(TokenTree::Literal(l.token()));
426-
}
427-
} else {
428-
new_t.push(TokenTree::Literal(l));
429-
}
430-
}
431-
_ => {
432-
new_t.push(tt);
433-
}
434-
}
435-
}
436-
437-
if altered {
438-
// finally, add register specs
439-
self.add_regs(&mut new_t);
440-
}
441-
442-
// replace all tokens within the asm! block with our new list
443-
m.tokens = new_t.into_iter().collect();
444-
}
445-
}
446-
}
447-
448260
/// Generates permutations of an `ImageWithMethods` implementation containing sampling functions
449261
/// that have asm instruction ending with a placeholder `$PARAMS` operand. The last parameter
450262
/// of each function must be named `params`, its type will be rewritten. Relevant generic
@@ -453,14 +265,5 @@ impl VisitMut for SampleImplRewriter {
453265
#[proc_macro_attribute]
454266
#[doc(hidden)]
455267
pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
456-
let item_impl = syn::parse_macro_input!(item as syn::ItemImpl);
457-
let mut fns = Vec::new();
458-
459-
for m in 1..(1 << SAMPLE_PARAM_COUNT) {
460-
fns.push(SampleImplRewriter::rewrite(m, &item_impl));
461-
}
462-
463-
// uncomment to output generated tokenstream to stdout
464-
//println!("{}", quote! { #(#fns)* }.to_string());
465-
quote! { #(#fns)* }.into()
268+
sample_param_permutations::gen_sample_param_permutations(item)
466269
}

0 commit comments

Comments
 (0)