diff --git a/CHANGELOG.md b/CHANGELOG.md index 86919568..a652c26a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Note:** `Clone` is excluded from blanket skipping and can only be used with selective skipping to avoid this being a breaking change. +### Fixed +- Support skipping only some variants with `Zeroize`. + ## [1.3.0] - 2025-04-21 ### Added diff --git a/src/test/zeroize.rs b/src/test/zeroize.rs index b2efd7f9..8ba1afd2 100644 --- a/src/test/zeroize.rs +++ b/src/test/zeroize.rs @@ -233,6 +233,35 @@ fn fqs() -> Result<()> { #[test] fn enum_skip() -> Result<()> { + test_derive( + quote! { + #[derive_where(Zeroize)] + enum Test { + A(std::marker::PhantomData), + #[derive_where(skip_inner(Zeroize))] + B(std::marker::PhantomData), + } + }, + quote! { + #[automatically_derived] + impl ::zeroize::Zeroize for Test { + fn zeroize(&mut self) { + use ::zeroize::Zeroize; + + match self { + Test::A(ref mut __field_0) => { + __field_0.zeroize(); + } + Test::B(ref mut __field_0) => { } + } + } + } + }, + ) +} + +#[test] +fn enum_skip_drop() -> Result<()> { test_derive( quote! { #[derive_where(ZeroizeOnDrop)] diff --git a/src/trait_/zeroize.rs b/src/trait_/zeroize.rs index dec02c31..eaa63eee 100644 --- a/src/trait_/zeroize.rs +++ b/src/trait_/zeroize.rs @@ -112,34 +112,30 @@ impl TraitImpl for Zeroize { trait_: &DeriveTrait, data: &Data, ) -> TokenStream { - if data.is_empty(**trait_) { - TokenStream::new() - } else { - match data.simple_type() { - SimpleType::Struct(fields) | SimpleType::Tuple(fields) => { - let trait_path = trait_.path(); - let self_pattern = fields.self_pattern_mut(); - - let body = data - .iter_fields(**trait_) - .zip(data.iter_self_ident(**trait_)) - .map(|(field, self_ident)| { - if field.attr.zeroize_fqs.0 { - quote! { #trait_path::zeroize(#self_ident); } - } else { - quote! { #self_ident.zeroize(); } - } - }); - - quote! { - #self_pattern => { - #(#body)* + match data.simple_type() { + SimpleType::Struct(fields) | SimpleType::Tuple(fields) => { + let trait_path = trait_.path(); + let self_pattern = fields.self_pattern_mut(); + + let body = data + .iter_fields(**trait_) + .zip(data.iter_self_ident(**trait_)) + .map(|(field, self_ident)| { + if field.attr.zeroize_fqs.0 { + quote! { #trait_path::zeroize(#self_ident); } + } else { + quote! { #self_ident.zeroize(); } } + }); + + quote! { + #self_pattern => { + #(#body)* } } - SimpleType::Unit(_) => TokenStream::new(), - SimpleType::Union => unreachable!("unexpected trait for union"), } + SimpleType::Unit(_) => TokenStream::new(), + SimpleType::Union => unreachable!("unexpected trait for union"), } } }