Skip to content

Commit cb29663

Browse files
committed
fix: removed ilog and added tests
1 parent 024a1cb commit cb29663

File tree

2 files changed

+38
-89
lines changed

2 files changed

+38
-89
lines changed

datafusion/functions/src/math/log.rs

Lines changed: 12 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -102,48 +102,17 @@ impl LogFunc {
102102
}
103103
}
104104

105-
/// Checks if the base is valid for the efficient integer logarithm algorithm.
106-
#[inline]
107-
fn is_valid_integer_base(base: f64) -> bool {
108-
base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64
109-
}
110-
111105
/// Generic function to calculate logarithm of a decimal value using the given base.
112106
///
113-
/// For integer bases >= 2 with non-negative scale, uses the efficient integer `ilog` algorithm.
114-
/// For all other cases (non-integer bases, negative bases, non-finite bases),
115-
/// falls back to f64 computation which naturally returns NaN for invalid inputs,
116-
/// matching the behavior of `f64::log`.
107+
/// Uses f64 computation which naturally returns NaN for invalid inputs
108+
/// (base <= 1, non-finite, value <= 0), matching the behavior of `f64::log`.
117109
fn log_decimal<T>(value: T, scale: i8, base: f64) -> Result<f64, ArrowError>
118110
where
119111
T: ToPrimitive + Copy,
120112
{
121-
// For integer bases >= 2 and non-negative scale, try the efficient integer algorithm
122-
if is_valid_integer_base(base)
123-
&& scale >= 0
124-
&& let Some(unscaled) = unscale_decimal_value(&value, scale)
125-
{
126-
return if unscaled > 0 {
127-
Ok(unscaled.ilog(base as u128) as f64)
128-
} else {
129-
Ok(f64::NAN)
130-
};
131-
}
132-
133-
// Fallback to f64 computation for non-integer bases, negative scale, etc.
134-
// This naturally returns NaN for invalid inputs (base <= 1, non-finite, value <= 0)
135113
decimal_to_f64(&value, scale).map(|v| v.log(base))
136114
}
137115

138-
/// Unscale a decimal value by dividing by 10^scale, returning the result as u128.
139-
/// Returns None if the value is negative or the conversion fails.
140-
#[inline]
141-
fn unscale_decimal_value<T: ToPrimitive>(value: &T, scale: i8) -> Option<u128> {
142-
let value_u128 = value.to_u128()?;
143-
let divisor = 10u128.checked_pow(scale as u32)?;
144-
Some(value_u128 / divisor)
145-
}
146-
147116
/// Convert a scaled decimal value to f64.
148117
#[inline]
149118
fn decimal_to_f64<T: ToPrimitive>(value: &T, scale: i8) -> Result<f64, ArrowError> {
@@ -408,13 +377,10 @@ mod tests {
408377
#[test]
409378
fn test_log_decimal_native() {
410379
let value = 10_i128.pow(35);
411-
assert_eq!((value as f64).log2(), 116.26748332105768);
412-
assert_eq!(
413-
log_decimal(value, 0, 2.0).unwrap(),
414-
// TODO: see we're losing our decimal points compared to above
415-
// https://github.com/apache/datafusion/issues/18524
416-
116.0
417-
);
380+
let expected = (value as f64).log2();
381+
assert_eq!(expected, 116.26748332105768);
382+
// Now using f64 computation, we get the precise value
383+
assert!((log_decimal(value, 0, 2.0).unwrap() - expected).abs() < 1e-10);
418384
}
419385

420386
#[test]
@@ -982,7 +948,8 @@ mod tests {
982948
assert!((floats.value(1) - 2.0).abs() < 1e-10);
983949
assert!((floats.value(2) - 3.0).abs() < 1e-10);
984950
assert!((floats.value(3) - 4.0).abs() < 1e-10);
985-
assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding
951+
// log10(12600) ≈ 4.1003 (not truncated to 4)
952+
assert!((floats.value(4) - 12600f64.log10()).abs() < 1e-10);
986953
assert!(floats.value(5).is_nan());
987954
}
988955
ColumnarValue::Scalar(_) => {
@@ -1117,8 +1084,10 @@ mod tests {
11171084
assert!((floats.value(1) - 2.0).abs() < 1e-10);
11181085
assert!((floats.value(2) - 3.0).abs() < 1e-10);
11191086
assert!((floats.value(3) - 4.0).abs() < 1e-10);
1120-
assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding for float log
1121-
assert!((floats.value(5) - 38.0).abs() < 1e-10);
1087+
// log10(12600) ≈ 4.1003 (not truncated to 4)
1088+
assert!((floats.value(4) - 12600f64.log10()).abs() < 1e-10);
1089+
// log10(i128::MAX - 1000) ≈ 38.23 (not truncated to 38)
1090+
assert!((floats.value(5) - ((i128::MAX - 1000) as f64).log10()).abs() < 1e-10);
11221091
assert!(floats.value(6).is_nan());
11231092
}
11241093
ColumnarValue::Scalar(_) => {
@@ -1127,40 +1096,6 @@ mod tests {
11271096
}
11281097
}
11291098

1130-
#[test]
1131-
fn test_log_decimal128_invalid_base() {
1132-
// Invalid base (-2.0) should return NaN, matching f64::log behavior
1133-
let arg_fields = vec![
1134-
Field::new("b", DataType::Float64, false).into(),
1135-
Field::new("x", DataType::Decimal128(38, 0), false).into(),
1136-
];
1137-
let args = ScalarFunctionArgs {
1138-
args: vec![
1139-
ColumnarValue::Scalar(ScalarValue::Float64(Some(-2.0))), // base
1140-
ColumnarValue::Scalar(ScalarValue::Decimal128(Some(64), 38, 0)), // num
1141-
],
1142-
arg_fields,
1143-
number_rows: 1,
1144-
return_field: Field::new("f", DataType::Float64, true).into(),
1145-
config_options: Arc::new(ConfigOptions::default()),
1146-
};
1147-
let result = LogFunc::new()
1148-
.invoke_with_args(args)
1149-
.expect("should not error on invalid base");
1150-
1151-
match result {
1152-
ColumnarValue::Array(arr) => {
1153-
let floats = as_float64_array(&arr)
1154-
.expect("failed to convert result to a Float64Array");
1155-
assert_eq!(floats.len(), 1);
1156-
assert!(floats.value(0).is_nan());
1157-
}
1158-
ColumnarValue::Scalar(_) => {
1159-
panic!("Expected an array value")
1160-
}
1161-
}
1162-
}
1163-
11641099
#[test]
11651100
fn test_log_decimal256_large() {
11661101
// Large Decimal256 values that don't fit in i128 now use f64 fallback

datafusion/sqllogictest/test_files/decimal.slt

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ select log(arrow_cast(100, 'Decimal32(9, 2)'));
804804
query R
805805
select log(2.0, arrow_cast(12345.67, 'Decimal32(9, 2)'));
806806
----
807-
13
807+
13.591717513272
808808

809809
# log for small decimal64
810810
query R
@@ -820,7 +820,7 @@ select log(arrow_cast(100, 'Decimal64(18, 2)'));
820820
query R
821821
select log(2.0, arrow_cast(12345.6789, 'Decimal64(15, 4)'));
822822
----
823-
13
823+
13.591718553311
824824

825825

826826
# log for small decimal128
@@ -896,15 +896,13 @@ select log(10::decimal(38, 0), 100000000000000000000000000000000000::decimal(38,
896896
query R
897897
select log(2, 100000000000000000000000000000000000::decimal(38,0));
898898
----
899-
116
899+
116.267483321058
900900

901901
# log(10^35) for decimal128 with another base (float base)
902-
# TODO: this should be 116.267483321058, error with native decimal log impl
903-
# https://github.com/apache/datafusion/issues/18524
904902
query R
905903
select log(2.0, 100000000000000000000000000000000000::decimal(38,0));
906904
----
907-
116
905+
116.267483321058
908906

909907
# log with non-integer base now works (fallback to f64)
910908
query R
@@ -1036,13 +1034,31 @@ from (values (10.0), (2.0), (3.0)) as t(base);
10361034
query R
10371035
SELECT log(10, arrow_cast(0.5, 'Decimal32(5, 1)'))
10381036
----
1039-
NaN
1037+
-0.301029995664
10401038

10411039
query R
10421040
SELECT log(10, arrow_cast(1 , 'Decimal32(5, 1)'))
10431041
----
10441042
0
10451043

1044+
# Test log with invalid base (-2.0) returns NaN, matching f64::log behavior
1045+
query R
1046+
SELECT log(-2.0, 64::decimal(38, 0))
1047+
----
1048+
NaN
1049+
1050+
# Test log with base 0 returns 0 (log(x)/log(0) = log(x)/-inf = -0 ≈ 0)
1051+
query R
1052+
SELECT log(0.0, 64::decimal(38, 0))
1053+
----
1054+
0
1055+
1056+
# Test log with base 1 returns Infinity (log base 1 is division by zero: log(x)/log(1) = log(x)/0)
1057+
query R
1058+
SELECT log(1.0, 64::decimal(38, 0))
1059+
----
1060+
Infinity
1061+
10461062
# power with decimals
10471063

10481064
query RT
@@ -1183,18 +1199,16 @@ select 100000000000000000000000000000000000::decimal(38,0)
11831199
99999999999999996863366107917975552
11841200

11851201
# log(10^35) for decimal128 with explicit decimal base
1186-
# Float parsing is rounding down
11871202
query R
11881203
select log(10, 100000000000000000000000000000000000::decimal(38,0));
11891204
----
1190-
34
1205+
35
11911206

1192-
# log(10^35) for large decimal128 if parsed as float
1193-
# Float parsing is rounding down
1207+
# log(10^35) for large decimal128
11941208
query R
11951209
select log(100000000000000000000000000000000000::decimal(38,0))
11961210
----
1197-
34
1211+
35
11981212

11991213
# Result is decimal since argument is decimal regardless decimals-as-floats parsing
12001214
query R

0 commit comments

Comments
 (0)