diff --git a/crates/duckdb/src/types/from_sql.rs b/crates/duckdb/src/types/from_sql.rs index 4065243d..e3116e77 100644 --- a/crates/duckdb/src/types/from_sql.rs +++ b/crates/duckdb/src/types/from_sql.rs @@ -1,6 +1,7 @@ use std::{error::Error, fmt}; use cast; +use rust_decimal::RoundingStrategy::MidpointAwayFromZero; use super::{TimeUnit, Value, ValueRef}; @@ -76,46 +77,36 @@ macro_rules! from_sql_integral( #[inline] fn column_result(value: ValueRef<'_>) -> FromSqlResult { match value { - // TODO: Update all cast operation same to HugeInt - ValueRef::TinyInt(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - ValueRef::SmallInt(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - ValueRef::Int(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - ValueRef::BigInt(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - ValueRef::HugeInt(i) => { - let v = <$t as cast::From>::cast(i); - if v.is_ok() { - Ok(v.unwrap()) - } else { - Err(FromSqlError::OutOfRange(i)) - } - }, - - ValueRef::UTinyInt(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - ValueRef::USmallInt(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - ValueRef::UInt(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - ValueRef::UBigInt(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - - ValueRef::Float(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - ValueRef::Double(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - - // TODO: more efficient way? - ValueRef::Decimal(i) => Ok(i.to_string().parse::<$t>().unwrap()), - - ValueRef::Timestamp(_, i) => Ok(<$t as cast::From>::cast(i).unwrap()), - ValueRef::Date32(i) => Ok(<$t as cast::From>::cast(i).unwrap()), - ValueRef::Time64(TimeUnit::Microsecond, i) => Ok(<$t as cast::From>::cast(i).unwrap()), + ValueRef::TinyInt(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + ValueRef::SmallInt(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + ValueRef::Int(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + ValueRef::BigInt(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + ValueRef::HugeInt(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i)), + + ValueRef::UTinyInt(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + ValueRef::USmallInt(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + ValueRef::UInt(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + ValueRef::UBigInt(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + + ValueRef::Float(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + ValueRef::Double(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + + ValueRef::Decimal(d) => { + // DuckDB rounds DECIMAL to INTEGER (following PostgreSQL behavior) + let rounded = d.round_dp_with_strategy(0, MidpointAwayFromZero); + <$t as cast::From>::cast(rounded.mantissa()).into_result(FromSqlError::OutOfRange(d.mantissa())) + } + + ValueRef::Timestamp(_, i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + ValueRef::Date32(i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), + ValueRef::Time64(TimeUnit::Microsecond, i) => <$t as cast::From>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)), ValueRef::Text(_) => { - let v = value.as_str()?.parse::<$t>(); - match v { - Ok(i) => Ok(i), - Err(_) => { - let v = value.as_str()?.parse::(); - match v { - Ok(i) => Err(FromSqlError::OutOfRange(i)), - _ => Err(FromSqlError::InvalidType), - } - }, - } + let s = value.as_str()?; + s.parse::<$t>().or_else(|_| { + s.parse::() + .map_err(|_| FromSqlError::InvalidType) + .and_then(|i| Err(FromSqlError::OutOfRange(i))) + }) } _ => Err(FromSqlError::InvalidType), } @@ -124,45 +115,49 @@ macro_rules! from_sql_integral( ) ); -/// A trait for to implement unwrap method for primitive types -/// cast::From trait returns Result or the primitive, and for -/// Result we need to unwrap() for the column_result function -/// We implement unwrap() for all the primitive types so -/// We can always call unwrap() for the cast() function. -trait Unwrap { - fn unwrap(self) -> Self; - fn is_ok(&self) -> bool; +/// A trait to provide ok_or method for both Result and primitive types +/// cast::From trait returns Result or the primitive, depending on the types +trait IntoResult { + type Value; + fn into_result(self, err: E) -> Result; } -macro_rules! unwrap_integral( - ($t:ident) => ( - impl Unwrap for $t { - #[inline] - fn unwrap(self) -> Self { - self - } +/// A macro to implement the IntoResult trait for all integral types +macro_rules! into_result_integral( + ($type_name:ident) => ( + impl IntoResult for $type_name { + type Value = $type_name; #[inline] - fn is_ok(&self) -> bool { - true + fn into_result(self, _err: E) -> Result { + Ok(self) } } ) ); -unwrap_integral!(i8); -unwrap_integral!(i16); -unwrap_integral!(i32); -unwrap_integral!(i64); -unwrap_integral!(i128); -unwrap_integral!(isize); -unwrap_integral!(u8); -unwrap_integral!(u16); -unwrap_integral!(u32); -unwrap_integral!(u64); -unwrap_integral!(usize); -unwrap_integral!(f32); -unwrap_integral!(f64); +into_result_integral!(i8); +into_result_integral!(i16); +into_result_integral!(i32); +into_result_integral!(i64); +into_result_integral!(i128); +into_result_integral!(isize); +into_result_integral!(u8); +into_result_integral!(u16); +into_result_integral!(u32); +into_result_integral!(u64); +into_result_integral!(usize); +into_result_integral!(f32); +into_result_integral!(f64); + +impl IntoResult for Result { + type Value = T; + + #[inline] + fn into_result(self, err: E2) -> Result { + self.map_err(|_| err) + } +} from_sql_integral!(i8); from_sql_integral!(i16); @@ -418,4 +413,167 @@ mod test { assert_eq!(v.0.to_string(), "47183823-2574-4bfd-b411-99ed177d3e43"); Ok(()) } + + #[test] + fn test_decimal_to_integer() -> Result<()> { + let db = Connection::open_in_memory()?; + + assert_eq!( + db.query_row("SELECT 0.1::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 0 + ); + assert_eq!( + db.query_row("SELECT 0.4::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 0 + ); + assert_eq!( + db.query_row("SELECT 0.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 1 + ); + assert_eq!( + db.query_row("SELECT 0.6::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 1 + ); + assert_eq!( + db.query_row("SELECT 0.9::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 1 + ); + + assert_eq!( + db.query_row("SELECT 1.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 2 + ); + assert_eq!( + db.query_row("SELECT 2.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 3 + ); + assert_eq!( + db.query_row("SELECT 3.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 4 + ); + assert_eq!( + db.query_row("SELECT 4.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 5 + ); + assert_eq!( + db.query_row("SELECT 5.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 6 + ); + assert_eq!( + db.query_row("SELECT 10.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 11 + ); + assert_eq!( + db.query_row("SELECT 99.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 100 + ); + + assert_eq!( + db.query_row("SELECT -0.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + -1 + ); + assert_eq!( + db.query_row("SELECT -1.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + -2 + ); + assert_eq!( + db.query_row("SELECT -2.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + -3 + ); + assert_eq!( + db.query_row("SELECT -3.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + -4 + ); + assert_eq!( + db.query_row("SELECT -4.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + -5 + ); + + assert_eq!( + db.query_row("SELECT -0.1::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 0 + ); + assert_eq!( + db.query_row("SELECT -0.4::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 0 + ); + assert_eq!( + db.query_row("SELECT -0.6::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + -1 + ); + assert_eq!( + db.query_row("SELECT -0.9::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + -1 + ); + + assert_eq!( + db.query_row("SELECT 999.4::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 999 + ); + assert_eq!( + db.query_row("SELECT 999.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 1000 + ); + assert_eq!( + db.query_row("SELECT 999.6::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 1000 + ); + + assert_eq!( + db.query_row("SELECT 123456.49::DECIMAL(18,3)", [], |row| row.get::<_, i64>(0))?, + 123456 + ); + assert_eq!( + db.query_row("SELECT 123456.50::DECIMAL(18,3)", [], |row| row.get::<_, i64>(0))?, + 123457 + ); + assert_eq!( + db.query_row("SELECT 123456.51::DECIMAL(18,3)", [], |row| row.get::<_, i64>(0))?, + 123457 + ); + + assert_eq!( + db.query_row("SELECT 0.49::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 0 + ); + assert_eq!( + db.query_row("SELECT 0.50::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 1 + ); + assert_eq!( + db.query_row("SELECT 0.51::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 1 + ); + assert_eq!( + db.query_row("SELECT -0.49::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + 0 + ); + assert_eq!( + db.query_row("SELECT -0.50::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + -1 + ); + assert_eq!( + db.query_row("SELECT -0.51::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?, + -1 + ); + + assert_eq!( + db.query_row("SELECT 126.4::DECIMAL(5,1)", [], |row| row.get::<_, i8>(0))?, + 126 + ); + assert_eq!( + db.query_row("SELECT 126.6::DECIMAL(5,1)", [], |row| row.get::<_, i8>(0))?, + 127 + ); + + let err = db + .query_row("SELECT 999::DECIMAL(10,0)", [], |row| row.get::<_, i8>(0)) + .unwrap_err(); + match err { + Error::IntegralValueOutOfRange(_, _) => {} // Expected + _ => panic!("Expected IntegralValueOutOfRange error, got: {err}"), + } + + Ok(()) + } }