Skip to content

Commit f0b96f7

Browse files
committed
feat: Allow log with non-integer base on decimals
1 parent d8e68a4 commit f0b96f7

File tree

2 files changed

+106
-97
lines changed

2 files changed

+106
-97
lines changed

datafusion/functions/src/math/log.rs

Lines changed: 99 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ use std::any::Any;
2121

2222
use super::power::PowerFunc;
2323

24-
use crate::utils::{
25-
calculate_binary_math, decimal32_to_i32, decimal64_to_i64, decimal128_to_i128,
26-
};
24+
use crate::utils::calculate_binary_math;
2725
use arrow::array::{Array, ArrayRef};
2826
use arrow::datatypes::{
2927
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type,
@@ -44,7 +42,7 @@ use datafusion_expr::{
4442
};
4543
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
4644
use datafusion_macros::user_doc;
47-
use num_traits::Float;
45+
use num_traits::{Float, ToPrimitive};
4846

4947
#[user_doc(
5048
doc_section(label = "Math Functions"),
@@ -104,91 +102,70 @@ impl LogFunc {
104102
}
105103
}
106104

107-
/// Binary function to calculate logarithm of Decimal32 `value` using `base` base
108-
/// Returns error if base is invalid
109-
fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> {
110-
if !base.is_finite() || base.trunc() != base {
111-
return Err(ArrowError::ComputeError(format!(
112-
"Log cannot use non-integer base: {base}"
113-
)));
114-
}
115-
if (base as u32) < 2 {
116-
return Err(ArrowError::ComputeError(format!(
117-
"Log base must be greater than 1: {base}"
118-
)));
119-
}
120-
121-
let unscaled_value = decimal32_to_i32(value, scale)?;
122-
if unscaled_value > 0 {
123-
let log_value: u32 = unscaled_value.ilog(base as i32);
124-
Ok(log_value as f64)
125-
} else {
126-
// Reflect f64::log behaviour
127-
Ok(f64::NAN)
128-
}
105+
/// Checks if the base is valid for the efficient integer logarithm algorithm.
106+
#[inline]
107+
fn is_valid_integer_base(base: f64) -> bool {
108+
base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64
129109
}
130110

131-
/// Binary function to calculate logarithm of Decimal64 `value` using `base` base
132-
/// Returns error if base is invalid
133-
fn log_decimal64(value: i64, scale: i8, base: f64) -> Result<f64, ArrowError> {
134-
if !base.is_finite() || base.trunc() != base {
135-
return Err(ArrowError::ComputeError(format!(
136-
"Log cannot use non-integer base: {base}"
137-
)));
138-
}
139-
if (base as u32) < 2 {
140-
return Err(ArrowError::ComputeError(format!(
141-
"Log base must be greater than 1: {base}"
142-
)));
111+
/// Generic function to calculate logarithm of a decimal value using the given base.
112+
///
113+
/// For integer bases >= 2 with non-negative scale, uses the efficient integer `ilog` algorithm.
114+
/// For all other cases (non-integer bases, negative bases, non-finite bases),
115+
/// falls back to f64 computation which naturally returns NaN for invalid inputs,
116+
/// matching the behavior of `f64::log`.
117+
fn log_decimal<T>(value: T, scale: i8, base: f64) -> Result<f64, ArrowError>
118+
where
119+
T: ToPrimitive + Copy,
120+
{
121+
// For integer bases >= 2 and non-negative scale, try the efficient integer algorithm
122+
if is_valid_integer_base(base)
123+
&& scale >= 0
124+
&& let Some(unscaled) = unscale_decimal_value(value, scale)
125+
{
126+
return if unscaled > 0 {
127+
Ok(unscaled.ilog(base as u128) as f64)
128+
} else {
129+
Ok(f64::NAN)
130+
};
143131
}
144132

145-
let unscaled_value = decimal64_to_i64(value, scale)?;
146-
if unscaled_value > 0 {
147-
let log_value: u32 = unscaled_value.ilog(base as i64);
148-
Ok(log_value as f64)
149-
} else {
150-
// Reflect f64::log behaviour
151-
Ok(f64::NAN)
152-
}
133+
// Fallback to f64 computation for non-integer bases, negative scale, etc.
134+
// This naturally returns NaN for invalid inputs (base <= 1, non-finite, value <= 0)
135+
decimal_to_f64(value, scale).map(|v| v.log(base))
153136
}
154137

155-
/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
156-
/// Returns error if base is invalid
157-
fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError> {
158-
if !base.is_finite() || base.trunc() != base {
159-
return Err(ArrowError::ComputeError(format!(
160-
"Log cannot use non-integer base: {base}"
161-
)));
162-
}
163-
if (base as u32) < 2 {
164-
return Err(ArrowError::ComputeError(format!(
165-
"Log base must be greater than 1: {base}"
166-
)));
167-
}
168-
169-
if value <= 0 {
170-
// Reflect f64::log behaviour
171-
return Ok(f64::NAN);
172-
}
138+
/// Unscale a decimal value by dividing by 10^scale, returning the result as u128.
139+
/// Returns None if the value is negative or the conversion fails.
140+
#[inline]
141+
fn unscale_decimal_value<T: ToPrimitive>(value: T, scale: i8) -> Option<u128> {
142+
let value_u128 = value.to_u128()?;
143+
let divisor = 10u128.checked_pow(scale as u32)?;
144+
Some(value_u128 / divisor)
145+
}
173146

174-
if scale < 0 {
175-
let actual_value = (value as f64) * 10.0_f64.powi(-(scale as i32));
176-
Ok(actual_value.log(base))
177-
} else {
178-
let unscaled_value = decimal128_to_i128(value, scale)?;
179-
let log_value: u32 = unscaled_value.ilog(base as i128);
180-
Ok(log_value as f64)
181-
}
147+
/// Convert a scaled decimal value to f64.
148+
#[inline]
149+
fn decimal_to_f64<T: ToPrimitive>(value: T, scale: i8) -> Result<f64, ArrowError> {
150+
let value_f64 = value
151+
.to_f64()
152+
.ok_or_else(|| ArrowError::ComputeError("Cannot convert value to f64".to_string()))?;
153+
let scale_factor = 10f64.powi(scale as i32);
154+
Ok(value_f64 / scale_factor)
182155
}
183156

184-
/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
185-
/// Returns error if base is invalid or if value is out of bounds of Decimal128
186157
fn log_decimal256(value: i256, scale: i8, base: f64) -> Result<f64, ArrowError> {
158+
// Try to convert to i128 for the optimized path
187159
match value.to_i128() {
188-
Some(value) => log_decimal128(value, scale, base),
189-
None => Err(ArrowError::NotYetImplemented(format!(
190-
"Log of Decimal256 larger than Decimal128 is not yet supported: {value}"
191-
))),
160+
Some(v) => log_decimal(v, scale, base),
161+
None => {
162+
// For very large Decimal256 values, use f64 computation
163+
let value_f64 = value.to_f64().ok_or_else(|| {
164+
ArrowError::ComputeError(format!("Cannot convert {value} to f64"))
165+
})?;
166+
let scale_factor = 10f64.powi(scale as i32);
167+
Ok((value_f64 / scale_factor).log(base))
168+
}
192169
}
193170
}
194171

@@ -282,21 +259,21 @@ impl ScalarUDFImpl for LogFunc {
282259
calculate_binary_math::<Decimal32Type, Float64Type, Float64Type, _>(
283260
&value,
284261
&base,
285-
|value, base| log_decimal32(value, *scale, base),
262+
|value, base| log_decimal(value, *scale, base),
286263
)?
287264
}
288265
DataType::Decimal64(_, scale) => {
289266
calculate_binary_math::<Decimal64Type, Float64Type, Float64Type, _>(
290267
&value,
291268
&base,
292-
|value, base| log_decimal64(value, *scale, base),
269+
|value, base| log_decimal(value, *scale, base),
293270
)?
294271
}
295272
DataType::Decimal128(_, scale) => {
296273
calculate_binary_math::<Decimal128Type, Float64Type, Float64Type, _>(
297274
&value,
298275
&base,
299-
|value, base| log_decimal128(value, *scale, base),
276+
|value, base| log_decimal(value, *scale, base),
300277
)?
301278
}
302279
DataType::Decimal256(_, scale) => {
@@ -433,7 +410,7 @@ mod tests {
433410
let value = 10_i128.pow(35);
434411
assert_eq!((value as f64).log2(), 116.26748332105768);
435412
assert_eq!(
436-
log_decimal128(value, 0, 2.0).unwrap(),
413+
log_decimal(value, 0, 2.0).unwrap(),
437414
// TODO: see we're losing our decimal points compared to above
438415
// https://github.com/apache/datafusion/issues/18524
439416
116.0
@@ -1151,7 +1128,8 @@ mod tests {
11511128
}
11521129

11531130
#[test]
1154-
fn test_log_decimal128_wrong_base() {
1131+
fn test_log_decimal128_invalid_base() {
1132+
// Invalid base (-2.0) should return NaN, matching f64::log behavior
11551133
let arg_fields = vec![
11561134
Field::new("b", DataType::Float64, false).into(),
11571135
Field::new("x", DataType::Decimal128(38, 0), false).into(),
@@ -1166,16 +1144,26 @@ mod tests {
11661144
return_field: Field::new("f", DataType::Float64, true).into(),
11671145
config_options: Arc::new(ConfigOptions::default()),
11681146
};
1169-
let result = LogFunc::new().invoke_with_args(args);
1170-
assert!(result.is_err());
1171-
assert_eq!(
1172-
"Arrow error: Compute error: Log base must be greater than 1: -2",
1173-
result.unwrap_err().to_string().lines().next().unwrap()
1174-
);
1147+
let result = LogFunc::new()
1148+
.invoke_with_args(args)
1149+
.expect("should not error on invalid base");
1150+
1151+
match result {
1152+
ColumnarValue::Array(arr) => {
1153+
let floats = as_float64_array(&arr)
1154+
.expect("failed to convert result to a Float64Array");
1155+
assert_eq!(floats.len(), 1);
1156+
assert!(floats.value(0).is_nan());
1157+
}
1158+
ColumnarValue::Scalar(_) => {
1159+
panic!("Expected an array value")
1160+
}
1161+
}
11751162
}
11761163

11771164
#[test]
1178-
fn test_log_decimal256_error() {
1165+
fn test_log_decimal256_large() {
1166+
// Large Decimal256 values that don't fit in i128 now use f64 fallback
11791167
let arg_field = Field::new("a", DataType::Decimal256(38, 0), false).into();
11801168
let args = ScalarFunctionArgs {
11811169
args: vec![
@@ -1189,11 +1177,26 @@ mod tests {
11891177
return_field: Field::new("f", DataType::Float64, true).into(),
11901178
config_options: Arc::new(ConfigOptions::default()),
11911179
};
1192-
let result = LogFunc::new().invoke_with_args(args);
1193-
assert!(result.is_err());
1194-
assert_eq!(
1195-
result.unwrap_err().to_string().lines().next().unwrap(),
1196-
"Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727"
1197-
);
1180+
let result = LogFunc::new()
1181+
.invoke_with_args(args)
1182+
.expect("should handle large Decimal256 via f64 fallback");
1183+
1184+
match result {
1185+
ColumnarValue::Array(arr) => {
1186+
let floats = as_float64_array(&arr)
1187+
.expect("failed to convert result to a Float64Array");
1188+
assert_eq!(floats.len(), 1);
1189+
// The f64 fallback may lose some precision for very large numbers,
1190+
// but we verify we get a reasonable positive result (not NaN/infinity)
1191+
let log_result = floats.value(0);
1192+
assert!(
1193+
log_result.is_finite() && log_result > 0.0,
1194+
"Expected positive finite log result, got {log_result}"
1195+
);
1196+
}
1197+
ColumnarValue::Scalar(_) => {
1198+
panic!("Expected an array value")
1199+
}
1200+
}
11981201
}
11991202
}

datafusion/sqllogictest/test_files/decimal.slt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -889,14 +889,20 @@ select log(2, 100000000000000000000000000000000000::decimal(38,0));
889889
----
890890
116
891891

892-
# log(10^35) for decimal128 with another base
892+
# log(10^35) for decimal128 with another base (float base)
893893
# TODO: this should be 116.267483321058, error with native decimal log impl
894894
# https://github.com/apache/datafusion/issues/18524
895895
query R
896896
select log(2.0, 100000000000000000000000000000000000::decimal(38,0));
897897
----
898898
116
899899

900+
# log with non-integer base now works (fallback to f64)
901+
query R
902+
select log(2.5, 100::decimal(38,0));
903+
----
904+
5.025883189464
905+
900906
# null cases
901907
query R
902908
select log(null, 100);

0 commit comments

Comments
 (0)