Skip to content

Commit 3daedd4

Browse files
committed
macros: move gen_sample_param_permutations macro to separate mod
1 parent d24e17a commit 3daedd4

File tree

2 files changed

+207
-198
lines changed

2 files changed

+207
-198
lines changed

crates/spirv-std/macros/src/lib.rs

Lines changed: 3 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,13 @@
7373

7474
mod debug_printf;
7575
mod image;
76+
mod sample_param_permutations;
7677

7778
use crate::debug_printf::{DebugPrintfInput, debug_printf_inner};
7879
use proc_macro::TokenStream;
79-
use proc_macro2::{Delimiter, Group, Ident, Span, TokenTree};
80+
use proc_macro2::{Delimiter, Group, Ident, TokenTree};
8081
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
8182
use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
82-
use syn::{ImplItemFn, visit_mut::VisitMut};
8383

8484
/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
8585
/// `spirv_std::image::Image<...>` type.
@@ -299,192 +299,6 @@ pub fn debug_printfln(input: TokenStream) -> TokenStream {
299299
debug_printf_inner(input)
300300
}
301301

302-
const SAMPLE_PARAM_COUNT: usize = 4;
303-
const SAMPLE_PARAM_GENERICS: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "G", "S"];
304-
const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "(G,G)", "S"];
305-
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Grad", "Sample"];
306-
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "grad", "sample_index"];
307-
const SAMPLE_PARAM_GRAD_INDEX: usize = 2; // Grad requires some special handling because it uses 2 arguments
308-
const SAMPLE_PARAM_EXPLICIT_LOD_MASK: usize = 0b0110; // which params require the use of ExplicitLod rather than ImplicitLod
309-
310-
fn is_grad(i: usize) -> bool {
311-
i == SAMPLE_PARAM_GRAD_INDEX
312-
}
313-
314-
struct SampleImplRewriter(usize, syn::Type);
315-
316-
impl SampleImplRewriter {
317-
pub fn rewrite(mask: usize, f: &syn::ItemImpl) -> syn::ItemImpl {
318-
let mut new_impl = f.clone();
319-
let mut ty_str = String::from("SampleParams<");
320-
321-
// based on the mask, form a `SampleParams` type string and add the generic parameters to the `impl<>` generics
322-
// example type string: `"SampleParams<SomeTy<B>, NoneTy, NoneTy>"`
323-
for i in 0..SAMPLE_PARAM_COUNT {
324-
if mask & (1 << i) != 0 {
325-
new_impl.generics.params.push(syn::GenericParam::Type(
326-
syn::Ident::new(SAMPLE_PARAM_GENERICS[i], Span::call_site()).into(),
327-
));
328-
ty_str.push_str("SomeTy<");
329-
ty_str.push_str(SAMPLE_PARAM_TYPES[i]);
330-
ty_str.push('>');
331-
} else {
332-
ty_str.push_str("NoneTy");
333-
}
334-
ty_str.push(',');
335-
}
336-
ty_str.push('>');
337-
let ty: syn::Type = syn::parse(ty_str.parse().unwrap()).unwrap();
338-
339-
// use the type to insert it into the generic argument of the trait we're implementing
340-
// e.g., `ImageWithMethods<Dummy>` becomes `ImageWithMethods<SampleParams<SomeTy<B>, NoneTy, NoneTy>>`
341-
if let Some(t) = &mut new_impl.trait_
342-
&& let syn::PathArguments::AngleBracketed(a) =
343-
&mut t.1.segments.last_mut().unwrap().arguments
344-
&& let Some(syn::GenericArgument::Type(t)) = a.args.last_mut()
345-
{
346-
*t = ty.clone();
347-
}
348-
349-
// rewrite the implemented functions
350-
SampleImplRewriter(mask, ty).visit_item_impl_mut(&mut new_impl);
351-
new_impl
352-
}
353-
354-
// generates an operands string for use in the assembly, e.g. "Bias %bias Lod %lod", based on the mask
355-
#[allow(clippy::needless_range_loop)]
356-
fn get_operands(&self) -> String {
357-
let mut op = String::new();
358-
for i in 0..SAMPLE_PARAM_COUNT {
359-
if self.0 & (1 << i) != 0 {
360-
if is_grad(i) {
361-
op.push_str("Grad %grad_x %grad_y ");
362-
} else {
363-
op.push_str(SAMPLE_PARAM_OPERANDS[i]);
364-
op.push_str(" %");
365-
op.push_str(SAMPLE_PARAM_NAMES[i]);
366-
op.push(' ');
367-
}
368-
}
369-
}
370-
op
371-
}
372-
373-
// generates list of assembly loads for the data, e.g. "%bias = OpLoad _ {bias}", etc.
374-
#[allow(clippy::needless_range_loop)]
375-
fn add_loads(&self, t: &mut Vec<TokenTree>) {
376-
for i in 0..SAMPLE_PARAM_COUNT {
377-
if self.0 & (1 << i) != 0 {
378-
if is_grad(i) {
379-
t.push(TokenTree::Literal(proc_macro2::Literal::string(
380-
"%grad_x = OpLoad _ {grad_x}",
381-
)));
382-
t.push(TokenTree::Punct(proc_macro2::Punct::new(
383-
',',
384-
proc_macro2::Spacing::Alone,
385-
)));
386-
t.push(TokenTree::Literal(proc_macro2::Literal::string(
387-
"%grad_y = OpLoad _ {grad_y}",
388-
)));
389-
t.push(TokenTree::Punct(proc_macro2::Punct::new(
390-
',',
391-
proc_macro2::Spacing::Alone,
392-
)));
393-
} else {
394-
let s = format!("%{0} = OpLoad _ {{{0}}}", SAMPLE_PARAM_NAMES[i]);
395-
t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
396-
t.push(TokenTree::Punct(proc_macro2::Punct::new(
397-
',',
398-
proc_macro2::Spacing::Alone,
399-
)));
400-
}
401-
}
402-
}
403-
}
404-
405-
// generates list of register specifications, e.g. `bias = in(reg) &params.bias.0, ...` as separate tokens
406-
#[allow(clippy::needless_range_loop)]
407-
fn add_regs(&self, t: &mut Vec<TokenTree>) {
408-
for i in 0..SAMPLE_PARAM_COUNT {
409-
if self.0 & (1 << i) != 0 {
410-
// HACK(eddyb) the extra `{...}` force the pointers to be to
411-
// fresh variables holding value copies, instead of the originals,
412-
// allowing `OpLoad _` inference to pick the appropriate type.
413-
let s = if is_grad(i) {
414-
"grad_x=in(reg) &{params.grad.0.0},grad_y=in(reg) &{params.grad.0.1},"
415-
.to_string()
416-
} else {
417-
format!("{0} = in(reg) &{{params.{0}.0}},", SAMPLE_PARAM_NAMES[i])
418-
};
419-
let ts: proc_macro2::TokenStream = s.parse().unwrap();
420-
t.extend(ts);
421-
}
422-
}
423-
}
424-
}
425-
426-
impl VisitMut for SampleImplRewriter {
427-
fn visit_impl_item_fn_mut(&mut self, item: &mut ImplItemFn) {
428-
// rewrite the last parameter of this method to be of type `SampleParams<...>` we generated earlier
429-
if let Some(syn::FnArg::Typed(p)) = item.sig.inputs.last_mut() {
430-
*p.ty.as_mut() = self.1.clone();
431-
}
432-
syn::visit_mut::visit_impl_item_fn_mut(self, item);
433-
}
434-
435-
fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
436-
if m.path.is_ident("asm") {
437-
// this is where the asm! block is manipulated
438-
let t = m.tokens.clone();
439-
let mut new_t = Vec::new();
440-
let mut altered = false;
441-
442-
for tt in t {
443-
match tt {
444-
TokenTree::Literal(l) => {
445-
if let Ok(l) = syn::parse::<syn::LitStr>(l.to_token_stream().into()) {
446-
// found a string literal
447-
let s = l.value();
448-
if s.contains("$PARAMS") {
449-
altered = true;
450-
// add load instructions before the sampling instruction
451-
self.add_loads(&mut new_t);
452-
// and insert image operands
453-
let s = s.replace("$PARAMS", &self.get_operands());
454-
let lod_type = if self.0 & SAMPLE_PARAM_EXPLICIT_LOD_MASK != 0 {
455-
"ExplicitLod"
456-
} else {
457-
"ImplicitLod "
458-
};
459-
let s = s.replace("$LOD", lod_type);
460-
461-
new_t.push(TokenTree::Literal(proc_macro2::Literal::string(
462-
s.as_str(),
463-
)));
464-
} else {
465-
new_t.push(TokenTree::Literal(l.token()));
466-
}
467-
} else {
468-
new_t.push(TokenTree::Literal(l));
469-
}
470-
}
471-
_ => {
472-
new_t.push(tt);
473-
}
474-
}
475-
}
476-
477-
if altered {
478-
// finally, add register specs
479-
self.add_regs(&mut new_t);
480-
}
481-
482-
// replace all tokens within the asm! block with our new list
483-
m.tokens = new_t.into_iter().collect();
484-
}
485-
}
486-
}
487-
488302
/// Generates permutations of an `ImageWithMethods` implementation containing sampling functions
489303
/// that have asm instruction ending with a placeholder `$PARAMS` operand. The last parameter
490304
/// of each function must be named `params`, its type will be rewritten. Relevant generic
@@ -493,14 +307,5 @@ impl VisitMut for SampleImplRewriter {
493307
#[proc_macro_attribute]
494308
#[doc(hidden)]
495309
pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
496-
let item_impl = syn::parse_macro_input!(item as syn::ItemImpl);
497-
let mut fns = Vec::new();
498-
499-
for m in 1..(1 << SAMPLE_PARAM_COUNT) {
500-
fns.push(SampleImplRewriter::rewrite(m, &item_impl));
501-
}
502-
503-
// uncomment to output generated tokenstream to stdout
504-
//println!("{}", quote! { #(#fns)* }.to_string());
505-
quote! { #(#fns)* }.into()
310+
sample_param_permutations::gen_sample_param_permutations(item)
506311
}

0 commit comments

Comments
 (0)