Skip to content

Commit 654128f

Browse files
committed
feat: add ZeroCopyMut enum support
1 parent 96c2faa commit 654128f

14 files changed

+575
-147
lines changed

program-libs/zero-copy-derive/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Procedural macros for borsh compatible zero copy serialization.
4848
- **Empty structs**: Not supported - structs must have at least one field for zero-copy serialization
4949
- **Enum support**:
5050
- `ZeroCopy` supports enums with unit variants or single unnamed field variants
51-
- `ZeroCopyMut` does NOT support enums (structs only)
51+
- `ZeroCopyMut` supports enums with unit variants or single unnamed field variants
5252
- `ZeroCopyEq` does NOT support enums (structs only)
5353

5454
### Special Type Handling

program-libs/zero-copy-derive/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
//! - **Empty structs**: Not supported - structs must have at least one field for zero-copy serialization
4949
//! - **Enum support**:
5050
//! - `ZeroCopy` supports enums with unit variants or single unnamed field variants
51-
//! - `ZeroCopyMut` does NOT support enums
51+
//! - `ZeroCopyMut` supports enums with unit variants or single unnamed field variants
5252
//! - `ZeroCopyEq` does NOT support enums
5353
//! - `ZeroCopyEq` does NOT support enums, vectors, arrays)
5454
//!

program-libs/zero-copy-derive/src/shared/z_enum.rs

Lines changed: 103 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,18 @@ use quote::{format_ident, quote};
33
use syn::{DataEnum, Fields, Ident};
44

55
/// Generate the zero-copy enum definition with type aliases for pattern matching
6-
pub fn generate_z_enum(z_enum_name: &Ident, enum_data: &DataEnum) -> syn::Result<TokenStream> {
6+
/// The `MUT` parameter controls whether to generate mutable or immutable variants
7+
pub fn generate_z_enum<const MUT: bool>(
8+
z_enum_name: &Ident,
9+
enum_data: &DataEnum,
10+
) -> syn::Result<TokenStream> {
11+
// Add Mut suffix when MUT is true
12+
let z_enum_name = if MUT {
13+
format_ident!("{}Mut", z_enum_name)
14+
} else {
15+
z_enum_name.clone()
16+
};
17+
718
// Collect type aliases for complex variants
819
let mut type_aliases = Vec::new();
920
let mut has_lifetime_dependent_variants = false;
@@ -28,9 +39,21 @@ pub fn generate_z_enum(z_enum_name: &Ident, enum_data: &DataEnum) -> syn::Result
2839
has_lifetime_dependent_variants = true;
2940

3041
// Create a type alias for this variant to enable pattern matching
31-
let alias_name = format_ident!("{}Type", variant_name);
32-
type_aliases.push(quote! {
33-
pub type #alias_name<'a> = <#field_type as ::light_zero_copy::traits::ZeroCopyAt<'a>>::ZeroCopyAt;
42+
let alias_name = if MUT {
43+
format_ident!("{}TypeMut", variant_name)
44+
} else {
45+
format_ident!("{}Type", variant_name)
46+
};
47+
48+
// Generate appropriate type based on MUT
49+
type_aliases.push(if MUT {
50+
quote! {
51+
pub type #alias_name<'a> = <#field_type as ::light_zero_copy::traits::ZeroCopyAtMut<'a>>::ZeroCopyAtMut;
52+
}
53+
} else {
54+
quote! {
55+
pub type #alias_name<'a> = <#field_type as ::light_zero_copy::traits::ZeroCopyAt<'a>>::ZeroCopyAt;
56+
}
3457
});
3558

3659
Ok(quote! { #variant_name(#alias_name<'a>) })
@@ -58,17 +81,24 @@ pub fn generate_z_enum(z_enum_name: &Ident, enum_data: &DataEnum) -> syn::Result
5881
}
5982
}).collect::<Result<Vec<_>, _>>()?;
6083

84+
// For mutable enums, we don't derive Clone (can't clone mutable references)
85+
let derive_attrs = if MUT {
86+
quote! { #[derive(Debug, PartialEq)] }
87+
} else {
88+
quote! { #[derive(Debug, Clone, PartialEq)] }
89+
};
90+
6191
// Conditionally add lifetime parameter only if needed
6292
let enum_declaration = if has_lifetime_dependent_variants {
6393
quote! {
64-
#[derive(Debug, Clone, PartialEq)]
94+
#derive_attrs
6595
pub enum #z_enum_name<'a> {
6696
#(#variants,)*
6797
}
6898
}
6999
} else {
70100
quote! {
71-
#[derive(Debug, Clone, PartialEq)]
101+
#derive_attrs
72102
pub enum #z_enum_name {
73103
#(#variants,)*
74104
}
@@ -84,11 +114,36 @@ pub fn generate_z_enum(z_enum_name: &Ident, enum_data: &DataEnum) -> syn::Result
84114
}
85115

86116
/// Generate the deserialize implementation for the enum
87-
pub fn generate_enum_deserialize_impl(
117+
/// The `MUT` parameter controls whether to generate mutable or immutable deserialization
118+
pub fn generate_enum_deserialize_impl<const MUT: bool>(
88119
original_name: &Ident,
89120
z_enum_name: &Ident,
90121
enum_data: &DataEnum,
91122
) -> syn::Result<TokenStream> {
123+
// Add Mut suffix when MUT is true
124+
let z_enum_name = if MUT {
125+
format_ident!("{}Mut", z_enum_name)
126+
} else {
127+
z_enum_name.clone()
128+
};
129+
130+
// Choose trait and method based on MUT
131+
let (trait_name, mutability, method_name, associated_type) = if MUT {
132+
(
133+
quote!(::light_zero_copy::traits::ZeroCopyAtMut),
134+
quote!(mut),
135+
quote!(zero_copy_at_mut),
136+
quote!(ZeroCopyAtMut),
137+
)
138+
} else {
139+
(
140+
quote!(::light_zero_copy::traits::ZeroCopyAt),
141+
quote!(),
142+
quote!(zero_copy_at),
143+
quote!(ZeroCopyAt),
144+
)
145+
};
146+
92147
// Check if any variants need lifetime parameters
93148
let mut has_lifetime_dependent_variants = false;
94149

@@ -120,10 +175,21 @@ pub fn generate_enum_deserialize_impl(
120175
"Internal error: expected exactly one unnamed field but found none"
121176
))?
122177
.ty;
178+
179+
// Use appropriate trait method based on MUT
180+
let deserialize_call = if MUT {
181+
quote! {
182+
<#field_type as ::light_zero_copy::traits::ZeroCopyAtMut>::zero_copy_at_mut(remaining_data)?
183+
}
184+
} else {
185+
quote! {
186+
<#field_type as ::light_zero_copy::traits::ZeroCopyAt>::zero_copy_at(remaining_data)?
187+
}
188+
};
189+
123190
Ok(quote! {
124191
#discriminant => {
125-
let (value, remaining_bytes) =
126-
<#field_type as ::light_zero_copy::traits::ZeroCopyAt>::zero_copy_at(remaining_data)?;
192+
let (value, remaining_bytes) = #deserialize_call;
127193
Ok((#z_enum_name::#variant_name(value), remaining_bytes))
128194
}
129195
})
@@ -148,13 +214,14 @@ pub fn generate_enum_deserialize_impl(
148214
};
149215

150216
Ok(quote! {
151-
impl<'a> ::light_zero_copy::traits::ZeroCopyAt<'a> for #original_name {
152-
type ZeroCopyAt = #type_annotation;
217+
impl<'a> #trait_name<'a> for #original_name {
218+
type #associated_type = #type_annotation;
153219

154-
fn zero_copy_at(
155-
data: &'a [u8],
156-
) -> Result<(Self::ZeroCopyAt, &'a [u8]), ::light_zero_copy::errors::ZeroCopyError> {
220+
fn #method_name(
221+
data: &'a #mutability [u8],
222+
) -> Result<(Self::#associated_type, &'a #mutability [u8]), ::light_zero_copy::errors::ZeroCopyError> {
157223
// Read discriminant (first 1 byte for borsh enum)
224+
// Note: Discriminant is ALWAYS immutable for safety, even in mutable deserialization
158225
if data.is_empty() {
159226
return Err(::light_zero_copy::errors::ZeroCopyError::ArraySize(
160227
1,
@@ -163,7 +230,7 @@ pub fn generate_enum_deserialize_impl(
163230
}
164231

165232
let discriminant = data[0];
166-
let remaining_data = &data[1..];
233+
let remaining_data = &#mutability data[1..];
167234

168235
match discriminant {
169236
#(#match_arms)*
@@ -175,11 +242,19 @@ pub fn generate_enum_deserialize_impl(
175242
}
176243

177244
/// Generate the ZeroCopyStructInner implementation for the enum
178-
pub fn generate_enum_zero_copy_struct_inner(
245+
/// The `MUT` parameter controls whether to generate mutable or immutable struct inner trait
246+
pub fn generate_enum_zero_copy_struct_inner<const MUT: bool>(
179247
original_name: &Ident,
180248
z_enum_name: &Ident,
181249
enum_data: &DataEnum,
182250
) -> syn::Result<TokenStream> {
251+
// Add Mut suffix when MUT is true
252+
let z_enum_name = if MUT {
253+
format_ident!("{}Mut", z_enum_name)
254+
} else {
255+
z_enum_name.clone()
256+
};
257+
183258
// Check if any variants need lifetime parameters
184259
let has_lifetime_dependent_variants = enum_data.variants.iter().any(
185260
|variant| matches!(&variant.fields, Fields::Unnamed(fields) if fields.unnamed.len() == 1),
@@ -192,9 +267,18 @@ pub fn generate_enum_zero_copy_struct_inner(
192267
quote! { #z_enum_name }
193268
};
194269

195-
Ok(quote! {
196-
impl ::light_zero_copy::traits::ZeroCopyStructInner for #original_name {
197-
type ZeroCopyInner = #type_annotation;
270+
// Generate appropriate trait impl based on MUT
271+
Ok(if MUT {
272+
quote! {
273+
impl ::light_zero_copy::traits::ZeroCopyStructInnerMut for #original_name {
274+
type ZeroCopyInnerMut = #type_annotation;
275+
}
276+
}
277+
} else {
278+
quote! {
279+
impl ::light_zero_copy::traits::ZeroCopyStructInner for #original_name {
280+
type ZeroCopyInner = #type_annotation;
281+
}
198282
}
199283
})
200284
}

program-libs/zero-copy-derive/src/shared/zero_copy_new.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,3 +477,132 @@ pub fn generate_byte_len_calculation(field_type: &FieldType) -> syn::Result<Toke
477477
};
478478
Ok(result)
479479
}
480+
481+
/// Generate ZeroCopyNew for enums with fixed variant selection
482+
pub fn generate_enum_zero_copy_new(
483+
enum_name: &syn::Ident,
484+
enum_data: &syn::DataEnum,
485+
) -> syn::Result<proc_macro2::TokenStream> {
486+
let config_name = quote::format_ident!("{}Config", enum_name);
487+
let z_enum_mut_name = quote::format_ident!("Z{}Mut", enum_name);
488+
489+
// Generate config enum that mirrors the original enum structure
490+
let config_variants = enum_data.variants.iter().map(|variant| {
491+
let variant_name = &variant.ident;
492+
match &variant.fields {
493+
syn::Fields::Unit => Ok(quote! { #variant_name }),
494+
syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
495+
let field_type = &fields.unnamed.first()
496+
.ok_or_else(|| syn::Error::new_spanned(
497+
fields,
498+
"Internal error: expected exactly one unnamed field but found none"
499+
))?
500+
.ty;
501+
// Get the config type for the inner field
502+
Ok(quote! {
503+
#variant_name(<#field_type as ::light_zero_copy::traits::ZeroCopyNew<'static>>::ZeroCopyConfig)
504+
})
505+
}
506+
_ => Err(syn::Error::new_spanned(variant, "Unsupported enum variant format for ZeroCopyMut")),
507+
}
508+
}).collect::<Result<Vec<_>, _>>()?;
509+
510+
// Generate byte_len match arms
511+
let byte_len_arms = enum_data.variants.iter().map(|variant| {
512+
let variant_name = &variant.ident;
513+
let discriminant_size = 1usize; // Always 1 byte for borsh
514+
515+
match &variant.fields {
516+
syn::Fields::Unit => {
517+
Ok(quote! {
518+
#config_name::#variant_name => Ok(#discriminant_size)
519+
})
520+
}
521+
syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
522+
let field_type = &fields.unnamed.first()
523+
.ok_or_else(|| syn::Error::new_spanned(
524+
fields,
525+
"Internal error: expected exactly one unnamed field but found none"
526+
))?
527+
.ty;
528+
Ok(quote! {
529+
#config_name::#variant_name(ref config) => {
530+
<#field_type as ::light_zero_copy::traits::ZeroCopyNew>::byte_len(config)
531+
.map(|len| #discriminant_size + len)
532+
}
533+
})
534+
}
535+
_ => Err(syn::Error::new_spanned(variant, "Unsupported enum variant format for ZeroCopyMut")),
536+
}
537+
}).collect::<Result<Vec<_>, _>>()?;
538+
539+
// Generate new_zero_copy match arms
540+
let new_arms = enum_data.variants.iter().enumerate().map(|(idx, variant)| {
541+
let variant_name = &variant.ident;
542+
let discriminant = idx as u8;
543+
544+
match &variant.fields {
545+
syn::Fields::Unit => {
546+
Ok(quote! {
547+
#config_name::#variant_name => {
548+
bytes[0] = #discriminant;
549+
let remaining = &mut bytes[1..];
550+
Ok((#z_enum_mut_name::#variant_name, remaining))
551+
}
552+
})
553+
}
554+
syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
555+
let field_type = &fields.unnamed.first()
556+
.ok_or_else(|| syn::Error::new_spanned(
557+
fields,
558+
"Internal error: expected exactly one unnamed field but found none"
559+
))?
560+
.ty;
561+
Ok(quote! {
562+
#config_name::#variant_name(config) => {
563+
bytes[0] = #discriminant;
564+
let remaining = &mut bytes[1..];
565+
let (value, remaining) =
566+
<#field_type as ::light_zero_copy::traits::ZeroCopyNew>::new_zero_copy(
567+
remaining, config
568+
)?;
569+
Ok((#z_enum_mut_name::#variant_name(value), remaining))
570+
}
571+
})
572+
}
573+
_ => Err(syn::Error::new_spanned(variant, "Unsupported enum variant format for ZeroCopyMut")),
574+
}
575+
}).collect::<Result<Vec<_>, _>>()?;
576+
577+
Ok(quote! {
578+
// Config enum that specifies which variant to create
579+
#[derive(Debug, Clone)]
580+
pub enum #config_name {
581+
#(#config_variants,)*
582+
}
583+
584+
impl<'a> ::light_zero_copy::traits::ZeroCopyNew<'a> for #enum_name {
585+
type ZeroCopyConfig = #config_name;
586+
type Output = <Self as ::light_zero_copy::traits::ZeroCopyAtMut<'a>>::ZeroCopyAtMut;
587+
588+
fn byte_len(config: &Self::ZeroCopyConfig) -> Result<usize, ::light_zero_copy::errors::ZeroCopyError> {
589+
match config {
590+
#(#byte_len_arms,)*
591+
}
592+
}
593+
594+
fn new_zero_copy(
595+
bytes: &'a mut [u8],
596+
config: Self::ZeroCopyConfig,
597+
) -> Result<(Self::Output, &'a mut [u8]), ::light_zero_copy::errors::ZeroCopyError> {
598+
if bytes.is_empty() {
599+
return Err(::light_zero_copy::errors::ZeroCopyError::ArraySize(1, bytes.len()));
600+
}
601+
602+
match config {
603+
#(#new_arms,)*
604+
}
605+
}
606+
}
607+
})
608+
}

program-libs/zero-copy-derive/src/zero_copy.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,12 @@ pub fn derive_zero_copy_impl(input: ProcTokenStream) -> syn::Result<proc_macro2:
303303
utils::InputType::Enum(enum_data) => {
304304
let z_enum_name = z_name;
305305

306-
let z_enum_def = generate_z_enum(&z_enum_name, enum_data)?;
307-
let deserialize_impl = generate_enum_deserialize_impl(name, &z_enum_name, enum_data)?;
306+
// Use refactored const generic functions with MUT=false
307+
let z_enum_def = generate_z_enum::<false>(&z_enum_name, enum_data)?;
308+
let deserialize_impl =
309+
generate_enum_deserialize_impl::<false>(name, &z_enum_name, enum_data)?;
308310
let zero_copy_struct_inner_impl =
309-
generate_enum_zero_copy_struct_inner(name, &z_enum_name, enum_data)?;
311+
generate_enum_zero_copy_struct_inner::<false>(name, &z_enum_name, enum_data)?;
310312

311313
Ok(quote! {
312314
#z_enum_def

0 commit comments

Comments
 (0)