Skip to content

Commit a247fbd

Browse files
committed
feat: implement Array and NullableArray wrapper types
1 parent 82b4dcd commit a247fbd

File tree

4 files changed

+234
-1
lines changed

4 files changed

+234
-1
lines changed

benzina/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ serde_test = "1"
3737
uuid = { version = ">=0.7.0, <2.0.0", default-features = false, features = ["v4"] }
3838

3939
[features]
40-
default = ["derive"]
40+
default = ["derive", "array"]
4141
derive = ["dep:benzina-derive", "dep:diesel", "dep:indexmap"]
4242
rustc-hash = ["dep:rustc-hash"]
4343

@@ -51,6 +51,7 @@ utoipa = ["dep:utoipa"]
5151
example-generated = ["typed-uuid"]
5252
dangerous-construction = ["typed-uuid"]
5353

54+
array = ["postgres"]
5455
json = ["postgres", "dep:serde_core", "dep:serde_json", "diesel/serde_json"]
5556
ctid = ["postgres", "diesel/i-implement-a-third-party-backend-and-opt-into-breaking-changes"]
5657

benzina/src/array.rs

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
use std::fmt::Debug;
2+
3+
use diesel::{
4+
deserialize::{FromSql, FromSqlRow},
5+
expression::{AppearsOnTable, Expression, SelectableExpression},
6+
pg::{Pg, PgValue},
7+
query_builder::{AstPass, QueryFragment, QueryId},
8+
result::QueryResult,
9+
serialize::ToSql,
10+
sql_types::{
11+
self, BigInt, Bool, Double, Float, Integer, Nullable,
12+
SmallInt, Text,
13+
},
14+
};
15+
16+
use crate::{U15, U31, U63, error::InvalidArray};
17+
18+
/// A diesel [Array](diesel::pg::sql_types::Array) serialization and deserialization wrapper that enforces every item is not null
19+
///
20+
/// See [`array_wrapped`](crate::array_wrapped) for an usage example.
21+
#[derive(Debug, FromSqlRow)]
22+
pub struct Array<T, const N: usize>([T; N]);
23+
impl<T, const N: usize> Array<T, N> {
24+
pub fn new(values: [T; N]) -> Self {
25+
Self(values)
26+
}
27+
28+
pub fn get(self) -> [T; N] {
29+
self.0
30+
}
31+
}
32+
33+
impl<T, const N: usize> From<[T; N]> for Array<T, N> {
34+
fn from(values: [T; N]) -> Self {
35+
Self(values)
36+
}
37+
}
38+
39+
impl<T, const N: usize> From<Array<T, N>> for [T; N] {
40+
fn from(value: Array<T, N>) -> Self {
41+
value.0
42+
}
43+
}
44+
45+
/// A diesel [Array](diesel::pg::sql_types::Array) serialization and deserialization wrapper
46+
///
47+
/// See [`array_wrapped`](crate::array_wrapped) for an usage example.
48+
#[derive(Debug, FromSqlRow)]
49+
pub struct NullableArray<T, const N: usize>([Option<T>; N]);
50+
impl<T, const N: usize> NullableArray<T, N> {
51+
pub fn new(values: [Option<T>; N]) -> Self {
52+
Self(values)
53+
}
54+
55+
pub fn get(self) -> [Option<T>; N] {
56+
self.0
57+
}
58+
}
59+
60+
macro_rules! impl_array {
61+
(
62+
$(
63+
$rust_type:tt $(< $generic:ident >)? => $diesel_type:ident
64+
),*
65+
) => {
66+
$(
67+
impl<$($generic,)? const N: usize> Expression for Array<$rust_type$(<$generic>)?, N> {
68+
type SqlType = sql_types::Array<Nullable<$diesel_type>>;
69+
}
70+
71+
impl<$($generic,)? const N: usize> QueryId for Array<$rust_type$(<$generic>)?, N> {
72+
type QueryId = <sql_types::Array<Nullable<$diesel_type>> as QueryId>::QueryId;
73+
74+
const HAS_STATIC_QUERY_ID: bool = <sql_types::Array<Nullable<$diesel_type>> as QueryId>::HAS_STATIC_QUERY_ID;
75+
}
76+
77+
impl<$($generic: Debug + std::clone::Clone,)? const N: usize> QueryFragment<Pg> for Array<$rust_type$(<$generic>)?, N>
78+
{
79+
fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
80+
pass.push_bind_param(self)?;
81+
Ok(())
82+
}
83+
}
84+
85+
impl<__QS, $($generic,)? const N: usize> AppearsOnTable<__QS> for Array<$rust_type$(<$generic>)?, N> {}
86+
87+
impl<__QS, $($generic,)? const N: usize> SelectableExpression<__QS> for Array<$rust_type$(<$generic>)?, N> {}
88+
89+
impl<$($generic: Debug + std::clone::Clone,)? const N: usize> ToSql<sql_types::Array<Nullable<$diesel_type>>, Pg> for Array<$rust_type$(<$generic>)?, N>
90+
{
91+
fn to_sql<'b>(
92+
&'b self,
93+
out: &mut diesel::serialize::Output<'b, '_, Pg>,
94+
) -> diesel::serialize::Result {
95+
<[$rust_type $(< $generic >)?] as ToSql<sql_types::Array<$diesel_type>, Pg>>::to_sql(&self.0.as_slice(), out)
96+
}
97+
}
98+
99+
impl<$($generic: Debug,)? const N: usize> FromSql<sql_types::Array<Nullable<$diesel_type>>, Pg> for Array<$rust_type$(< $generic >)?, N>
100+
{
101+
fn from_sql(bytes: PgValue<'_>) -> diesel::deserialize::Result<Self> {
102+
let raw = <Vec<Option<$rust_type $(< $generic >)?>> as FromSql<sql_types::Array<Nullable<$diesel_type>>, Pg>>::from_sql(bytes)?;
103+
104+
let res: [$rust_type $(< $generic >)?; N] = raw
105+
.into_iter()
106+
.collect::<Option<Vec<$rust_type $(< $generic >)?>>>()
107+
.ok_or(diesel::result::Error::DeserializationError(Box::new(
108+
InvalidArray::UnexpectedNullValue,
109+
)))?
110+
.try_into()
111+
.map_err(|_| {
112+
diesel::result::Error::DeserializationError(Box::new(
113+
InvalidArray::UnexpectedLength,
114+
))
115+
})?;
116+
117+
Ok(Self(res))
118+
}
119+
}
120+
)*
121+
}
122+
}
123+
124+
macro_rules! impl_nullable_array {
125+
(
126+
$(
127+
$rust_type:tt $(< $generic:ident >)? => $diesel_type:ident
128+
),*
129+
) => {
130+
$(
131+
impl<$($generic,)? const N: usize> Expression for NullableArray<$rust_type$(<$generic>)?, N> {
132+
type SqlType = sql_types::Array<Nullable<$diesel_type>>;
133+
}
134+
135+
impl<$($generic,)? const N: usize> QueryId for NullableArray<$rust_type$(<$generic>)?, N> {
136+
type QueryId = <sql_types::Array<Nullable<$diesel_type>> as QueryId>::QueryId;
137+
138+
const HAS_STATIC_QUERY_ID: bool = <sql_types::Array<Nullable<$diesel_type>> as QueryId>::HAS_STATIC_QUERY_ID;
139+
}
140+
141+
impl<$($generic: Debug + std::clone::Clone,)? const N: usize> QueryFragment<Pg> for NullableArray<$rust_type$(<$generic>)?, N>
142+
{
143+
fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
144+
pass.push_bind_param(self)?;
145+
Ok(())
146+
}
147+
}
148+
149+
impl<__QS, $($generic,)? const N: usize> AppearsOnTable<__QS> for NullableArray<$rust_type$(<$generic>)?, N> {}
150+
impl<__QS, $($generic,)? const N: usize> SelectableExpression<__QS> for NullableArray<$rust_type$(<$generic>)?, N> {}
151+
152+
153+
impl<$($generic: Debug + std::clone::Clone,)? const N: usize> ToSql<sql_types::Array<Nullable<$diesel_type>>, Pg> for NullableArray<$rust_type$(<$generic>)?, N>
154+
{
155+
fn to_sql<'b>(
156+
&'b self,
157+
out: &mut diesel::serialize::Output<'b, '_, Pg>,
158+
) -> diesel::serialize::Result {
159+
<[Option<$rust_type$(< $generic >)?>] as ToSql<sql_types::Array<Nullable<$diesel_type>>, Pg>>::to_sql(self.0.as_slice(), out)
160+
}
161+
}
162+
163+
impl<$($generic: Debug,)? const N: usize> FromSql<sql_types::Array<Nullable<$diesel_type>>, Pg> for NullableArray<$rust_type$(< $generic >)?, N>
164+
{
165+
fn from_sql(bytes: PgValue<'_>) -> diesel::deserialize::Result<Self> {
166+
let raw = <Vec<Option<$rust_type$(< $generic >)?>> as FromSql<sql_types::Array<Nullable<$diesel_type>>, Pg>>::from_sql(bytes)?;
167+
168+
let res: [Option<$rust_type $(< $generic >)?>; N] = raw
169+
.try_into()
170+
.map_err(|_| {
171+
diesel::result::Error::DeserializationError(Box::new(
172+
InvalidArray::UnexpectedLength,
173+
))
174+
})?;
175+
176+
Ok(Self(res))
177+
}
178+
}
179+
)*
180+
};
181+
}
182+
183+
impl_array! {
184+
U15 => SmallInt,
185+
U31 => Integer,
186+
U63 => BigInt,
187+
i16 => SmallInt,
188+
i32 => Integer,
189+
i64 => BigInt,
190+
f32 => Float,
191+
f64 => Double,
192+
bool => Bool,
193+
String => Text
194+
}
195+
196+
impl_nullable_array! {
197+
U15 => SmallInt,
198+
U31 => Integer,
199+
U63 => BigInt,
200+
i16 => SmallInt,
201+
i32 => Integer,
202+
i64 => BigInt,
203+
f32 => Float,
204+
f64 => Double,
205+
bool => Bool,
206+
String => Text
207+
}

benzina/src/error.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,24 @@ impl Error for ParseIntError {
3838
}
3939
}
4040
}
41+
42+
#[derive(Debug, Clone)]
43+
pub enum InvalidArray {
44+
UnexpectedLength,
45+
UnexpectedNullValue,
46+
}
47+
48+
impl Display for InvalidArray {
49+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50+
f.write_str(match self {
51+
Self::UnexpectedLength => "mismatched array length",
52+
Self::UnexpectedNullValue => "the array contains an unexpected null value",
53+
})
54+
}
55+
}
56+
57+
impl Error for InvalidArray {
58+
fn source(&self) -> Option<&(dyn Error + 'static)> {
59+
None
60+
}
61+
}

benzina/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#[cfg(feature = "ctid")]
44
pub use self::ctid::{Ctid, ctid};
5+
#[cfg(feature = "array")]
6+
pub use self::array::{Array, NullableArray};
57
#[cfg(feature = "postgres")]
68
pub use self::int::{U15, U31, U63};
79
#[cfg(feature = "json")]
@@ -17,6 +19,8 @@ pub use benzina_derive::{Enum, join};
1719
pub mod __private;
1820
#[cfg(feature = "ctid")]
1921
mod ctid;
22+
#[cfg(feature = "array")]
23+
mod array;
2024
#[cfg(feature = "postgres")]
2125
pub mod error;
2226
#[cfg(feature = "example-generated")]

0 commit comments

Comments
 (0)