diff --git a/impl/src/try_into.rs b/impl/src/try_into.rs index c76baa03..2ef37f06 100644 --- a/impl/src/try_into.rs +++ b/impl/src/try_into.rs @@ -2,21 +2,50 @@ use crate::utils::{ add_extra_generic_param, numbered_vars, replace_self::DeriveInputExt as _, AttrParams, DeriveType, MultiFieldData, State, }; -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens}; -use syn::{DeriveInput, Result}; +use syn::{Attribute, DeriveInput, Ident, Result}; -use crate::utils::HashMap; +use crate::utils::{ + attr::{self, ParseMultiple as _}, + HashMap, Spanning, +}; /// Provides the hook to expand `#[derive(TryInto)]` into an implementation of `TryInto` #[allow(clippy::cognitive_complexity)] pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result { - let input = &input.replace_self_type(); + let input = &mut input.replace_self_type(); + + let trait_attr = "try_into"; + + let error_attrs: Vec<(usize, Attribute)> = input + .attrs + .iter() + .enumerate() + .filter(|(_, attr)| attr.path().is_ident(trait_attr)) + .filter(|(_, attr)| attr.parse_args_with(detect_error_attr).is_ok()) + .map(|(i, attr)| (i, attr.clone())) + .collect(); + + for (i, _) in &error_attrs { + let _ = &mut input.attrs.remove(*i); + } + + let error_attrs = error_attrs + .into_iter() + .map(|(_, attr)| attr) + .collect::>(); + + let custom_error = attr::Error::parse_attrs( + error_attrs, + &Ident::new(trait_attr, Span::call_site()), + )? + .map(Spanning::into_inner); let state = State::with_attr_params( input, trait_name, - "try_into".into(), + trait_attr.into(), AttrParams { enum_: vec!["ignore", "owned", "ref", "ref_mut"], variant: vec!["ignore", "owned", "ref", "ref_mut"], @@ -101,26 +130,36 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result }; + let mut error_conv = quote! {}; + + if let Some(custom_error) = custom_error.as_ref() { + error_ty = custom_error.ty.to_token_stream(); + error_conv = custom_error.conv.as_ref().map_or_else( + || quote! { .map_err(derive_more::core::convert::Into::into) }, + |conv| quote! { .map_err(#conv) }, + ); + } + let try_from = quote! { #[automatically_derived] impl #impl_generics derive_more::core::convert::TryFrom< #reference_with_lifetime #input_type #ty_generics > for (#(#reference_with_lifetime #original_types),*) #where_clause { - type Error = #error; + type Error = #error_ty; #[inline] fn try_from( value: #reference_with_lifetime #input_type #ty_generics, - ) -> derive_more::core::result::Result { + ) -> derive_more::core::result::Result { match value { #(#matchers)|* => derive_more::core::result::Result::Ok(#vars), _ => derive_more::core::result::Result::Err( derive_more::TryIntoError::new(value, #variant_names, #output_type), - ), + )#error_conv, } } } @@ -129,3 +168,17 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result Result<()> { + mod ident { + syn::custom_keyword!(error); + } + + let ahead = input.lookahead1(); + if ahead.peek(ident::error) { + let _ = input.parse::(); + Ok(()) + } else { + Err(ahead.error()) + } +} diff --git a/impl/src/utils.rs b/impl/src/utils.rs index 6da96012..705e03d0 100644 --- a/impl/src/utils.rs +++ b/impl/src/utils.rs @@ -21,6 +21,7 @@ use syn::{ feature = "from_str", feature = "into", feature = "try_from", + feature = "try_into", ))] pub(crate) use self::either::Either; #[cfg(any(feature = "from", feature = "into"))] @@ -36,6 +37,7 @@ pub(crate) use self::generics_search::GenericsSearch; feature = "from_str", feature = "into", feature = "try_from", + feature = "try_into", ))] pub(crate) use self::spanning::Spanning; @@ -1324,6 +1326,7 @@ pub fn is_type_parameter_used_in_type( feature = "from_str", feature = "into", feature = "try_from", + feature = "try_into", ))] mod either { use proc_macro2::TokenStream; @@ -1397,6 +1400,7 @@ mod either { feature = "from_str", feature = "into", feature = "try_from", + feature = "try_into", ))] mod spanning { use std::ops::{Deref, DerefMut}; @@ -1494,6 +1498,7 @@ mod spanning { feature = "from_str", feature = "into", feature = "try_from", + feature = "try_into", ))] pub(crate) mod attr { use std::any::Any; @@ -1512,7 +1517,7 @@ pub(crate) mod attr { feature = "try_from" ))] pub(crate) use self::empty::Empty; - #[cfg(feature = "from_str")] + #[cfg(any(feature = "from_str", feature = "try_into"))] pub(crate) use self::error::Error; #[cfg(any(feature = "display", feature = "from_str"))] pub(crate) use self::rename_all::RenameAll; @@ -2117,7 +2122,7 @@ pub(crate) mod attr { } } - #[cfg(feature = "from_str")] + #[cfg(any(feature = "from_str", feature = "try_into"))] pub(crate) mod error { use syn::parse::{Parse, ParseStream}; diff --git a/tests/try_into.rs b/tests/try_into.rs index b1d9390b..7540fd2e 100644 --- a/tests/try_into.rs +++ b/tests/try_into.rs @@ -260,3 +260,73 @@ fn test_try_into() { ); assert!(matches!(i.try_into().unwrap(), ())); } + +mod error { + use core::fmt; + + #[cfg(not(feature = "std"))] + use alloc::string::String; + + use derive_more::TryIntoError; + + use super::*; + + struct CustomError(String); + + impl From> for CustomError { + fn from(value: TryIntoError) -> Self { + Self(value.to_string()) + } + } + + impl fmt::Display for CustomError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } + } + + #[derive(TryInto, Clone, Copy, Debug, Eq, PartialEq)] + #[try_into(ref)] + #[try_into(error(CustomError))] + enum MixedInts { + SmallInt(i32), + NamedBigInt { + int: i64, + }, + UnsignedWithIgnoredField(#[try_into(ignore)] bool, i64), + NamedUnsignedWithIgnoredField { + #[try_into(ignore)] + useless: bool, + x: i64, + }, + TwoSmallInts(i32, i32), + NamedBigInts { + x: i64, + y: i64, + }, + Unsigned(u32), + NamedUnsigned { + x: u32, + }, + Unit, + #[try_into(ignore)] + Unit2, + } + + #[test] + fn test() { + let i = MixedInts::Unsigned(42); + assert_eq!( + i32::try_from(i).unwrap_err().to_string(), + "Only SmallInt can be converted to i32" + ); + assert_eq!( + i64::try_from(i).unwrap_err().to_string(), + "Only NamedBigInt, UnsignedWithIgnoredField, NamedUnsignedWithIgnoredField can be converted to i64" + ); + assert_eq!( + <(i32, i32)>::try_from(i).unwrap_err().to_string(), + "Only TwoSmallInts can be converted to (i32, i32)" + ); + } +}