Skip to content

Commit a301d9a

Browse files
authored
Allow single-field named structs to be transparent (#3971)
* Allow single-field named structs to be transparent This more closely matches the criteria for e.g. #[repr(transparent)] and #[serde(transparent)]. * Add tests, fix error messages
1 parent ff93aa0 commit a301d9a

File tree

7 files changed

+98
-36
lines changed

7 files changed

+98
-36
lines changed

sqlx-macros-core/src/derives/attributes.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ pub fn check_struct_attributes(
281281

282282
assert_attribute!(
283283
!attributes.transparent,
284-
"unexpected #[sqlx(transparent)]",
284+
"#[sqlx(transparent)] is only valid for structs with exactly one field",
285285
input
286286
);
287287

sqlx-macros-core/src/derives/decode.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,17 @@ use quote::quote;
88
use syn::punctuated::Punctuated;
99
use syn::token::Comma;
1010
use syn::{
11-
parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed,
12-
FieldsUnnamed, Stmt, TypeParamBound, Variant,
11+
parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Stmt,
12+
TypeParamBound, Variant,
1313
};
1414

1515
pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result<TokenStream> {
1616
let attrs = parse_container_attributes(&input.attrs)?;
1717
match &input.data {
18-
Data::Struct(DataStruct {
19-
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
20-
..
21-
}) if unnamed.len() == 1 => {
22-
expand_derive_decode_transparent(input, unnamed.first().unwrap())
18+
Data::Struct(DataStruct { fields, .. })
19+
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || attrs.transparent) =>
20+
{
21+
expand_derive_decode_transparent(input, fields.iter().next().unwrap())
2322
}
2423
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
2524
Some(_) => expand_derive_decode_weak_enum(input, variants),
@@ -35,7 +34,7 @@ pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result<TokenStream> {
3534
..
3635
}) => Err(syn::Error::new_spanned(
3736
input,
38-
"structs with zero or more than one unnamed field are not supported",
37+
"tuple structs may only have a single field",
3938
)),
4039
Data::Struct(DataStruct {
4140
fields: Fields::Unit,
@@ -72,6 +71,12 @@ fn expand_derive_decode_transparent(
7271
.push(parse_quote!(#ty: ::sqlx::decode::Decode<'r, DB>));
7372
let (impl_generics, _, where_clause) = generics.split_for_impl();
7473

74+
let field_ident = if let Some(ident) = &field.ident {
75+
quote! { #ident }
76+
} else {
77+
quote! { 0 }
78+
};
79+
7580
let tts = quote!(
7681
#[automatically_derived]
7782
impl #impl_generics ::sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause {
@@ -83,7 +88,8 @@ fn expand_derive_decode_transparent(
8388
dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync,
8489
>,
8590
> {
86-
<#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value).map(Self)
91+
<#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value)
92+
.map(|val| Self { #field_ident: val })
8793
}
8894
}
8995
);

sqlx-macros-core/src/derives/encode.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,17 @@ use syn::punctuated::Punctuated;
99
use syn::token::Comma;
1010
use syn::{
1111
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
12-
FieldsUnnamed, Lifetime, LifetimeParam, Stmt, TypeParamBound, Variant,
12+
Lifetime, LifetimeParam, Stmt, TypeParamBound, Variant,
1313
};
1414

1515
pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<TokenStream> {
1616
let args = parse_container_attributes(&input.attrs)?;
1717

1818
match &input.data {
19-
Data::Struct(DataStruct {
20-
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
21-
..
22-
}) if unnamed.len() == 1 => {
23-
expand_derive_encode_transparent(input, unnamed.first().unwrap())
19+
Data::Struct(DataStruct { fields, .. })
20+
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || args.transparent) =>
21+
{
22+
expand_derive_encode_transparent(input, fields.iter().next().unwrap())
2423
}
2524
Data::Enum(DataEnum { variants, .. }) => match args.repr {
2625
Some(_) => expand_derive_encode_weak_enum(input, variants),
@@ -36,7 +35,7 @@ pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<TokenStream> {
3635
..
3736
}) => Err(syn::Error::new_spanned(
3837
input,
39-
"structs with zero or more than one unnamed field are not supported",
38+
"tuple structs may only have a single field",
4039
)),
4140
Data::Struct(DataStruct {
4241
fields: Fields::Unit,
@@ -77,6 +76,12 @@ fn expand_derive_encode_transparent(
7776
.push(parse_quote!(#ty: ::sqlx::encode::Encode<#lifetime, DB>));
7877
let (impl_generics, _, where_clause) = generics.split_for_impl();
7978

79+
let field_ident = if let Some(ident) = &field.ident {
80+
quote! { #ident }
81+
} else {
82+
quote! { 0 }
83+
};
84+
8085
Ok(quote!(
8186
#[automatically_derived]
8287
impl #impl_generics ::sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics
@@ -86,15 +91,15 @@ fn expand_derive_encode_transparent(
8691
&self,
8792
buf: &mut <DB as ::sqlx::database::Database>::ArgumentBuffer<#lifetime>,
8893
) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> {
89-
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf)
94+
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.#field_ident, buf)
9095
}
9196

9297
fn produces(&self) -> Option<DB::TypeInfo> {
93-
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0)
98+
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::produces(&self.#field_ident)
9499
}
95100

96101
fn size_hint(&self) -> usize {
97-
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0)
102+
<#ty as ::sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.#field_ident)
98103
}
99104
}
100105
))

sqlx-macros-core/src/derives/type.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,33 @@ use quote::{quote, quote_spanned};
77
use syn::punctuated::Punctuated;
88
use syn::token::Comma;
99
use syn::{
10-
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed,
11-
FieldsUnnamed, Variant,
10+
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Variant,
1211
};
1312

1413
pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<TokenStream> {
1514
let attrs = parse_container_attributes(&input.attrs)?;
1615
match &input.data {
1716
// Newtype structs:
1817
// struct Foo(i32);
19-
Data::Struct(DataStruct {
20-
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
21-
..
22-
}) => {
23-
if unnamed.len() == 1 {
24-
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
25-
} else {
26-
Err(syn::Error::new_spanned(
27-
input,
28-
"structs with zero or more than one unnamed field are not supported",
29-
))
30-
}
18+
// struct Foo { field: i32 };
19+
Data::Struct(DataStruct { fields, .. })
20+
if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || attrs.transparent) =>
21+
{
22+
expand_derive_has_sql_type_transparent(input, fields.iter().next().unwrap())
3123
}
3224
// Record types
3325
// struct Foo { foo: i32, bar: String }
3426
Data::Struct(DataStruct {
3527
fields: Fields::Named(FieldsNamed { named, .. }),
3628
..
3729
}) => expand_derive_has_sql_type_struct(input, named),
30+
Data::Struct(DataStruct {
31+
fields: Fields::Unnamed(..),
32+
..
33+
}) => Err(syn::Error::new_spanned(
34+
input,
35+
"tuple structs may only have a single field",
36+
)),
3837
Data::Struct(DataStruct {
3938
fields: Fields::Unit,
4039
..

tests/mysql/derives.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use sqlx_mysql::MySql;
2-
use sqlx_test::new;
2+
use sqlx_test::{new, test_type};
33

44
#[sqlx::test]
55
async fn test_derive_strong_enum() -> anyhow::Result<()> {
@@ -300,3 +300,23 @@ async fn test_derive_weak_enum() -> anyhow::Result<()> {
300300

301301
Ok(())
302302
}
303+
304+
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
305+
#[sqlx(transparent)]
306+
struct TransparentTuple(i64);
307+
308+
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
309+
#[sqlx(transparent)]
310+
struct TransparentNamed {
311+
field: i64,
312+
}
313+
314+
test_type!(transparent_tuple<TransparentTuple>(MySql,
315+
"0" == TransparentTuple(0),
316+
"23523" == TransparentTuple(23523)
317+
));
318+
319+
test_type!(transparent_named<TransparentNamed>(MySql,
320+
"0" == TransparentNamed { field: 0 },
321+
"23523" == TransparentNamed { field: 23523 },
322+
));

tests/postgres/derives.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ use std::ops::Bound;
1212
#[sqlx(transparent)]
1313
struct Transparent(i32);
1414

15+
// Also possible for single-field named structs
16+
#[derive(PartialEq, Debug, sqlx::Type)]
17+
#[sqlx(transparent)]
18+
struct TransparentNamed {
19+
field: i32,
20+
}
21+
1522
#[derive(PartialEq, Debug, sqlx::Type)]
1623
// https://github.com/launchbadge/sqlx/issues/2611
1724
// Previously, the derive would generate a `PgHasArrayType` impl that errored on an
@@ -143,11 +150,16 @@ struct FloatRange(PgRange<f64>);
143150
#[sqlx(type_name = "int4rangeL0pC")]
144151
struct RangeInclusive(PgRange<i32>);
145152

146-
test_type!(transparent<Transparent>(Postgres,
153+
test_type!(transparent_tuple<Transparent>(Postgres,
147154
"0" == Transparent(0),
148155
"23523" == Transparent(23523)
149156
));
150157

158+
test_type!(transparent_named<TransparentNamed>(Postgres,
159+
"0" == TransparentNamed { field: 0 },
160+
"23523" == TransparentNamed { field: 23523 },
161+
));
162+
151163
test_type!(transparent_array<TransparentArray>(Postgres,
152164
"'{}'::int8[]" == TransparentArray(vec![]),
153165
"'{ 23523, 123456, 789 }'::int8[]" == TransparentArray(vec![23523, 123456, 789])

tests/sqlite/derives.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,23 @@ test_type!(origin_enum<Origin>(Sqlite,
1212
"1" == Origin::Foo,
1313
"2" == Origin::Bar,
1414
));
15+
16+
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
17+
#[sqlx(transparent)]
18+
struct TransparentTuple(i64);
19+
20+
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
21+
#[sqlx(transparent)]
22+
struct TransparentNamed {
23+
field: i64,
24+
}
25+
26+
test_type!(transparent_tuple<TransparentTuple>(Sqlite,
27+
"0" == TransparentTuple(0),
28+
"23523" == TransparentTuple(23523)
29+
));
30+
31+
test_type!(transparent_named<TransparentNamed>(Sqlite,
32+
"0" == TransparentNamed { field: 0 },
33+
"23523" == TransparentNamed { field: 23523 },
34+
));

0 commit comments

Comments
 (0)