Skip to content

Commit 09ecc07

Browse files
committed
refactor(types): improve the Decimal cast by leveraging rust_decimal to avoid string allocation
Following PostgreSQL, the round_dp_with_strategy(..., MidpointAwayFromZero) gives identical semantics to DuckDB’s Decimal to Integer cast. Signed-off-by: Florian Valeye <[email protected]>
1 parent 2a19420 commit 09ecc07

File tree

1 file changed

+169
-2
lines changed

1 file changed

+169
-2
lines changed

crates/duckdb/src/types/from_sql.rs

Lines changed: 169 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::{error::Error, fmt};
22

33
use cast;
4+
use rust_decimal::RoundingStrategy::MidpointAwayFromZero;
45

56
use super::{TimeUnit, Value, ValueRef};
67

@@ -90,8 +91,11 @@ macro_rules! from_sql_integral(
9091
ValueRef::Float(i) => <$t as cast::From<f32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
9192
ValueRef::Double(i) => <$t as cast::From<f64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
9293

93-
// TODO: more efficient way?
94-
ValueRef::Decimal(i) => Ok(i.to_string().parse::<$t>().unwrap()),
94+
ValueRef::Decimal(d) => {
95+
// DuckDB rounds DECIMAL to INTEGER (following PostgreSQL behavior)
96+
let rounded = d.round_dp_with_strategy(0, MidpointAwayFromZero);
97+
<$t as cast::From<i128>>::cast(rounded.mantissa()).into_result(FromSqlError::OutOfRange(d.mantissa()))
98+
}
9599

96100
ValueRef::Timestamp(_, i) => <$t as cast::From<i64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
97101
ValueRef::Date32(i) => <$t as cast::From<i32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
@@ -414,4 +418,167 @@ mod test {
414418
assert_eq!(v.0.to_string(), "47183823-2574-4bfd-b411-99ed177d3e43");
415419
Ok(())
416420
}
421+
422+
#[test]
423+
fn test_decimal_to_integer() -> Result<()> {
424+
let db = Connection::open_in_memory()?;
425+
426+
assert_eq!(
427+
db.query_row("SELECT 0.1::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
428+
0
429+
);
430+
assert_eq!(
431+
db.query_row("SELECT 0.4::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
432+
0
433+
);
434+
assert_eq!(
435+
db.query_row("SELECT 0.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
436+
1
437+
);
438+
assert_eq!(
439+
db.query_row("SELECT 0.6::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
440+
1
441+
);
442+
assert_eq!(
443+
db.query_row("SELECT 0.9::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
444+
1
445+
);
446+
447+
assert_eq!(
448+
db.query_row("SELECT 1.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
449+
2
450+
);
451+
assert_eq!(
452+
db.query_row("SELECT 2.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
453+
3
454+
);
455+
assert_eq!(
456+
db.query_row("SELECT 3.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
457+
4
458+
);
459+
assert_eq!(
460+
db.query_row("SELECT 4.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
461+
5
462+
);
463+
assert_eq!(
464+
db.query_row("SELECT 5.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
465+
6
466+
);
467+
assert_eq!(
468+
db.query_row("SELECT 10.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
469+
11
470+
);
471+
assert_eq!(
472+
db.query_row("SELECT 99.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
473+
100
474+
);
475+
476+
assert_eq!(
477+
db.query_row("SELECT -0.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
478+
-1
479+
);
480+
assert_eq!(
481+
db.query_row("SELECT -1.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
482+
-2
483+
);
484+
assert_eq!(
485+
db.query_row("SELECT -2.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
486+
-3
487+
);
488+
assert_eq!(
489+
db.query_row("SELECT -3.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
490+
-4
491+
);
492+
assert_eq!(
493+
db.query_row("SELECT -4.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
494+
-5
495+
);
496+
497+
assert_eq!(
498+
db.query_row("SELECT -0.1::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
499+
0
500+
);
501+
assert_eq!(
502+
db.query_row("SELECT -0.4::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
503+
0
504+
);
505+
assert_eq!(
506+
db.query_row("SELECT -0.6::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
507+
-1
508+
);
509+
assert_eq!(
510+
db.query_row("SELECT -0.9::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
511+
-1
512+
);
513+
514+
assert_eq!(
515+
db.query_row("SELECT 999.4::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
516+
999
517+
);
518+
assert_eq!(
519+
db.query_row("SELECT 999.5::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
520+
1000
521+
);
522+
assert_eq!(
523+
db.query_row("SELECT 999.6::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
524+
1000
525+
);
526+
527+
assert_eq!(
528+
db.query_row("SELECT 123456.49::DECIMAL(18,3)", [], |row| row.get::<_, i64>(0))?,
529+
123456
530+
);
531+
assert_eq!(
532+
db.query_row("SELECT 123456.50::DECIMAL(18,3)", [], |row| row.get::<_, i64>(0))?,
533+
123457
534+
);
535+
assert_eq!(
536+
db.query_row("SELECT 123456.51::DECIMAL(18,3)", [], |row| row.get::<_, i64>(0))?,
537+
123457
538+
);
539+
540+
assert_eq!(
541+
db.query_row("SELECT 0.49::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
542+
0
543+
);
544+
assert_eq!(
545+
db.query_row("SELECT 0.50::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
546+
1
547+
);
548+
assert_eq!(
549+
db.query_row("SELECT 0.51::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
550+
1
551+
);
552+
assert_eq!(
553+
db.query_row("SELECT -0.49::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
554+
0
555+
);
556+
assert_eq!(
557+
db.query_row("SELECT -0.50::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
558+
-1
559+
);
560+
assert_eq!(
561+
db.query_row("SELECT -0.51::DECIMAL(10,2)", [], |row| row.get::<_, i32>(0))?,
562+
-1
563+
);
564+
565+
assert_eq!(
566+
db.query_row("SELECT 126.4::DECIMAL(5,1)", [], |row| row.get::<_, i8>(0))?,
567+
126
568+
);
569+
assert_eq!(
570+
db.query_row("SELECT 126.6::DECIMAL(5,1)", [], |row| row.get::<_, i8>(0))?,
571+
127
572+
);
573+
574+
let err = db
575+
.query_row("SELECT 999::DECIMAL(10,0)", [], |row| row.get::<_, i8>(0))
576+
.unwrap_err();
577+
match err {
578+
Error::IntegralValueOutOfRange(_, _) => {} // Expected
579+
_ => panic!("Expected IntegralValueOutOfRange error, got: {err}"),
580+
}
581+
582+
Ok(())
583+
}
417584
}

0 commit comments

Comments
 (0)