Skip to content

Commit c6220a5

Browse files
committed
feat: implement Array and NullableArray wrapper types
1 parent 82b4dcd commit c6220a5

File tree

4 files changed

+231
-1
lines changed

4 files changed

+231
-1
lines changed

benzina/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ serde_test = "1"
3737
uuid = { version = ">=0.7.0, <2.0.0", default-features = false, features = ["v4"] }
3838

3939
[features]
40-
default = ["derive"]
40+
default = ["derive", "array"]
4141
derive = ["dep:benzina-derive", "dep:diesel", "dep:indexmap"]
4242
rustc-hash = ["dep:rustc-hash"]
4343

@@ -51,6 +51,7 @@ utoipa = ["dep:utoipa"]
5151
example-generated = ["typed-uuid"]
5252
dangerous-construction = ["typed-uuid"]
5353

54+
array = ["postgres"]
5455
json = ["postgres", "dep:serde_core", "dep:serde_json", "diesel/serde_json"]
5556
ctid = ["postgres", "diesel/i-implement-a-third-party-backend-and-opt-into-breaking-changes"]
5657

benzina/src/array.rs

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
use std::fmt::Debug;
2+
3+
use diesel::{
4+
deserialize::{FromSql, FromSqlRow},
5+
expression::{AppearsOnTable, Expression, SelectableExpression},
6+
pg::{Pg, PgValue},
7+
query_builder::{AstPass, QueryFragment, QueryId},
8+
result::QueryResult,
9+
serialize::ToSql,
10+
sql_types::{self, BigInt, Bool, Double, Float, Integer, Nullable, SmallInt, Text},
11+
};
12+
13+
use crate::{U15, U31, U63, error::InvalidArray};
14+
15+
/// A diesel [Array](diesel::pg::sql_types::Array) serialization and deserialization wrapper that enforces every item is not null
16+
///
17+
/// See [`array_wrapped`](crate::array_wrapped) for an usage example.
18+
#[derive(Debug, FromSqlRow)]
19+
pub struct Array<T, const N: usize>([T; N]);
20+
impl<T, const N: usize> Array<T, N> {
21+
pub fn new(values: [T; N]) -> Self {
22+
Self(values)
23+
}
24+
25+
pub fn get(self) -> [T; N] {
26+
self.0
27+
}
28+
}
29+
30+
impl<T, const N: usize> From<[T; N]> for Array<T, N> {
31+
fn from(values: [T; N]) -> Self {
32+
Self(values)
33+
}
34+
}
35+
36+
impl<T, const N: usize> From<Array<T, N>> for [T; N] {
37+
fn from(value: Array<T, N>) -> Self {
38+
value.0
39+
}
40+
}
41+
42+
/// A diesel [Array](diesel::pg::sql_types::Array) serialization and deserialization wrapper
43+
///
44+
/// See [`array_wrapped`](crate::array_wrapped) for an usage example.
45+
#[derive(Debug, FromSqlRow)]
46+
pub struct NullableArray<T, const N: usize>([Option<T>; N]);
47+
impl<T, const N: usize> NullableArray<T, N> {
48+
pub fn new(values: [Option<T>; N]) -> Self {
49+
Self(values)
50+
}
51+
52+
pub fn get(self) -> [Option<T>; N] {
53+
self.0
54+
}
55+
}
56+
57+
macro_rules! impl_array {
58+
(
59+
$(
60+
$rust_type:tt $(< $generic:ident >)? => $diesel_type:ident
61+
),*
62+
) => {
63+
$(
64+
impl<$($generic,)? const N: usize> Expression for Array<$rust_type$(<$generic>)?, N> {
65+
type SqlType = sql_types::Array<Nullable<$diesel_type>>;
66+
}
67+
68+
impl<$($generic,)? const N: usize> QueryId for Array<$rust_type$(<$generic>)?, N> {
69+
type QueryId = <sql_types::Array<Nullable<$diesel_type>> as QueryId>::QueryId;
70+
71+
const HAS_STATIC_QUERY_ID: bool = <sql_types::Array<Nullable<$diesel_type>> as QueryId>::HAS_STATIC_QUERY_ID;
72+
}
73+
74+
impl<$($generic: Debug + std::clone::Clone,)? const N: usize> QueryFragment<Pg> for Array<$rust_type$(<$generic>)?, N>
75+
{
76+
fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
77+
pass.push_bind_param(self)?;
78+
Ok(())
79+
}
80+
}
81+
82+
impl<__QS, $($generic,)? const N: usize> AppearsOnTable<__QS> for Array<$rust_type$(<$generic>)?, N> {}
83+
84+
impl<__QS, $($generic,)? const N: usize> SelectableExpression<__QS> for Array<$rust_type$(<$generic>)?, N> {}
85+
86+
impl<$($generic: Debug + std::clone::Clone,)? const N: usize> ToSql<sql_types::Array<Nullable<$diesel_type>>, Pg> for Array<$rust_type$(<$generic>)?, N>
87+
{
88+
fn to_sql<'b>(
89+
&'b self,
90+
out: &mut diesel::serialize::Output<'b, '_, Pg>,
91+
) -> diesel::serialize::Result {
92+
<[$rust_type $(< $generic >)?] as ToSql<sql_types::Array<$diesel_type>, Pg>>::to_sql(&self.0.as_slice(), out)
93+
}
94+
}
95+
96+
impl<$($generic: Debug,)? const N: usize> FromSql<sql_types::Array<Nullable<$diesel_type>>, Pg> for Array<$rust_type$(< $generic >)?, N>
97+
{
98+
fn from_sql(bytes: PgValue<'_>) -> diesel::deserialize::Result<Self> {
99+
let raw = <Vec<Option<$rust_type $(< $generic >)?>> as FromSql<sql_types::Array<Nullable<$diesel_type>>, Pg>>::from_sql(bytes)?;
100+
101+
let res: [$rust_type $(< $generic >)?; N] = raw
102+
.into_iter()
103+
.collect::<Option<Vec<$rust_type $(< $generic >)?>>>()
104+
.ok_or(diesel::result::Error::DeserializationError(Box::new(
105+
InvalidArray::UnexpectedNullValue,
106+
)))?
107+
.try_into()
108+
.map_err(|_| {
109+
diesel::result::Error::DeserializationError(Box::new(
110+
InvalidArray::UnexpectedLength,
111+
))
112+
})?;
113+
114+
Ok(Self(res))
115+
}
116+
}
117+
)*
118+
}
119+
}
120+
121+
macro_rules! impl_nullable_array {
122+
(
123+
$(
124+
$rust_type:tt $(< $generic:ident >)? => $diesel_type:ident
125+
),*
126+
) => {
127+
$(
128+
impl<$($generic,)? const N: usize> Expression for NullableArray<$rust_type$(<$generic>)?, N> {
129+
type SqlType = sql_types::Array<Nullable<$diesel_type>>;
130+
}
131+
132+
impl<$($generic,)? const N: usize> QueryId for NullableArray<$rust_type$(<$generic>)?, N> {
133+
type QueryId = <sql_types::Array<Nullable<$diesel_type>> as QueryId>::QueryId;
134+
135+
const HAS_STATIC_QUERY_ID: bool = <sql_types::Array<Nullable<$diesel_type>> as QueryId>::HAS_STATIC_QUERY_ID;
136+
}
137+
138+
impl<$($generic: Debug + std::clone::Clone,)? const N: usize> QueryFragment<Pg> for NullableArray<$rust_type$(<$generic>)?, N>
139+
{
140+
fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
141+
pass.push_bind_param(self)?;
142+
Ok(())
143+
}
144+
}
145+
146+
impl<__QS, $($generic,)? const N: usize> AppearsOnTable<__QS> for NullableArray<$rust_type$(<$generic>)?, N> {}
147+
impl<__QS, $($generic,)? const N: usize> SelectableExpression<__QS> for NullableArray<$rust_type$(<$generic>)?, N> {}
148+
149+
150+
impl<$($generic: Debug + std::clone::Clone,)? const N: usize> ToSql<sql_types::Array<Nullable<$diesel_type>>, Pg> for NullableArray<$rust_type$(<$generic>)?, N>
151+
{
152+
fn to_sql<'b>(
153+
&'b self,
154+
out: &mut diesel::serialize::Output<'b, '_, Pg>,
155+
) -> diesel::serialize::Result {
156+
<[Option<$rust_type$(< $generic >)?>] as ToSql<sql_types::Array<Nullable<$diesel_type>>, Pg>>::to_sql(self.0.as_slice(), out)
157+
}
158+
}
159+
160+
impl<$($generic: Debug,)? const N: usize> FromSql<sql_types::Array<Nullable<$diesel_type>>, Pg> for NullableArray<$rust_type$(< $generic >)?, N>
161+
{
162+
fn from_sql(bytes: PgValue<'_>) -> diesel::deserialize::Result<Self> {
163+
let raw = <Vec<Option<$rust_type$(< $generic >)?>> as FromSql<sql_types::Array<Nullable<$diesel_type>>, Pg>>::from_sql(bytes)?;
164+
165+
let res: [Option<$rust_type $(< $generic >)?>; N] = raw
166+
.try_into()
167+
.map_err(|_| {
168+
diesel::result::Error::DeserializationError(Box::new(
169+
InvalidArray::UnexpectedLength,
170+
))
171+
})?;
172+
173+
Ok(Self(res))
174+
}
175+
}
176+
)*
177+
};
178+
}
179+
180+
impl_array! {
181+
U15 => SmallInt,
182+
U31 => Integer,
183+
U63 => BigInt,
184+
i16 => SmallInt,
185+
i32 => Integer,
186+
i64 => BigInt,
187+
f32 => Float,
188+
f64 => Double,
189+
bool => Bool,
190+
String => Text
191+
}
192+
193+
impl_nullable_array! {
194+
U15 => SmallInt,
195+
U31 => Integer,
196+
U63 => BigInt,
197+
i16 => SmallInt,
198+
i32 => Integer,
199+
i64 => BigInt,
200+
f32 => Float,
201+
f64 => Double,
202+
bool => Bool,
203+
String => Text
204+
}

benzina/src/error.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,24 @@ impl Error for ParseIntError {
3838
}
3939
}
4040
}
41+
42+
#[derive(Debug, Clone)]
43+
pub enum InvalidArray {
44+
UnexpectedLength,
45+
UnexpectedNullValue,
46+
}
47+
48+
impl Display for InvalidArray {
49+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50+
f.write_str(match self {
51+
Self::UnexpectedLength => "mismatched array length",
52+
Self::UnexpectedNullValue => "the array contains an unexpected null value",
53+
})
54+
}
55+
}
56+
57+
impl Error for InvalidArray {
58+
fn source(&self) -> Option<&(dyn Error + 'static)> {
59+
None
60+
}
61+
}

benzina/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
22

3+
#[cfg(feature = "array")]
4+
pub use self::array::{Array, NullableArray};
35
#[cfg(feature = "ctid")]
46
pub use self::ctid::{Ctid, ctid};
57
#[cfg(feature = "postgres")]
@@ -15,6 +17,8 @@ pub use benzina_derive::{Enum, join};
1517

1618
#[doc(hidden)]
1719
pub mod __private;
20+
#[cfg(feature = "array")]
21+
mod array;
1822
#[cfg(feature = "ctid")]
1923
mod ctid;
2024
#[cfg(feature = "postgres")]

0 commit comments

Comments
 (0)