diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 33166f6444f2a..ffdcade4a8404 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -22,9 +22,10 @@ use super::log::LogFunc; use crate::utils::{calculate_binary_decimal_math, calculate_binary_math}; use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::i256; use arrow::datatypes::{ - ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, - Decimal256Type, Float64Type, Int64Type, + ArrowNativeType, ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, + Decimal128Type, Decimal256Type, Float64Type, Int64Type, }; use arrow::error::ArrowError; use datafusion_common::types::{NativeType, logical_float64, logical_int64}; @@ -37,6 +38,7 @@ use datafusion_expr::{ ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit, }; use datafusion_macros::user_doc; +use num_traits::{NumCast, ToPrimitive}; #[user_doc( doc_section(label = "Math Functions"), @@ -112,12 +114,15 @@ impl PowerFunc { /// 2.5 is represented as 25 with scale 1 /// The unscaled result is 25^4 = 390625 /// Scale it back to 1: 390625 / 10^4 = 39 -/// -/// Returns error if base is invalid fn pow_decimal_int(base: T, scale: i8, exp: i64) -> Result where - T: From + ArrowNativeTypeOp, + T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy, { + // Negative exponent: fall back to float computation + if exp < 0 { + return pow_decimal_float(base, scale, exp as f64); + } + let exp: u32 = exp.try_into().map_err(|_| { ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}")) })?; @@ -125,13 +130,13 @@ where // If scale < 0, 10^scale (e.g., 10^-2 = 0.01) becomes 0 in integer arithmetic. if exp == 0 { return if scale >= 0 { - T::from(10).pow_checked(scale as u32).map_err(|_| { + T::usize_as(10).pow_checked(scale as u32).map_err(|_| { ArrowError::ArithmeticOverflow(format!( "Cannot make unscale factor for {scale} and {exp}" )) }) } else { - Ok(T::from(0)) + Ok(T::ZERO) }; } let powered: T = base.pow_checked(exp).map_err(|_| { @@ -149,11 +154,12 @@ where // If mul_exp is positive, we divide (standard case). // If mul_exp is negative, we multiply (negative scale case). if mul_exp > 0 { - let div_factor: T = T::from(10).pow_checked(mul_exp as u32).map_err(|_| { - ArrowError::ArithmeticOverflow(format!( - "Cannot make div factor for {scale} and {exp}" - )) - })?; + let div_factor: T = + T::usize_as(10).pow_checked(mul_exp as u32).map_err(|_| { + ArrowError::ArithmeticOverflow(format!( + "Cannot make div factor for {scale} and {exp}" + )) + })?; powered.div_checked(div_factor) } else { // mul_exp is negative, so we multiply by 10^(-mul_exp) @@ -162,33 +168,227 @@ where "Overflow while negating scale exponent".to_string(), ) })?; - let mul_factor: T = T::from(10).pow_checked(abs_exp as u32).map_err(|_| { - ArrowError::ArithmeticOverflow(format!( - "Cannot make mul factor for {scale} and {exp}" - )) - })?; + let mul_factor: T = + T::usize_as(10).pow_checked(abs_exp as u32).map_err(|_| { + ArrowError::ArithmeticOverflow(format!( + "Cannot make mul factor for {scale} and {exp}" + )) + })?; powered.mul_checked(mul_factor) } } /// Binary function to calculate a math power to float exponent /// for scaled integer types. -/// Returns error if exponent is negative or non-integer, or base invalid fn pow_decimal_float(base: T, scale: i8, exp: f64) -> Result where - T: From + ArrowNativeTypeOp, + T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy, { - if !exp.is_finite() || exp.trunc() != exp { + if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 { + return pow_decimal_int(base, scale, exp as i64); + } + + if !exp.is_finite() { return Err(ArrowError::ComputeError(format!( - "Cannot use non-integer exp: {exp}" + "Cannot use non-finite exp: {exp}" ))); } - if exp < 0f64 || exp >= u32::MAX as f64 { + + pow_decimal_float_fallback(base, scale, exp) +} + +/// Compute the f64 power result and scale it back. +/// Returns the rounded i128 result for conversion to target type. +#[inline] +fn compute_pow_f64_result( + base_f64: f64, + scale: i8, + exp: f64, +) -> Result { + let result_f64 = base_f64.powf(exp); + + if !result_f64.is_finite() { return Err(ArrowError::ArithmeticOverflow(format!( - "Unsupported exp value: {exp}" + "Result of {base_f64}^{exp} is not finite" + ))); + } + + let scale_factor = 10f64.powi(scale as i32); + let result_scaled = result_f64 * scale_factor; + let result_rounded = result_scaled.round(); + + if result_rounded.abs() > i128::MAX as f64 { + return Err(ArrowError::ArithmeticOverflow(format!( + "Result {result_rounded} is too large for the target decimal type" + ))); + } + + Ok(result_rounded as i128) +} + +/// Convert i128 result to target decimal native type using NumCast. +/// Returns error if value overflows the target type. +#[inline] +fn decimal_from_i128(value: i128) -> Result +where + T: NumCast, +{ + NumCast::from(value).ok_or_else(|| { + ArrowError::ArithmeticOverflow(format!( + "Value {value} is too large for the target decimal type" + )) + }) +} + +/// Fallback implementation using f64 for negative or non-integer exponents. +/// This handles cases that cannot be computed using integer arithmetic. +fn pow_decimal_float_fallback(base: T, scale: i8, exp: f64) -> Result +where + T: ToPrimitive + NumCast + Copy, +{ + if scale < 0 { + return Err(ArrowError::NotYetImplemented(format!( + "Negative scale is not yet supported: {scale}" ))); } - pow_decimal_int(base, scale, exp as i64) + + let scale_factor = 10f64.powi(scale as i32); + let base_f64 = base.to_f64().ok_or_else(|| { + ArrowError::ComputeError("Cannot convert base to f64".to_string()) + })? / scale_factor; + + let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?; + + decimal_from_i128(result_i128) +} + +/// Decimal256 specialized float exponent version. +fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result { + if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 { + return pow_decimal256_int(base, scale, exp as i64); + } + + if !exp.is_finite() { + return Err(ArrowError::ComputeError(format!( + "Cannot use non-finite exp: {exp}" + ))); + } + + pow_decimal256_float_fallback(base, scale, exp) +} + +/// Decimal256 specialized integer exponent version. +fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result { + if exp < 0 { + return pow_decimal256_float(base, scale, exp as f64); + } + + let exp: u32 = exp.try_into().map_err(|_| { + ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}")) + })?; + + if exp == 0 { + return if scale >= 0 { + i256::from_i128(10).pow_checked(scale as u32).map_err(|_| { + ArrowError::ArithmeticOverflow(format!( + "Cannot make unscale factor for {scale} and {exp}" + )) + }) + } else { + Ok(i256::from_i128(0)) + }; + } + + let powered: i256 = base.pow_checked(exp).map_err(|_| { + ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}")) + })?; + + let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1); + + if mul_exp == 0 { + return Ok(powered); + } + + if mul_exp > 0 { + let div_factor: i256 = + i256::from_i128(10) + .pow_checked(mul_exp as u32) + .map_err(|_| { + ArrowError::ArithmeticOverflow(format!( + "Cannot make div factor for {scale} and {exp}" + )) + })?; + powered.div_checked(div_factor) + } else { + let abs_exp = mul_exp.checked_neg().ok_or_else(|| { + ArrowError::ArithmeticOverflow( + "Overflow while negating scale exponent".to_string(), + ) + })?; + let mul_factor: i256 = + i256::from_i128(10) + .pow_checked(abs_exp as u32) + .map_err(|_| { + ArrowError::ArithmeticOverflow(format!( + "Cannot make mul factor for {scale} and {exp}" + )) + })?; + powered.mul_checked(mul_factor) + } +} + +/// Fallback implementation for Decimal256. +fn pow_decimal256_float_fallback( + base: i256, + scale: i8, + exp: f64, +) -> Result { + if scale < 0 { + return Err(ArrowError::NotYetImplemented(format!( + "Negative scale is not yet supported: {scale}" + ))); + } + + let scale_factor = 10f64.powi(scale as i32); + let base_f64 = base.to_f64().ok_or_else(|| { + ArrowError::ComputeError("Cannot convert base to f64".to_string()) + })? / scale_factor; + + let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?; + + // i256 can be constructed from i128 directly + Ok(i256::from_i128(result_i128)) +} + +/// Fallback implementation for decimal power when exponent is an array. +/// Casts decimal to float64, computes power, and casts back to original decimal type. +/// This is used for performance when exponent varies per-row. +fn pow_decimal_with_float_fallback( + base: &ArrayRef, + exponent: &ColumnarValue, + num_rows: usize, +) -> Result { + use arrow::compute::cast; + + let original_type = base.data_type().clone(); + let base_f64 = cast(base.as_ref(), &DataType::Float64)?; + + let exp_f64 = match exponent { + ColumnarValue::Array(arr) => cast(arr.as_ref(), &DataType::Float64)?, + ColumnarValue::Scalar(scalar) => { + let scalar_f64 = scalar.cast_to(&DataType::Float64)?; + scalar_f64.to_array_of_size(num_rows)? + } + }; + + let result_f64 = calculate_binary_math::( + &base_f64, + &ColumnarValue::Array(exp_f64), + |b, e| Ok(f64::powf(b, e)), + )?; + + let result = cast(result_f64.as_ref(), &original_type)?; + Ok(ColumnarValue::Array(result)) } impl ScalarUDFImpl for PowerFunc { @@ -218,8 +418,25 @@ impl ScalarUDFImpl for PowerFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let [base, exponent] = take_function_args(self.name(), &args.args)?; + + // For decimal types, only use native decimal + // operations when we have a scalar exponent. When the exponent is an array, + // fall back to float computation for better performance. + let use_float_fallback = matches!( + base.data_type(), + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ) && matches!(exponent, ColumnarValue::Array(_)); + let base = base.to_array(args.number_rows)?; + // If decimal with array exponent, cast to float and compute + if use_float_fallback { + return pow_decimal_with_float_fallback(&base, exponent, args.number_rows); + } + let arr: ArrayRef = match (base.data_type(), exponent.data_type()) { (DataType::Float64, DataType::Float64) => { calculate_binary_math::( @@ -311,7 +528,7 @@ impl ScalarUDFImpl for PowerFunc { >( &base, exponent, - |b, e| pow_decimal_int(b, *scale, e), + |b, e| pow_decimal256_int(b, *scale, e), *precision, *scale, )? @@ -325,7 +542,7 @@ impl ScalarUDFImpl for PowerFunc { >( &base, exponent, - |b, e| pow_decimal_float(b, *scale, e), + |b, e| pow_decimal256_float(b, *scale, e), *precision, *scale, )? @@ -398,19 +615,53 @@ mod tests { #[test] fn test_pow_decimal128_helper() { // Expression: 2.5 ^ 4 = 39.0625 - assert_eq!(pow_decimal_int(25, 1, 4).unwrap(), i128::from(390)); - assert_eq!(pow_decimal_int(2500, 3, 4).unwrap(), i128::from(39062)); - assert_eq!(pow_decimal_int(25000, 4, 4).unwrap(), i128::from(390625)); + assert_eq!(pow_decimal_int(25i128, 1, 4).unwrap(), 390i128); + assert_eq!(pow_decimal_int(2500i128, 3, 4).unwrap(), 39062i128); + assert_eq!(pow_decimal_int(25000i128, 4, 4).unwrap(), 390625i128); // Expression: 25 ^ 4 = 390625 - assert_eq!(pow_decimal_int(25, 0, 4).unwrap(), i128::from(390625)); + assert_eq!(pow_decimal_int(25i128, 0, 4).unwrap(), 390625i128); // Expressions for edge cases - assert_eq!(pow_decimal_int(25, 1, 1).unwrap(), i128::from(25)); - assert_eq!(pow_decimal_int(25, 0, 1).unwrap(), i128::from(25)); - assert_eq!(pow_decimal_int(25, 0, 0).unwrap(), i128::from(1)); - assert_eq!(pow_decimal_int(25, 1, 0).unwrap(), i128::from(10)); + assert_eq!(pow_decimal_int(25i128, 1, 1).unwrap(), 25i128); + assert_eq!(pow_decimal_int(25i128, 0, 1).unwrap(), 25i128); + assert_eq!(pow_decimal_int(25i128, 0, 0).unwrap(), 1i128); + assert_eq!(pow_decimal_int(25i128, 1, 0).unwrap(), 10i128); + + assert_eq!(pow_decimal_int(25i128, -1, 4).unwrap(), 390625000i128); + } + + #[test] + fn test_pow_decimal_float_fallback() { + // Test negative exponent: 4^(-1) = 0.25 + // 4 with scale 2 = 400, result should be 25 (0.25 with scale 2) + let result: i128 = pow_decimal_float(400i128, 2, -1.0).unwrap(); + assert_eq!(result, 25); + + // Test non-integer exponent: 4^0.5 = 2 + // 4 with scale 2 = 400, result should be 200 (2.0 with scale 2) + let result: i128 = pow_decimal_float(400i128, 2, 0.5).unwrap(); + assert_eq!(result, 200); + + // Test 8^(1/3) = 2 (cube root) + // 8 with scale 1 = 80, result should be 20 (2.0 with scale 1) + let result: i128 = pow_decimal_float(80i128, 1, 1.0 / 3.0).unwrap(); + assert_eq!(result, 20); + + // Test negative base with integer exponent still works + // (-2)^3 = -8 + // -2 with scale 1 = -20, result should be -80 (-8.0 with scale 1) + let result: i128 = pow_decimal_float(-20i128, 1, 3.0).unwrap(); + assert_eq!(result, -80); + + // Test positive integer exponent goes through fast path + // 2.5^4 = 39.0625 + // 25 with scale 1, result should be 390 (39.0 with scale 1) - truncated + let result: i128 = pow_decimal_float(25i128, 1, 4.0).unwrap(); + assert_eq!(result, 390); // Uses integer path - assert_eq!(pow_decimal_int(25, -1, 4).unwrap(), i128::from(390625000)); + // Test non-finite exponent returns error + assert!(pow_decimal_float(100i128, 2, f64::NAN).is_err()); + assert!(pow_decimal_float(100i128, 2, f64::INFINITY).is_err()); } } diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index 9dd31427dcb4a..1e55b5d2b865e 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -1087,8 +1087,17 @@ SELECT power(2, 100000000000) ---- Infinity -query error Arrow error: Arithmetic overflow: Unsupported exp value -SELECT power(2::decimal(38, 0), -5) +# Negative exponent now works (fallback to f64) +query RT +SELECT power(2::decimal(38, 0), -5), arrow_typeof(power(2::decimal(38, 0), -5)); +---- +0 Decimal128(38, 0) + +# Negative exponent with scale preserves decimal places +query RT +SELECT power(4::decimal(38, 5), -1), arrow_typeof(power(4::decimal(38, 5), -1)); +---- +0.25 Decimal128(38, 5) # Expected to have `16 Decimal128(38, 0)` # Due to type coericion, it becomes Float -> Float -> Float @@ -1108,20 +1117,23 @@ SELECT power(2.5, 4.0), arrow_typeof(power(2.5, 4.0)); ---- 39 Decimal128(2, 1) -query error Compute error: Cannot use non-integer exp +# Non-integer exponent now works (fallback to f64) +query RT SELECT power(2.5, 4.2), arrow_typeof(power(2.5, 4.2)); +---- +46.9 Decimal128(2, 1) -query error Compute error: Cannot use non-integer exp: NaN +query error Compute error: Cannot use non-finite exp: NaN SELECT power(2::decimal(38, 0), arrow_cast('NaN','Float64')) -query error Compute error: Cannot use non-integer exp: inf +query error Compute error: Cannot use non-finite exp: inf SELECT power(2::decimal(38, 0), arrow_cast('INF','Float64')) -# Floating above u32::max -query error Compute error: Cannot use non-integer exp +# Floating above u32::max now works (fallback to f64, returns infinity which is an error) +query error Arrow error: Arithmetic overflow: Result of 2\^5000000000.1 is not finite SELECT power(2::decimal(38, 0), 5000000000.1) -# Integer Above u32::max +# Integer Above u32::max - still goes through integer path which fails query error Arrow error: Arithmetic overflow: Unsupported exp value SELECT power(2::decimal(38, 0), 5000000000)