Skip to content

Commit 9f79025

Browse files
committed
feat: Allow pow with negative & non-integer exponent on decimals
1 parent 33ac70d commit 9f79025

File tree

2 files changed

+264
-45
lines changed

2 files changed

+264
-45
lines changed

datafusion/functions/src/math/power.rs

Lines changed: 244 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use super::log::LogFunc;
2222

2323
use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
2424
use arrow::array::{Array, ArrayRef};
25+
use arrow::datatypes::i256;
2526
use arrow::datatypes::{
2627
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
2728
Decimal256Type, Float64Type, Int64Type,
@@ -37,6 +38,7 @@ use datafusion_expr::{
3738
ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit,
3839
};
3940
use datafusion_macros::user_doc;
41+
use num_traits::{NumCast, ToPrimitive};
4042

4143
#[user_doc(
4244
doc_section(label = "Math Functions"),
@@ -112,26 +114,31 @@ impl PowerFunc {
112114
/// 2.5 is represented as 25 with scale 1
113115
/// The unscaled result is 25^4 = 390625
114116
/// Scale it back to 1: 390625 / 10^4 = 39
115-
///
116-
/// Returns error if base is invalid
117117
fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
118118
where
119-
T: From<i32> + ArrowNativeTypeOp,
119+
T: From<i32> + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
120120
{
121+
// Negative exponent: fall back to float computation
122+
if exp < 0 {
123+
return pow_decimal_float(base, scale, exp as f64);
124+
}
125+
121126
let exp: u32 = exp.try_into().map_err(|_| {
122127
ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
123128
})?;
124129
// Handle edge case for exp == 0
125130
// If scale < 0, 10^scale (e.g., 10^-2 = 0.01) becomes 0 in integer arithmetic.
126131
if exp == 0 {
127132
return if scale >= 0 {
128-
T::from(10).pow_checked(scale as u32).map_err(|_| {
129-
ArrowError::ArithmeticOverflow(format!(
130-
"Cannot make unscale factor for {scale} and {exp}"
131-
))
132-
})
133+
<T as From<i32>>::from(10)
134+
.pow_checked(scale as u32)
135+
.map_err(|_| {
136+
ArrowError::ArithmeticOverflow(format!(
137+
"Cannot make unscale factor for {scale} and {exp}"
138+
))
139+
})
133140
} else {
134-
Ok(T::from(0))
141+
Ok(<T as From<i32>>::from(0))
135142
};
136143
}
137144
let powered: T = base.pow_checked(exp).map_err(|_| {
@@ -149,11 +156,13 @@ where
149156
// If mul_exp is positive, we divide (standard case).
150157
// If mul_exp is negative, we multiply (negative scale case).
151158
if mul_exp > 0 {
152-
let div_factor: T = T::from(10).pow_checked(mul_exp as u32).map_err(|_| {
153-
ArrowError::ArithmeticOverflow(format!(
154-
"Cannot make div factor for {scale} and {exp}"
155-
))
156-
})?;
159+
let div_factor: T = <T as From<i32>>::from(10)
160+
.pow_checked(mul_exp as u32)
161+
.map_err(|_| {
162+
ArrowError::ArithmeticOverflow(format!(
163+
"Cannot make div factor for {scale} and {exp}"
164+
))
165+
})?;
157166
powered.div_checked(div_factor)
158167
} else {
159168
// mul_exp is negative, so we multiply by 10^(-mul_exp)
@@ -162,33 +171,197 @@ where
162171
"Overflow while negating scale exponent".to_string(),
163172
)
164173
})?;
165-
let mul_factor: T = T::from(10).pow_checked(abs_exp as u32).map_err(|_| {
166-
ArrowError::ArithmeticOverflow(format!(
167-
"Cannot make mul factor for {scale} and {exp}"
168-
))
169-
})?;
174+
let mul_factor: T = <T as From<i32>>::from(10)
175+
.pow_checked(abs_exp as u32)
176+
.map_err(|_| {
177+
ArrowError::ArithmeticOverflow(format!(
178+
"Cannot make mul factor for {scale} and {exp}"
179+
))
180+
})?;
170181
powered.mul_checked(mul_factor)
171182
}
172183
}
173184

174185
/// Binary function to calculate a math power to float exponent
175186
/// for scaled integer types.
176-
/// Returns error if exponent is negative or non-integer, or base invalid
177187
fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
178188
where
179-
T: From<i32> + ArrowNativeTypeOp,
189+
T: From<i32> + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
180190
{
181-
if !exp.is_finite() || exp.trunc() != exp {
191+
if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
192+
return pow_decimal_int(base, scale, exp as i64);
193+
}
194+
195+
if !exp.is_finite() {
182196
return Err(ArrowError::ComputeError(format!(
183-
"Cannot use non-integer exp: {exp}"
197+
"Cannot use non-finite exp: {exp}"
198+
)));
199+
}
200+
201+
pow_decimal_float_fallback(base, scale, exp)
202+
}
203+
204+
/// Compute the f64 power result and scale it back.
205+
/// Returns the rounded i128 result for conversion to target type.
206+
#[inline]
207+
fn compute_pow_f64_result(
208+
base_f64: f64,
209+
scale: i8,
210+
exp: f64,
211+
) -> Result<i128, ArrowError> {
212+
let result_f64 = base_f64.powf(exp);
213+
214+
if !result_f64.is_finite() {
215+
return Err(ArrowError::ArithmeticOverflow(format!(
216+
"Result of {base_f64}^{exp} is not finite"
184217
)));
185218
}
186-
if exp < 0f64 || exp >= u32::MAX as f64 {
219+
220+
let scale_factor = 10f64.powi(scale as i32);
221+
let result_scaled = result_f64 * scale_factor;
222+
let result_rounded = result_scaled.round();
223+
224+
if result_rounded.abs() > i128::MAX as f64 {
187225
return Err(ArrowError::ArithmeticOverflow(format!(
188-
"Unsupported exp value: {exp}"
226+
"Result {result_rounded} is too large for the target decimal type"
227+
)));
228+
}
229+
230+
Ok(result_rounded as i128)
231+
}
232+
233+
/// Convert i128 result to target decimal native type using NumCast.
234+
/// Returns error if value overflows the target type.
235+
#[inline]
236+
fn decimal_from_i128<T>(value: i128) -> Result<T, ArrowError>
237+
where
238+
T: NumCast,
239+
{
240+
NumCast::from(value).ok_or_else(|| {
241+
ArrowError::ArithmeticOverflow(format!(
242+
"Value {value} is too large for the target decimal type"
243+
))
244+
})
245+
}
246+
247+
/// Fallback implementation using f64 for negative or non-integer exponents.
248+
/// This handles cases that cannot be computed using integer arithmetic.
249+
fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
250+
where
251+
T: ToPrimitive + NumCast + Copy,
252+
{
253+
if scale < 0 {
254+
return Err(ArrowError::NotYetImplemented(format!(
255+
"Negative scale is not yet supported: {scale}"
256+
)));
257+
}
258+
259+
let scale_factor = 10f64.powi(scale as i32);
260+
let base_f64 = base.to_f64().ok_or_else(|| {
261+
ArrowError::ComputeError("Cannot convert base to f64".to_string())
262+
})? / scale_factor;
263+
264+
let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
265+
266+
decimal_from_i128(result_i128)
267+
}
268+
269+
/// Decimal256 specialized float exponent version.
270+
fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result<i256, ArrowError> {
271+
if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
272+
return pow_decimal256_int(base, scale, exp as i64);
273+
}
274+
275+
if !exp.is_finite() {
276+
return Err(ArrowError::ComputeError(format!(
277+
"Cannot use non-finite exp: {exp}"
278+
)));
279+
}
280+
281+
pow_decimal256_float_fallback(base, scale, exp)
282+
}
283+
284+
/// Decimal256 specialized integer exponent version.
285+
fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result<i256, ArrowError> {
286+
if exp < 0 {
287+
return pow_decimal256_float(base, scale, exp as f64);
288+
}
289+
290+
let exp: u32 = exp.try_into().map_err(|_| {
291+
ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
292+
})?;
293+
294+
if exp == 0 {
295+
return if scale >= 0 {
296+
i256::from_i128(10).pow_checked(scale as u32).map_err(|_| {
297+
ArrowError::ArithmeticOverflow(format!(
298+
"Cannot make unscale factor for {scale} and {exp}"
299+
))
300+
})
301+
} else {
302+
Ok(i256::from_i128(0))
303+
};
304+
}
305+
306+
let powered: i256 = base.pow_checked(exp).map_err(|_| {
307+
ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
308+
})?;
309+
310+
let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
311+
312+
if mul_exp == 0 {
313+
return Ok(powered);
314+
}
315+
316+
if mul_exp > 0 {
317+
let div_factor: i256 =
318+
i256::from_i128(10)
319+
.pow_checked(mul_exp as u32)
320+
.map_err(|_| {
321+
ArrowError::ArithmeticOverflow(format!(
322+
"Cannot make div factor for {scale} and {exp}"
323+
))
324+
})?;
325+
powered.div_checked(div_factor)
326+
} else {
327+
let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
328+
ArrowError::ArithmeticOverflow(
329+
"Overflow while negating scale exponent".to_string(),
330+
)
331+
})?;
332+
let mul_factor: i256 =
333+
i256::from_i128(10)
334+
.pow_checked(abs_exp as u32)
335+
.map_err(|_| {
336+
ArrowError::ArithmeticOverflow(format!(
337+
"Cannot make mul factor for {scale} and {exp}"
338+
))
339+
})?;
340+
powered.mul_checked(mul_factor)
341+
}
342+
}
343+
344+
/// Fallback implementation for Decimal256.
345+
fn pow_decimal256_float_fallback(
346+
base: i256,
347+
scale: i8,
348+
exp: f64,
349+
) -> Result<i256, ArrowError> {
350+
if scale < 0 {
351+
return Err(ArrowError::NotYetImplemented(format!(
352+
"Negative scale is not yet supported: {scale}"
189353
)));
190354
}
191-
pow_decimal_int(base, scale, exp as i64)
355+
356+
let scale_factor = 10f64.powi(scale as i32);
357+
let base_f64 = base.to_f64().ok_or_else(|| {
358+
ArrowError::ComputeError("Cannot convert base to f64".to_string())
359+
})? / scale_factor;
360+
361+
let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
362+
363+
// i256 can be constructed from i128 directly
364+
Ok(i256::from_i128(result_i128))
192365
}
193366

194367
impl ScalarUDFImpl for PowerFunc {
@@ -311,7 +484,7 @@ impl ScalarUDFImpl for PowerFunc {
311484
>(
312485
&base,
313486
exponent,
314-
|b, e| pow_decimal_int(b, *scale, e),
487+
|b, e| pow_decimal256_int(b, *scale, e),
315488
*precision,
316489
*scale,
317490
)?
@@ -325,7 +498,7 @@ impl ScalarUDFImpl for PowerFunc {
325498
>(
326499
&base,
327500
exponent,
328-
|b, e| pow_decimal_float(b, *scale, e),
501+
|b, e| pow_decimal256_float(b, *scale, e),
329502
*precision,
330503
*scale,
331504
)?
@@ -398,19 +571,53 @@ mod tests {
398571
#[test]
399572
fn test_pow_decimal128_helper() {
400573
// Expression: 2.5 ^ 4 = 39.0625
401-
assert_eq!(pow_decimal_int(25, 1, 4).unwrap(), i128::from(390));
402-
assert_eq!(pow_decimal_int(2500, 3, 4).unwrap(), i128::from(39062));
403-
assert_eq!(pow_decimal_int(25000, 4, 4).unwrap(), i128::from(390625));
574+
assert_eq!(pow_decimal_int(25i128, 1, 4).unwrap(), 390i128);
575+
assert_eq!(pow_decimal_int(2500i128, 3, 4).unwrap(), 39062i128);
576+
assert_eq!(pow_decimal_int(25000i128, 4, 4).unwrap(), 390625i128);
404577

405578
// Expression: 25 ^ 4 = 390625
406-
assert_eq!(pow_decimal_int(25, 0, 4).unwrap(), i128::from(390625));
579+
assert_eq!(pow_decimal_int(25i128, 0, 4).unwrap(), 390625i128);
407580

408581
// Expressions for edge cases
409-
assert_eq!(pow_decimal_int(25, 1, 1).unwrap(), i128::from(25));
410-
assert_eq!(pow_decimal_int(25, 0, 1).unwrap(), i128::from(25));
411-
assert_eq!(pow_decimal_int(25, 0, 0).unwrap(), i128::from(1));
412-
assert_eq!(pow_decimal_int(25, 1, 0).unwrap(), i128::from(10));
582+
assert_eq!(pow_decimal_int(25i128, 1, 1).unwrap(), 25i128);
583+
assert_eq!(pow_decimal_int(25i128, 0, 1).unwrap(), 25i128);
584+
assert_eq!(pow_decimal_int(25i128, 0, 0).unwrap(), 1i128);
585+
assert_eq!(pow_decimal_int(25i128, 1, 0).unwrap(), 10i128);
586+
587+
assert_eq!(pow_decimal_int(25i128, -1, 4).unwrap(), 390625000i128);
588+
}
589+
590+
#[test]
591+
fn test_pow_decimal_float_fallback() {
592+
// Test negative exponent: 4^(-1) = 0.25
593+
// 4 with scale 2 = 400, result should be 25 (0.25 with scale 2)
594+
let result: i128 = pow_decimal_float(400i128, 2, -1.0).unwrap();
595+
assert_eq!(result, 25);
596+
597+
// Test non-integer exponent: 4^0.5 = 2
598+
// 4 with scale 2 = 400, result should be 200 (2.0 with scale 2)
599+
let result: i128 = pow_decimal_float(400i128, 2, 0.5).unwrap();
600+
assert_eq!(result, 200);
601+
602+
// Test 8^(1/3) = 2 (cube root)
603+
// 8 with scale 1 = 80, result should be 20 (2.0 with scale 1)
604+
let result: i128 = pow_decimal_float(80i128, 1, 1.0 / 3.0).unwrap();
605+
assert_eq!(result, 20);
606+
607+
// Test negative base with integer exponent still works
608+
// (-2)^3 = -8
609+
// -2 with scale 1 = -20, result should be -80 (-8.0 with scale 1)
610+
let result: i128 = pow_decimal_float(-20i128, 1, 3.0).unwrap();
611+
assert_eq!(result, -80);
612+
613+
// Test positive integer exponent goes through fast path
614+
// 2.5^4 = 39.0625
615+
// 25 with scale 1, result should be 390 (39.0 with scale 1) - truncated
616+
let result: i128 = pow_decimal_float(25i128, 1, 4.0).unwrap();
617+
assert_eq!(result, 390); // Uses integer path
413618

414-
assert_eq!(pow_decimal_int(25, -1, 4).unwrap(), i128::from(390625000));
619+
// Test non-finite exponent returns error
620+
assert!(pow_decimal_float(100i128, 2, f64::NAN).is_err());
621+
assert!(pow_decimal_float(100i128, 2, f64::INFINITY).is_err());
415622
}
416623
}

0 commit comments

Comments
 (0)