Skip to content

Commit c2747eb

Browse files
Mark1626alamb
andauthored
feat: Support log for Decimal32 and Decimal64 (#18999)
## Which issue does this PR close? - Part of #17555. ## Rationale for this change ### Analysis Other engines: 1. ClickHouse seems to only consider `"(U)Int*", "Float*", "Decimal*"` as arguments for log https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/log.cpp#L47-L63 Libraries 1. There is a C++ library, libdecimal, which internally uses the [Intel Decimal Floating Point Library](https://www.intel.com/content/www/us/en/developer/articles/tool/intel-decimal-floating-point-math-library.html) for its [decimal32](https://github.com/GaryHughes/stddecimal/blob/main/libdecimal/decimal_cmath.cpp#L150-L159) operations. Intel's library itself converts the decimal32 to double and calls `log`. https://github.com/karlorz/IntelRDFPMathLib20U2/blob/main/LIBRARY/src/bid32_log.c 2. There is another C++ library based on IBM's decNumber decimal library https://github.com/semihc/CppDecimal . This one's implementation of [`log`](https://github.com/semihc/CppDecimal/blob/main/src/decNumber.c#L1384-L1518) works entirely in decimal, but I don't think this would be a very performant way to do this. I'm going to go with an approach similar to the one inside Intel's decimal library. To begin with, the `decimal32 -> double` conversion is done by a simple scaling. ## What changes are included in this PR? 1. Support Decimal32 for log ## Are these changes tested? Yes, unit tests have been added, and I've tested this from the datafusion cli for Decimal32 ``` > select log(2.0, arrow_cast(12345.67, 'Decimal32(9, 2)')); +-----------------------------------------------------------------------+ | log(Float64(2),arrow_cast(Float64(12345.67),Utf8("Decimal32(9, 2)"))) | +-----------------------------------------------------------------------+ | 13.591717513271785 | +-----------------------------------------------------------------------+ 1 row(s) fetched. Elapsed 0.021 seconds. ``` ## Are there any user-facing changes? 1. 
The precision of the result for Decimal32 will change, the precision loss in #18524 does not occur in this PR --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent d59ebac commit c2747eb

File tree

3 files changed

+212
-11
lines changed

3 files changed

+212
-11
lines changed

datafusion/functions/src/math/log.rs

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ use std::any::Any;
2121

2222
use super::power::PowerFunc;
2323

24-
use crate::utils::{calculate_binary_math, decimal128_to_i128};
24+
use crate::utils::{
25+
calculate_binary_math, decimal32_to_i32, decimal64_to_i64, decimal128_to_i128,
26+
};
2527
use arrow::array::{Array, ArrayRef};
26-
use arrow::compute::kernels::cast;
2728
use arrow::datatypes::{
28-
DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type,
29+
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type,
30+
Float32Type, Float64Type,
2931
};
3032
use arrow::error::ArrowError;
3133
use arrow_buffer::i256;
@@ -102,6 +104,54 @@ impl LogFunc {
102104
}
103105
}
104106

107+
/// Binary function to calculate logarithm of Decimal32 `value` using `base` base
108+
/// Returns error if base is invalid
109+
fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> {
110+
if !base.is_finite() || base.trunc() != base {
111+
return Err(ArrowError::ComputeError(format!(
112+
"Log cannot use non-integer base: {base}"
113+
)));
114+
}
115+
if (base as u32) < 2 {
116+
return Err(ArrowError::ComputeError(format!(
117+
"Log base must be greater than 1: {base}"
118+
)));
119+
}
120+
121+
let unscaled_value = decimal32_to_i32(value, scale)?;
122+
if unscaled_value > 0 {
123+
let log_value: u32 = unscaled_value.ilog(base as i32);
124+
Ok(log_value as f64)
125+
} else {
126+
// Reflect f64::log behaviour
127+
Ok(f64::NAN)
128+
}
129+
}
130+
131+
/// Binary function to calculate logarithm of Decimal64 `value` using `base` base
132+
/// Returns error if base is invalid
133+
fn log_decimal64(value: i64, scale: i8, base: f64) -> Result<f64, ArrowError> {
134+
if !base.is_finite() || base.trunc() != base {
135+
return Err(ArrowError::ComputeError(format!(
136+
"Log cannot use non-integer base: {base}"
137+
)));
138+
}
139+
if (base as u32) < 2 {
140+
return Err(ArrowError::ComputeError(format!(
141+
"Log base must be greater than 1: {base}"
142+
)));
143+
}
144+
145+
let unscaled_value = decimal64_to_i64(value, scale)?;
146+
if unscaled_value > 0 {
147+
let log_value: u32 = unscaled_value.ilog(base as i64);
148+
Ok(log_value as f64)
149+
} else {
150+
// Reflect f64::log behaviour
151+
Ok(f64::NAN)
152+
}
153+
}
154+
105155
/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
106156
/// Returns error if base is invalid
107157
fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError> {
@@ -223,15 +273,18 @@ impl ScalarUDFImpl for LogFunc {
223273
|value, base| Ok(value.log(base)),
224274
)?
225275
}
226-
// TODO: native log support for decimal 32 & 64; right now upcast
227-
// to decimal128 to calculate
228-
// https://github.com/apache/datafusion/issues/17555
229-
DataType::Decimal32(precision, scale)
230-
| DataType::Decimal64(precision, scale) => {
231-
calculate_binary_math::<Decimal128Type, Float64Type, Float64Type, _>(
232-
&cast(&value, &DataType::Decimal128(*precision, *scale))?,
276+
DataType::Decimal32(_, scale) => {
277+
calculate_binary_math::<Decimal32Type, Float64Type, Float64Type, _>(
278+
&value,
233279
&base,
234-
|value, base| log_decimal128(value, *scale, base),
280+
|value, base| log_decimal32(value, *scale, base),
281+
)?
282+
}
283+
DataType::Decimal64(_, scale) => {
284+
calculate_binary_math::<Decimal64Type, Float64Type, Float64Type, _>(
285+
&value,
286+
&base,
287+
|value, base| log_decimal64(value, *scale, base),
235288
)?
236289
}
237290
DataType::Decimal128(_, scale) => {

datafusion/functions/src/utils.rs

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,40 @@ pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> {
219219
}
220220
}
221221

222+
pub fn decimal32_to_i32(value: i32, scale: i8) -> Result<i32, ArrowError> {
223+
if scale < 0 {
224+
Err(ArrowError::ComputeError(
225+
"Negative scale is not supported".into(),
226+
))
227+
} else if scale == 0 {
228+
Ok(value)
229+
} else {
230+
match 10_i32.checked_pow(scale as u32) {
231+
Some(divisor) => Ok(value / divisor),
232+
None => Err(ArrowError::ComputeError(format!(
233+
"Cannot get a power of {scale}"
234+
))),
235+
}
236+
}
237+
}
238+
239+
pub fn decimal64_to_i64(value: i64, scale: i8) -> Result<i64, ArrowError> {
240+
if scale < 0 {
241+
Err(ArrowError::ComputeError(
242+
"Negative scale is not supported".into(),
243+
))
244+
} else if scale == 0 {
245+
Ok(value)
246+
} else {
247+
match i64::from(10).checked_pow(scale as u32) {
248+
Some(divisor) => Ok(value / divisor),
249+
None => Err(ArrowError::ComputeError(format!(
250+
"Cannot get a power of {scale}"
251+
))),
252+
}
253+
}
254+
}
255+
222256
#[cfg(test)]
223257
pub mod test {
224258
/// $FUNC ScalarUDFImpl to test
@@ -334,6 +368,7 @@ pub mod test {
334368
}
335369

336370
use arrow::datatypes::DataType;
371+
use itertools::Either;
337372
pub(crate) use test_function;
338373

339374
use super::*;
@@ -376,4 +411,106 @@ pub mod test {
376411
}
377412
}
378413
}
414+
415+
#[test]
fn test_decimal32_to_i32() {
    // Builders for the expected outcome of a case: `Left` carries the expected
    // integer part, `Right` the expected `ComputeError` message.
    let value_of = |v: i32| -> Either<i32, String> { Either::Left(v) };
    let error_of = |msg: &str| -> Either<i32, String> { Either::Right(msg.into()) };

    // (raw decimal value, scale, expected outcome)
    let cases = [
        (123, 0, value_of(123)),
        (1230, 1, value_of(123)),
        (123000, 3, value_of(123)),
        (1234567, 2, value_of(12345)),
        (-1234567, 2, value_of(-12345)),
        (1, 0, value_of(1)),
        (123, -3, error_of("Negative scale is not supported")),
        (123, i8::MAX, error_of("Cannot get a power of 127")),
        (999999999, 0, value_of(999999999)),
        (999999999, 3, value_of(999999)),
    ];

    for (value, scale, expected) in cases {
        let outcome = decimal32_to_i32(value, scale);
        match outcome {
            Err(ArrowError::ComputeError(msg)) => {
                let expected_msg =
                    expected.right().expect("Got error but expected value");
                assert_eq!(msg, expected_msg);
            }
            // Any other error kind only needs to coincide with an expected failure.
            Err(_) => assert!(expected.is_right()),
            Ok(actual) => {
                let expected_value =
                    expected.left().expect("Got value but expected none");
                assert_eq!(
                    actual, expected_value,
                    "{value} and {scale} vs {expected_value:?}"
                );
            }
        }
    }
}
460+
461+
#[test]
fn test_decimal64_to_i64() {
    // Builders for the expected outcome of a case: `Left` carries the expected
    // integer part, `Right` the expected `ComputeError` message.
    let value_of = |v: i64| -> Either<i64, String> { Either::Left(v) };
    let error_of = |msg: &str| -> Either<i64, String> { Either::Right(msg.into()) };

    // (raw decimal value, scale, expected outcome)
    let cases = [
        (123, 0, value_of(123)),
        (1234567890, 2, value_of(12345678)),
        (-1234567890, 2, value_of(-12345678)),
        (123, -3, error_of("Negative scale is not supported")),
        (123, i8::MAX, error_of("Cannot get a power of 127")),
        (999999999999999999i64, 0, value_of(999999999999999999i64)),
        (
            999999999999999999i64,
            3,
            value_of(999999999999999999i64 / 1000),
        ),
        (
            -999999999999999999i64,
            3,
            value_of(-999999999999999999i64 / 1000),
        ),
    ];

    for (value, scale, expected) in cases {
        let outcome = decimal64_to_i64(value, scale);
        match outcome {
            Err(ArrowError::ComputeError(msg)) => {
                let expected_msg =
                    expected.right().expect("Got error but expected value");
                assert_eq!(msg, expected_msg);
            }
            // Any other error kind only needs to coincide with an expected failure.
            Err(_) => assert!(expected.is_right()),
            Ok(actual) => {
                let expected_value =
                    expected.left().expect("Got value but expected none");
                assert_eq!(
                    actual, expected_value,
                    "{value} and {scale} vs {expected_value:?}"
                );
            }
        }
    }
}
379516
}

datafusion/sqllogictest/test_files/decimal.slt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,11 @@ select log(arrow_cast(100, 'Decimal32(9, 2)'));
794794
----
795795
2
796796

797+
query R
798+
select log(2.0, arrow_cast(12345.67, 'Decimal32(9, 2)'));
799+
----
800+
13
801+
797802
# log for small decimal64
798803
query R
799804
select log(arrow_cast(100, 'Decimal64(18, 0)'));
@@ -805,6 +810,12 @@ select log(arrow_cast(100, 'Decimal64(18, 2)'));
805810
----
806811
2
807812

813+
query R
814+
select log(2.0, arrow_cast(12345.6789, 'Decimal64(15, 4)'));
815+
----
816+
13
817+
818+
808819
# log for small decimal128
809820
query R
810821
select log(arrow_cast(100, 'Decimal128(38, 0)'));

0 commit comments

Comments
 (0)