Skip to content

Commit 029e2bd

Browse files
committed
feat: Allow pow with negative & non-integer exponent on decimals
1 parent 7900cd6 commit 029e2bd

File tree

2 files changed

+154
-16
lines changed

2 files changed

+154
-16
lines changed

datafusion/functions/src/math/power.rs

Lines changed: 134 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,14 @@ impl PowerFunc {
112112
/// 2.5 is represented as 25 with scale 1
113113
/// The unscaled result is 25^4 = 390625
114114
/// Scale it back to 1: 390625 / 10^4 = 39
115-
///
116-
/// Returns error if base is invalid
117115
fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
118116
where
119117
T: From<i32> + ArrowNativeTypeOp,
120118
{
119+
if exp < 0 {
120+
return pow_decimal_float(base, scale, exp as f64);
121+
}
122+
121123
let scale: u32 = scale.try_into().map_err(|_| {
122124
ArrowError::NotYetImplemented(format!(
123125
"Negative scale is not yet supported value: {scale}"
@@ -149,22 +151,112 @@ where
149151

150152
/// Binary function to calculate a math power to float exponent
151153
/// for scaled integer types.
152-
/// Returns error if exponent is negative or non-integer, or base invalid
153154
fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
154155
where
155156
T: From<i32> + ArrowNativeTypeOp,
156157
{
157-
if !exp.is_finite() || exp.trunc() != exp {
158+
if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
159+
return pow_decimal_int(base, scale, exp as i64);
160+
}
161+
162+
if !exp.is_finite() {
158163
return Err(ArrowError::ComputeError(format!(
159-
"Cannot use non-integer exp: {exp}"
164+
"Cannot use non-finite exp: {exp}"
165+
)));
166+
}
167+
168+
pow_decimal_float_fallback(base, scale, exp)
169+
}
170+
171+
/// Fallback implementation using f64 for negative or non-integer exponents.
172+
/// This handles cases that cannot be computed using integer arithmetic.
173+
fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
174+
where
175+
T: From<i32> + ArrowNativeTypeOp,
176+
{
177+
let scale_factor = 10f64.powi(scale as i32);
178+
let base_f64 = format!("{base:?}")
179+
.parse::<f64>()
180+
.map(|v| v / scale_factor)
181+
.map_err(|_| {
182+
ArrowError::ComputeError(format!("Cannot convert base {base:?} to f64"))
183+
})?;
184+
185+
let result_f64 = base_f64.powf(exp);
186+
187+
if !result_f64.is_finite() {
188+
return Err(ArrowError::ArithmeticOverflow(format!(
189+
"Result of {base_f64}^{exp} is not finite"
160190
)));
161191
}
162-
if exp < 0f64 || exp >= u32::MAX as f64 {
192+
193+
let result_scaled = result_f64 * scale_factor;
194+
let result_rounded = result_scaled.round();
195+
196+
if result_rounded.abs() > i128::MAX as f64 {
163197
return Err(ArrowError::ArithmeticOverflow(format!(
164-
"Unsupported exp value: {exp}"
198+
"Result {result_rounded} is too large for the target decimal type"
165199
)));
166200
}
167-
pow_decimal_int(base, scale, exp as i64)
201+
202+
decimal_from_i128::<T>(result_rounded as i128)
203+
}
204+
205+
fn decimal_from_i128<T>(value: i128) -> Result<T, ArrowError>
206+
where
207+
T: From<i32> + ArrowNativeTypeOp,
208+
{
209+
if value == 0 {
210+
return Ok(T::from(0));
211+
}
212+
213+
if value >= i32::MIN as i128 && value <= i32::MAX as i128 {
214+
return Ok(T::from(value as i32));
215+
}
216+
217+
let is_negative = value < 0;
218+
let abs_value = value.unsigned_abs();
219+
220+
let billion = 1_000_000_000u128;
221+
let mut result = T::from(0);
222+
let mut multiplier = T::from(1);
223+
let billion_t = T::from(1_000_000_000);
224+
225+
let mut remaining = abs_value;
226+
while remaining > 0 {
227+
let chunk = (remaining % billion) as i32;
228+
remaining /= billion;
229+
230+
let chunk_value = T::from(chunk).mul_checked(multiplier).map_err(|_| {
231+
ArrowError::ArithmeticOverflow(format!(
232+
"Overflow while converting {value} to decimal type"
233+
))
234+
})?;
235+
236+
result = result.add_checked(chunk_value).map_err(|_| {
237+
ArrowError::ArithmeticOverflow(format!(
238+
"Overflow while converting {value} to decimal type"
239+
))
240+
})?;
241+
242+
if remaining > 0 {
243+
multiplier = multiplier.mul_checked(billion_t).map_err(|_| {
244+
ArrowError::ArithmeticOverflow(format!(
245+
"Overflow while converting {value} to decimal type"
246+
))
247+
})?;
248+
}
249+
}
250+
251+
if is_negative {
252+
result = T::from(0).sub_checked(result).map_err(|_| {
253+
ArrowError::ArithmeticOverflow(format!(
254+
"Overflow while negating {value} in decimal type"
255+
))
256+
})?;
257+
}
258+
259+
Ok(result)
168260
}
169261

170262
impl ScalarUDFImpl for PowerFunc {
@@ -392,4 +484,38 @@ mod tests {
392484
"Not yet implemented: Negative scale is not yet supported value: -1"
393485
);
394486
}
487+
488+
#[test]
489+
fn test_pow_decimal_float_fallback() {
490+
// Test negative exponent: 4^(-1) = 0.25
491+
// 4 with scale 2 = 400, result should be 25 (0.25 with scale 2)
492+
let result: i128 = pow_decimal_float(400i128, 2, -1.0).unwrap();
493+
assert_eq!(result, 25);
494+
495+
// Test non-integer exponent: 4^0.5 = 2
496+
// 4 with scale 2 = 400, result should be 200 (2.0 with scale 2)
497+
let result: i128 = pow_decimal_float(400i128, 2, 0.5).unwrap();
498+
assert_eq!(result, 200);
499+
500+
// Test 8^(1/3) = 2 (cube root)
501+
// 8 with scale 1 = 80, result should be 20 (2.0 with scale 1)
502+
let result: i128 = pow_decimal_float(80i128, 1, 1.0 / 3.0).unwrap();
503+
assert_eq!(result, 20);
504+
505+
// Test negative base with integer exponent still works
506+
// (-2)^3 = -8
507+
// -2 with scale 1 = -20, result should be -80 (-8.0 with scale 1)
508+
let result: i128 = pow_decimal_float(-20i128, 1, 3.0).unwrap();
509+
assert_eq!(result, -80);
510+
511+
// Test positive integer exponent goes through fast path
512+
// 2.5^4 = 39.0625
513+
// 25 with scale 1, result should be 390 (39.0 with scale 1) - truncated
514+
let result: i128 = pow_decimal_float(25i128, 1, 4.0).unwrap();
515+
assert_eq!(result, 390); // Uses integer path
516+
517+
// Test non-finite exponent returns error
518+
assert!(pow_decimal_float(100i128, 2, f64::NAN).is_err());
519+
assert!(pow_decimal_float(100i128, 2, f64::INFINITY).is_err());
520+
}
395521
}

datafusion/sqllogictest/test_files/decimal.slt

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -954,8 +954,17 @@ SELECT power(2, 100000000000)
954954
----
955955
Infinity
956956

957-
query error Arrow error: Arithmetic overflow: Unsupported exp value
958-
SELECT power(2::decimal(38, 0), -5)
957+
# Negative exponent now works (fallback to f64)
958+
query RT
959+
SELECT power(2::decimal(38, 0), -5), arrow_typeof(power(2::decimal(38, 0), -5));
960+
----
961+
0 Decimal128(38, 0)
962+
963+
# Negative exponent with scale preserves decimal places
964+
query RT
965+
SELECT power(4::decimal(38, 5), -1), arrow_typeof(power(4::decimal(38, 5), -1));
966+
----
967+
0.25 Decimal128(38, 5)
959968

960969
# Expected to have `16 Decimal128(38, 0)`
961970
# Due to type coericion, it becomes Float -> Float -> Float
@@ -975,20 +984,23 @@ SELECT power(2.5, 4.0), arrow_typeof(power(2.5, 4.0));
975984
----
976985
39 Decimal128(2, 1)
977986

978-
query error Compute error: Cannot use non-integer exp
987+
# Non-integer exponent now works (fallback to f64)
988+
query RT
979989
SELECT power(2.5, 4.2), arrow_typeof(power(2.5, 4.2));
990+
----
991+
46.9 Decimal128(2, 1)
980992

981-
query error Compute error: Cannot use non-integer exp: NaN
993+
query error Compute error: Cannot use non-finite exp: NaN
982994
SELECT power(2::decimal(38, 0), arrow_cast('NaN','Float64'))
983995

984-
query error Compute error: Cannot use non-integer exp: inf
996+
query error Compute error: Cannot use non-finite exp: inf
985997
SELECT power(2::decimal(38, 0), arrow_cast('INF','Float64'))
986998

987-
# Floating above u32::max
988-
query error Compute error: Cannot use non-integer exp
999+
# Floating above u32::max now works (fallback to f64, returns infinity which is an error)
1000+
query error Arrow error: Arithmetic overflow: Result of 2\^5000000000.1 is not finite
9891001
SELECT power(2::decimal(38, 0), 5000000000.1)
9901002

991-
# Integer Above u32::max
1003+
# Integer Above u32::max - still goes through integer path which fails
9921004
query error Arrow error: Arithmetic overflow: Unsupported exp value
9931005
SELECT power(2::decimal(38, 0), 5000000000)
9941006

0 commit comments

Comments
 (0)