|
| 1 | +use proc_macro::TokenStream; |
| 2 | +use proc_macro2::Span; |
| 3 | +use std::fmt::Write; |
| 4 | + |
| 5 | +pub struct DebugPrintfInput { |
| 6 | + pub span: Span, |
| 7 | + pub format_string: String, |
| 8 | + pub variables: Vec<syn::Expr>, |
| 9 | +} |
| 10 | + |
| 11 | +impl syn::parse::Parse for DebugPrintfInput { |
| 12 | + fn parse(input: syn::parse::ParseStream<'_>) -> syn::parse::Result<Self> { |
| 13 | + let span = input.span(); |
| 14 | + |
| 15 | + if input.is_empty() { |
| 16 | + return Ok(Self { |
| 17 | + span, |
| 18 | + format_string: Default::default(), |
| 19 | + variables: Default::default(), |
| 20 | + }); |
| 21 | + } |
| 22 | + |
| 23 | + let format_string = input.parse::<syn::LitStr>()?; |
| 24 | + if !input.is_empty() { |
| 25 | + input.parse::<syn::token::Comma>()?; |
| 26 | + } |
| 27 | + let variables = |
| 28 | + syn::punctuated::Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated(input)?; |
| 29 | + |
| 30 | + Ok(Self { |
| 31 | + span, |
| 32 | + format_string: format_string.value(), |
| 33 | + variables: variables.into_iter().collect(), |
| 34 | + }) |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +fn parsing_error(message: &str, span: Span) -> TokenStream { |
| 39 | + syn::Error::new(span, message).to_compile_error().into() |
| 40 | +} |
| 41 | + |
| 42 | +enum FormatType { |
| 43 | + Scalar { |
| 44 | + ty: proc_macro2::TokenStream, |
| 45 | + }, |
| 46 | + Vector { |
| 47 | + ty: proc_macro2::TokenStream, |
| 48 | + width: usize, |
| 49 | + }, |
| 50 | +} |
| 51 | + |
| 52 | +pub fn debug_printf_inner(input: DebugPrintfInput) -> TokenStream { |
| 53 | + let DebugPrintfInput { |
| 54 | + format_string, |
| 55 | + variables, |
| 56 | + span, |
| 57 | + } = input; |
| 58 | + |
| 59 | + fn map_specifier_to_type( |
| 60 | + specifier: char, |
| 61 | + chars: &mut std::str::Chars<'_>, |
| 62 | + ) -> Option<proc_macro2::TokenStream> { |
| 63 | + let mut peekable = chars.peekable(); |
| 64 | + |
| 65 | + Some(match specifier { |
| 66 | + 'd' | 'i' => quote::quote! { i32 }, |
| 67 | + 'o' | 'x' | 'X' => quote::quote! { u32 }, |
| 68 | + 'a' | 'A' | 'e' | 'E' | 'f' | 'F' | 'g' | 'G' => quote::quote! { f32 }, |
| 69 | + 'u' => { |
| 70 | + if matches!(peekable.peek(), Some('l')) { |
| 71 | + chars.next(); |
| 72 | + quote::quote! { u64 } |
| 73 | + } else { |
| 74 | + quote::quote! { u32 } |
| 75 | + } |
| 76 | + } |
| 77 | + 'l' => { |
| 78 | + if matches!(peekable.peek(), Some('u' | 'x')) { |
| 79 | + chars.next(); |
| 80 | + quote::quote! { u64 } |
| 81 | + } else { |
| 82 | + return None; |
| 83 | + } |
| 84 | + } |
| 85 | + _ => return None, |
| 86 | + }) |
| 87 | + } |
| 88 | + |
| 89 | + let mut chars = format_string.chars(); |
| 90 | + let mut format_arguments = Vec::new(); |
| 91 | + |
| 92 | + while let Some(mut ch) = chars.next() { |
| 93 | + if ch == '%' { |
| 94 | + ch = match chars.next() { |
| 95 | + Some('%') => continue, |
| 96 | + None => return parsing_error("Unterminated format specifier", span), |
| 97 | + Some(ch) => ch, |
| 98 | + }; |
| 99 | + |
| 100 | + let mut has_precision = false; |
| 101 | + |
| 102 | + while ch.is_ascii_digit() { |
| 103 | + ch = match chars.next() { |
| 104 | + Some(ch) => ch, |
| 105 | + None => { |
| 106 | + return parsing_error( |
| 107 | + "Unterminated format specifier: missing type after precision", |
| 108 | + span, |
| 109 | + ); |
| 110 | + } |
| 111 | + }; |
| 112 | + |
| 113 | + has_precision = true; |
| 114 | + } |
| 115 | + |
| 116 | + if has_precision && ch == '.' { |
| 117 | + ch = match chars.next() { |
| 118 | + Some(ch) => ch, |
| 119 | + None => { |
| 120 | + return parsing_error( |
| 121 | + "Unterminated format specifier: missing type after decimal point", |
| 122 | + span, |
| 123 | + ); |
| 124 | + } |
| 125 | + }; |
| 126 | + |
| 127 | + while ch.is_ascii_digit() { |
| 128 | + ch = match chars.next() { |
| 129 | + Some(ch) => ch, |
| 130 | + None => { |
| 131 | + return parsing_error( |
| 132 | + "Unterminated format specifier: missing type after fraction precision", |
| 133 | + span, |
| 134 | + ); |
| 135 | + } |
| 136 | + }; |
| 137 | + } |
| 138 | + } |
| 139 | + |
| 140 | + if ch == 'v' { |
| 141 | + let width = match chars.next() { |
| 142 | + Some('2') => 2, |
| 143 | + Some('3') => 3, |
| 144 | + Some('4') => 4, |
| 145 | + Some(ch) => { |
| 146 | + return parsing_error(&format!("Invalid width for vector: {ch}"), span); |
| 147 | + } |
| 148 | + None => return parsing_error("Missing vector dimensions specifier", span), |
| 149 | + }; |
| 150 | + |
| 151 | + ch = match chars.next() { |
| 152 | + Some(ch) => ch, |
| 153 | + None => return parsing_error("Missing vector type specifier", span), |
| 154 | + }; |
| 155 | + |
| 156 | + let ty = match map_specifier_to_type(ch, &mut chars) { |
| 157 | + Some(ty) => ty, |
| 158 | + _ => { |
| 159 | + return parsing_error( |
| 160 | + &format!("Unrecognised vector type specifier: '{ch}'"), |
| 161 | + span, |
| 162 | + ); |
| 163 | + } |
| 164 | + }; |
| 165 | + |
| 166 | + format_arguments.push(FormatType::Vector { ty, width }); |
| 167 | + } else { |
| 168 | + let ty = match map_specifier_to_type(ch, &mut chars) { |
| 169 | + Some(ty) => ty, |
| 170 | + _ => { |
| 171 | + return parsing_error( |
| 172 | + &format!("Unrecognised format specifier: '{ch}'"), |
| 173 | + span, |
| 174 | + ); |
| 175 | + } |
| 176 | + }; |
| 177 | + |
| 178 | + format_arguments.push(FormatType::Scalar { ty }); |
| 179 | + } |
| 180 | + } |
| 181 | + } |
| 182 | + |
| 183 | + if format_arguments.len() != variables.len() { |
| 184 | + return syn::Error::new( |
| 185 | + span, |
| 186 | + format!( |
| 187 | + "{} % arguments were found, but {} variables were given", |
| 188 | + format_arguments.len(), |
| 189 | + variables.len() |
| 190 | + ), |
| 191 | + ) |
| 192 | + .to_compile_error() |
| 193 | + .into(); |
| 194 | + } |
| 195 | + |
| 196 | + let mut variable_idents = String::new(); |
| 197 | + let mut input_registers = Vec::new(); |
| 198 | + let mut op_loads = Vec::new(); |
| 199 | + |
| 200 | + for (i, (variable, format_argument)) in variables.into_iter().zip(format_arguments).enumerate() |
| 201 | + { |
| 202 | + let ident = quote::format_ident!("_{}", i); |
| 203 | + |
| 204 | + let _ = write!(variable_idents, "%{ident} "); |
| 205 | + |
| 206 | + let assert_fn = match format_argument { |
| 207 | + FormatType::Scalar { ty } => { |
| 208 | + quote::quote! { spirv_std::debug_printf_assert_is_type::<#ty> } |
| 209 | + } |
| 210 | + FormatType::Vector { ty, width } => { |
| 211 | + quote::quote! { spirv_std::debug_printf_assert_is_vector::<#ty, _, #width> } |
| 212 | + } |
| 213 | + }; |
| 214 | + |
| 215 | + input_registers.push(quote::quote! { |
| 216 | + #ident = in(reg) &#assert_fn(#variable), |
| 217 | + }); |
| 218 | + |
| 219 | + let op_load = format!("%{ident} = OpLoad _ {{{ident}}}"); |
| 220 | + |
| 221 | + op_loads.push(quote::quote! { |
| 222 | + #op_load, |
| 223 | + }); |
| 224 | + } |
| 225 | + |
| 226 | + let input_registers = input_registers |
| 227 | + .into_iter() |
| 228 | + .collect::<proc_macro2::TokenStream>(); |
| 229 | + let op_loads = op_loads.into_iter().collect::<proc_macro2::TokenStream>(); |
| 230 | + // Escapes the '{' and '}' characters in the format string. |
| 231 | + // Since the `asm!` macro expects '{' '}' to surround its arguments, we have to use '{{' and '}}' instead. |
| 232 | + // The `asm!` macro will then later turn them back into '{' and '}'. |
| 233 | + let format_string = format_string.replace('{', "{{").replace('}', "}}"); |
| 234 | + |
| 235 | + let op_string = format!("%string = OpString {format_string:?}"); |
| 236 | + |
| 237 | + let output = quote::quote! { |
| 238 | + ::core::arch::asm!( |
| 239 | + "%void = OpTypeVoid", |
| 240 | + #op_string, |
| 241 | + "%debug_printf = OpExtInstImport \"NonSemantic.DebugPrintf\"", |
| 242 | + #op_loads |
| 243 | + concat!("%result = OpExtInst %void %debug_printf 1 %string ", #variable_idents), |
| 244 | + #input_registers |
| 245 | + ) |
| 246 | + }; |
| 247 | + |
| 248 | + output.into() |
| 249 | +} |
0 commit comments