Skip to content

Commit 26697de

Browse files
wip
1 parent 8d807e2 commit 26697de

File tree

4 files changed

+101
-36
lines changed

4 files changed

+101
-36
lines changed

benzina-derive/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ diesel = { version = "2", default-features = false, features = ["postgres", "mys
2727
[features]
2828
postgres = []
2929
mysql = []
30+
json = []
3031

3132
[lints]
3233
workspace = true

benzina-derive/src/enum_derive.rs

Lines changed: 98 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,20 @@ struct EnumVariant {
3434
original_name: String,
3535
original_name_span: Span,
3636
rename: Option<String>,
37+
#[cfg(all(feature = "postgres", feature = "json"))]
38+
payload: Option<EnumVariantPayload>,
3739

3840
crate_name: Option<Path>,
3941
}
4042

43+
#[cfg(all(feature = "postgres", feature = "json"))]
44+
struct EnumVariantPayload {
45+
type_: Type,
46+
span: Span,
47+
}
48+
4149
impl Enum {
50+
#[expect(clippy::too_many_lines)]
4251
pub(crate) fn parse(input: DeriveInput) -> Result<Self, syn::Error> {
4352
let Data::Enum(e) = input.data else {
4453
fail!(input, "`benzina::Enum` macro available only for enums");
@@ -97,9 +106,29 @@ impl Enum {
97106
.variants
98107
.into_iter()
99108
.map(|variant| {
100-
if !matches!(variant.fields, Fields::Unit) {
101-
fail!(variant, "only unit variants are supported");
102-
}
109+
let payload = match &variant.fields {
110+
Fields::Unit => None,
111+
#[cfg(all(feature = "postgres", feature = "json"))]
112+
Fields::Unnamed(fields) => {
113+
let mut fields = fields.unnamed.iter();
114+
let (Some(field), None) = (fields.next(), fields.next()) else {
115+
fail!(variant, "only single-item variants are supported");
116+
};
117+
118+
let span = field.span();
119+
Some(EnumVariantPayload {
120+
type_: field.ty.clone(),
121+
span,
122+
})
123+
}
124+
#[cfg(not(all(feature = "postgres", feature = "json")))]
125+
Fields::Unnamed(_fields) => {
126+
fail!(variant, "fields require both the `postgres` and the `json` feature to be enabled");
127+
}
128+
Fields::Named(_fields) => {
129+
fail!(variant, "only unit an unnamed variants are supported");
130+
}
131+
};
103132

104133
let name = variant.ident.to_string();
105134
let mut rename = None;
@@ -120,11 +149,18 @@ impl Enum {
120149
})?;
121150
}
122151

152+
// Suppress build breakage when building without the
153+
// PostgreSQL JSON feature.
154+
#[cfg(not(all(feature = "postgres", feature = "json")))]
155+
let _: Option<()> = payload;
156+
123157
let original_name_span = variant.span();
124158
Ok(EnumVariant {
125159
original_name: name,
126160
original_name_span,
127161
rename,
162+
#[cfg(all(feature = "postgres", feature = "json"))]
163+
payload,
128164

129165
crate_name: crate_name.clone(),
130166
})
@@ -139,6 +175,22 @@ impl Enum {
139175
crate_name,
140176
})
141177
}
178+
179+
#[cfg(all(feature = "postgres", feature = "json"))]
180+
fn has_json_fields(&self) -> bool {
181+
self.variants
182+
.iter()
183+
.any(|variant| variant.payload.is_some())
184+
}
185+
186+
#[cfg(not(all(feature = "postgres", feature = "json")))]
187+
#[expect(
188+
clippy::unused_self,
189+
reason = "kept for compatibility with the above implementation"
190+
)]
191+
fn has_json_fields(&self) -> bool {
192+
false
193+
}
142194
}
143195

144196
impl ToTokens for Enum {
@@ -153,19 +205,26 @@ impl ToTokens for Enum {
153205
} = &self;
154206
let crate_name = crate::crate_name(crate_name);
155207

208+
let has_json_fields = self.has_json_fields();
156209
let from_bytes_arms = variants
157210
.iter()
158-
.map(|variant| variant.gen_from_bytes(*rename_all))
211+
.map(|variant| variant.gen_from_bytes(has_json_fields, *rename_all))
159212
.collect::<Vec<_>>();
160213
let to_byte_str_arms = variants
161214
.iter()
162-
.map(|variant| variant.gen_to_byte_str(*rename_all))
215+
.map(|variant| variant.gen_to_byte_str(has_json_fields, *rename_all))
163216
.collect::<Vec<_>>();
164217

165-
#[cfg(feature = "postgres")]
218+
let pg_sql_type = if self.has_json_fields() {
219+
#[cfg(all(feature = "postgres", feature = "json"))]
220+
quote! { (#sql_type, #crate_name::__private::diesel::pg::sql_types::Jsonb) }
221+
} else {
222+
quote! { #sql_type }
223+
};
224+
166225
tokens.append_all(quote! {
167226
#[automatically_derived]
168-
impl #crate_name::__private::diesel::deserialize::FromSql<#sql_type, #crate_name::__private::diesel::pg::Pg> for #ident {
227+
impl #crate_name::__private::diesel::deserialize::FromSql<#pg_sql_type, #crate_name::__private::diesel::pg::Pg> for #ident {
169228
fn from_sql(bytes: #crate_name::__private::diesel::pg::PgValue<'_>) -> #crate_name::__private::diesel::deserialize::Result<Self> {
170229
match bytes.as_bytes() {
171230
#(#from_bytes_arms)*
@@ -181,7 +240,7 @@ impl ToTokens for Enum {
181240
}
182241

183242
#[automatically_derived]
184-
impl #crate_name::__private::diesel::serialize::ToSql<#sql_type, #crate_name::__private::diesel::pg::Pg> for #ident {
243+
impl #crate_name::__private::diesel::serialize::ToSql<#pg_sql_type, #crate_name::__private::diesel::pg::Pg> for #ident {
185244
fn to_sql<'b>(&'b self, out: &mut #crate_name::__private::diesel::serialize::Output<'b, '_, #crate_name::__private::diesel::pg::Pg>) -> #crate_name::__private::diesel::serialize::Result {
186245
let s = match self {
187246
#(#to_byte_str_arms)*
@@ -196,44 +255,48 @@ impl ToTokens for Enum {
196255
});
197256

198257
#[cfg(feature = "mysql")]
199-
tokens.append_all(quote! {
200-
#[automatically_derived]
201-
impl #crate_name::__private::diesel::deserialize::FromSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident {
202-
fn from_sql(bytes: #crate_name::__private::diesel::mysql::MysqlValue<'_>) -> #crate_name::__private::diesel::deserialize::Result<Self> {
203-
match bytes.as_bytes() {
204-
#(#from_bytes_arms)*
205-
_ => {
206-
#crate_name::__private::std::result::Result::Err(
207-
#crate_name::__private::std::convert::Into::into(
208-
"Unrecognized enum variant"
258+
if !self.has_json_fields() {
259+
tokens.append_all(quote! {
260+
#[automatically_derived]
261+
impl #crate_name::"mysql"__private::diesel::deserialize::FromSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident {
262+
fn from_sql(bytes: #crate_name::__private::diesel::mysql::MysqlValue<'_>) -> #crate_name::__private::diesel::deserialize::Result<Self> {
263+
match bytes.as_bytes() {
264+
#(#from_bytes_arms)*
265+
_ => {
266+
#crate_name::__private::std::result::Result::Err(
267+
#crate_name::__private::std::convert::Into::into(
268+
"Unrecognized enum variant"
269+
)
209270
)
210-
)
211-
},
271+
},
272+
}
212273
}
213274
}
214-
}
215275

216-
#[automatically_derived]
217-
impl #crate_name::__private::diesel::serialize::ToSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident {
218-
fn to_sql<'b>(&'b self, out: &mut #crate_name::__private::diesel::serialize::Output<'b, '_, #crate_name::__private::diesel::mysql::Mysql>) -> #crate_name::__private::diesel::serialize::Result {
219-
let s = match self {
220-
#(#to_byte_str_arms)*
221-
};
222-
#crate_name::__private::std::io::Write::write_all(out, s)?;
276+
#[automatically_derived]
277+
impl #crate_name::__private::diesel::serialize::ToSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident {
278+
fn to_sql<'b>(&'b self, out: &mut #crate_name::__private::diesel::serialize::Output<'b, '_, #crate_name::__private::diesel::mysql::Mysql>) -> #crate_name::__private::diesel::serialize::Result {
279+
let s = match self {
280+
#(#to_byte_str_arms)*
281+
};
282+
#crate_name::__private::std::io::Write::write_all(out, s)?;
223283

224-
#crate_name::__private::std::result::Result::Ok(#crate_name::__private::diesel::serialize::IsNull::No)
284+
#crate_name::__private::std::result::Result::Ok(#crate_name::__private::diesel::serialize::IsNull::No)
285+
}
225286
}
226-
}
227-
});
287+
});
288+
}
228289
}
229290
}
230291

231292
impl EnumVariant {
232-
fn gen_from_bytes(&self, rename_rule: RenameRule) -> impl ToTokens {
293+
fn gen_from_bytes(&self, has_fields: bool, rename_rule: RenameRule) -> impl ToTokens {
233294
let Self {
234295
original_name,
235296
original_name_span,
236297
rename,
298+
#[cfg(all(feature = "postgres", feature = "json"))]
299+
payload,
237300

238301
crate_name,
239302
} = self;
@@ -250,11 +313,13 @@ impl EnumVariant {
250313
}
251314
}
252315

253-
fn gen_to_byte_str(&self, rename_rule: RenameRule) -> impl ToTokens {
316+
fn gen_to_byte_str(&self, has_fields: bool, rename_rule: RenameRule) -> impl ToTokens {
254317
let Self {
255318
original_name,
256319
original_name_span,
257320
rename,
321+
#[cfg(all(feature = "postgres", feature = "json"))]
322+
payload,
258323

259324
crate_name: _,
260325
} = self;

benzina-derive/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ mod rename_rule;
4040
/// expression::AsExpression,
4141
/// };
4242
///
43-
/// #[derive(Debug, Copy, Clone, AsExpression, FromSqlRow, benzina::Enum)]
44-
/// #[diesel(sql_type = crate::schema::sql_types::Animal)]
43+
/// #[derive(Debug, Copy, Clone, benzina::Enum)]
4544
/// #[benzina(
4645
/// sql_type = crate::schema::sql_types::Animal,
4746
/// rename_all = "snake_case"

benzina/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ utoipa = ["dep:utoipa"]
5151
example-generated = ["typed-uuid"]
5252
dangerous-construction = ["typed-uuid"]
5353

54-
json = ["postgres", "dep:serde_core", "dep:serde_json", "diesel/serde_json"]
54+
json = ["postgres", "benzina-derive?/json", "dep:serde_core", "dep:serde_json", "diesel/serde_json"]
5555
ctid = ["postgres", "diesel/i-implement-a-third-party-backend-and-opt-into-breaking-changes"]
5656

5757
[lints]

0 commit comments

Comments
 (0)