1- use std:: { error:: Error , str:: FromStr } ;
1+ use std:: error:: Error as StdError ;
2+ use std:: str:: FromStr ;
23
34use proc_macro:: TokenStream ;
45use proc_macro2:: Span ;
56use quote:: quote;
67use syn:: {
78 fold:: { self , Fold } ,
8- parse_macro_input , parse_quote,
9+ parse_quote,
910 punctuated:: Punctuated ,
1011 token:: Comma ,
11- Expr , Item ,
12+ Error , Expr , Item ,
1213} ;
1314
1415#[ derive( PartialEq , Eq ) ]
@@ -18,12 +19,12 @@ struct Version {
1819}
1920
2021impl FromStr for Version {
21- type Err = Box < dyn Error > ;
22+ type Err = Box < dyn StdError > ;
2223
2324 fn from_str ( s : & str ) -> Result < Self , Self :: Err > {
2425 let ( major_str, minor_str) = s
2526 . split_once ( '.' )
26- . ok_or_else ( || Self :: Err :: from ( "versions must have a `.`" . to_owned ( ) ) ) ?;
27+ . ok_or_else ( || Self :: Err :: from ( "missing `.`" . to_owned ( ) ) ) ?;
2728
2829 Ok ( Self {
2930 major : major_str. parse ( ) ?,
@@ -33,55 +34,56 @@ impl FromStr for Version {
3334}
3435
3536impl Version {
36- fn extract_from_attrs ( attrs : & mut Vec < syn:: Attribute > ) -> Option < Self > {
37- let mut version = None ;
37+ fn as_ident ( & self ) -> syn:: Ident {
38+ syn:: Ident :: new (
39+ & format ! ( "v{}_{}" , self . major, self . minor) ,
40+ Span :: call_site ( ) ,
41+ )
42+ }
43+ }
44+
45+ struct VersionFilter {
46+ version : Version ,
47+ error : Option < Error > ,
48+ }
49+
50+ impl VersionFilter {
51+ fn extract_version ( & mut self , attrs : & mut Vec < syn:: Attribute > ) -> Option < Version > {
52+ let mut opt_version = None ;
3853
3954 attrs. retain ( |attr| {
4055 let path = attr. path ( ) ;
4156
4257 if path. is_ident ( "versioned" ) {
43- version = Some (
44- attr . parse_args :: < syn:: LitStr > ( )
45- . expect ( "expected a string literal" )
46- . value ( )
47- . parse ( )
48- . expect ( "cannot parse version" ) ,
49- ) ;
58+ match attr
59+ . parse_args :: < syn:: LitStr > ( )
60+ . and_then ( |s| s . value ( ) . parse ( ) . map_err ( |err| Error :: new ( s . span ( ) , err ) ) )
61+ {
62+ Ok ( version ) => opt_version = Some ( version ) ,
63+ Err ( err ) => self . error = Some ( err ) ,
64+ }
5065
5166 false
5267 } else {
5368 true
5469 }
5570 } ) ;
5671
57- version
58- }
59-
60- fn as_ident ( & self ) -> syn:: Ident {
61- syn:: Ident :: new (
62- & format ! ( "v{}_{}" , self . major, self . minor) ,
63- Span :: call_site ( ) ,
64- )
72+ opt_version
6573 }
66- }
67-
68- struct VersionFilter {
69- version : Version ,
70- }
7174
72- impl VersionFilter {
7375 fn matches ( & self , found_version : & Version ) -> bool {
7476 & self . version == found_version
7577 }
7678
7779 fn filter_fields (
78- & self ,
80+ & mut self ,
7981 fields : Punctuated < syn:: Field , Comma > ,
8082 ) -> Punctuated < syn:: Field , Comma > {
8183 fields
8284 . into_pairs ( )
8385 . filter_map (
84- |mut pair| match Version :: extract_from_attrs ( & mut pair. value_mut ( ) . attrs ) {
86+ |mut pair| match self . extract_version ( & mut pair. value_mut ( ) . attrs ) {
8587 Some ( version) => self . matches ( & version) . then_some ( pair) ,
8688 None => Some ( pair) ,
8789 } ,
@@ -105,7 +107,7 @@ impl Fold for VersionFilter {
105107 match stmt {
106108 syn:: Stmt :: Local ( syn:: Local { ref mut attrs, .. } )
107109 | syn:: Stmt :: Macro ( syn:: StmtMacro { ref mut attrs, .. } ) => {
108- if let Some ( version) = Version :: extract_from_attrs ( attrs) {
110+ if let Some ( version) = self . extract_version ( attrs) {
109111 if !self . matches ( & version) {
110112 stmt = parse_quote ! ( { } ; ) ;
111113 }
@@ -157,7 +159,7 @@ impl Fold for VersionFilter {
157159 | Expr :: Unsafe ( syn:: ExprUnsafe { ref mut attrs, .. } )
158160 | Expr :: While ( syn:: ExprWhile { ref mut attrs, .. } )
159161 | Expr :: Yield ( syn:: ExprYield { ref mut attrs, .. } ) => {
160- if let Some ( version) = Version :: extract_from_attrs ( attrs) {
162+ if let Some ( version) = self . extract_version ( attrs) {
161163 if !self . matches ( & version) {
162164 expr = parse_quote ! ( { } ) ;
163165 }
@@ -174,7 +176,7 @@ impl Fold for VersionFilter {
174176 . fields
175177 . into_pairs ( )
176178 . filter_map (
177- |mut pair| match Version :: extract_from_attrs ( & mut pair. value_mut ( ) . attrs ) {
179+ |mut pair| match self . extract_version ( & mut pair. value_mut ( ) . attrs ) {
178180 Some ( version) => self . matches ( & version) . then_some ( pair) ,
179181 None => Some ( pair) ,
180182 } ,
@@ -186,7 +188,7 @@ impl Fold for VersionFilter {
186188
187189 fn fold_expr_match ( & mut self , mut expr : syn:: ExprMatch ) -> syn:: ExprMatch {
188190 expr. arms
189- . retain_mut ( |arm| match Version :: extract_from_attrs ( & mut arm. attrs ) {
191+ . retain_mut ( |arm| match self . extract_version ( & mut arm. attrs ) {
190192 Some ( version) => self . matches ( & version) ,
191193 None => true ,
192194 } ) ;
@@ -211,7 +213,7 @@ impl Fold for VersionFilter {
211213 | Item :: Type ( syn:: ItemType { ref mut attrs, .. } )
212214 | Item :: Union ( syn:: ItemUnion { ref mut attrs, .. } )
213215 | Item :: Use ( syn:: ItemUse { ref mut attrs, .. } ) => {
214- if let Some ( version) = Version :: extract_from_attrs ( attrs) {
216+ if let Some ( version) = self . extract_version ( attrs) {
215217 if !self . matches ( & version) {
216218 item = parse_quote ! (
217219 use { } ;
@@ -226,33 +228,46 @@ impl Fold for VersionFilter {
226228 }
227229}
228230
229- #[ proc_macro_attribute]
230- pub fn versioned ( input : TokenStream , annotated_item : TokenStream ) -> TokenStream {
231+ fn helper (
232+ input : TokenStream ,
233+ annotated_item : TokenStream ,
234+ ) -> syn:: Result < proc_macro2:: TokenStream > {
231235 // This parses the module being annotated by the `#[versioned(..)]` attribute.
232- let module = parse_macro_input ! ( annotated_item as syn:: ItemMod ) ;
236+ let module = syn:: parse :: < syn :: ItemMod > ( annotated_item ) ? ;
233237
234238 // This parses the versions passed to the attribute, e.g. the `"1.3"`
235239 // and `"1.4"`in `#[versioned("1.3", "1.4")]
236- let versions: Vec < Version > =
237- parse_macro_input ! ( input with Punctuated :: <syn:: LitStr , Comma >:: parse_terminated)
240+ let versions =
241+ syn :: parse :: Parser :: parse ( Punctuated :: < syn:: LitStr , Comma > :: parse_terminated, input ) ?
238242 . into_iter ( )
239- . map ( |s| s. value ( ) . parse ( ) . expect ( "cannot parse version" ) )
240- . collect ( ) ;
243+ . map ( |s| s. value ( ) . parse ( ) . map_err ( |err| Error :: new ( s. span ( ) , err) ) )
244+ . collect :: < syn:: Result < Vec < Version > > > ( ) ?;
245+
246+ let content = module
247+ . content
248+ . as_ref ( )
249+ . ok_or_else ( || Error :: new ( module. ident . span ( ) , "found module without content" ) ) ?;
241250
242251 let mut tokens = proc_macro2:: TokenStream :: new ( ) ;
243252
244253 for version in versions {
245254 let mod_vis = & module. vis ;
246255 let mod_ident = version. as_ident ( ) ;
247256
248- let ( _ , items) = module . content . clone ( ) . unwrap ( ) ;
257+ let items = content. 1 . clone ( ) ;
249258
250259 let mut folded_items = Vec :: new ( ) ;
251260
252- let mut filter = VersionFilter { version } ;
261+ let mut filter = VersionFilter {
262+ version,
263+ error : None ,
264+ } ;
253265
254266 for item in items {
255267 folded_items. push ( filter. fold_item ( item) ) ;
268+ if let Some ( error) = filter. error {
269+ return Err ( error) ;
270+ }
256271 }
257272
258273 tokens. extend ( quote ! {
@@ -262,5 +277,18 @@ pub fn versioned(input: TokenStream, annotated_item: TokenStream) -> TokenStream
262277 } )
263278 }
264279
265- tokens. into ( )
280+ Ok ( tokens)
281+ }
282+
283+ #[ proc_macro_attribute]
284+ pub fn versioned ( input : TokenStream , annotated_item : TokenStream ) -> TokenStream {
285+ match helper ( input, annotated_item) {
286+ Ok ( tokens) => tokens,
287+ Err ( err) => Error :: new (
288+ err. span ( ) ,
289+ format ! ( "{err} while using the `#[versioned]` macro" ) ,
290+ )
291+ . into_compile_error ( ) ,
292+ }
293+ . into ( )
266294}
0 commit comments