|
| 1 | +use proc_macro::TokenStream; |
| 2 | +use proc_macro2::Span; |
| 3 | +use quote::quote; |
| 4 | +use syn::{ |
| 5 | + fold::{self, Fold}, |
| 6 | + parse_macro_input, parse_quote, |
| 7 | + punctuated::Punctuated, |
| 8 | + token::Comma, |
| 9 | + Expr, Item, |
| 10 | +}; |
| 11 | + |
| 12 | +struct VersionFilter { |
| 13 | + version: String, |
| 14 | +} |
| 15 | + |
| 16 | +fn extract_version(attrs: &mut Vec<syn::Attribute>) -> Option<String> { |
| 17 | + let mut version = None; |
| 18 | + |
| 19 | + attrs.retain(|attr| { |
| 20 | + let path = attr.path(); |
| 21 | + |
| 22 | + if path.is_ident("versioned") { |
| 23 | + version = Some( |
| 24 | + attr.parse_args::<syn::LitStr>() |
| 25 | + .expect("expected a string literal with a version number") |
| 26 | + .value(), |
| 27 | + ); |
| 28 | + |
| 29 | + false |
| 30 | + } else { |
| 31 | + true |
| 32 | + } |
| 33 | + }); |
| 34 | + |
| 35 | + version |
| 36 | +} |
| 37 | + |
| 38 | +impl VersionFilter { |
| 39 | + fn matches(&self, found_version: &str) -> bool { |
| 40 | + self.version == found_version |
| 41 | + } |
| 42 | + |
| 43 | + fn filter_fields( |
| 44 | + &self, |
| 45 | + fields: Punctuated<syn::Field, Comma>, |
| 46 | + ) -> Punctuated<syn::Field, Comma> { |
| 47 | + fields |
| 48 | + .into_pairs() |
| 49 | + .filter_map( |
| 50 | + |mut pair| match extract_version(&mut pair.value_mut().attrs) { |
| 51 | + Some(version) => self.matches(&version).then_some(pair), |
| 52 | + None => Some(pair), |
| 53 | + }, |
| 54 | + ) |
| 55 | + .collect() |
| 56 | + } |
| 57 | +} |
| 58 | + |
| 59 | +impl Fold for VersionFilter { |
| 60 | + fn fold_fields_named(&mut self, mut fields: syn::FieldsNamed) -> syn::FieldsNamed { |
| 61 | + fields.named = self.filter_fields(fields.named); |
| 62 | + fields |
| 63 | + } |
| 64 | + |
| 65 | + fn fold_fields_unnamed(&mut self, mut fields: syn::FieldsUnnamed) -> syn::FieldsUnnamed { |
| 66 | + fields.unnamed = self.filter_fields(fields.unnamed); |
| 67 | + fields |
| 68 | + } |
| 69 | + |
| 70 | + fn fold_stmt(&mut self, mut stmt: syn::Stmt) -> syn::Stmt { |
| 71 | + match stmt { |
| 72 | + syn::Stmt::Local(syn::Local { ref mut attrs, .. }) |
| 73 | + | syn::Stmt::Macro(syn::StmtMacro { ref mut attrs, .. }) => { |
| 74 | + if let Some(version) = extract_version(attrs) { |
| 75 | + if !self.matches(&version) { |
| 76 | + stmt = parse_quote!({};); |
| 77 | + } |
| 78 | + } |
| 79 | + } |
| 80 | + _ => {} |
| 81 | + } |
| 82 | + |
| 83 | + fold::fold_stmt(self, stmt) |
| 84 | + } |
| 85 | + |
| 86 | + fn fold_expr(&mut self, mut expr: Expr) -> Expr { |
| 87 | + match &mut expr { |
| 88 | + Expr::Array(syn::ExprArray { ref mut attrs, .. }) |
| 89 | + | Expr::Assign(syn::ExprAssign { ref mut attrs, .. }) |
| 90 | + | Expr::Async(syn::ExprAsync { ref mut attrs, .. }) |
| 91 | + | Expr::Await(syn::ExprAwait { ref mut attrs, .. }) |
| 92 | + | Expr::Binary(syn::ExprBinary { ref mut attrs, .. }) |
| 93 | + | Expr::Block(syn::ExprBlock { ref mut attrs, .. }) |
| 94 | + | Expr::Break(syn::ExprBreak { ref mut attrs, .. }) |
| 95 | + | Expr::Call(syn::ExprCall { ref mut attrs, .. }) |
| 96 | + | Expr::Cast(syn::ExprCast { ref mut attrs, .. }) |
| 97 | + | Expr::Closure(syn::ExprClosure { ref mut attrs, .. }) |
| 98 | + | Expr::Const(syn::ExprConst { ref mut attrs, .. }) |
| 99 | + | Expr::Continue(syn::ExprContinue { ref mut attrs, .. }) |
| 100 | + | Expr::Field(syn::ExprField { ref mut attrs, .. }) |
| 101 | + | Expr::ForLoop(syn::ExprForLoop { ref mut attrs, .. }) |
| 102 | + | Expr::Group(syn::ExprGroup { ref mut attrs, .. }) |
| 103 | + | Expr::If(syn::ExprIf { ref mut attrs, .. }) |
| 104 | + | Expr::Index(syn::ExprIndex { ref mut attrs, .. }) |
| 105 | + | Expr::Infer(syn::ExprInfer { ref mut attrs, .. }) |
| 106 | + | Expr::Let(syn::ExprLet { ref mut attrs, .. }) |
| 107 | + | Expr::Lit(syn::ExprLit { ref mut attrs, .. }) |
| 108 | + | Expr::Loop(syn::ExprLoop { ref mut attrs, .. }) |
| 109 | + | Expr::Macro(syn::ExprMacro { ref mut attrs, .. }) |
| 110 | + | Expr::Match(syn::ExprMatch { ref mut attrs, .. }) |
| 111 | + | Expr::MethodCall(syn::ExprMethodCall { ref mut attrs, .. }) |
| 112 | + | Expr::Paren(syn::ExprParen { ref mut attrs, .. }) |
| 113 | + | Expr::Path(syn::ExprPath { ref mut attrs, .. }) |
| 114 | + | Expr::Range(syn::ExprRange { ref mut attrs, .. }) |
| 115 | + | Expr::Reference(syn::ExprReference { ref mut attrs, .. }) |
| 116 | + | Expr::Repeat(syn::ExprRepeat { ref mut attrs, .. }) |
| 117 | + | Expr::Return(syn::ExprReturn { ref mut attrs, .. }) |
| 118 | + | Expr::Struct(syn::ExprStruct { ref mut attrs, .. }) |
| 119 | + | Expr::Try(syn::ExprTry { ref mut attrs, .. }) |
| 120 | + | Expr::TryBlock(syn::ExprTryBlock { ref mut attrs, .. }) |
| 121 | + | Expr::Tuple(syn::ExprTuple { ref mut attrs, .. }) |
| 122 | + | Expr::Unary(syn::ExprUnary { ref mut attrs, .. }) |
| 123 | + | Expr::Unsafe(syn::ExprUnsafe { ref mut attrs, .. }) |
| 124 | + | Expr::While(syn::ExprWhile { ref mut attrs, .. }) |
| 125 | + | Expr::Yield(syn::ExprYield { ref mut attrs, .. }) => { |
| 126 | + if let Some(version) = extract_version(attrs) { |
| 127 | + if !self.matches(&version) { |
| 128 | + expr = parse_quote!({}); |
| 129 | + } |
| 130 | + } |
| 131 | + } |
| 132 | + _ => {} |
| 133 | + } |
| 134 | + |
| 135 | + fold::fold_expr(self, expr) |
| 136 | + } |
| 137 | + |
| 138 | + fn fold_expr_struct(&mut self, mut expr: syn::ExprStruct) -> syn::ExprStruct { |
| 139 | + expr.fields = expr |
| 140 | + .fields |
| 141 | + .into_pairs() |
| 142 | + .filter_map( |
| 143 | + |mut pair| match extract_version(&mut pair.value_mut().attrs) { |
| 144 | + Some(version) => self.matches(&version).then_some(pair), |
| 145 | + None => Some(pair), |
| 146 | + }, |
| 147 | + ) |
| 148 | + .collect(); |
| 149 | + |
| 150 | + fold::fold_expr_struct(self, expr) |
| 151 | + } |
| 152 | + |
| 153 | + fn fold_expr_match(&mut self, mut expr: syn::ExprMatch) -> syn::ExprMatch { |
| 154 | + expr.arms |
| 155 | + .retain_mut(|arm| match extract_version(&mut arm.attrs) { |
| 156 | + Some(version) => self.matches(&version), |
| 157 | + None => true, |
| 158 | + }); |
| 159 | + |
| 160 | + fold::fold_expr_match(self, expr) |
| 161 | + } |
| 162 | + |
| 163 | + fn fold_item(&mut self, mut item: Item) -> Item { |
| 164 | + match item { |
| 165 | + Item::Const(syn::ItemConst { ref mut attrs, .. }) |
| 166 | + | Item::Enum(syn::ItemEnum { ref mut attrs, .. }) |
| 167 | + | Item::ExternCrate(syn::ItemExternCrate { ref mut attrs, .. }) |
| 168 | + | Item::Fn(syn::ItemFn { ref mut attrs, .. }) |
| 169 | + | Item::ForeignMod(syn::ItemForeignMod { ref mut attrs, .. }) |
| 170 | + | Item::Impl(syn::ItemImpl { ref mut attrs, .. }) |
| 171 | + | Item::Macro(syn::ItemMacro { ref mut attrs, .. }) |
| 172 | + | Item::Mod(syn::ItemMod { ref mut attrs, .. }) |
| 173 | + | Item::Static(syn::ItemStatic { ref mut attrs, .. }) |
| 174 | + | Item::Struct(syn::ItemStruct { ref mut attrs, .. }) |
| 175 | + | Item::Trait(syn::ItemTrait { ref mut attrs, .. }) |
| 176 | + | Item::TraitAlias(syn::ItemTraitAlias { ref mut attrs, .. }) |
| 177 | + | Item::Type(syn::ItemType { ref mut attrs, .. }) |
| 178 | + | Item::Union(syn::ItemUnion { ref mut attrs, .. }) |
| 179 | + | Item::Use(syn::ItemUse { ref mut attrs, .. }) => { |
| 180 | + if let Some(version) = extract_version(attrs) { |
| 181 | + if !self.matches(&version) { |
| 182 | + item = parse_quote!( |
| 183 | + use {}; |
| 184 | + ); |
| 185 | + } |
| 186 | + } |
| 187 | + } |
| 188 | + _ => {} |
| 189 | + } |
| 190 | + |
| 191 | + fold::fold_item(self, item) |
| 192 | + } |
| 193 | +} |
| 194 | + |
| 195 | +#[proc_macro_attribute] |
| 196 | +pub fn versioned(input: TokenStream, annotated_item: TokenStream) -> TokenStream { |
| 197 | + // This parses the module being annotated by the `#[versioned(..)]` attribute. |
| 198 | + let module = parse_macro_input!(annotated_item as syn::ItemMod); |
| 199 | + |
| 200 | + // This parses the versions passed to the attribute, e.g. the `"1.3"` |
| 201 | + // and `"1.4"`in `#[versioned("1.3", "1.4")] |
| 202 | + // FIXME: we should do extra validations for the version numbers themselves. |
| 203 | + let versions: Vec<String> = |
| 204 | + parse_macro_input!(input with Punctuated::<syn::LitStr, Comma>::parse_terminated) |
| 205 | + .into_iter() |
| 206 | + .map(|s| s.value()) |
| 207 | + .collect(); |
| 208 | + |
| 209 | + let mut tokens = proc_macro2::TokenStream::new(); |
| 210 | + |
| 211 | + for version in versions { |
| 212 | + let mod_vis = &module.vis; |
| 213 | + let mod_ident = syn::Ident::new( |
| 214 | + &format!("v{}", version.replace('.', "_")), |
| 215 | + Span::call_site(), |
| 216 | + ); |
| 217 | + |
| 218 | + let (_, items) = module.content.clone().unwrap(); |
| 219 | + |
| 220 | + let mut folded_items = Vec::new(); |
| 221 | + |
| 222 | + let mut filter = VersionFilter { version }; |
| 223 | + |
| 224 | + for item in items { |
| 225 | + folded_items.push(filter.fold_item(item)); |
| 226 | + } |
| 227 | + |
| 228 | + tokens.extend(quote! { |
| 229 | + #mod_vis mod #mod_ident { |
| 230 | + #(#folded_items)* |
| 231 | + } |
| 232 | + }) |
| 233 | + } |
| 234 | + |
| 235 | + tokens.into() |
| 236 | +} |
0 commit comments