Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 227 additions & 69 deletions crates/duckdb/src/types/from_sql.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{error::Error, fmt};

use cast;
use rust_decimal::RoundingStrategy::MidpointAwayFromZero;

use super::{TimeUnit, Value, ValueRef};

Expand Down Expand Up @@ -76,46 +77,36 @@ macro_rules! from_sql_integral(
#[inline]
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
match value {
// TODO: Update all cast operation same to HugeInt
ValueRef::TinyInt(i) => Ok(<$t as cast::From<i8>>::cast(i).unwrap()),
ValueRef::SmallInt(i) => Ok(<$t as cast::From<i16>>::cast(i).unwrap()),
ValueRef::Int(i) => Ok(<$t as cast::From<i32>>::cast(i).unwrap()),
ValueRef::BigInt(i) => Ok(<$t as cast::From<i64>>::cast(i).unwrap()),
ValueRef::HugeInt(i) => {
let v = <$t as cast::From<i128>>::cast(i);
if v.is_ok() {
Ok(v.unwrap())
} else {
Err(FromSqlError::OutOfRange(i))
}
},

ValueRef::UTinyInt(i) => Ok(<$t as cast::From<u8>>::cast(i).unwrap()),
ValueRef::USmallInt(i) => Ok(<$t as cast::From<u16>>::cast(i).unwrap()),
ValueRef::UInt(i) => Ok(<$t as cast::From<u32>>::cast(i).unwrap()),
ValueRef::UBigInt(i) => Ok(<$t as cast::From<u64>>::cast(i).unwrap()),

ValueRef::Float(i) => Ok(<$t as cast::From<f32>>::cast(i).unwrap()),
ValueRef::Double(i) => Ok(<$t as cast::From<f64>>::cast(i).unwrap()),

// TODO: more efficient way?
ValueRef::Decimal(i) => Ok(i.to_string().parse::<$t>().unwrap()),

ValueRef::Timestamp(_, i) => Ok(<$t as cast::From<i64>>::cast(i).unwrap()),
ValueRef::Date32(i) => Ok(<$t as cast::From<i32>>::cast(i).unwrap()),
ValueRef::Time64(TimeUnit::Microsecond, i) => Ok(<$t as cast::From<i64>>::cast(i).unwrap()),
ValueRef::TinyInt(i) => <$t as cast::From<i8>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::SmallInt(i) => <$t as cast::From<i16>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::Int(i) => <$t as cast::From<i32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::BigInt(i) => <$t as cast::From<i64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::HugeInt(i) => <$t as cast::From<i128>>::cast(i).into_result(FromSqlError::OutOfRange(i)),

ValueRef::UTinyInt(i) => <$t as cast::From<u8>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::USmallInt(i) => <$t as cast::From<u16>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::UInt(i) => <$t as cast::From<u32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::UBigInt(i) => <$t as cast::From<u64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),

ValueRef::Float(i) => <$t as cast::From<f32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::Double(i) => <$t as cast::From<f64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),

ValueRef::Decimal(d) => {
// DuckDB rounds DECIMAL to INTEGER (following PostgreSQL behavior)
let rounded = d.round_dp_with_strategy(0, MidpointAwayFromZero);
<$t as cast::From<i128>>::cast(rounded.mantissa()).into_result(FromSqlError::OutOfRange(d.mantissa()))
}

ValueRef::Timestamp(_, i) => <$t as cast::From<i64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::Date32(i) => <$t as cast::From<i32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::Time64(TimeUnit::Microsecond, i) => <$t as cast::From<i64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
ValueRef::Text(_) => {
let v = value.as_str()?.parse::<$t>();
match v {
Ok(i) => Ok(i),
Err(_) => {
let v = value.as_str()?.parse::<i128>();
match v {
Ok(i) => Err(FromSqlError::OutOfRange(i)),
_ => Err(FromSqlError::InvalidType),
}
},
}
let s = value.as_str()?;
s.parse::<$t>().or_else(|_| {
s.parse::<i128>()
.map_err(|_| FromSqlError::InvalidType)
.and_then(|i| Err(FromSqlError::OutOfRange(i)))
})
}
_ => Err(FromSqlError::InvalidType),
}
Expand All @@ -124,45 +115,49 @@ macro_rules! from_sql_integral(
)
);

/// A trait for to implement unwrap method for primitive types
/// cast::From trait returns Result or the primitive, and for
/// Result we need to unwrap() for the column_result function
/// We implement unwrap() for all the primitive types so
/// We can always call unwrap() for the cast() function.
trait Unwrap {
fn unwrap(self) -> Self;
fn is_ok(&self) -> bool;
/// A trait to provide ok_or method for both Result and primitive types
/// cast::From trait returns Result or the primitive, depending on the types
trait IntoResult {
type Value;
fn into_result<E>(self, err: E) -> Result<Self::Value, E>;
}

macro_rules! unwrap_integral(
($t:ident) => (
impl Unwrap for $t {
#[inline]
fn unwrap(self) -> Self {
self
}
/// A macro to implement the IntoResult trait for all integral types
macro_rules! into_result_integral(
($type_name:ident) => (
impl IntoResult for $type_name {
type Value = $type_name;

#[inline]
fn is_ok(&self) -> bool {
true
fn into_result<E>(self, _err: E) -> Result<Self::Value, E> {
Ok(self)
}
}
)
);

unwrap_integral!(i8);
unwrap_integral!(i16);
unwrap_integral!(i32);
unwrap_integral!(i64);
unwrap_integral!(i128);
unwrap_integral!(isize);
unwrap_integral!(u8);
unwrap_integral!(u16);
unwrap_integral!(u32);
unwrap_integral!(u64);
unwrap_integral!(usize);
unwrap_integral!(f32);
unwrap_integral!(f64);
into_result_integral!(i8);
into_result_integral!(i16);
into_result_integral!(i32);
into_result_integral!(i64);
into_result_integral!(i128);
into_result_integral!(isize);
into_result_integral!(u8);
into_result_integral!(u16);
into_result_integral!(u32);
into_result_integral!(u64);
into_result_integral!(usize);
into_result_integral!(f32);
into_result_integral!(f64);

impl<T, E> IntoResult for Result<T, E> {
type Value = T;

#[inline]
fn into_result<E2>(self, err: E2) -> Result<Self::Value, E2> {
self.map_err(|_| err)
}
}

from_sql_integral!(i8);
from_sql_integral!(i16);
Expand Down Expand Up @@ -418,4 +413,167 @@ mod test {
assert_eq!(v.0.to_string(), "47183823-2574-4bfd-b411-99ed177d3e43");
Ok(())
}

#[test]
fn test_decimal_to_integer() -> Result<()> {
let db = Connection::open_in_memory()?;

assert_eq!(
db.query_row("SELECT 0.1::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
0
);
assert_eq!(
db.query_row("SELECT 0.4::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
0
);
assert_eq!(
db.query_row("SELECT 0.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
1
);
assert_eq!(
db.query_row("SELECT 0.6::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
1
);
assert_eq!(
db.query_row("SELECT 0.9::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
1
);

assert_eq!(
db.query_row("SELECT 1.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
2
);
assert_eq!(
db.query_row("SELECT 2.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
3
);
assert_eq!(
db.query_row("SELECT 3.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
4
);
assert_eq!(
db.query_row("SELECT 4.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
5
);
assert_eq!(
db.query_row("SELECT 5.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
6
);
assert_eq!(
db.query_row("SELECT 10.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
11
);
assert_eq!(
db.query_row("SELECT 99.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
100
);

assert_eq!(
db.query_row("SELECT -0.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
-1
);
assert_eq!(
db.query_row("SELECT -1.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
-2
);
assert_eq!(
db.query_row("SELECT -2.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
-3
);
assert_eq!(
db.query_row("SELECT -3.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
-4
);
assert_eq!(
db.query_row("SELECT -4.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
-5
);

assert_eq!(
db.query_row("SELECT -0.1::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
0
);
assert_eq!(
db.query_row("SELECT -0.4::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
0
);
assert_eq!(
db.query_row("SELECT -0.6::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
-1
);
assert_eq!(
db.query_row("SELECT -0.9::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
-1
);

assert_eq!(
db.query_row("SELECT 999.4::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
999
);
assert_eq!(
db.query_row("SELECT 999.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
1000
);
assert_eq!(
db.query_row("SELECT 999.6::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
1000
);

assert_eq!(
db.query_row("SELECT 123456.49::DECIMAL(18,3)", [], |row| row.get::<_, i64>(0))?,
123456
);
assert_eq!(
db.query_row("SELECT 123456.50::DECIMAL(18,3)", [], |row| row.get::<_, i64>(0))?,
123457
);
assert_eq!(
db.query_row("SELECT 123456.51::DECIMAL(18,3)", [], |row| row.get::<_, i64>(0))?,
123457
);

assert_eq!(
db.query_row("SELECT 0.49::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
0
);
assert_eq!(
db.query_row("SELECT 0.50::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
1
);
assert_eq!(
db.query_row("SELECT 0.51::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
1
);
assert_eq!(
db.query_row("SELECT -0.49::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
0
);
assert_eq!(
db.query_row("SELECT -0.50::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
-1
);
assert_eq!(
db.query_row("SELECT -0.51::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
-1
);

assert_eq!(
db.query_row("SELECT 126.4::DECIMAL(5,1)", [], |row| row.get::<_, i8>(0))?,
126
);
assert_eq!(
db.query_row("SELECT 126.6::DECIMAL(5,1)", [], |row| row.get::<_, i8>(0))?,
127
);

let err = db
.query_row("SELECT 999::DECIMAL(10,0)", [], |row| row.get::<_, i8>(0))
.unwrap_err();
match err {
Error::IntegralValueOutOfRange(_, _) => {} // Expected
_ => panic!("Expected IntegralValueOutOfRange error, got: {err}"),
}

Ok(())
}
}