diff --git a/diffus-derive-test/src/lib.rs b/diffus-derive-test/src/lib.rs index 67f180e..31e7559 100644 --- a/diffus-derive-test/src/lib.rs +++ b/diffus-derive-test/src/lib.rs @@ -412,4 +412,61 @@ mod test { &vec![string::Edit::Copy('a'), string::Edit::Insert('\''),] ); } + + mod generics { + use diffus::Diffus; + + pub trait Thing { + type Foo; + type Bar; + } + + pub struct ConcreteThing; + + impl Thing for ConcreteThing { + type Foo = String; + type Bar = i64; + } + + #[derive(Diffus)] + pub struct TestNamedStruct where A: Thing { + pub a: A::Foo, + pub inner: i32, + } + + #[derive(Diffus)] + pub enum TestTuple where A: Thing { + Hello { + bar: A::Bar, + }, + UnitVariant, + TupleVariant(A::Bar, A::Bar), + } + + #[derive(Diffus)] + pub struct TestUnnamedStruct(pub A::Foo) where A: Thing; + } + + #[test] + fn test() { + use self::generics::{ConcreteThing, TestNamedStruct}; + use edit::string; + + let a: TestNamedStruct = TestNamedStruct { a: "a".to_string(), inner: 12 }; + let ap = TestNamedStruct { a: "a'".to_string(), inner: 13 }; + + let diff = a.diff(&ap); + let actual_a = diff.change().unwrap().a.change().unwrap(); + let actual_inner = diff.change().unwrap().inner.change().unwrap(); + + assert_eq!( + actual_a, + &vec![string::Edit::Copy('a'), string::Edit::Insert('\''),] + ); + + assert_eq!( + actual_inner, + &(&12, &13), + ); + } } diff --git a/diffus-derive/src/lib.rs b/diffus-derive/src/lib.rs index 8d9be47..a620d15 100644 --- a/diffus-derive/src/lib.rs +++ b/diffus-derive/src/lib.rs @@ -132,18 +132,33 @@ fn input_lifetime(generics: &syn::Generics) -> Option<&syn::Lifetime> { lifetime } +struct Generics { + ty_generic_params: syn::punctuated::Punctuated, + + edited_ty_generic_params: syn::punctuated::Punctuated, + edited_ty_where_clause: syn::WhereClause, + + impl_diffable_generic_params: syn::punctuated::Punctuated, + impl_diffable_where_clause: syn::WhereClause, + + impl_lifetime: syn::Lifetime, +} + #[proc_macro_derive(Diffus)] pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input: syn::DeriveInput = syn::parse2(proc_macro2::TokenStream::from(input)).unwrap(); let ident = &input.ident; let vis = &input.vis; - let where_clause = &input.generics.where_clause; + let edited_ident = syn::parse_str::(&format!("Edited{}", ident)).unwrap(); - let data_lifetime = input_lifetime(&input.generics); - let default_lifetime = syn::parse_str::("'diffus_a").unwrap(); - let impl_lifetime = data_lifetime.unwrap_or(&default_lifetime); + let Generics { + ty_generic_params, + edited_ty_generic_params, edited_ty_where_clause, + impl_diffable_generic_params, impl_diffable_where_clause, + impl_lifetime, + } = Generics::new(&input.generics, &input.data); #[cfg(feature = "serialize-impl")] let derive_serialize = Some(quote! { #[derive(serde::Serialize)] }); @@ -182,8 +197,8 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream } }); - let unit_enum_impl_lifetime = if has_non_unit_variant { - Some(impl_lifetime.clone()) + let unit_enum_generic_params = if has_non_unit_variant { + Some(edited_ty_generic_params.clone()) } else { None }; @@ -262,12 +277,12 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream quote! { #derive_serialize - #vis enum #edited_ident <#unit_enum_impl_lifetime> where #where_clause { + #vis enum #edited_ident <#unit_enum_generic_params> #edited_ty_where_clause { #(#edit_variants),* } - impl<#impl_lifetime> diffus::Diffable<#impl_lifetime> for #ident <#data_lifetime> where #where_clause { - type Diff = diffus::edit::enm::Edit<#impl_lifetime, Self, #edited_ident <#unit_enum_impl_lifetime>>; + impl<#impl_diffable_generic_params> diffus::Diffable<#impl_lifetime> for #ident <#ty_generic_params> #impl_diffable_where_clause { + type Diff = diffus::edit::enm::Edit<#impl_lifetime, Self, #edited_ident <#unit_enum_generic_params>>; fn diff(&#impl_lifetime self, other: &#impl_lifetime Self) -> diffus::edit::Edit<#impl_lifetime, Self> { match (self, other) { @@ -290,12 +305,12 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream syn::Fields::Named(_) => { quote! { #derive_serialize - #vis struct #edited_ident<#impl_lifetime> where #where_clause { + #vis struct #edited_ident<#edited_ty_generic_params> #edited_ty_where_clause { #edit_fields } - impl<#impl_lifetime> diffus::Diffable<#impl_lifetime> for #ident <#data_lifetime> where #where_clause { - type Diff = #edited_ident<#impl_lifetime>; + impl<#impl_diffable_generic_params> diffus::Diffable<#impl_lifetime> for #ident <#ty_generic_params> #impl_diffable_where_clause { + type Diff = #edited_ident<#edited_ty_generic_params>; fn diff(&#impl_lifetime self, other: &#impl_lifetime Self) -> diffus::edit::Edit<#impl_lifetime, Self> { match ( #field_diffs ) { @@ -311,10 +326,10 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream syn::Fields::Unnamed(_) => { quote! { #derive_serialize - #vis struct #edited_ident<#impl_lifetime> ( #edit_fields ) where #where_clause; + #vis struct #edited_ident<#edited_ty_generic_params> ( #edit_fields ) #edited_ty_where_clause; - impl<#impl_lifetime> diffus::Diffable<#impl_lifetime> for #ident <#data_lifetime> where #where_clause { - type Diff = #edited_ident<#impl_lifetime>; + impl<#impl_diffable_generic_params> diffus::Diffable<#impl_lifetime> for #ident <#ty_generic_params> #impl_diffable_where_clause { + type Diff = #edited_ident<#edited_ty_generic_params>; fn diff(&#impl_lifetime self, other: &#impl_lifetime Self) -> diffus::edit::Edit<#impl_lifetime, Self> { match ( #field_diffs ) { @@ -330,9 +345,9 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream syn::Fields::Unit => { quote! { #derive_serialize - #vis struct #edited_ident< > where #where_clause; + #vis struct #edited_ident< > #edited_ty_where_clause; - impl<#impl_lifetime> diffus::Diffable<#impl_lifetime> for #ident< > where #where_clause { + impl<#impl_lifetime> diffus::Diffable<#impl_lifetime> for #ident< > #impl_diffable_where_clause { type Diff = #edited_ident; fn diff(&#impl_lifetime self, other: &#impl_lifetime Self) -> diffus::edit::Edit<#impl_lifetime, Self> { @@ -346,3 +361,112 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream syn::Data::Union(_) => panic!("union type not supported yet"), }) } + +impl Generics { + pub fn new( + input_generics: &syn::Generics, + data: &syn::Data, + ) -> Self { + let input_generic_params = &input_generics.params; + let input_where_clause = &input_generics.where_clause; + let empty_where_clause = input_generics.clone().make_where_clause().clone(); + + let generic_types_used = Self::collect_generic_types_used(input_generics, data); + + let ty_generic_params = input_generic_params.clone(); + + let mut edited_ty_generic_params = ty_generic_params.clone(); + let mut edited_ty_where_clause = input_where_clause.clone().unwrap_or(empty_where_clause.clone()); + + let mut impl_diffable_generic_params = input_generic_params.clone(); + let mut impl_diffable_where_clause = input_where_clause.clone().unwrap_or(empty_where_clause); + + let explicit_data_lifetime = input_lifetime(input_generics); + let impl_lifetime = explicit_data_lifetime.cloned().unwrap_or_else(|| { + let default_lifetime = syn::parse_str::("'diffus_a").unwrap(); + + // Add the lifetime into the generics lists. + impl_diffable_generic_params.insert(0, syn::GenericParam::Lifetime(syn::LifetimeDef::new(default_lifetime.clone()))); + edited_ty_generic_params.insert(0, syn::GenericParam::Lifetime(syn::LifetimeDef::new(default_lifetime.clone()))); + + default_lifetime.clone() + }); + + // Ensure that all generic types that exist live for as long as the diffus lifetime. + impl_diffable_where_clause.predicates.extend(input_generics.type_params().map(|type_param| { + let where_predicate = quote!(#type_param : #impl_lifetime); + let where_predicate: syn::WherePredicate = syn::parse(where_predicate.into()).unwrap(); + where_predicate + })); + + // Ensure that all generic types actually used are diffable and live for as long as the + // diffus lifetime. + for generic_ty_path in generic_types_used { + let where_predicate = quote!(#generic_ty_path : diffus::Diffable<#impl_lifetime> + #impl_lifetime); + let where_predicate = syn::parse::(where_predicate.into()).unwrap(); + + impl_diffable_where_clause.predicates.push(where_predicate.clone()); + edited_ty_where_clause.predicates.push(where_predicate.clone()); + } + + Generics { + ty_generic_params, + edited_ty_generic_params, edited_ty_where_clause, + impl_diffable_generic_params, impl_diffable_where_clause, + impl_lifetime, + } + } + + /// Collects all of the generic types used in a type including associated types. + fn collect_generic_types_used( + input_generics: &syn::Generics, + data: &syn::Data, + ) -> Vec { + let all_possible_fields: Vec<&syn::Fields> = match *data { + syn::Data::Struct(ref s) => vec![&s.fields], + syn::Data::Enum(ref e) => e.variants.iter().map(|v| &v.fields).collect(), + syn::Data::Union(..) => Vec::new(), // unimplemented + }; + + let all_possible_types: Vec<&syn::Type> = all_possible_fields.into_iter().flat_map(|fields| match fields { + syn::Fields::Named(ref fields) => fields.named.iter().map(|f| &f.ty).collect(), + syn::Fields::Unnamed(ref fields) => fields.unnamed.iter().map(|f| &f.ty).collect(), + syn::Fields::Unit => Vec::new(), + }).collect(); + + let mut generic_types_used = Vec::new(); + let mut remaining_types_to_check = all_possible_types.clone(); + + while let Some(type_to_check) = remaining_types_to_check.pop() { + match *type_to_check { + syn::Type::Path(ref path) => { + if let Some(first_segment) = path.path.segments.first().map(|s| &s.ident) { + let first_segment: syn::Ident = first_segment.clone().into(); + + if input_generics.type_params().any(|type_param| type_param.ident == first_segment) { + generic_types_used.push(path.path.clone()); + } + } + }, + + syn::Type::Array(ref array) => remaining_types_to_check.push(&array.elem), + syn::Type::Group(ref group) => remaining_types_to_check.push(&group.elem), + syn::Type::Paren(ref paren) => remaining_types_to_check.push(&paren.elem), + syn::Type::Ptr(ref ptr) => remaining_types_to_check.push(&ptr.elem), + syn::Type::Reference(ref reference) => remaining_types_to_check.push(&reference.elem), + syn::Type::Slice(ref slice) => remaining_types_to_check.push(&slice.elem), + syn::Type::Tuple(ref tuple) => remaining_types_to_check.extend(tuple.elems.iter()), + syn::Type::Verbatim(..) | + syn::Type::ImplTrait(..) | + syn::Type::Infer(..) | + syn::Type::Macro(..) | + syn::Type::Never(..) | + syn::Type::TraitObject(..) | + syn::Type::BareFn(..) => (), + _ => (), // unknown/unsupported type + } + } + + generic_types_used + } +} diff --git a/diffus/src/same.rs b/diffus/src/same.rs index fa635c4..1909391 100644 --- a/diffus/src/same.rs +++ b/diffus/src/same.rs @@ -49,3 +49,11 @@ impl Same for &T { (*self).same(*other) } } + +impl Same for Box + where T: Same { + fn same(&self, other: &Self) -> bool { + let (a, b): (&T, &T) = (&*self, &*other); + a.same(b) + } +}