Skip to content

Commit 18a26aa

Browse files
committed
add parsing of enums and derive macro
1 parent aa380a3 commit 18a26aa

File tree

4 files changed

+339
-26
lines changed

4 files changed

+339
-26
lines changed

derive/src/difference.rs

Lines changed: 281 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::collections::HashSet;
33
use alloc::format;
44
use alloc::string::String;
55

6-
use crate::parse::{Category, ConstValType, Generic, Struct, Type};
6+
use crate::parse::{Category, ConstValType, Generic, Struct, Type, Enum};
77
use crate::shared::{attrs_collection_type, attrs_recurse, attrs_skip};
88

99
use proc_macro::TokenStream;
@@ -521,3 +521,283 @@ pub(crate) fn derive_struct_diff_struct(struct_: &Struct) -> TokenStream {
521521
.parse()
522522
.unwrap()
523523
}
524+
525+
pub(crate) fn derive_struct_diff_enum(enum_: &Enum) -> TokenStream {
526+
let derives: String = vec![
527+
#[cfg(feature = "debug_diffs")]
528+
"core::fmt::Debug",
529+
"Clone",
530+
#[cfg(feature = "nanoserde")]
531+
"nanoserde::SerBin",
532+
#[cfg(feature = "nanoserde")]
533+
"nanoserde::DeBin",
534+
#[cfg(feature = "serde")]
535+
"serde::Serialize",
536+
#[cfg(feature = "serde")]
537+
"serde::Deserialize",
538+
]
539+
.join(", ");
540+
541+
let mut replace_enum_body = String::new();
542+
#[cfg(unused)]
543+
let mut diff_enum_body = String::new();
544+
let mut diff_body = String::new();
545+
let mut apply_single_body = String::new();
546+
#[allow(unused_mut)]
547+
let mut type_aliases = String::new();
548+
let mut used_generics: Vec<&Generic> = Vec::new();
549+
550+
let enum_name = String::from("__".to_owned() + enum_.name.as_str() + "StructDiffEnum");
551+
let struct_generics_names_hash: HashSet<String> =
552+
enum_.generics.iter().map(|x| x.full()).collect();
553+
554+
if enum_
555+
.variants
556+
.iter()
557+
.any(|x| attrs_skip(&x.attributes)) {
558+
panic!("Enum variants may not be skipped");
559+
};
560+
561+
enum_
562+
.variants
563+
.iter()
564+
.enumerate()
565+
.for_each(|(_, field)| {
566+
let field_name = &field.name;
567+
if let Some(ty) = &field.ty {
568+
used_generics.extend(enum_.generics.iter().filter(|x| x.full() == ty.ident.path(&ty, false)));
569+
570+
571+
let to_add = enum_.generics.iter().filter(|x| ty.wraps().iter().find(|&wrapped_type| &x.full() == wrapped_type ).is_some());
572+
used_generics.extend(to_add);
573+
574+
used_generics.extend(get_used_lifetimes(&ty).into_iter().filter_map(|x| match struct_generics_names_hash.contains(&x) {
575+
true => Some(enum_.generics.iter().find(|generic| generic.full() == x ).unwrap()),
576+
false => None,
577+
}));
578+
579+
for val in get_array_lens(&ty) {
580+
if let Some(const_gen) = enum_.generics.iter().find(|x| x.full() == val) {
581+
used_generics.push(const_gen)
582+
}
583+
}
584+
}
585+
if let Some(ty) = &field.ty {
586+
match (attrs_recurse(&field.attributes), attrs_collection_type(&field.attributes), ty.base() == "Option") {
587+
588+
(false, None, false) => { // The default case
589+
l!(replace_enum_body, " {}({}),", field_name, ty.full());
590+
591+
if matches!(ty.ident, Category::UnNamed) {
592+
l!(
593+
apply_single_body,
594+
"variant @ Self::{}{{..}} => *self = variant,",
595+
field_name
596+
);
597+
598+
l!(
599+
diff_body,
600+
"variant @ Self::{}{{..}} => Self::Diff::Replace(variant),",
601+
field_name
602+
);
603+
} else {
604+
l!(
605+
apply_single_body,
606+
"variant @ Self::{}(..) => *self = variant,",
607+
field_name
608+
);
609+
610+
l!(
611+
diff_body,
612+
"variant @ Self::{}(..) => Self::Diff::Replace(variant),",
613+
field_name
614+
);
615+
}
616+
617+
618+
},
619+
#[allow(unreachable_patterns)]
620+
_ => panic!("this combination of options is not yet supported, please file an issue")
621+
}
622+
} else { //empty variant
623+
l!(replace_enum_body, " {},", field_name);
624+
625+
l!(
626+
apply_single_body,
627+
"variant @ Self::{}(..) => *self = variant,",
628+
field_name
629+
);
630+
631+
l!(
632+
diff_body,
633+
"variant @ Self::{}(..) => Self::Diff::Replace(variant),",
634+
field_name
635+
);
636+
};
637+
638+
639+
});
640+
641+
#[allow(unused)]
642+
let nanoserde_hack = String::new();
643+
#[cfg(feature = "nanoserde")]
644+
let nanoserde_hack = String::from("\nuse nanoserde::*;");
645+
646+
#[cfg(unused)]
647+
let used_generics = {
648+
let mut added: HashSet<String> = HashSet::new();
649+
let mut ret = Vec::new();
650+
for maybe_used in enum_.generics.iter() {
651+
if added.insert(maybe_used.full()) && used_generics.contains(&maybe_used) {
652+
ret.push(maybe_used.clone())
653+
}
654+
}
655+
656+
ret
657+
};
658+
659+
#[inline(always)]
660+
fn get_used_generic_bounds() -> &'static [&'static str] {
661+
BOUNDS
662+
}
663+
664+
#[cfg(feature = "serde")]
665+
let serde_bound = {
666+
let start = "\n#[serde(bound = \"";
667+
let mid = used_generics
668+
.iter()
669+
.filter(|gen| !matches!(gen, Generic::Lifetime { .. } | Generic::ConstGeneric { .. }))
670+
.map(|x| {
671+
format!(
672+
"{}: serde::Serialize + serde::de::DeserializeOwned",
673+
x.ident_only()
674+
)
675+
})
676+
.collect::<Vec<_>>()
677+
.join(", ");
678+
let end = "\")]";
679+
vec![start, &mid, end].join("")
680+
};
681+
#[cfg(not(feature = "serde"))]
682+
let serde_bound = "";
683+
684+
format!(
685+
"const _: () = {{
686+
use structdiff::collections::*;
687+
{type_aliases}
688+
{nanoserde_hack}
689+
690+
/// Generated type from StructDiff
691+
#[derive({derives})]{serde_bounds}
692+
pub enum {enum_name}{enum_def_generics}
693+
where
694+
{enum_where_bounds}
695+
{{
696+
Replace({struct_name}{struct_generics})
697+
}}
698+
699+
impl{impl_generics} StructDiff for {struct_name}{struct_generics}
700+
where
701+
{struct_where_bounds}
702+
{{
703+
type Diff = {enum_name}{enum_impl_generics};
704+
705+
fn diff(&self, updated: &Self) -> Vec<Self::Diff> {{
706+
if self == updated {{
707+
vec![]
708+
}} else {{
709+
vec![match updated.clone() {{
710+
{diff_body}
711+
}}]
712+
}}
713+
}}
714+
715+
#[inline(always)]
716+
fn apply_single(&mut self, diff: Self::Diff) {{
717+
match diff {{
718+
Self::Diff::Replace(diff) => match diff {{
719+
{apply_single_body}
720+
}}
721+
}}
722+
}}
723+
}}
724+
}};",
725+
type_aliases = type_aliases,
726+
nanoserde_hack = nanoserde_hack,
727+
derives = derives,
728+
struct_name = enum_.name,
729+
diff_body = diff_body,
730+
enum_name = enum_name,
731+
apply_single_body = apply_single_body,
732+
enum_def_generics = format!(
733+
"<{}>",
734+
enum_
735+
.generics
736+
.iter()
737+
.filter(|gen| !matches!(gen, Generic::WhereBounded { .. }))
738+
.map(|gen| Generic::ident_with_const(gen))
739+
.collect::<Vec<_>>()
740+
.join(", ")
741+
),
742+
enum_where_bounds = format!(
743+
"{}",
744+
enum_
745+
.generics
746+
.iter()
747+
.filter(|gen| !matches!(
748+
gen,
749+
Generic::ConstGeneric { .. }
750+
))
751+
.map(|gen| Generic::full_with_const(gen, get_used_generic_bounds(), true))
752+
.collect::<Vec<_>>()
753+
.join(",\n")
754+
),
755+
impl_generics = format!(
756+
"<{}>",
757+
enum_
758+
.generics
759+
.iter()
760+
.filter(|gen| !matches!(gen, Generic::WhereBounded { .. }))
761+
.map(|gen| Generic::ident_with_const(gen))
762+
.collect::<Vec<_>>()
763+
.join(", ")
764+
),
765+
struct_generics = format!(
766+
"<{}>",
767+
enum_
768+
.generics
769+
.iter()
770+
.filter(|gen| !matches!(gen, Generic::WhereBounded { .. }))
771+
.map(|gen| Generic::ident_only(gen))
772+
.collect::<Vec<_>>()
773+
.join(", ")
774+
),
775+
struct_where_bounds = format!(
776+
"{}",
777+
enum_
778+
.generics
779+
.iter()
780+
.filter(|gen| !matches!(gen, Generic::ConstGeneric { .. } | Generic::WhereBounded { .. }))
781+
.map(|gen| Generic::full_with_const(gen, get_used_generic_bounds(), true))
782+
.collect::<Vec<_>>().into_iter().chain(enum_
783+
.generics
784+
.iter()
785+
.filter(|gen| matches!(gen, Generic::WhereBounded { .. }))
786+
.map(|gen| Generic::full_with_const(gen, &[], true)).collect::<Vec<_>>().into_iter()).collect::<Vec<_>>()
787+
.join(",\n")
788+
),
789+
enum_impl_generics = format!(
790+
"<{}>",
791+
enum_
792+
.generics
793+
.iter()
794+
.filter(|gen| !matches!(gen, Generic::WhereBounded { .. }))
795+
.map(|gen| Generic::ident_only(gen))
796+
.collect::<Vec<_>>()
797+
.join(", ")
798+
),
799+
serde_bounds = serde_bound
800+
)
801+
.parse()
802+
.unwrap()
803+
}

derive/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ extern crate proc_macro;
55
mod shared;
66

77
mod difference;
8+
use difference::derive_struct_diff_enum;
9+
810
use crate::difference::derive_struct_diff_struct;
911

1012
mod parse;
@@ -17,6 +19,7 @@ pub fn derive_struct_diff(input: proc_macro::TokenStream) -> proc_macro::TokenSt
1719
// ok we have an ident, hopefully it's a struct
1820
let ts = match &input {
1921
parse::Data::Struct(struct_) if struct_.named => derive_struct_diff_struct(struct_),
22+
parse::Data::Enum(enum_) => derive_struct_diff_enum(enum_),
2023
_ => unimplemented!("Only structs are supported"),
2124
};
2225

derive/src/parse.rs

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ pub enum Category {
7575
args: Option<Box<Type>>,
7676
return_type: Option<Box<Type>>,
7777
},
78-
#[allow(unused)]
7978
UnNamed,
8079
Object {
8180
is_dyn: bool,
@@ -135,8 +134,7 @@ pub struct Struct {
135134
#[derive(Debug)]
136135
pub struct EnumVariant {
137136
pub name: String,
138-
pub named: bool,
139-
pub fields: Vec<Field>,
137+
pub ty: Option<Type>,
140138
pub attributes: Vec<Attribute>,
141139
}
142140

@@ -944,12 +942,21 @@ fn next_type<T: Iterator<Item = TokenTree> + Clone>(mut source: &mut Peekable<T>
944942
})
945943
} else {
946944
let as_other = as_other_type(source).map(Box::new);
947-
Some(Type {
948-
ident: Category::Named { path: ty },
949-
wraps: None,
950-
ref_type,
951-
as_other,
952-
})
945+
if ty.is_empty() {
946+
Some(Type {
947+
ident: Category::UnNamed,
948+
wraps: None,
949+
ref_type,
950+
as_other,
951+
})
952+
} else {
953+
Some(Type {
954+
ident: Category::Named { path: ty },
955+
wraps: None,
956+
ref_type,
957+
as_other,
958+
})
959+
}
953960
}
954961
}
955962

@@ -1154,32 +1161,21 @@ fn next_enum<T: Iterator<Item = TokenTree> + Clone>(mut source: &mut Peekable<T>
11541161
let attributes = next_attributes_list(&mut body);
11551162

11561163
let variant_name = next_ident(&mut body).expect("Unnamed variants are not supported");
1157-
let group = next_group(&mut body);
1158-
if group.is_none() {
1164+
let ty = next_type(&mut body);
1165+
let Some(ty) = ty else {
11591166
variants.push(EnumVariant {
11601167
name: variant_name,
1161-
named: false,
1162-
fields: vec![],
1168+
ty: None,
11631169
attributes,
11641170
});
11651171
let _maybe_comma = next_exact_punct(&mut body, ",");
11661172
continue;
1167-
}
1168-
let group = group.unwrap();
1169-
let delimiter = group.delimiter();
1170-
let named = match delimiter {
1171-
Delimiter::Parenthesis => false,
1172-
Delimiter::Brace => true,
1173-
1174-
_ => panic!("Enum with unsupported delimiter"),
11751173
};
1174+
11761175
{
1177-
let mut body = group.stream().into_iter().peekable();
1178-
let fields = next_fields(&mut body, named);
11791176
variants.push(EnumVariant {
11801177
name: variant_name,
1181-
named,
1182-
fields,
1178+
ty: Some(ty),
11831179
attributes,
11841180
});
11851181
}

0 commit comments

Comments
 (0)