Skip to content

Commit 1e7cc77

Browse files
authored
Merge pull request #43 from cameroncros/feature/memoize_result
Allow memoizing Results
2 parents b2fe6b5 + 90d3f49 commit 1e7cc77

File tree

3 files changed

+131
-10
lines changed

3 files changed

+131
-10
lines changed

examples/result.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
use memoize::memoize;
2+
3+
#[memoize]
4+
fn hello(arg: String, arg2: usize) -> std::io::Result<bool> {
5+
println!("{} => {}", arg, arg2);
6+
Ok(arg.len() % 2 == arg2)
7+
}
8+
9+
fn main() {
10+
// `hello` is only called once here.
11+
assert!(hello("World2".to_string(), 0).unwrap());
12+
assert!(hello("World2".to_string(), 0).unwrap());
13+
// Sometimes one might need the original function.
14+
assert!(memoized_original_hello("World2".to_string(), 0).unwrap());
15+
memoized_flush_hello();
16+
}

inner/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ proc-macro2 = "1.0"
1919
quote = "1.0"
2020
syn = { version = "1.0", features = ["full"] }
2121

22+
[dev-dependencies]
23+
test-case = "3.3.1"
24+
2225
[features]
2326
default = []
2427
full = []

inner/src/lib.rs

Lines changed: 112 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#![crate_type = "proc-macro"]
22
#![allow(unused_imports)] // Spurious complaints about a required trait import.
3-
use syn::{self, parse, parse_macro_input, spanned::Spanned, Expr, ExprCall, ItemFn, Path};
3+
use syn::{self, parse, parse2, parse_macro_input, spanned::Spanned, AngleBracketedGenericArguments, Expr, ExprCall, ItemFn, Path, PathArguments, Type};
44

55
use proc_macro::TokenStream;
66
use quote::{self, ToTokens};
@@ -110,6 +110,40 @@ impl parse::Parse for CacheOptions {
110110
}
111111
}
112112

113+
114+
fn check_for_result_type(outer: proc_macro2::TokenStream) -> bool {
115+
// Parse the input as a Rust type
116+
let input_ty = parse2::<Type>(outer).expect("failed to parse outer type");
117+
118+
if let Type::Path(path) = input_ty {
119+
return path.path.segments.last().expect("O length path?").ident == "Result";
120+
}
121+
false
122+
}
123+
124+
fn try_unwrap_result_type(outer: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
125+
let original = outer.clone();
126+
// Parse the input as a Rust type
127+
let input_ty = parse2::<Type>(outer).expect("failed to parse outer type");
128+
129+
// Ensure it’s a path type (e.g., Result<T, E>)
130+
if let Type::Path(path) = input_ty {
131+
// Look at the last segment (e.g., "Result")
132+
let last_segment = path.path.segments.last().expect("Expected a Result");
133+
134+
if last_segment.ident.to_string().contains("Result") {
135+
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) = &last_segment.arguments
136+
{
137+
// The first generic argument of Result<T, E> is the Ok type
138+
if let Some(syn::GenericArgument::Type(ok_type)) = args.first() {
139+
return ok_type.to_token_stream()
140+
}
141+
}
142+
}
143+
}
144+
original
145+
}
146+
113147
// This implementation of the storage backend does not depend on any more crates.
114148
#[cfg(not(feature = "full"))]
115149
mod store {
@@ -122,6 +156,7 @@ mod store {
122156
key_type: proc_macro2::TokenStream,
123157
value_type: proc_macro2::TokenStream,
124158
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
159+
let value_type = crate::try_unwrap_result_type(value_type);
125160
// This is the unbounded default.
126161
if let Some(hasher) = &_options.custom_hasher {
127162
return (
@@ -148,8 +183,11 @@ mod store {
148183
// This implementation of the storage backend also depends on the `lru` crate.
149184
#[cfg(feature = "full")]
150185
mod store {
151-
use crate::CacheOptions;
186+
use crate::{try_unwrap_result_type, CacheOptions};
152187
use proc_macro::TokenStream;
188+
use quote::quote;
189+
use syn::{parse2, AngleBracketedGenericArguments, PathArguments, Type, TypePath};
190+
153191

154192
/// Returns TokenStreams to be used in quote!{} for parametrizing the memoize store variable,
155193
/// and initializing it.
@@ -161,6 +199,8 @@ mod store {
161199
key_type: proc_macro2::TokenStream,
162200
value_type: proc_macro2::TokenStream,
163201
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
202+
let value_type = try_unwrap_result_type(value_type);
203+
164204
let value_type = match options.time_to_live {
165205
None => quote::quote! {#value_type},
166206
Some(_) => quote::quote! {(std::time::Instant, #value_type)},
@@ -201,7 +241,6 @@ mod store {
201241
}
202242
}
203243
}
204-
205244
/// Returns names of methods as TokenStreams to insert and get (respectively) elements from a
206245
/// store.
207246
pub(crate) fn cache_access_methods(
@@ -379,20 +418,40 @@ pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream {
379418
),
380419
};
381420

421+
let get_value = if check_for_result_type(return_type.clone()) {
422+
quote::quote! {
423+
let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple?;
424+
}
425+
} else {
426+
quote::quote! {
427+
let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple;
428+
}
429+
};
430+
431+
let return_value = if check_for_result_type(return_type.clone()) {
432+
quote::quote! {
433+
Ok(ATTR_MEMOIZE_RETURN__)
434+
}
435+
} else {
436+
quote::quote! {
437+
ATTR_MEMOIZE_RETURN__
438+
}
439+
};
440+
382441
let memoizer = if options.shared_cache {
383442
quote::quote! {
384443
{
385444
let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap();
386445
if let Some(ATTR_MEMOIZE_RETURN__) = #read_memo {
387-
return ATTR_MEMOIZE_RETURN__
446+
return #return_value;
388447
}
389448
}
390-
let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple;
449+
#get_value
391450

392451
let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap();
393452
#memoize
394453

395-
ATTR_MEMOIZE_RETURN__
454+
#return_value
396455
}
397456
} else {
398457
quote::quote! {
@@ -401,17 +460,17 @@ pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream {
401460
#read_memo
402461
});
403462
if let Some(ATTR_MEMOIZE_RETURN__) = ATTR_MEMOIZE_RETURN__ {
404-
return ATTR_MEMOIZE_RETURN__;
463+
return #return_value;
405464
}
406465

407-
let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple;
466+
#get_value
408467

409468
#store_ident.with(|ATTR_MEMOIZE_HM__| {
410469
let mut ATTR_MEMOIZE_HM__ = ATTR_MEMOIZE_HM__.borrow_mut();
411470
#memoize
412471
});
413472

414-
ATTR_MEMOIZE_RETURN__
473+
#return_value
415474
}
416475
};
417476

@@ -505,4 +564,47 @@ fn check_signature(
505564
}
506565

507566
#[cfg(test)]
508-
mod tests {}
567+
mod tests {
568+
use std::str::FromStr;
569+
use test_case::test_case;
570+
use proc_macro2::TokenStream;
571+
use quote::quote;
572+
use crate::{check_for_result_type, try_unwrap_result_type};
573+
574+
#[test_case("Result<bool>")]
575+
#[test_case("anyhow::Result<bool>")]
576+
#[test_case("std::io::Result<bool>")]
577+
#[test_case("io::Result<bool>")]
578+
fn test_check_for_result_type_success(typestr: &str) {
579+
let input = TokenStream::from_str(typestr).unwrap();
580+
assert_eq!(true, check_for_result_type(input));
581+
}
582+
583+
#[test_case("Option<bool>")]
584+
#[test_case("(bool, bool)")]
585+
#[test_case("bool")]
586+
fn test_check_for_result_type_fail(typestr: &str) {
587+
let input = TokenStream::from_str(typestr).unwrap();
588+
assert_eq!(false, check_for_result_type(input));
589+
}
590+
591+
#[test_case("Result<bool>", "bool")]
592+
#[test_case("anyhow::Result<u32>", "u32")]
593+
#[test_case("std::io::Result<String>", "String")]
594+
#[test_case("io::Result<(u32, u32)>", "(u32 , u32)")]
595+
#[test_case("CustomResult<CustomStruct, Error>", "CustomStruct")]
596+
fn test_try_unwrap_result_type_inner(input_type: &str, output_type: &str) {
597+
let input = TokenStream::from_str(input_type).unwrap();
598+
assert_eq!(output_type,
599+
try_unwrap_result_type(input).to_string());
600+
}
601+
602+
#[test_case("Option < bool >")]
603+
#[test_case("(bool , bool)")]
604+
#[test_case("bool")]
605+
fn test_try_unwrap_result_type_original(typestr: &str) {
606+
let input = TokenStream::from_str(typestr).unwrap();
607+
assert_eq!(typestr.replace(" ", ""),
608+
try_unwrap_result_type(input).to_string().replace(" ", ""));
609+
}
610+
}

0 commit comments

Comments
 (0)