@@ -3,7 +3,7 @@ use std::collections::HashSet;
33use alloc:: format;
44use alloc:: string:: String ;
55
6- use crate :: parse:: { Category , ConstValType , Generic , Struct , Type } ;
6+ use crate :: parse:: { Category , ConstValType , Generic , Struct , Type , Enum } ;
77use crate :: shared:: { attrs_collection_type, attrs_recurse, attrs_skip} ;
88
99use proc_macro:: TokenStream ;
@@ -521,3 +521,283 @@ pub(crate) fn derive_struct_diff_struct(struct_: &Struct) -> TokenStream {
521521 . parse ( )
522522 . unwrap ( )
523523}
524+
525+ pub ( crate ) fn derive_struct_diff_enum ( enum_ : & Enum ) -> TokenStream {
526+ let derives: String = vec ! [
527+ #[ cfg( feature = "debug_diffs" ) ]
528+ "core::fmt::Debug" ,
529+ "Clone" ,
530+ #[ cfg( feature = "nanoserde" ) ]
531+ "nanoserde::SerBin" ,
532+ #[ cfg( feature = "nanoserde" ) ]
533+ "nanoserde::DeBin" ,
534+ #[ cfg( feature = "serde" ) ]
535+ "serde::Serialize" ,
536+ #[ cfg( feature = "serde" ) ]
537+ "serde::Deserialize" ,
538+ ]
539+ . join ( ", " ) ;
540+
541+ let mut replace_enum_body = String :: new ( ) ;
542+ #[ cfg( unused) ]
543+ let mut diff_enum_body = String :: new ( ) ;
544+ let mut diff_body = String :: new ( ) ;
545+ let mut apply_single_body = String :: new ( ) ;
546+ #[ allow( unused_mut) ]
547+ let mut type_aliases = String :: new ( ) ;
548+ let mut used_generics: Vec < & Generic > = Vec :: new ( ) ;
549+
550+ let enum_name = String :: from ( "__" . to_owned ( ) + enum_. name . as_str ( ) + "StructDiffEnum" ) ;
551+ let struct_generics_names_hash: HashSet < String > =
552+ enum_. generics . iter ( ) . map ( |x| x. full ( ) ) . collect ( ) ;
553+
554+ if enum_
555+ . variants
556+ . iter ( )
557+ . any ( |x| attrs_skip ( & x. attributes ) ) {
558+ panic ! ( "Enum variants may not be skipped" ) ;
559+ } ;
560+
561+ enum_
562+ . variants
563+ . iter ( )
564+ . enumerate ( )
565+ . for_each ( |( _, field) | {
566+ let field_name = & field. name ;
567+ if let Some ( ty) = & field. ty {
568+ used_generics. extend ( enum_. generics . iter ( ) . filter ( |x| x. full ( ) == ty. ident . path ( & ty, false ) ) ) ;
569+
570+
571+ let to_add = enum_. generics . iter ( ) . filter ( |x| ty. wraps ( ) . iter ( ) . find ( |& wrapped_type| & x. full ( ) == wrapped_type ) . is_some ( ) ) ;
572+ used_generics. extend ( to_add) ;
573+
574+ used_generics. extend ( get_used_lifetimes ( & ty) . into_iter ( ) . filter_map ( |x| match struct_generics_names_hash. contains ( & x) {
575+ true => Some ( enum_. generics . iter ( ) . find ( |generic| generic. full ( ) == x ) . unwrap ( ) ) ,
576+ false => None ,
577+ } ) ) ;
578+
579+ for val in get_array_lens ( & ty) {
580+ if let Some ( const_gen) = enum_. generics . iter ( ) . find ( |x| x. full ( ) == val) {
581+ used_generics. push ( const_gen)
582+ }
583+ }
584+ }
585+ if let Some ( ty) = & field. ty {
586+ match ( attrs_recurse ( & field. attributes ) , attrs_collection_type ( & field. attributes ) , ty. base ( ) == "Option" ) {
587+
588+ ( false , None , false ) => { // The default case
589+ l ! ( replace_enum_body, " {}({})," , field_name, ty. full( ) ) ;
590+
591+ if matches ! ( ty. ident, Category :: UnNamed ) {
592+ l ! (
593+ apply_single_body,
594+ "variant @ Self::{}{{..}} => *self = variant," ,
595+ field_name
596+ ) ;
597+
598+ l ! (
599+ diff_body,
600+ "variant @ Self::{}{{..}} => Self::Diff::Replace(variant)," ,
601+ field_name
602+ ) ;
603+ } else {
604+ l ! (
605+ apply_single_body,
606+ "variant @ Self::{}(..) => *self = variant," ,
607+ field_name
608+ ) ;
609+
610+ l ! (
611+ diff_body,
612+ "variant @ Self::{}(..) => Self::Diff::Replace(variant)," ,
613+ field_name
614+ ) ;
615+ }
616+
617+
618+ } ,
619+ #[ allow( unreachable_patterns) ]
620+ _ => panic ! ( "this combination of options is not yet supported, please file an issue" )
621+ }
622+ } else { //empty variant
623+ l ! ( replace_enum_body, " {}," , field_name) ;
624+
625+ l ! (
626+ apply_single_body,
627+ "variant @ Self::{}(..) => *self = variant," ,
628+ field_name
629+ ) ;
630+
631+ l ! (
632+ diff_body,
633+ "variant @ Self::{}(..) => Self::Diff::Replace(variant)," ,
634+ field_name
635+ ) ;
636+ } ;
637+
638+
639+ } ) ;
640+
641+ #[ allow( unused) ]
642+ let nanoserde_hack = String :: new ( ) ;
643+ #[ cfg( feature = "nanoserde" ) ]
644+ let nanoserde_hack = String :: from ( "\n use nanoserde::*;" ) ;
645+
646+ #[ cfg( unused) ]
647+ let used_generics = {
648+ let mut added: HashSet < String > = HashSet :: new ( ) ;
649+ let mut ret = Vec :: new ( ) ;
650+ for maybe_used in enum_. generics . iter ( ) {
651+ if added. insert ( maybe_used. full ( ) ) && used_generics. contains ( & maybe_used) {
652+ ret. push ( maybe_used. clone ( ) )
653+ }
654+ }
655+
656+ ret
657+ } ;
658+
659+ #[ inline( always) ]
660+ fn get_used_generic_bounds ( ) -> & ' static [ & ' static str ] {
661+ BOUNDS
662+ }
663+
664+ #[ cfg( feature = "serde" ) ]
665+ let serde_bound = {
666+ let start = "\n #[serde(bound = \" " ;
667+ let mid = used_generics
668+ . iter ( )
669+ . filter ( |gen| !matches ! ( gen , Generic :: Lifetime { .. } | Generic :: ConstGeneric { .. } ) )
670+ . map ( |x| {
671+ format ! (
672+ "{}: serde::Serialize + serde::de::DeserializeOwned" ,
673+ x. ident_only( )
674+ )
675+ } )
676+ . collect :: < Vec < _ > > ( )
677+ . join ( ", " ) ;
678+ let end = "\" )]" ;
679+ vec ! [ start, & mid, end] . join ( "" )
680+ } ;
681+ #[ cfg( not( feature = "serde" ) ) ]
682+ let serde_bound = "" ;
683+
684+ format ! (
685+ "const _: () = {{
686+ use structdiff::collections::*;
687+ {type_aliases}
688+ {nanoserde_hack}
689+
690+ /// Generated type from StructDiff
691+ #[derive({derives})]{serde_bounds}
692+ pub enum {enum_name}{enum_def_generics}
693+ where
694+ {enum_where_bounds}
695+ {{
696+ Replace({struct_name}{struct_generics})
697+ }}
698+
699+ impl{impl_generics} StructDiff for {struct_name}{struct_generics}
700+ where
701+ {struct_where_bounds}
702+ {{
703+ type Diff = {enum_name}{enum_impl_generics};
704+
705+ fn diff(&self, updated: &Self) -> Vec<Self::Diff> {{
706+ if self == updated {{
707+ vec![]
708+ }} else {{
709+ vec![match updated.clone() {{
710+ {diff_body}
711+ }}]
712+ }}
713+ }}
714+
715+ #[inline(always)]
716+ fn apply_single(&mut self, diff: Self::Diff) {{
717+ match diff {{
718+ Self::Diff::Replace(diff) => match diff {{
719+ {apply_single_body}
720+ }}
721+ }}
722+ }}
723+ }}
724+ }};" ,
725+ type_aliases = type_aliases,
726+ nanoserde_hack = nanoserde_hack,
727+ derives = derives,
728+ struct_name = enum_. name,
729+ diff_body = diff_body,
730+ enum_name = enum_name,
731+ apply_single_body = apply_single_body,
732+ enum_def_generics = format!(
733+ "<{}>" ,
734+ enum_
735+ . generics
736+ . iter( )
737+ . filter( |gen | !matches!( gen , Generic :: WhereBounded { .. } ) )
738+ . map( |gen | Generic :: ident_with_const( gen ) )
739+ . collect:: <Vec <_>>( )
740+ . join( ", " )
741+ ) ,
742+ enum_where_bounds = format!(
743+ "{}" ,
744+ enum_
745+ . generics
746+ . iter( )
747+ . filter( |gen | !matches!(
748+ gen ,
749+ Generic :: ConstGeneric { .. }
750+ ) )
751+ . map( |gen | Generic :: full_with_const( gen , get_used_generic_bounds( ) , true ) )
752+ . collect:: <Vec <_>>( )
753+ . join( ",\n " )
754+ ) ,
755+ impl_generics = format!(
756+ "<{}>" ,
757+ enum_
758+ . generics
759+ . iter( )
760+ . filter( |gen | !matches!( gen , Generic :: WhereBounded { .. } ) )
761+ . map( |gen | Generic :: ident_with_const( gen ) )
762+ . collect:: <Vec <_>>( )
763+ . join( ", " )
764+ ) ,
765+ struct_generics = format!(
766+ "<{}>" ,
767+ enum_
768+ . generics
769+ . iter( )
770+ . filter( |gen | !matches!( gen , Generic :: WhereBounded { .. } ) )
771+ . map( |gen | Generic :: ident_only( gen ) )
772+ . collect:: <Vec <_>>( )
773+ . join( ", " )
774+ ) ,
775+ struct_where_bounds = format!(
776+ "{}" ,
777+ enum_
778+ . generics
779+ . iter( )
780+ . filter( |gen | !matches!( gen , Generic :: ConstGeneric { .. } | Generic :: WhereBounded { .. } ) )
781+ . map( |gen | Generic :: full_with_const( gen , get_used_generic_bounds( ) , true ) )
782+ . collect:: <Vec <_>>( ) . into_iter( ) . chain( enum_
783+ . generics
784+ . iter( )
785+ . filter( |gen | matches!( gen , Generic :: WhereBounded { .. } ) )
786+ . map( |gen | Generic :: full_with_const( gen , & [ ] , true ) ) . collect:: <Vec <_>>( ) . into_iter( ) ) . collect:: <Vec <_>>( )
787+ . join( ",\n " )
788+ ) ,
789+ enum_impl_generics = format!(
790+ "<{}>" ,
791+ enum_
792+ . generics
793+ . iter( )
794+ . filter( |gen | !matches!( gen , Generic :: WhereBounded { .. } ) )
795+ . map( |gen | Generic :: ident_only( gen ) )
796+ . collect:: <Vec <_>>( )
797+ . join( ", " )
798+ ) ,
799+ serde_bounds = serde_bound
800+ )
801+ . parse ( )
802+ . unwrap ( )
803+ }
0 commit comments