11#![ crate_type = "proc-macro" ]
22#![ allow( unused_imports) ] // Spurious complaints about a required trait import.
3- use syn:: { self , parse, parse_macro_input, spanned:: Spanned , Expr , ExprCall , ItemFn , Path } ;
3+ use syn:: { self , parse, parse2 , parse_macro_input, spanned:: Spanned , AngleBracketedGenericArguments , Expr , ExprCall , ItemFn , Path , PathArguments , Type } ;
44
55use proc_macro:: TokenStream ;
66use quote:: { self , ToTokens } ;
@@ -110,6 +110,40 @@ impl parse::Parse for CacheOptions {
110110 }
111111}
112112
113+
114+ fn check_for_result_type ( outer : proc_macro2:: TokenStream ) -> bool {
115+ // Parse the input as a Rust type
116+ let input_ty = parse2 :: < Type > ( outer) . expect ( "failed to parse outer type" ) ;
117+
118+ if let Type :: Path ( path) = input_ty {
119+ return path. path . segments . last ( ) . expect ( "O length path?" ) . ident == "Result" ;
120+ }
121+ false
122+ }
123+
124+ fn try_unwrap_result_type ( outer : proc_macro2:: TokenStream ) -> proc_macro2:: TokenStream {
125+ let original = outer. clone ( ) ;
126+ // Parse the input as a Rust type
127+ let input_ty = parse2 :: < Type > ( outer) . expect ( "failed to parse outer type" ) ;
128+
129+ // Ensure it’s a path type (e.g., Result<T, E>)
130+ if let Type :: Path ( path) = input_ty {
131+ // Look at the last segment (e.g., "Result")
132+ let last_segment = path. path . segments . last ( ) . expect ( "Expected a Result" ) ;
133+
134+ if last_segment. ident . to_string ( ) . contains ( "Result" ) {
135+ if let PathArguments :: AngleBracketed ( AngleBracketedGenericArguments { args, .. } ) = & last_segment. arguments
136+ {
137+ // The first generic argument of Result<T, E> is the Ok type
138+ if let Some ( syn:: GenericArgument :: Type ( ok_type) ) = args. first ( ) {
139+ return ok_type. to_token_stream ( )
140+ }
141+ }
142+ }
143+ }
144+ original
145+ }
146+
113147// This implementation of the storage backend does not depend on any more crates.
114148#[ cfg( not( feature = "full" ) ) ]
115149mod store {
@@ -122,6 +156,7 @@ mod store {
122156 key_type : proc_macro2:: TokenStream ,
123157 value_type : proc_macro2:: TokenStream ,
124158 ) -> ( proc_macro2:: TokenStream , proc_macro2:: TokenStream ) {
159+ let value_type = crate :: try_unwrap_result_type ( value_type) ;
125160 // This is the unbounded default.
126161 if let Some ( hasher) = & _options. custom_hasher {
127162 return (
@@ -148,8 +183,11 @@ mod store {
148183// This implementation of the storage backend also depends on the `lru` crate.
149184#[ cfg( feature = "full" ) ]
150185mod store {
151- use crate :: CacheOptions ;
186+ use crate :: { try_unwrap_result_type , CacheOptions } ;
152187 use proc_macro:: TokenStream ;
188+ use quote:: quote;
189+ use syn:: { parse2, AngleBracketedGenericArguments , PathArguments , Type , TypePath } ;
190+
153191
154192 /// Returns TokenStreams to be used in quote!{} for parametrizing the memoize store variable,
155193 /// and initializing it.
@@ -161,6 +199,8 @@ mod store {
161199 key_type : proc_macro2:: TokenStream ,
162200 value_type : proc_macro2:: TokenStream ,
163201 ) -> ( proc_macro2:: TokenStream , proc_macro2:: TokenStream ) {
202+ let value_type = try_unwrap_result_type ( value_type) ;
203+
164204 let value_type = match options. time_to_live {
165205 None => quote:: quote! { #value_type} ,
166206 Some ( _) => quote:: quote! { ( std:: time:: Instant , #value_type) } ,
@@ -201,7 +241,6 @@ mod store {
201241 }
202242 }
203243 }
204-
205244 /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a
206245 /// store.
207246 pub ( crate ) fn cache_access_methods (
@@ -379,20 +418,40 @@ pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream {
379418 ) ,
380419 } ;
381420
421+ let get_value = if check_for_result_type ( return_type. clone ( ) ) {
422+ quote:: quote! {
423+ let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple?;
424+ }
425+ } else {
426+ quote:: quote! {
427+ let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple;
428+ }
429+ } ;
430+
431+ let return_value = if check_for_result_type ( return_type. clone ( ) ) {
432+ quote:: quote! {
433+ Ok ( ATTR_MEMOIZE_RETURN__ )
434+ }
435+ } else {
436+ quote:: quote! {
437+ ATTR_MEMOIZE_RETURN__
438+ }
439+ } ;
440+
382441 let memoizer = if options. shared_cache {
383442 quote:: quote! {
384443 {
385444 let mut ATTR_MEMOIZE_HM__ = #store_ident. lock( ) . unwrap( ) ;
386445 if let Some ( ATTR_MEMOIZE_RETURN__ ) = #read_memo {
387- return ATTR_MEMOIZE_RETURN__
446+ return #return_value ;
388447 }
389448 }
390- let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple ;
449+ #get_value
391450
392451 let mut ATTR_MEMOIZE_HM__ = #store_ident. lock( ) . unwrap( ) ;
393452 #memoize
394453
395- ATTR_MEMOIZE_RETURN__
454+ #return_value
396455 }
397456 } else {
398457 quote:: quote! {
@@ -401,17 +460,17 @@ pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream {
401460 #read_memo
402461 } ) ;
403462 if let Some ( ATTR_MEMOIZE_RETURN__ ) = ATTR_MEMOIZE_RETURN__ {
404- return ATTR_MEMOIZE_RETURN__ ;
463+ return #return_value ;
405464 }
406465
407- let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple ;
466+ #get_value
408467
409468 #store_ident. with( |ATTR_MEMOIZE_HM__ | {
410469 let mut ATTR_MEMOIZE_HM__ = ATTR_MEMOIZE_HM__ . borrow_mut( ) ;
411470 #memoize
412471 } ) ;
413472
414- ATTR_MEMOIZE_RETURN__
473+ #return_value
415474 }
416475 } ;
417476
@@ -505,4 +564,47 @@ fn check_signature(
505564}
506565
507566#[ cfg( test) ]
508- mod tests { }
567+ mod tests {
568+ use std:: str:: FromStr ;
569+ use test_case:: test_case;
570+ use proc_macro2:: TokenStream ;
571+ use quote:: quote;
572+ use crate :: { check_for_result_type, try_unwrap_result_type} ;
573+
574+ #[ test_case( "Result<bool>" ) ]
575+ #[ test_case( "anyhow::Result<bool>" ) ]
576+ #[ test_case( "std::io::Result<bool>" ) ]
577+ #[ test_case( "io::Result<bool>" ) ]
578+ fn test_check_for_result_type_success ( typestr : & str ) {
579+ let input = TokenStream :: from_str ( typestr) . unwrap ( ) ;
580+ assert_eq ! ( true , check_for_result_type( input) ) ;
581+ }
582+
583+ #[ test_case( "Option<bool>" ) ]
584+ #[ test_case( "(bool, bool)" ) ]
585+ #[ test_case( "bool" ) ]
586+ fn test_check_for_result_type_fail ( typestr : & str ) {
587+ let input = TokenStream :: from_str ( typestr) . unwrap ( ) ;
588+ assert_eq ! ( false , check_for_result_type( input) ) ;
589+ }
590+
591+ #[ test_case( "Result<bool>" , "bool" ) ]
592+ #[ test_case( "anyhow::Result<u32>" , "u32" ) ]
593+ #[ test_case( "std::io::Result<String>" , "String" ) ]
594+ #[ test_case( "io::Result<(u32, u32)>" , "(u32 , u32)" ) ]
595+ #[ test_case( "CustomResult<CustomStruct, Error>" , "CustomStruct" ) ]
596+ fn test_try_unwrap_result_type_inner ( input_type : & str , output_type : & str ) {
597+ let input = TokenStream :: from_str ( input_type) . unwrap ( ) ;
598+ assert_eq ! ( output_type,
599+ try_unwrap_result_type( input) . to_string( ) ) ;
600+ }
601+
602+ #[ test_case( "Option < bool >" ) ]
603+ #[ test_case( "(bool , bool)" ) ]
604+ #[ test_case( "bool" ) ]
605+ fn test_try_unwrap_result_type_original ( typestr : & str ) {
606+ let input = TokenStream :: from_str ( typestr) . unwrap ( ) ;
607+ assert_eq ! ( typestr. replace( " " , "" ) ,
608+ try_unwrap_result_type( input) . to_string( ) . replace( " " , "" ) ) ;
609+ }
610+ }
0 commit comments