Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 138 additions & 7 deletions asn1_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,25 @@ fn derive_asn1_read_expand(input: syn::DeriveInput) -> syn::Result<proc_macro2::
Ok(expanded)
}

#[proc_macro_derive(Asn1Write, attributes(explicit, implicit, default, defined_by))]
fn get_err_attribute(input: &syn::DeriveInput) -> syn::Result<syn::Type> {
for attr in &input.attrs {
if attr.path().is_ident("error_type") {
let err_type: syn::Type = attr.parse_args()?;
return Ok(err_type);
}
}
Err(syn::Error::new_spanned(
input,
"Error type for asn1::Asn1Writable/Asn1DefinedByWritable implementation was not specified",
))
}

#[proc_macro_derive(
Asn1Write,
attributes(explicit, implicit, default, defined_by, error_type)
)]
pub fn derive_asn1_write(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = syn::parse_macro_input!(input as syn::DeriveInput);

derive_asn1_write_expand(input)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
Expand All @@ -75,11 +90,20 @@ pub fn derive_asn1_write(input: proc_macro::TokenStream) -> proc_macro::TokenStr
fn derive_asn1_write_expand(mut input: syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let name = &input.ident;
let fields = all_field_types(&input.data, false, &input.generics)?;
let err_type = get_err_attribute(&input)?;
add_bounds(
&mut input.generics,
fields.clone(),
syn::parse_quote!(asn1::Asn1Writable),
syn::parse_quote!(asn1::Asn1DefinedByWritable<asn1::ObjectIdentifier>),
true,
);
add_write_error_bounds(
&mut input.generics,
fields,
syn::parse_quote!(asn1::Asn1Writable),
syn::parse_quote!(asn1::Asn1DefinedByWritable<asn1::ObjectIdentifier>),
err_type.clone(),
true,
);
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
Expand All @@ -89,8 +113,9 @@ fn derive_asn1_write_expand(mut input: syn::DeriveInput) -> syn::Result<proc_mac
let (write_block, data_length_block) = generate_struct_write_block(&data)?;
quote::quote! {
impl #impl_generics asn1::SimpleAsn1Writable for #name #ty_generics #where_clause {
type Error = #err_type;
const TAG: asn1::Tag = <asn1::SequenceWriter as asn1::SimpleAsn1Writable>::TAG;
fn write_data(&self, dest: &mut asn1::WriteBuf) -> asn1::WriteResult {
fn write_data(&self, dest: &mut asn1::WriteBuf) -> Result<(), Self::Error> {
#write_block

Ok(())
Expand All @@ -106,7 +131,8 @@ fn derive_asn1_write_expand(mut input: syn::DeriveInput) -> syn::Result<proc_mac
let (write_block, length_block) = generate_enum_write_block(name, &data)?;
quote::quote! {
impl #impl_generics asn1::Asn1Writable for #name #ty_generics #where_clause {
fn write(&self, w: &mut asn1::Writer) -> asn1::WriteResult {
type Error = #err_type;
fn write(&self, w: &mut asn1::Writer) -> Result<(), Self::Error> {
#write_block
}
fn encoded_length(&self) -> Option<usize> {
Expand All @@ -115,7 +141,8 @@ fn derive_asn1_write_expand(mut input: syn::DeriveInput) -> syn::Result<proc_mac
}

impl #impl_generics asn1::Asn1Writable for &#name #ty_generics #where_clause {
fn write(&self, w: &mut asn1::Writer) -> asn1::WriteResult {
type Error = #err_type;
fn write(&self, w: &mut asn1::Writer) -> Result<(), Self::Error> {
(*self).write(w)
}

Expand Down Expand Up @@ -256,7 +283,7 @@ fn derive_asn1_defined_by_read_expand(
})
}

#[proc_macro_derive(Asn1DefinedByWrite, attributes(default, defined_by))]
#[proc_macro_derive(Asn1DefinedByWrite, attributes(default, defined_by, error_type))]
pub fn derive_asn1_defined_by_write(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = syn::parse_macro_input!(input as syn::DeriveInput);

Expand All @@ -270,11 +297,20 @@ fn derive_asn1_defined_by_write_expand(
) -> syn::Result<proc_macro2::TokenStream> {
let name = &input.ident;
let fields = all_field_types(&input.data, true, &input.generics)?;
let err_type = get_err_attribute(&input)?;
add_bounds(
&mut input.generics,
fields.clone(),
syn::parse_quote!(asn1::Asn1Writable),
syn::parse_quote!(asn1::Asn1DefinedByWritable<asn1::ObjectIdentifier>),
true,
);
add_write_error_bounds(
&mut input.generics,
fields,
syn::parse_quote!(asn1::Asn1Writable),
syn::parse_quote!(asn1::Asn1DefinedByWritable<asn1::ObjectIdentifier>),
err_type.clone(),
true,
);
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
Expand Down Expand Up @@ -329,13 +365,14 @@ fn derive_asn1_defined_by_write_expand(

Ok(quote::quote! {
impl #impl_generics asn1::Asn1DefinedByWritable<asn1::ObjectIdentifier> for #name #ty_generics #where_clause {
type Error = #err_type;
fn item(&self) -> &asn1::ObjectIdentifier {
match self {
#(#item_blocks)*
}
}

fn write(&self, w: &mut asn1::Writer) -> asn1::WriteResult {
fn write(&self, w: &mut asn1::Writer) -> Result<(), Self::Error> {
match self {
#(#write_blocks)*
}
Expand Down Expand Up @@ -604,6 +641,100 @@ fn add_bounds(
}
}

fn add_write_error_bounds(
generics: &mut syn::Generics,
field_types: Vec<(syn::Type, OpType, bool)>,
bound: syn::TypeParamBound,
defined_by_bound: syn::TypeParamBound,
write_error_type: syn::Type,
add_ref: bool,
) {
let where_clause = if field_types.is_empty() {
return;
} else {
generics
.where_clause
.get_or_insert_with(|| syn::WhereClause {
where_token: Default::default(),
predicates: syn::punctuated::Punctuated::new(),
})
};
for (f, op_type, has_default) in &field_types {
let (bounded_ty, lifetimes): (syn::Type, Option<syn::BoundLifetimes>) = match (
op_type, add_ref,
) {
(OpType::Regular, _) => (syn::parse_quote!(<#f as #bound>::Error), None),
(OpType::DefinedBy(_), _) => {
(syn::parse_quote!(<#f as #defined_by_bound>::Error), None)
}
(OpType::Implicit(OpTypeArgs { value, required }), false) => {
if *required || *has_default {
(
syn::parse_quote!(<asn1::Implicit::<#f, #value> as #bound>::Error),
None,
)
} else {
(
syn::parse_quote!(<asn1::Implicit::<<#f as asn1::OptionExt>::T, #value> as #bound>::Error),
None,
)
}
}
(OpType::Implicit(OpTypeArgs { value, required }), true) => {
if *required || *has_default {
(
syn::parse_quote!(<asn1::Implicit::<&'asn1_internal #f, #value> as #bound>::Error),
Some(syn::parse_quote!(for<'asn1_internal>)),
)
} else {
(
syn::parse_quote!(<asn1::Implicit::<&'asn1_internal <#f as asn1::OptionExt>::T, #value> as #bound>::Error),
Some(syn::parse_quote!(for<'asn1_internal>)),
)
}
}
(OpType::Explicit(OpTypeArgs { value, required }), false) => {
if *required || *has_default {
(
syn::parse_quote!(<asn1::Explicit::<#f, #value> as #bound>::Error),
None,
)
} else {
(
syn::parse_quote!(<asn1::Explicit::<<#f as asn1::OptionExt>::T, #value> as #bound>::Error),
None,
)
}
}
(OpType::Explicit(OpTypeArgs { value, required }), true) => {
if *required || *has_default {
(
syn::parse_quote!(<asn1::Explicit::<&'asn1_internal #f, #value> as #bound>::Error),
Some(syn::parse_quote!(for<'asn1_internal>)),
)
} else {
(
syn::parse_quote!(<asn1::Explicit::<&'asn1_internal <#f as asn1::OptionExt>::T, #value> as #bound>::Error),
Some(syn::parse_quote!(for<'asn1_internal>)),
)
}
}
};
where_clause
.predicates
.push(syn::WherePredicate::Type(syn::PredicateType {
lifetimes,
bounded_ty,
colon_token: Default::default(),
bounds: {
let mut p = syn::punctuated::Punctuated::new();
p.push(syn::parse_quote!(Into<#write_error_type>));
p
},
}));
}
}

#[derive(Clone)]
enum OpType {
Regular,
Expand Down
12 changes: 7 additions & 5 deletions examples/no_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ fn main() {
};

let computed = asn1::write(|w| {
w.write_element(&asn1::SequenceWriter::new(&|w: &mut asn1::Writer| {
w.write_element(&1i64)?;
w.write_element(&3i64)?;
Ok(())
}))
w.write_element(&asn1::SequenceWriter::<asn1::WriteError>::new(
&|w: &mut asn1::Writer| {
w.write_element(&1i64)?;
w.write_element(&3i64)?;
Ok(())
},
))
})
.unwrap();
unsafe {
Expand Down
Loading
Loading