diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 0c50afa2dffd3..fa691f946d166 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -21,9 +21,7 @@ use std::any::Any; use super::power::PowerFunc; -use crate::utils::{ - calculate_binary_math, decimal32_to_i32, decimal64_to_i64, decimal128_to_i128, -}; +use crate::utils::calculate_binary_math; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{ DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, @@ -44,7 +42,7 @@ use datafusion_expr::{ }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use num_traits::Float; +use num_traits::{Float, ToPrimitive}; #[user_doc( doc_section(label = "Math Functions"), @@ -104,109 +102,109 @@ impl LogFunc { } } -/// Binary function to calculate logarithm of Decimal32 `value` using `base` base -/// Returns error if base is invalid -fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { - if !base.is_finite() || base.trunc() != base { - return Err(ArrowError::ComputeError(format!( - "Log cannot use non-integer base: {base}" - ))); - } - if (base as u32) < 2 { - return Err(ArrowError::ComputeError(format!( - "Log base must be greater than 1: {base}" - ))); - } - - // Match f64::log behaviour - if value <= 0 { - return Ok(f64::NAN); - } +/// Checks if the base is valid for the efficient integer logarithm algorithm. +#[inline] +fn is_valid_integer_base(base: f64) -> bool { + base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64 +} - if scale < 0 { - let actual_value = (value as f64) * 10.0_f64.powi(-(scale as i32)); - Ok(actual_value.log(base)) - } else { - let unscaled_value = decimal32_to_i32(value, scale)?; - if unscaled_value <= 0 { - return Ok(f64::NAN); - } - let log_value: u32 = unscaled_value.ilog(base as i32); - Ok(log_value as f64) +/// Calculate logarithm for Decimal32 values. +/// For integer bases >= 2 with non-negative scale, uses the efficient u32 ilog algorithm. +/// Otherwise falls back to f64 computation. +fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { + if is_valid_integer_base(base) + && scale >= 0 + && let Some(unscaled) = unscale_to_u32(value, scale) + { + return if unscaled > 0 { + Ok(unscaled.ilog(base as u32) as f64) + } else { + Ok(f64::NAN) + }; } + decimal_to_f64(value, scale).map(|v| v.log(base)) } -/// Binary function to calculate logarithm of Decimal64 `value` using `base` base -/// Returns error if base is invalid +/// Calculate logarithm for Decimal64 values. +/// For integer bases >= 2 with non-negative scale, uses the efficient u64 ilog algorithm. +/// Otherwise falls back to f64 computation. fn log_decimal64(value: i64, scale: i8, base: f64) -> Result { - if !base.is_finite() || base.trunc() != base { - return Err(ArrowError::ComputeError(format!( - "Log cannot use non-integer base: {base}" - ))); - } - if (base as u32) < 2 { - return Err(ArrowError::ComputeError(format!( - "Log base must be greater than 1: {base}" - ))); + if is_valid_integer_base(base) + && scale >= 0 + && let Some(unscaled) = unscale_to_u64(value, scale) + { + return if unscaled > 0 { + Ok(unscaled.ilog(base as u64) as f64) + } else { + Ok(f64::NAN) + }; } + decimal_to_f64(value, scale).map(|v| v.log(base)) +} - if value <= 0 { - return Ok(f64::NAN); +/// Calculate logarithm for Decimal128 values. +/// For integer bases >= 2 with non-negative scale, uses the efficient u128 ilog algorithm. +/// Otherwise falls back to f64 computation. +fn log_decimal128(value: i128, scale: i8, base: f64) -> Result { + if is_valid_integer_base(base) + && scale >= 0 + && let Some(unscaled) = unscale_to_u128(value, scale) + { + return if unscaled > 0 { + Ok(unscaled.ilog(base as u128) as f64) + } else { + Ok(f64::NAN) + }; } + decimal_to_f64(value, scale).map(|v| v.log(base)) +} - if scale < 0 { - let actual_value = (value as f64) * 10.0_f64.powi(-(scale as i32)); - Ok(actual_value.log(base)) - } else { - let unscaled_value = decimal64_to_i64(value, scale)?; - if unscaled_value <= 0 { - return Ok(f64::NAN); - } - let log_value: u32 = unscaled_value.ilog(base as i64); - Ok(log_value as f64) - } +/// Unscale a Decimal32 value to u32. +#[inline] +fn unscale_to_u32(value: i32, scale: i8) -> Option { + let value_u32 = u32::try_from(value).ok()?; + let divisor = 10u32.checked_pow(scale as u32)?; + Some(value_u32 / divisor) } -/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base -/// Returns error if base is invalid -fn log_decimal128(value: i128, scale: i8, base: f64) -> Result { - if !base.is_finite() || base.trunc() != base { - return Err(ArrowError::ComputeError(format!( - "Log cannot use non-integer base: {base}" - ))); - } - if (base as u32) < 2 { - return Err(ArrowError::ComputeError(format!( - "Log base must be greater than 1: {base}" - ))); - } +/// Unscale a Decimal64 value to u64. +#[inline] +fn unscale_to_u64(value: i64, scale: i8) -> Option { + let value_u64 = u64::try_from(value).ok()?; + let divisor = 10u64.checked_pow(scale as u32)?; + Some(value_u64 / divisor) +} - if value <= 0 { - // Reflect f64::log behaviour - return Ok(f64::NAN); - } +/// Unscale a Decimal128 value to u128. +#[inline] +fn unscale_to_u128(value: i128, scale: i8) -> Option { + let value_u128 = u128::try_from(value).ok()?; + let divisor = 10u128.checked_pow(scale as u32)?; + Some(value_u128 / divisor) +} - if scale < 0 { - let actual_value = (value as f64) * 10.0_f64.powi(-(scale as i32)); - Ok(actual_value.log(base)) - } else { - let unscaled_value = decimal128_to_i128(value, scale)?; - if unscaled_value <= 0 { - return Ok(f64::NAN); - } - let log_value: u32 = unscaled_value.ilog(base as i128); - Ok(log_value as f64) - } +/// Convert a scaled decimal value to f64. +#[inline] +fn decimal_to_f64(value: T, scale: i8) -> Result { + let value_f64 = value.to_f64().ok_or_else(|| { + ArrowError::ComputeError("Cannot convert value to f64".to_string()) + })?; + let scale_factor = 10f64.powi(scale as i32); + Ok(value_f64 / scale_factor) } -/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base -/// Returns error if base is invalid or if value is out of bounds of Decimal128 fn log_decimal256(value: i256, scale: i8, base: f64) -> Result { + // Try to convert to i128 for the optimized path match value.to_i128() { - Some(value) => log_decimal128(value, scale, base), - None => Err(ArrowError::NotYetImplemented(format!( - "Log of Decimal256 larger than Decimal128 is not yet supported: {value}" - ))), + Some(v) => log_decimal128(v, scale, base), + None => { + // For very large Decimal256 values, use f64 computation + let value_f64 = value.to_f64().ok_or_else(|| { + ArrowError::ComputeError(format!("Cannot convert {value} to f64")) + })?; + let scale_factor = 10f64.powi(scale as i32); + Ok((value_f64 / scale_factor).log(base)) + } } } @@ -1169,7 +1167,8 @@ mod tests { } #[test] - fn test_log_decimal128_wrong_base() { + fn test_log_decimal128_invalid_base() { + // Invalid base (-2.0) should return NaN, matching f64::log behavior let arg_fields = vec![ Field::new("b", DataType::Float64, false).into(), Field::new("x", DataType::Decimal128(38, 0), false).into(), @@ -1184,16 +1183,26 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; - let result = LogFunc::new().invoke_with_args(args); - assert!(result.is_err()); - assert_eq!( - "Arrow error: Compute error: Log base must be greater than 1: -2", - result.unwrap_err().to_string().lines().next().unwrap() - ); + let result = LogFunc::new() + .invoke_with_args(args) + .expect("should not error on invalid base"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + assert_eq!(floats.len(), 1); + assert!(floats.value(0).is_nan()); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } #[test] - fn test_log_decimal256_error() { + fn test_log_decimal256_large() { + // Large Decimal256 values that don't fit in i128 now use f64 fallback let arg_field = Field::new("a", DataType::Decimal256(38, 0), false).into(); let args = ScalarFunctionArgs { args: vec![ @@ -1207,11 +1216,26 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; - let result = LogFunc::new().invoke_with_args(args); - assert!(result.is_err()); - assert_eq!( - result.unwrap_err().to_string().lines().next().unwrap(), - "Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727" - ); + let result = LogFunc::new() + .invoke_with_args(args) + .expect("should handle large Decimal256 via f64 fallback"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + assert_eq!(floats.len(), 1); + // The f64 fallback may lose some precision for very large numbers, + // but we verify we get a reasonable positive result (not NaN/infinity) + let log_result = floats.value(0); + assert!( + log_result.is_finite() && log_result > 0.0, + "Expected positive finite log result, got {log_result}" + ); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } } diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index 9dd31427dcb4a..85f2559f583dc 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -868,9 +868,11 @@ select log(100000000000000000000000000000000000::decimal(76,0)); ---- 35 -# log(10^50) for decimal256 for a value larger than i128 -query error Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported +# log(10^50) for decimal256 for a value larger than i128 (uses f64 fallback) +query R select log(100000000000000000000000000000000000000000000000000::decimal(76,0)); +---- +50 # log(10^35) for decimal128 with explicit base query R @@ -904,6 +906,12 @@ select log(2.0, 100000000000000000000000000000000000::decimal(38,0)); ---- 116 +# log with non-integer base (fallback to f64) +query R +select log(2.5, 100::decimal(38,0)); +---- +5.025883189464 + # null cases query R select log(null, 100);