Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion program-libs/zero-copy-derive/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Procedural macros for borsh compatible zero copy serialization.
- **Empty structs**: Not supported - structs must have at least one field for zero-copy serialization
- **Enum support**:
- `ZeroCopy` supports enums with unit variants or single unnamed field variants
- `ZeroCopyMut` does NOT support enums (structs only)
- `ZeroCopyMut` supports enums with unit variants or single unnamed field variants
- `ZeroCopyEq` does NOT support enums (structs only)

### Special Type Handling
Expand Down
2 changes: 1 addition & 1 deletion program-libs/zero-copy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
//! - **Empty structs**: Not supported - structs must have at least one field for zero-copy serialization
//! - **Enum support**:
//! - `ZeroCopy` supports enums with unit variants or single unnamed field variants
//! - `ZeroCopyMut` does NOT support enums
//! - `ZeroCopyMut` supports enums with unit variants or single unnamed field variants
//! - `ZeroCopyEq` does NOT support enums
//! - `ZeroCopyEq` does NOT support enums, vectors, arrays)
//!
Expand Down
122 changes: 103 additions & 19 deletions program-libs/zero-copy-derive/src/shared/z_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@ use quote::{format_ident, quote};
use syn::{DataEnum, Fields, Ident};

/// Generate the zero-copy enum definition with type aliases for pattern matching
pub fn generate_z_enum(z_enum_name: &Ident, enum_data: &DataEnum) -> syn::Result<TokenStream> {
/// The `MUT` parameter controls whether to generate mutable or immutable variants
pub fn generate_z_enum<const MUT: bool>(
z_enum_name: &Ident,
enum_data: &DataEnum,
) -> syn::Result<TokenStream> {
// Add Mut suffix when MUT is true
let z_enum_name = if MUT {
format_ident!("{}Mut", z_enum_name)
} else {
z_enum_name.clone()
};

// Collect type aliases for complex variants
let mut type_aliases = Vec::new();
let mut has_lifetime_dependent_variants = false;
Expand All @@ -28,9 +39,21 @@ pub fn generate_z_enum(z_enum_name: &Ident, enum_data: &DataEnum) -> syn::Result
has_lifetime_dependent_variants = true;

// Create a type alias for this variant to enable pattern matching
let alias_name = format_ident!("{}Type", variant_name);
type_aliases.push(quote! {
pub type #alias_name<'a> = <#field_type as ::light_zero_copy::traits::ZeroCopyAt<'a>>::ZeroCopyAt;
let alias_name = if MUT {
format_ident!("{}TypeMut", variant_name)
} else {
format_ident!("{}Type", variant_name)
};

// Generate appropriate type based on MUT
type_aliases.push(if MUT {
quote! {
pub type #alias_name<'a> = <#field_type as ::light_zero_copy::traits::ZeroCopyAtMut<'a>>::ZeroCopyAtMut;
}
} else {
quote! {
pub type #alias_name<'a> = <#field_type as ::light_zero_copy::traits::ZeroCopyAt<'a>>::ZeroCopyAt;
}
});

Ok(quote! { #variant_name(#alias_name<'a>) })
Expand Down Expand Up @@ -58,17 +81,24 @@ pub fn generate_z_enum(z_enum_name: &Ident, enum_data: &DataEnum) -> syn::Result
}
}).collect::<Result<Vec<_>, _>>()?;

// For mutable enums, we don't derive Clone (can't clone mutable references)
let derive_attrs = if MUT {
quote! { #[derive(Debug, PartialEq)] }
} else {
quote! { #[derive(Debug, Clone, PartialEq)] }
};

// Conditionally add lifetime parameter only if needed
let enum_declaration = if has_lifetime_dependent_variants {
quote! {
#[derive(Debug, Clone, PartialEq)]
#derive_attrs
pub enum #z_enum_name<'a> {
#(#variants,)*
}
}
} else {
quote! {
#[derive(Debug, Clone, PartialEq)]
#derive_attrs
pub enum #z_enum_name {
#(#variants,)*
}
Expand All @@ -84,11 +114,36 @@ pub fn generate_z_enum(z_enum_name: &Ident, enum_data: &DataEnum) -> syn::Result
}

/// Generate the deserialize implementation for the enum
pub fn generate_enum_deserialize_impl(
/// The `MUT` parameter controls whether to generate mutable or immutable deserialization
pub fn generate_enum_deserialize_impl<const MUT: bool>(
original_name: &Ident,
z_enum_name: &Ident,
enum_data: &DataEnum,
) -> syn::Result<TokenStream> {
// Add Mut suffix when MUT is true
let z_enum_name = if MUT {
format_ident!("{}Mut", z_enum_name)
} else {
z_enum_name.clone()
};

// Choose trait and method based on MUT
let (trait_name, mutability, method_name, associated_type) = if MUT {
(
quote!(::light_zero_copy::traits::ZeroCopyAtMut),
quote!(mut),
quote!(zero_copy_at_mut),
quote!(ZeroCopyAtMut),
)
} else {
(
quote!(::light_zero_copy::traits::ZeroCopyAt),
quote!(),
quote!(zero_copy_at),
quote!(ZeroCopyAt),
)
};

// Check if any variants need lifetime parameters
let mut has_lifetime_dependent_variants = false;

Expand Down Expand Up @@ -120,10 +175,21 @@ pub fn generate_enum_deserialize_impl(
"Internal error: expected exactly one unnamed field but found none"
))?
.ty;

// Use appropriate trait method based on MUT
let deserialize_call = if MUT {
quote! {
<#field_type as ::light_zero_copy::traits::ZeroCopyAtMut>::zero_copy_at_mut(remaining_data)?
}
} else {
quote! {
<#field_type as ::light_zero_copy::traits::ZeroCopyAt>::zero_copy_at(remaining_data)?
}
};

Ok(quote! {
#discriminant => {
let (value, remaining_bytes) =
<#field_type as ::light_zero_copy::traits::ZeroCopyAt>::zero_copy_at(remaining_data)?;
let (value, remaining_bytes) = #deserialize_call;
Ok((#z_enum_name::#variant_name(value), remaining_bytes))
}
})
Expand All @@ -148,13 +214,14 @@ pub fn generate_enum_deserialize_impl(
};

Ok(quote! {
impl<'a> ::light_zero_copy::traits::ZeroCopyAt<'a> for #original_name {
type ZeroCopyAt = #type_annotation;
impl<'a> #trait_name<'a> for #original_name {
type #associated_type = #type_annotation;

fn zero_copy_at(
data: &'a [u8],
) -> Result<(Self::ZeroCopyAt, &'a [u8]), ::light_zero_copy::errors::ZeroCopyError> {
fn #method_name(
data: &'a #mutability [u8],
) -> Result<(Self::#associated_type, &'a #mutability [u8]), ::light_zero_copy::errors::ZeroCopyError> {
// Read discriminant (first 1 byte for borsh enum)
// Note: Discriminant is ALWAYS immutable for safety, even in mutable deserialization
if data.is_empty() {
return Err(::light_zero_copy::errors::ZeroCopyError::ArraySize(
1,
Expand All @@ -163,7 +230,7 @@ pub fn generate_enum_deserialize_impl(
}

let discriminant = data[0];
let remaining_data = &data[1..];
let remaining_data = &#mutability data[1..];

match discriminant {
#(#match_arms)*
Expand All @@ -175,11 +242,19 @@ pub fn generate_enum_deserialize_impl(
}

/// Generate the ZeroCopyStructInner implementation for the enum
pub fn generate_enum_zero_copy_struct_inner(
/// The `MUT` parameter controls whether to generate mutable or immutable struct inner trait
pub fn generate_enum_zero_copy_struct_inner<const MUT: bool>(
original_name: &Ident,
z_enum_name: &Ident,
enum_data: &DataEnum,
) -> syn::Result<TokenStream> {
// Add Mut suffix when MUT is true
let z_enum_name = if MUT {
format_ident!("{}Mut", z_enum_name)
} else {
z_enum_name.clone()
};

// Check if any variants need lifetime parameters
let has_lifetime_dependent_variants = enum_data.variants.iter().any(
|variant| matches!(&variant.fields, Fields::Unnamed(fields) if fields.unnamed.len() == 1),
Expand All @@ -192,9 +267,18 @@ pub fn generate_enum_zero_copy_struct_inner(
quote! { #z_enum_name }
};

Ok(quote! {
impl ::light_zero_copy::traits::ZeroCopyStructInner for #original_name {
type ZeroCopyInner = #type_annotation;
// Generate appropriate trait impl based on MUT
Ok(if MUT {
quote! {
impl ::light_zero_copy::traits::ZeroCopyStructInnerMut for #original_name {
type ZeroCopyInnerMut = #type_annotation;
}
}
} else {
quote! {
impl ::light_zero_copy::traits::ZeroCopyStructInner for #original_name {
type ZeroCopyInner = #type_annotation;
}
}
})
}
129 changes: 129 additions & 0 deletions program-libs/zero-copy-derive/src/shared/zero_copy_new.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,132 @@ pub fn generate_byte_len_calculation(field_type: &FieldType) -> syn::Result<Toke
};
Ok(result)
}

/// Generate ZeroCopyNew for enums with fixed variant selection
pub fn generate_enum_zero_copy_new(
enum_name: &syn::Ident,
enum_data: &syn::DataEnum,
) -> syn::Result<proc_macro2::TokenStream> {
let config_name = quote::format_ident!("{}Config", enum_name);
let z_enum_mut_name = quote::format_ident!("Z{}Mut", enum_name);

Comment on lines +481 to +488
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Guard against enums with > 256 variants (u8 discriminant overflow).

The code writes the variant tag as a u8. If an enum has more than 256 variants, the tag would silently truncate. Add a compile-time (macro-time) check to fail derivation early.

Apply this diff:

 pub fn generate_enum_zero_copy_new(
     enum_name: &syn::Ident,
     enum_data: &syn::DataEnum,
 ) -> syn::Result<proc_macro2::TokenStream> {
+    // Borsh encodes enum variant index as a single u8; enforce this invariant at macro time.
+    let variant_count = enum_data.variants.len();
+    if variant_count > 256 {
+        return Err(syn::Error::new_spanned(
+            enum_name,
+            format!(
+                "ZeroCopyNew for enums supports at most 256 variants (found {}). Borsh encodes the tag as u8.",
+                variant_count
+            ),
+        ));
+    }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
/// Generate ZeroCopyNew for enums with fixed variant selection
pub fn generate_enum_zero_copy_new(
enum_name: &syn::Ident,
enum_data: &syn::DataEnum,
) -> syn::Result<proc_macro2::TokenStream> {
let config_name = quote::format_ident!("{}Config", enum_name);
let z_enum_mut_name = quote::format_ident!("Z{}Mut", enum_name);
/// Generate ZeroCopyNew for enums with fixed variant selection
pub fn generate_enum_zero_copy_new(
enum_name: &syn::Ident,
enum_data: &syn::DataEnum,
) -> syn::Result<proc_macro2::TokenStream> {
// Borsh encodes enum variant index as a single u8; enforce this invariant at macro time.
let variant_count = enum_data.variants.len();
if variant_count > 256 {
return Err(syn::Error::new_spanned(
enum_name,
format!(
"ZeroCopyNew for enums supports at most 256 variants (found {}). \
Borsh encodes the tag as u8.",
variant_count
),
));
}
let config_name = quote::format_ident!("{}Config", enum_name);
let z_enum_mut_name = quote::format_ident!("Z{}Mut", enum_name);
// …rest of the implementation…
}
🤖 Prompt for AI Agents
program-libs/zero-copy-derive/src/shared/zero_copy_new.rs around lines 481 to
488: The function generates a u8 discriminant for enum variants but doesn't
guard against enums with more than 256 variants which would silently overflow;
add a macro-time check at the start of generate_enum_zero_copy_new that inspects
enum_data.variants.len() and if it is greater than 256 returns
Err(syn::Error::new_spanned(enum_name, "enum has more than 256 variants; u8
discriminant would overflow")); this ensures derivation fails early with a clear
compile error.

// Generate config enum that mirrors the original enum structure
let config_variants = enum_data.variants.iter().map(|variant| {
let variant_name = &variant.ident;
match &variant.fields {
syn::Fields::Unit => Ok(quote! { #variant_name }),
syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
let field_type = &fields.unnamed.first()
.ok_or_else(|| syn::Error::new_spanned(
fields,
"Internal error: expected exactly one unnamed field but found none"
))?
.ty;
// Get the config type for the inner field
Ok(quote! {
#variant_name(<#field_type as ::light_zero_copy::traits::ZeroCopyNew<'static>>::ZeroCopyConfig)
})
}
_ => Err(syn::Error::new_spanned(variant, "Unsupported enum variant format for ZeroCopyMut")),
}
}).collect::<Result<Vec<_>, _>>()?;

// Generate byte_len match arms
let byte_len_arms = enum_data.variants.iter().map(|variant| {
let variant_name = &variant.ident;
let discriminant_size = 1usize; // Always 1 byte for borsh

match &variant.fields {
syn::Fields::Unit => {
Ok(quote! {
#config_name::#variant_name => Ok(#discriminant_size)
})
}
syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
let field_type = &fields.unnamed.first()
.ok_or_else(|| syn::Error::new_spanned(
fields,
"Internal error: expected exactly one unnamed field but found none"
))?
.ty;
Ok(quote! {
#config_name::#variant_name(ref config) => {
<#field_type as ::light_zero_copy::traits::ZeroCopyNew>::byte_len(config)
.map(|len| #discriminant_size + len)
}
})
}
_ => Err(syn::Error::new_spanned(variant, "Unsupported enum variant format for ZeroCopyMut")),
}
}).collect::<Result<Vec<_>, _>>()?;

// Generate new_zero_copy match arms
let new_arms = enum_data.variants.iter().enumerate().map(|(idx, variant)| {
let variant_name = &variant.ident;
let discriminant = idx as u8;

match &variant.fields {
syn::Fields::Unit => {
Ok(quote! {
#config_name::#variant_name => {
bytes[0] = #discriminant;
let remaining = &mut bytes[1..];
Ok((#z_enum_mut_name::#variant_name, remaining))
}
})
}
syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
let field_type = &fields.unnamed.first()
.ok_or_else(|| syn::Error::new_spanned(
fields,
"Internal error: expected exactly one unnamed field but found none"
))?
.ty;
Ok(quote! {
#config_name::#variant_name(config) => {
bytes[0] = #discriminant;
let remaining = &mut bytes[1..];
let (value, remaining) =
<#field_type as ::light_zero_copy::traits::ZeroCopyNew>::new_zero_copy(
remaining, config
)?;
Ok((#z_enum_mut_name::#variant_name(value), remaining))
}
})
}
_ => Err(syn::Error::new_spanned(variant, "Unsupported enum variant format for ZeroCopyMut")),
}
}).collect::<Result<Vec<_>, _>>()?;

Ok(quote! {
// Config enum that specifies which variant to create
#[derive(Debug, Clone)]
pub enum #config_name {
#(#config_variants,)*
}

impl<'a> ::light_zero_copy::traits::ZeroCopyNew<'a> for #enum_name {
type ZeroCopyConfig = #config_name;
type Output = <Self as ::light_zero_copy::traits::ZeroCopyAtMut<'a>>::ZeroCopyAtMut;

fn byte_len(config: &Self::ZeroCopyConfig) -> Result<usize, ::light_zero_copy::errors::ZeroCopyError> {
match config {
#(#byte_len_arms,)*
}
}

fn new_zero_copy(
bytes: &'a mut [u8],
config: Self::ZeroCopyConfig,
) -> Result<(Self::Output, &'a mut [u8]), ::light_zero_copy::errors::ZeroCopyError> {
if bytes.is_empty() {
return Err(::light_zero_copy::errors::ZeroCopyError::ArraySize(1, bytes.len()));
}

match config {
#(#new_arms,)*
}
}
}
})
}
8 changes: 5 additions & 3 deletions program-libs/zero-copy-derive/src/zero_copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,12 @@ pub fn derive_zero_copy_impl(input: ProcTokenStream) -> syn::Result<proc_macro2:
utils::InputType::Enum(enum_data) => {
let z_enum_name = z_name;

let z_enum_def = generate_z_enum(&z_enum_name, enum_data)?;
let deserialize_impl = generate_enum_deserialize_impl(name, &z_enum_name, enum_data)?;
// Use refactored const generic functions with MUT=false
let z_enum_def = generate_z_enum::<false>(&z_enum_name, enum_data)?;
let deserialize_impl =
generate_enum_deserialize_impl::<false>(name, &z_enum_name, enum_data)?;
let zero_copy_struct_inner_impl =
generate_enum_zero_copy_struct_inner(name, &z_enum_name, enum_data)?;
generate_enum_zero_copy_struct_inner::<false>(name, &z_enum_name, enum_data)?;

Ok(quote! {
#z_enum_def
Expand Down
Loading
Loading