Skip to content

Commit aa112b4

Browse files
feat: #[derive(Decode)] changes to make generics more robust (#515)
* Derive Decode changes to make generics more robust This allows `#[derive(Decode)]` to be used in the following cases: ```rust enum MyType<T> { // has no T Variant1(String), // uses T directly Variant2(T), // uses T indirectly Variant3(Vec<T>), // indirect T on struct variant VariantStruct4 { val: Vec<T>, } } ``` Before, the `<T>` generic would be applied and expected to exist on ALL enum variants, causing compile errors when deriving. Tests updated. * Fixing formatting * Lintfix
1 parent e7b60d9 commit aa112b4

File tree

4 files changed

+252
-7
lines changed

4 files changed

+252
-7
lines changed

macros/macros_impl/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ keywords = ["asn1", "der", "ber", "cer", "per"]
99
repository.workspace = true
1010

1111
[dependencies]
12-
syn = { version = "2.0.79", features = ["extra-traits"] }
12+
syn = { version = "2.0.79", features = ["extra-traits", "visit"] }
1313
quote = "1.0.37"
1414
proc-macro2 = "1.0.88"
1515
itertools = "0.13"

macros/macros_impl/src/decode.rs

Lines changed: 236 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
use std::collections::HashSet;
2+
13
use quote::ToTokens;
2-
use syn::Fields;
4+
use syn::{
5+
visit::{self, Visit},
6+
Fields,
7+
};
38

49
use crate::{
510
config::{map_to_inner_type, Config, FieldConfig},
@@ -305,6 +310,7 @@ pub fn map_from_inner_type(
305310
let inner_name = quote::format_ident!("Inner{}", name);
306311
let crate_root = &config.crate_root;
307312
let outer_name = outer_name.unwrap_or(quote!(Self));
313+
let inner_generics = filter_generics_for_fields(generics, fields.raw());
308314

309315
let map_from_inner = fields.iter().enumerate().map(|(i, field)| {
310316
let name = field
@@ -314,7 +320,7 @@ pub fn map_from_inner_type(
314320
quote!(#name : inner.#name)
315321
});
316322

317-
let (_, ty_generics, _) = generics.split_for_impl();
323+
let (_, ty_generics, _) = inner_generics.split_for_impl();
318324
let decode_op = if is_explicit {
319325
quote!(decoder.decode_explicit_prefix::<#inner_name #ty_generics>(#tag)?)
320326
} else {
@@ -324,7 +330,7 @@ pub fn map_from_inner_type(
324330

325331
quote! {
326332
#[derive(#crate_root::AsnType, #crate_root::Decode, #crate_root::Encode)]
327-
struct #inner_name #generics #sanitized_fields #semi
333+
struct #inner_name #inner_generics #sanitized_fields #semi
328334

329335
let inner = #decode_op;
330336

@@ -345,6 +351,10 @@ impl UnsanitizedFields<'_> {
345351
self.0.iter()
346352
}
347353

354+
fn raw(&self) -> &syn::Fields {
355+
self.0
356+
}
357+
348358
fn sanitize(&self) -> proc_macro2::TokenStream {
349359
let fields = self.0.iter().map(|field| {
350360
let syn::Field {
@@ -372,3 +382,226 @@ impl UnsanitizedFields<'_> {
372382
}
373383
}
374384
}
385+
386+
#[derive(Default)]
387+
struct GenericUsage {
388+
type_params: HashSet<String>,
389+
lifetime_params: HashSet<String>,
390+
const_params: HashSet<String>,
391+
used_type_params: HashSet<String>,
392+
used_lifetime_params: HashSet<String>,
393+
used_const_params: HashSet<String>,
394+
}
395+
396+
impl GenericUsage {
397+
fn from_generics(generics: &syn::Generics) -> Self {
398+
Self {
399+
type_params: generics
400+
.type_params()
401+
.map(|param| param.ident.to_string())
402+
.collect(),
403+
lifetime_params: generics
404+
.lifetimes()
405+
.map(|param| param.lifetime.ident.to_string())
406+
.collect(),
407+
const_params: generics
408+
.const_params()
409+
.map(|param| param.ident.to_string())
410+
.collect(),
411+
..Self::default()
412+
}
413+
}
414+
415+
fn is_empty(&self) -> bool {
416+
self.used_type_params.is_empty()
417+
&& self.used_lifetime_params.is_empty()
418+
&& self.used_const_params.is_empty()
419+
}
420+
421+
fn intersects(&self, other: &Self) -> bool {
422+
self.used_type_params
423+
.iter()
424+
.any(|name| other.used_type_params.contains(name))
425+
|| self
426+
.used_lifetime_params
427+
.iter()
428+
.any(|name| other.used_lifetime_params.contains(name))
429+
|| self
430+
.used_const_params
431+
.iter()
432+
.any(|name| other.used_const_params.contains(name))
433+
}
434+
435+
fn is_subset_of(&self, other: &Self) -> bool {
436+
self.used_type_params
437+
.iter()
438+
.all(|name| other.used_type_params.contains(name))
439+
&& self
440+
.used_lifetime_params
441+
.iter()
442+
.all(|name| other.used_lifetime_params.contains(name))
443+
&& self
444+
.used_const_params
445+
.iter()
446+
.all(|name| other.used_const_params.contains(name))
447+
}
448+
449+
fn merge_from(&mut self, other: Self) -> bool {
450+
let before = self.used_type_params.len()
451+
+ self.used_lifetime_params.len()
452+
+ self.used_const_params.len();
453+
self.used_type_params.extend(other.used_type_params);
454+
self.used_lifetime_params.extend(other.used_lifetime_params);
455+
self.used_const_params.extend(other.used_const_params);
456+
let after = self.used_type_params.len()
457+
+ self.used_lifetime_params.len()
458+
+ self.used_const_params.len();
459+
before != after
460+
}
461+
}
462+
463+
impl<'ast> Visit<'ast> for GenericUsage {
464+
fn visit_type_path(&mut self, ty: &'ast syn::TypePath) {
465+
if ty.qself.is_none() {
466+
if let Some(first) = ty.path.segments.first() {
467+
let name = first.ident.to_string();
468+
if self.type_params.contains(&name) {
469+
self.used_type_params.insert(name.clone());
470+
}
471+
if self.const_params.contains(&name) {
472+
self.used_const_params.insert(name);
473+
}
474+
}
475+
}
476+
visit::visit_type_path(self, ty);
477+
}
478+
479+
fn visit_expr_path(&mut self, expr: &'ast syn::ExprPath) {
480+
if expr.qself.is_none() {
481+
if let Some(first) = expr.path.segments.first() {
482+
let name = first.ident.to_string();
483+
if self.const_params.contains(&name) {
484+
self.used_const_params.insert(name);
485+
}
486+
}
487+
}
488+
visit::visit_expr_path(self, expr);
489+
}
490+
491+
fn visit_lifetime(&mut self, lifetime: &'ast syn::Lifetime) {
492+
let name = lifetime.ident.to_string();
493+
if self.lifetime_params.contains(&name) {
494+
self.used_lifetime_params.insert(name);
495+
}
496+
visit::visit_lifetime(self, lifetime);
497+
}
498+
}
499+
500+
pub(crate) fn filter_generics_for_fields(
501+
generics: &syn::Generics,
502+
fields: &syn::Fields,
503+
) -> syn::Generics {
504+
let mut usage = GenericUsage::from_generics(generics);
505+
usage.visit_fields(fields);
506+
507+
let mut changed = true;
508+
while changed {
509+
changed = false;
510+
511+
for param in &generics.params {
512+
match param {
513+
syn::GenericParam::Type(type_param) => {
514+
if !usage
515+
.used_type_params
516+
.contains(&type_param.ident.to_string())
517+
{
518+
continue;
519+
}
520+
let mut bound_usage = GenericUsage::from_generics(generics);
521+
for bound in &type_param.bounds {
522+
bound_usage.visit_type_param_bound(bound);
523+
}
524+
if usage.merge_from(bound_usage) {
525+
changed = true;
526+
}
527+
}
528+
syn::GenericParam::Lifetime(lifetime_param) => {
529+
if !usage
530+
.used_lifetime_params
531+
.contains(&lifetime_param.lifetime.ident.to_string())
532+
{
533+
continue;
534+
}
535+
let mut bound_usage = GenericUsage::from_generics(generics);
536+
for bound in &lifetime_param.bounds {
537+
bound_usage.visit_lifetime(bound);
538+
}
539+
if usage.merge_from(bound_usage) {
540+
changed = true;
541+
}
542+
}
543+
syn::GenericParam::Const(const_param) => {
544+
if !usage
545+
.used_const_params
546+
.contains(&const_param.ident.to_string())
547+
{
548+
continue;
549+
}
550+
let mut type_usage = GenericUsage::from_generics(generics);
551+
type_usage.visit_type(&const_param.ty);
552+
if usage.merge_from(type_usage) {
553+
changed = true;
554+
}
555+
}
556+
}
557+
}
558+
559+
if let Some(where_clause) = &generics.where_clause {
560+
for predicate in &where_clause.predicates {
561+
let mut predicate_usage = GenericUsage::from_generics(generics);
562+
predicate_usage.visit_where_predicate(predicate);
563+
if predicate_usage.is_empty() || !predicate_usage.intersects(&usage) {
564+
continue;
565+
}
566+
if usage.merge_from(predicate_usage) {
567+
changed = true;
568+
}
569+
}
570+
}
571+
}
572+
573+
let mut inner_generics = generics.clone();
574+
inner_generics.params = inner_generics
575+
.params
576+
.into_iter()
577+
.filter(|param| match param {
578+
syn::GenericParam::Type(type_param) => usage
579+
.used_type_params
580+
.contains(&type_param.ident.to_string()),
581+
syn::GenericParam::Lifetime(lifetime_param) => usage
582+
.used_lifetime_params
583+
.contains(&lifetime_param.lifetime.ident.to_string()),
584+
syn::GenericParam::Const(const_param) => usage
585+
.used_const_params
586+
.contains(&const_param.ident.to_string()),
587+
})
588+
.collect();
589+
590+
if let Some(where_clause) = &mut inner_generics.where_clause {
591+
where_clause.predicates = where_clause
592+
.predicates
593+
.clone()
594+
.into_iter()
595+
.filter(|predicate| {
596+
let mut predicate_usage = GenericUsage::from_generics(generics);
597+
predicate_usage.visit_where_predicate(predicate);
598+
predicate_usage.is_empty() || predicate_usage.is_subset_of(&usage)
599+
})
600+
.collect();
601+
if where_clause.predicates.is_empty() {
602+
inner_generics.where_clause = None;
603+
}
604+
}
605+
606+
inner_generics
607+
}

macros/macros_impl/src/encode.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ pub fn map_to_inner_type(
135135
|id| quote!(#crate_root::types::Identifier(Some(#id))),
136136
);
137137

138-
let mut inner_generics = generics.clone();
138+
let mut inner_generics = crate::decode::filter_generics_for_fields(generics, fields);
139139
let lifetime = syn::Lifetime::new(
140140
&format!("'inner{}", uuid::Uuid::new_v4().as_u128()),
141141
proc_macro2::Span::call_site(),

tests/derive.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,19 +328,31 @@ fn decode_enum_with_generics() {
328328
#[rasn(choice)]
329329
#[allow(dead_code)]
330330
enum MyContainer<M> {
331-
SomeVal {
331+
Struct {
332332
#[rasn(tag(0))]
333333
inner: Vec<M>,
334334
},
335+
StructNoGeneric {
336+
#[rasn(tag(0))]
337+
name: String,
338+
},
339+
Newtype(Vec<M>),
340+
NewtypeNoGeneric(String),
335341
}
336342

337343
#[derive(AsnType, Encode, Decode)]
338344
#[rasn(choice)]
339345
#[allow(dead_code)]
340346
enum MyContainerExplicit<M> {
341-
SomeVal {
347+
Struct {
342348
#[rasn(tag(explicit(0)))]
343349
inner: Vec<M>,
344350
},
351+
StructNoGeneric {
352+
#[rasn(tag(explicit(0)))]
353+
name: String,
354+
},
355+
Newtype(Vec<M>),
356+
NewtypeNoGeneric(String),
345357
}
346358
}

0 commit comments

Comments
 (0)