Skip to content

Commit 0e5771b

Browse files
wip
1 parent ac297e2 commit 0e5771b

File tree

5 files changed

+146
-35
lines changed

5 files changed

+146
-35
lines changed

benzina-derive/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ diesel = { version = "2", default-features = false, features = ["postgres", "mys
2727
[features]
2828
postgres = []
2929
mysql = []
30+
json = []
3031

3132
[lints]
3233
workspace = true

benzina-derive/src/enum_derive.rs

Lines changed: 98 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,20 @@ struct EnumVariant {
3434
original_name: String,
3535
original_name_span: Span,
3636
rename: Option<String>,
37+
#[cfg(all(feature = "postgres", feature = "json"))]
38+
payload: Option<EnumVariantPayload>,
3739

3840
crate_name: Option<Path>,
3941
}
4042

43+
#[cfg(all(feature = "postgres", feature = "json"))]
44+
struct EnumVariantPayload {
45+
type_: Type,
46+
span: Span,
47+
}
48+
4149
impl Enum {
50+
#[expect(clippy::too_many_lines)]
4251
pub(crate) fn parse(input: DeriveInput) -> Result<Self, syn::Error> {
4352
let Data::Enum(e) = input.data else {
4453
fail!(input, "`benzina::Enum` macro available only for enums");
@@ -97,9 +106,29 @@ impl Enum {
97106
.variants
98107
.into_iter()
99108
.map(|variant| {
100-
if !matches!(variant.fields, Fields::Unit) {
101-
fail!(variant, "only unit variants are supported");
102-
}
109+
let payload = match &variant.fields {
110+
Fields::Unit => None,
111+
#[cfg(all(feature = "postgres", feature = "json"))]
112+
Fields::Unnamed(fields) => {
113+
let mut fields = fields.unnamed.iter();
114+
let (Some(field), None) = (fields.next(), fields.next()) else {
115+
fail!(variant, "only single-item variants are supported");
116+
};
117+
118+
let span = field.span();
119+
Some(EnumVariantPayload {
120+
type_: field.ty.clone(),
121+
span,
122+
})
123+
}
124+
#[cfg(not(all(feature = "postgres", feature = "json")))]
125+
Fields::Unnamed(_fields) => {
126+
fail!(variant, "fields require both the `postgres` and the `json` feature to be enabled");
127+
}
128+
Fields::Named(_fields) => {
129+
fail!(variant, "only unit an unnamed variants are supported");
130+
}
131+
};
103132

104133
let name = variant.ident.to_string();
105134
let mut rename = None;
@@ -120,11 +149,18 @@ impl Enum {
120149
})?;
121150
}
122151

152+
// Suppress build breakage when building without the
153+
// PostgreSQL JSON feature.
154+
#[cfg(not(all(feature = "postgres", feature = "json")))]
155+
let _: Option<()> = payload;
156+
123157
let original_name_span = variant.span();
124158
Ok(EnumVariant {
125159
original_name: name,
126160
original_name_span,
127161
rename,
162+
#[cfg(all(feature = "postgres", feature = "json"))]
163+
payload,
128164

129165
crate_name: crate_name.clone(),
130166
})
@@ -139,6 +175,22 @@ impl Enum {
139175
crate_name,
140176
})
141177
}
178+
179+
#[cfg(all(feature = "postgres", feature = "json"))]
180+
fn has_json_fields(&self) -> bool {
181+
self.variants
182+
.iter()
183+
.any(|variant| variant.payload.is_some())
184+
}
185+
186+
#[cfg(not(all(feature = "postgres", feature = "json")))]
187+
#[expect(
188+
clippy::unused_self,
189+
reason = "kept for compatibility with the above implementation"
190+
)]
191+
fn has_json_fields(&self) -> bool {
192+
false
193+
}
142194
}
143195

144196
impl ToTokens for Enum {
@@ -153,16 +205,23 @@ impl ToTokens for Enum {
153205
} = &self;
154206
let crate_name = crate::crate_name(crate_name);
155207

208+
let has_json_fields = self.has_json_fields();
156209
let from_bytes_arms = variants
157210
.iter()
158-
.map(|variant| variant.gen_from_bytes(*rename_all))
211+
.map(|variant| variant.gen_from_bytes(has_json_fields, *rename_all))
159212
.collect::<Vec<_>>();
160213
let to_byte_str_arms = variants
161214
.iter()
162-
.map(|variant| variant.gen_to_byte_str(*rename_all))
215+
.map(|variant| variant.gen_to_byte_str(has_json_fields, *rename_all))
163216
.collect::<Vec<_>>();
164217

165-
#[cfg(feature = "postgres")]
218+
let pg_sql_type = if self.has_json_fields() {
219+
#[cfg(all(feature = "postgres", feature = "json"))]
220+
quote! { (#sql_type, #crate_name::__private::diesel::pg::sql_types::Jsonb) }
221+
} else {
222+
quote! { #sql_type }
223+
};
224+
166225
tokens.append_all(quote! {
167226
impl diesel::deserialize::Queryable<#sql_type, #crate_name::__private::diesel::pg::Pg> for #ident {
168227
type Row = Self;
@@ -173,7 +232,7 @@ impl ToTokens for Enum {
173232
}
174233

175234
#[automatically_derived]
176-
impl #crate_name::__private::diesel::deserialize::FromSql<#sql_type, #crate_name::__private::diesel::pg::Pg> for #ident {
235+
impl #crate_name::__private::diesel::deserialize::FromSql<#pg_sql_type, #crate_name::__private::diesel::pg::Pg> for #ident {
177236
fn from_sql(bytes: #crate_name::__private::diesel::pg::PgValue<'_>) -> #crate_name::__private::diesel::deserialize::Result<Self> {
178237
match bytes.as_bytes() {
179238
#(#from_bytes_arms)*
@@ -189,7 +248,7 @@ impl ToTokens for Enum {
189248
}
190249

191250
#[automatically_derived]
192-
impl #crate_name::__private::diesel::serialize::ToSql<#sql_type, #crate_name::__private::diesel::pg::Pg> for #ident {
251+
impl #crate_name::__private::diesel::serialize::ToSql<#pg_sql_type, #crate_name::__private::diesel::pg::Pg> for #ident {
193252
fn to_sql<'b>(&'b self, out: &mut #crate_name::__private::diesel::serialize::Output<'b, '_, #crate_name::__private::diesel::pg::Pg>) -> #crate_name::__private::diesel::serialize::Result {
194253
let s = match self {
195254
#(#to_byte_str_arms)*
@@ -204,44 +263,48 @@ impl ToTokens for Enum {
204263
});
205264

206265
#[cfg(feature = "mysql")]
207-
tokens.append_all(quote! {
208-
#[automatically_derived]
209-
impl #crate_name::__private::diesel::deserialize::FromSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident {
210-
fn from_sql(bytes: #crate_name::__private::diesel::mysql::MysqlValue<'_>) -> #crate_name::__private::diesel::deserialize::Result<Self> {
211-
match bytes.as_bytes() {
212-
#(#from_bytes_arms)*
213-
_ => {
214-
#crate_name::__private::std::result::Result::Err(
215-
#crate_name::__private::std::convert::Into::into(
216-
"Unrecognized enum variant"
266+
if !self.has_json_fields() {
267+
tokens.append_all(quote! {
268+
#[automatically_derived]
269+
impl #crate_name::"mysql"__private::diesel::deserialize::FromSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident {
270+
fn from_sql(bytes: #crate_name::__private::diesel::mysql::MysqlValue<'_>) -> #crate_name::__private::diesel::deserialize::Result<Self> {
271+
match bytes.as_bytes() {
272+
#(#from_bytes_arms)*
273+
_ => {
274+
#crate_name::__private::std::result::Result::Err(
275+
#crate_name::__private::std::convert::Into::into(
276+
"Unrecognized enum variant"
277+
)
217278
)
218-
)
219-
},
279+
},
280+
}
220281
}
221282
}
222-
}
223283

224-
#[automatically_derived]
225-
impl #crate_name::__private::diesel::serialize::ToSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident {
226-
fn to_sql<'b>(&'b self, out: &mut #crate_name::__private::diesel::serialize::Output<'b, '_, #crate_name::__private::diesel::mysql::Mysql>) -> #crate_name::__private::diesel::serialize::Result {
227-
let s = match self {
228-
#(#to_byte_str_arms)*
229-
};
230-
#crate_name::__private::std::io::Write::write_all(out, s)?;
284+
#[automatically_derived]
285+
impl #crate_name::__private::diesel::serialize::ToSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident {
286+
fn to_sql<'b>(&'b self, out: &mut #crate_name::__private::diesel::serialize::Output<'b, '_, #crate_name::__private::diesel::mysql::Mysql>) -> #crate_name::__private::diesel::serialize::Result {
287+
let s = match self {
288+
#(#to_byte_str_arms)*
289+
};
290+
#crate_name::__private::std::io::Write::write_all(out, s)?;
231291

232-
#crate_name::__private::std::result::Result::Ok(#crate_name::__private::diesel::serialize::IsNull::No)
292+
#crate_name::__private::std::result::Result::Ok(#crate_name::__private::diesel::serialize::IsNull::No)
293+
}
233294
}
234-
}
235-
});
295+
});
296+
}
236297
}
237298
}
238299

239300
impl EnumVariant {
240-
fn gen_from_bytes(&self, rename_rule: RenameRule) -> impl ToTokens {
301+
fn gen_from_bytes(&self, has_fields: bool, rename_rule: RenameRule) -> impl ToTokens {
241302
let Self {
242303
original_name,
243304
original_name_span,
244305
rename,
306+
#[cfg(all(feature = "postgres", feature = "json"))]
307+
payload,
245308

246309
crate_name,
247310
} = self;
@@ -258,11 +321,13 @@ impl EnumVariant {
258321
}
259322
}
260323

261-
fn gen_to_byte_str(&self, rename_rule: RenameRule) -> impl ToTokens {
324+
fn gen_to_byte_str(&self, has_fields: bool, rename_rule: RenameRule) -> impl ToTokens {
262325
let Self {
263326
original_name,
264327
original_name_span,
265328
rename,
329+
#[cfg(all(feature = "postgres", feature = "json"))]
330+
payload,
266331

267332
crate_name: _,
268333
} = self;

benzina/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ utoipa = ["dep:utoipa"]
5151
example-generated = ["typed-uuid"]
5252
dangerous-construction = ["typed-uuid"]
5353

54-
json = ["postgres", "dep:serde_core", "dep:serde_json", "diesel/serde_json"]
54+
json = ["postgres", "benzina-derive?/json", "dep:serde_core", "dep:serde_json", "diesel/serde_json"]
5555
ctid = ["postgres", "diesel/i-implement-a-third-party-backend-and-opt-into-breaking-changes"]
5656

5757
[lints]

benzina/src/__private.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,35 @@ pub fn new_indexmap<K, V>() -> IndexMap<K, V> {
2323
IndexMap::with_hasher(Hasher::default())
2424
}
2525

26+
#[cfg(all(feature = "postgres", feature = "json"))]
27+
pub mod json {
28+
use diesel::{
29+
deserialize::{FromSql, FromSqlRow},
30+
expression::AsExpression,
31+
pg::{Pg, PgValue},
32+
serialize::ToSql,
33+
sql_types,
34+
};
35+
36+
use crate::json::convert::{sql_deserialize_binary_raw, sql_serialize_binary_raw};
37+
38+
#[derive(Debug, FromSqlRow, AsExpression)]
39+
#[diesel(sql_type = sql_types::Jsonb)]
40+
pub struct RawJsonb(Box<[u8]>);
41+
42+
impl FromSql<sql_types::Jsonb, Pg> for RawJsonb {
43+
fn from_sql(value: PgValue) -> diesel::deserialize::Result<Self> {
44+
sql_deserialize_binary_raw(&value).map(Box::from).map(Self)
45+
}
46+
}
47+
48+
impl ToSql<sql_types::Jsonb, Pg> for RawJsonb {
49+
fn to_sql(&self, out: &mut diesel::serialize::Output<Pg>) -> diesel::serialize::Result {
50+
sql_serialize_binary_raw(&self.0, out)
51+
}
52+
}
53+
}
54+
2655
pub mod deep_clone {
2756
pub trait DeepClone {
2857
type Output;

benzina/src/json/convert.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ where
5555
sql_serialize(value, out)
5656
}
5757

58+
pub(crate) fn sql_serialize_binary_raw(
59+
value: &[u8],
60+
out: &mut diesel::serialize::Output<'_, '_, Pg>,
61+
) -> diesel::serialize::Result {
62+
out.write_all(&[1])?;
63+
out.write_all(value)?;
64+
Ok(IsNull::No)
65+
}
66+
5867
pub(super) fn sql_deserialize<T>(value: PgValue<'_>) -> diesel::deserialize::Result<T>
5968
where
6069
T: DeserializeOwned,
@@ -66,6 +75,13 @@ pub(super) fn sql_deserialize_binary<T>(value: PgValue<'_>) -> diesel::deseriali
6675
where
6776
T: DeserializeOwned,
6877
{
78+
let bytes = sql_deserialize_binary_raw(&value)?;
79+
serde_json::from_slice(bytes).map_err(Into::into)
80+
}
81+
82+
pub(crate) fn sql_deserialize_binary_raw<'a>(
83+
value: &'a PgValue<'_>,
84+
) -> diesel::deserialize::Result<&'a [u8]> {
6985
let (version, bytes) = value
7086
.as_bytes()
7187
.split_first()
@@ -75,5 +91,5 @@ where
7591
return Err("Unsupported JSONB encoding version".into());
7692
}
7793

78-
serde_json::from_slice(bytes).map_err(Into::into)
94+
Ok(bytes)
7995
}

0 commit comments

Comments
 (0)