Skip to content

Commit eaaa700

Browse files
committed
Allow single-field named structs to be transparent
This more closely matches the criteria for e.g. #[repr(transparent)] and #[serde(transparent)].
1 parent c825bd3 commit eaaa700

File tree

3 files changed

+40
-31
lines changed

3 files changed

+40
-31
lines changed

sqlx-macros-core/src/derives/decode.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,17 @@ use quote::quote;
88
use syn::punctuated::Punctuated;
99
use syn::token::Comma;
1010
use syn::{
11-
parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed,
12-
FieldsUnnamed, Stmt, TypeParamBound, Variant,
11+
parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Stmt,
12+
TypeParamBound, Variant,
1313
};
1414

1515
pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result<TokenStream> {
1616
let attrs = parse_container_attributes(&input.attrs)?;
1717
match &input.data {
18-
Data::Struct(DataStruct {
19-
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
20-
..
21-
}) if unnamed.len() == 1 => {
22-
expand_derive_decode_transparent(input, unnamed.first().unwrap())
18+
Data::Struct(DataStruct { fields, .. })
19+
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || attrs.transparent) =>
20+
{
21+
expand_derive_decode_transparent(input, fields.iter().next().unwrap())
2322
}
2423
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
2524
Some(_) => expand_derive_decode_weak_enum(input, variants),
@@ -72,6 +71,12 @@ fn expand_derive_decode_transparent(
7271
.push(parse_quote!(#ty: ::sqlx::decode::Decode<'r, DB>));
7372
let (impl_generics, _, where_clause) = generics.split_for_impl();
7473

74+
let field_ident = if let Some(ident) = &field.ident {
75+
quote! { #ident }
76+
} else {
77+
quote! { 0 }
78+
};
79+
7580
let tts = quote!(
7681
#[automatically_derived]
7782
impl #impl_generics ::sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause {
@@ -83,7 +88,8 @@ fn expand_derive_decode_transparent(
8388
dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync,
8489
>,
8590
> {
86-
<#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value).map(Self)
91+
<#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value)
92+
.map(|val| Self { #field_ident: val })
8793
}
8894
}
8995
);

sqlx-macros-core/src/derives/encode.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,17 @@ use syn::punctuated::Punctuated;
99
use syn::token::Comma;
1010
use syn::{
1111
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
12-
FieldsUnnamed, Lifetime, LifetimeParam, Stmt, TypeParamBound, Variant,
12+
Lifetime, LifetimeParam, Stmt, TypeParamBound, Variant,
1313
};
1414

1515
pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<TokenStream> {
1616
let args = parse_container_attributes(&input.attrs)?;
1717

1818
match &input.data {
19-
Data::Struct(DataStruct {
20-
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
21-
..
22-
}) if unnamed.len() == 1 => {
23-
expand_derive_encode_transparent(input, unnamed.first().unwrap())
19+
Data::Struct(DataStruct { fields, .. })
20+
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || args.transparent) =>
21+
{
22+
expand_derive_encode_transparent(input, fields.iter().next().unwrap())
2423
}
2524
Data::Enum(DataEnum { variants, .. }) => match args.repr {
2625
Some(_) => expand_derive_encode_weak_enum(input, variants),
@@ -77,6 +76,12 @@ fn expand_derive_encode_transparent(
7776
.push(parse_quote!(#ty: ::sqlx::encode::Encode<#lifetime, DB>));
7877
let (impl_generics, _, where_clause) = generics.split_for_impl();
7978

79+
let field_ident = if let Some(ident) = &field.ident {
80+
quote! { #ident }
81+
} else {
82+
quote! { 0 }
83+
};
84+
8085
Ok(quote!(
8186
#[automatically_derived]
8287
impl #impl_generics ::sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics
@@ -86,15 +91,15 @@ fn expand_derive_encode_transparent(
8691
&self,
8792
buf: &mut <DB as ::sqlx::database::Database>::ArgumentBuffer<#lifetime>,
8893
) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> {
89-
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf)
94+
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.#field_ident, buf)
9095
}
9196

9297
fn produces(&self) -> Option<DB::TypeInfo> {
93-
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0)
98+
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::produces(&self.#field_ident)
9499
}
95100

96101
fn size_hint(&self) -> usize {
97-
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0)
102+
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.#field_ident)
98103
}
99104
}
100105
))

sqlx-macros-core/src/derives/type.rs

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,32 @@ use quote::{quote, quote_spanned};
77
use syn::punctuated::Punctuated;
88
use syn::token::Comma;
99
use syn::{
10-
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed,
11-
FieldsUnnamed, Variant,
10+
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Variant,
1211
};
1312

1413
pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
1514
let attrs = parse_container_attributes(&input.attrs)?;
1615
match &input.data {
1716
// Newtype structs:
1817
// struct Foo(i32);
19-
Data::Struct(DataStruct {
20-
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
21-
..
22-
}) => {
23-
if unnamed.len() == 1 {
24-
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
25-
} else {
26-
Err(syn::Error::new_spanned(
27-
input,
28-
"structs with zero or more than one unnamed field are not supported",
29-
))
30-
}
18+
Data::Struct(DataStruct { fields, .. })
19+
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || attrs.transparent) =>
20+
{
21+
expand_derive_has_sql_type_transparent(input, fields.iter().next().unwrap())
3122
}
3223
// Record types
3324
// struct Foo { foo: i32, bar: String }
3425
Data::Struct(DataStruct {
3526
fields: Fields::Named(FieldsNamed { named, .. }),
3627
..
3728
}) => expand_derive_has_sql_type_struct(input, named),
29+
Data::Struct(DataStruct {
30+
fields: Fields::Unnamed(..),
31+
..
32+
}) => Err(syn::Error::new_spanned(
33+
input,
34+
"structs with zero or more than one unnamed field are not supported",
35+
)),
3836
Data::Struct(DataStruct {
3937
fields: Fields::Unit,
4038
..

0 commit comments

Comments
 (0)