@@ -22,6 +22,7 @@ use super::log::LogFunc;
2222
2323use crate :: utils:: { calculate_binary_decimal_math, calculate_binary_math} ;
2424use arrow:: array:: { Array , ArrayRef } ;
25+ use arrow:: datatypes:: i256;
2526use arrow:: datatypes:: {
2627 ArrowNativeTypeOp , DataType , Decimal32Type , Decimal64Type , Decimal128Type ,
2728 Decimal256Type , Float64Type , Int64Type ,
@@ -37,6 +38,7 @@ use datafusion_expr::{
3738 ScalarUDFImpl , Signature , TypeSignature , TypeSignatureClass , Volatility , lit,
3839} ;
3940use datafusion_macros:: user_doc;
41+ use num_traits:: { NumCast , ToPrimitive } ;
4042
4143#[ user_doc(
4244 doc_section( label = "Math Functions" ) ,
@@ -112,26 +114,31 @@ impl PowerFunc {
112114/// 2.5 is represented as 25 with scale 1
113115/// The unscaled result is 25^4 = 390625
114116/// Scale it back to 1: 390625 / 10^4 = 39
115- ///
116- /// Returns error if base is invalid
117117fn pow_decimal_int < T > ( base : T , scale : i8 , exp : i64 ) -> Result < T , ArrowError >
118118where
119- T : From < i32 > + ArrowNativeTypeOp ,
119+ T : From < i32 > + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy ,
120120{
121+ // Negative exponent: fall back to float computation
122+ if exp < 0 {
123+ return pow_decimal_float ( base, scale, exp as f64 ) ;
124+ }
125+
121126 let exp: u32 = exp. try_into ( ) . map_err ( |_| {
122127 ArrowError :: ArithmeticOverflow ( format ! ( "Unsupported exp value: {exp}" ) )
123128 } ) ?;
124129 // Handle edge case for exp == 0
125130 // If scale < 0, 10^scale (e.g., 10^-2 = 0.01) becomes 0 in integer arithmetic.
126131 if exp == 0 {
127132 return if scale >= 0 {
128- T :: from ( 10 ) . pow_checked ( scale as u32 ) . map_err ( |_| {
129- ArrowError :: ArithmeticOverflow ( format ! (
130- "Cannot make unscale factor for {scale} and {exp}"
131- ) )
132- } )
133+ <T as From < i32 > >:: from ( 10 )
134+ . pow_checked ( scale as u32 )
135+ . map_err ( |_| {
136+ ArrowError :: ArithmeticOverflow ( format ! (
137+ "Cannot make unscale factor for {scale} and {exp}"
138+ ) )
139+ } )
133140 } else {
134- Ok ( T :: from ( 0 ) )
141+ Ok ( < T as From < i32 > > :: from ( 0 ) )
135142 } ;
136143 }
137144 let powered: T = base. pow_checked ( exp) . map_err ( |_| {
@@ -149,11 +156,13 @@ where
149156 // If mul_exp is positive, we divide (standard case).
150157 // If mul_exp is negative, we multiply (negative scale case).
151158 if mul_exp > 0 {
152- let div_factor: T = T :: from ( 10 ) . pow_checked ( mul_exp as u32 ) . map_err ( |_| {
153- ArrowError :: ArithmeticOverflow ( format ! (
154- "Cannot make div factor for {scale} and {exp}"
155- ) )
156- } ) ?;
159+ let div_factor: T = <T as From < i32 > >:: from ( 10 )
160+ . pow_checked ( mul_exp as u32 )
161+ . map_err ( |_| {
162+ ArrowError :: ArithmeticOverflow ( format ! (
163+ "Cannot make div factor for {scale} and {exp}"
164+ ) )
165+ } ) ?;
157166 powered. div_checked ( div_factor)
158167 } else {
159168 // mul_exp is negative, so we multiply by 10^(-mul_exp)
@@ -162,33 +171,197 @@ where
162171 "Overflow while negating scale exponent" . to_string ( ) ,
163172 )
164173 } ) ?;
165- let mul_factor: T = T :: from ( 10 ) . pow_checked ( abs_exp as u32 ) . map_err ( |_| {
166- ArrowError :: ArithmeticOverflow ( format ! (
167- "Cannot make mul factor for {scale} and {exp}"
168- ) )
169- } ) ?;
174+ let mul_factor: T = <T as From < i32 > >:: from ( 10 )
175+ . pow_checked ( abs_exp as u32 )
176+ . map_err ( |_| {
177+ ArrowError :: ArithmeticOverflow ( format ! (
178+ "Cannot make mul factor for {scale} and {exp}"
179+ ) )
180+ } ) ?;
170181 powered. mul_checked ( mul_factor)
171182 }
172183}
173184
174185/// Binary function to calculate a math power to float exponent
175186/// for scaled integer types.
176- /// Returns error if exponent is negative or non-integer, or base invalid
177187fn pow_decimal_float < T > ( base : T , scale : i8 , exp : f64 ) -> Result < T , ArrowError >
178188where
179- T : From < i32 > + ArrowNativeTypeOp ,
189+ T : From < i32 > + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy ,
180190{
181- if !exp. is_finite ( ) || exp. trunc ( ) != exp {
191+ if exp. is_finite ( ) && exp. trunc ( ) == exp && exp >= 0f64 && exp < u32:: MAX as f64 {
192+ return pow_decimal_int ( base, scale, exp as i64 ) ;
193+ }
194+
195+ if !exp. is_finite ( ) {
182196 return Err ( ArrowError :: ComputeError ( format ! (
183- "Cannot use non-integer exp: {exp}"
197+ "Cannot use non-finite exp: {exp}"
198+ ) ) ) ;
199+ }
200+
201+ pow_decimal_float_fallback ( base, scale, exp)
202+ }
203+
204+ /// Compute the f64 power result and scale it back.
205+ /// Returns the rounded i128 result for conversion to target type.
206+ #[ inline]
207+ fn compute_pow_f64_result (
208+ base_f64 : f64 ,
209+ scale : i8 ,
210+ exp : f64 ,
211+ ) -> Result < i128 , ArrowError > {
212+ let result_f64 = base_f64. powf ( exp) ;
213+
214+ if !result_f64. is_finite ( ) {
215+ return Err ( ArrowError :: ArithmeticOverflow ( format ! (
216+ "Result of {base_f64}^{exp} is not finite"
184217 ) ) ) ;
185218 }
186- if exp < 0f64 || exp >= u32:: MAX as f64 {
219+
220+ let scale_factor = 10f64 . powi ( scale as i32 ) ;
221+ let result_scaled = result_f64 * scale_factor;
222+ let result_rounded = result_scaled. round ( ) ;
223+
224+ if result_rounded. abs ( ) > i128:: MAX as f64 {
187225 return Err ( ArrowError :: ArithmeticOverflow ( format ! (
188- "Unsupported exp value: {exp}"
226+ "Result {result_rounded} is too large for the target decimal type"
227+ ) ) ) ;
228+ }
229+
230+ Ok ( result_rounded as i128 )
231+ }
232+
233+ /// Convert i128 result to target decimal native type using NumCast.
234+ /// Returns error if value overflows the target type.
235+ #[ inline]
236+ fn decimal_from_i128 < T > ( value : i128 ) -> Result < T , ArrowError >
237+ where
238+ T : NumCast ,
239+ {
240+ NumCast :: from ( value) . ok_or_else ( || {
241+ ArrowError :: ArithmeticOverflow ( format ! (
242+ "Value {value} is too large for the target decimal type"
243+ ) )
244+ } )
245+ }
246+
247+ /// Fallback implementation using f64 for negative or non-integer exponents.
248+ /// This handles cases that cannot be computed using integer arithmetic.
249+ fn pow_decimal_float_fallback < T > ( base : T , scale : i8 , exp : f64 ) -> Result < T , ArrowError >
250+ where
251+ T : ToPrimitive + NumCast + Copy ,
252+ {
253+ if scale < 0 {
254+ return Err ( ArrowError :: NotYetImplemented ( format ! (
255+ "Negative scale is not yet supported: {scale}"
256+ ) ) ) ;
257+ }
258+
259+ let scale_factor = 10f64 . powi ( scale as i32 ) ;
260+ let base_f64 = base. to_f64 ( ) . ok_or_else ( || {
261+ ArrowError :: ComputeError ( "Cannot convert base to f64" . to_string ( ) )
262+ } ) ? / scale_factor;
263+
264+ let result_i128 = compute_pow_f64_result ( base_f64, scale, exp) ?;
265+
266+ decimal_from_i128 ( result_i128)
267+ }
268+
269+ /// Decimal256 specialized float exponent version.
270+ fn pow_decimal256_float ( base : i256 , scale : i8 , exp : f64 ) -> Result < i256 , ArrowError > {
271+ if exp. is_finite ( ) && exp. trunc ( ) == exp && exp >= 0f64 && exp < u32:: MAX as f64 {
272+ return pow_decimal256_int ( base, scale, exp as i64 ) ;
273+ }
274+
275+ if !exp. is_finite ( ) {
276+ return Err ( ArrowError :: ComputeError ( format ! (
277+ "Cannot use non-finite exp: {exp}"
278+ ) ) ) ;
279+ }
280+
281+ pow_decimal256_float_fallback ( base, scale, exp)
282+ }
283+
284+ /// Decimal256 specialized integer exponent version.
285+ fn pow_decimal256_int ( base : i256 , scale : i8 , exp : i64 ) -> Result < i256 , ArrowError > {
286+ if exp < 0 {
287+ return pow_decimal256_float ( base, scale, exp as f64 ) ;
288+ }
289+
290+ let exp: u32 = exp. try_into ( ) . map_err ( |_| {
291+ ArrowError :: ArithmeticOverflow ( format ! ( "Unsupported exp value: {exp}" ) )
292+ } ) ?;
293+
294+ if exp == 0 {
295+ return if scale >= 0 {
296+ i256:: from_i128 ( 10 ) . pow_checked ( scale as u32 ) . map_err ( |_| {
297+ ArrowError :: ArithmeticOverflow ( format ! (
298+ "Cannot make unscale factor for {scale} and {exp}"
299+ ) )
300+ } )
301+ } else {
302+ Ok ( i256:: from_i128 ( 0 ) )
303+ } ;
304+ }
305+
306+ let powered: i256 = base. pow_checked ( exp) . map_err ( |_| {
307+ ArrowError :: ArithmeticOverflow ( format ! ( "Cannot raise base {base:?} to exp {exp}" ) )
308+ } ) ?;
309+
310+ let mul_exp = ( scale as i64 ) . wrapping_mul ( exp as i64 - 1 ) ;
311+
312+ if mul_exp == 0 {
313+ return Ok ( powered) ;
314+ }
315+
316+ if mul_exp > 0 {
317+ let div_factor: i256 =
318+ i256:: from_i128 ( 10 )
319+ . pow_checked ( mul_exp as u32 )
320+ . map_err ( |_| {
321+ ArrowError :: ArithmeticOverflow ( format ! (
322+ "Cannot make div factor for {scale} and {exp}"
323+ ) )
324+ } ) ?;
325+ powered. div_checked ( div_factor)
326+ } else {
327+ let abs_exp = mul_exp. checked_neg ( ) . ok_or_else ( || {
328+ ArrowError :: ArithmeticOverflow (
329+ "Overflow while negating scale exponent" . to_string ( ) ,
330+ )
331+ } ) ?;
332+ let mul_factor: i256 =
333+ i256:: from_i128 ( 10 )
334+ . pow_checked ( abs_exp as u32 )
335+ . map_err ( |_| {
336+ ArrowError :: ArithmeticOverflow ( format ! (
337+ "Cannot make mul factor for {scale} and {exp}"
338+ ) )
339+ } ) ?;
340+ powered. mul_checked ( mul_factor)
341+ }
342+ }
343+
344+ /// Fallback implementation for Decimal256.
345+ fn pow_decimal256_float_fallback (
346+ base : i256 ,
347+ scale : i8 ,
348+ exp : f64 ,
349+ ) -> Result < i256 , ArrowError > {
350+ if scale < 0 {
351+ return Err ( ArrowError :: NotYetImplemented ( format ! (
352+ "Negative scale is not yet supported: {scale}"
189353 ) ) ) ;
190354 }
191- pow_decimal_int ( base, scale, exp as i64 )
355+
356+ let scale_factor = 10f64 . powi ( scale as i32 ) ;
357+ let base_f64 = base. to_f64 ( ) . ok_or_else ( || {
358+ ArrowError :: ComputeError ( "Cannot convert base to f64" . to_string ( ) )
359+ } ) ? / scale_factor;
360+
361+ let result_i128 = compute_pow_f64_result ( base_f64, scale, exp) ?;
362+
363+ // i256 can be constructed from i128 directly
364+ Ok ( i256:: from_i128 ( result_i128) )
192365}
193366
194367impl ScalarUDFImpl for PowerFunc {
@@ -311,7 +484,7 @@ impl ScalarUDFImpl for PowerFunc {
311484 > (
312485 & base,
313486 exponent,
314- |b, e| pow_decimal_int ( b, * scale, e) ,
487+ |b, e| pow_decimal256_int ( b, * scale, e) ,
315488 * precision,
316489 * scale,
317490 ) ?
@@ -325,7 +498,7 @@ impl ScalarUDFImpl for PowerFunc {
325498 > (
326499 & base,
327500 exponent,
328- |b, e| pow_decimal_float ( b, * scale, e) ,
501+ |b, e| pow_decimal256_float ( b, * scale, e) ,
329502 * precision,
330503 * scale,
331504 ) ?
@@ -398,19 +571,53 @@ mod tests {
398571 #[ test]
399572 fn test_pow_decimal128_helper ( ) {
400573 // Expression: 2.5 ^ 4 = 39.0625
401- assert_eq ! ( pow_decimal_int( 25 , 1 , 4 ) . unwrap( ) , i128 :: from ( 390 ) ) ;
402- assert_eq ! ( pow_decimal_int( 2500 , 3 , 4 ) . unwrap( ) , i128 :: from ( 39062 ) ) ;
403- assert_eq ! ( pow_decimal_int( 25000 , 4 , 4 ) . unwrap( ) , i128 :: from ( 390625 ) ) ;
574+ assert_eq ! ( pow_decimal_int( 25i128 , 1 , 4 ) . unwrap( ) , 390i128 ) ;
575+ assert_eq ! ( pow_decimal_int( 2500i128 , 3 , 4 ) . unwrap( ) , 39062i128 ) ;
576+ assert_eq ! ( pow_decimal_int( 25000i128 , 4 , 4 ) . unwrap( ) , 390625i128 ) ;
404577
405578 // Expression: 25 ^ 4 = 390625
406- assert_eq ! ( pow_decimal_int( 25 , 0 , 4 ) . unwrap( ) , i128 :: from ( 390625 ) ) ;
579+ assert_eq ! ( pow_decimal_int( 25i128 , 0 , 4 ) . unwrap( ) , 390625i128 ) ;
407580
408581 // Expressions for edge cases
409- assert_eq ! ( pow_decimal_int( 25 , 1 , 1 ) . unwrap( ) , i128 :: from( 25 ) ) ;
410- assert_eq ! ( pow_decimal_int( 25 , 0 , 1 ) . unwrap( ) , i128 :: from( 25 ) ) ;
411- assert_eq ! ( pow_decimal_int( 25 , 0 , 0 ) . unwrap( ) , i128 :: from( 1 ) ) ;
412- assert_eq ! ( pow_decimal_int( 25 , 1 , 0 ) . unwrap( ) , i128 :: from( 10 ) ) ;
582+ assert_eq ! ( pow_decimal_int( 25i128 , 1 , 1 ) . unwrap( ) , 25i128 ) ;
583+ assert_eq ! ( pow_decimal_int( 25i128 , 0 , 1 ) . unwrap( ) , 25i128 ) ;
584+ assert_eq ! ( pow_decimal_int( 25i128 , 0 , 0 ) . unwrap( ) , 1i128 ) ;
585+ assert_eq ! ( pow_decimal_int( 25i128 , 1 , 0 ) . unwrap( ) , 10i128 ) ;
586+
587+ assert_eq ! ( pow_decimal_int( 25i128 , -1 , 4 ) . unwrap( ) , 390625000i128 ) ;
588+ }
589+
590+ #[ test]
591+ fn test_pow_decimal_float_fallback ( ) {
592+ // Test negative exponent: 4^(-1) = 0.25
593+ // 4 with scale 2 = 400, result should be 25 (0.25 with scale 2)
594+ let result: i128 = pow_decimal_float ( 400i128 , 2 , -1.0 ) . unwrap ( ) ;
595+ assert_eq ! ( result, 25 ) ;
596+
597+ // Test non-integer exponent: 4^0.5 = 2
598+ // 4 with scale 2 = 400, result should be 200 (2.0 with scale 2)
599+ let result: i128 = pow_decimal_float ( 400i128 , 2 , 0.5 ) . unwrap ( ) ;
600+ assert_eq ! ( result, 200 ) ;
601+
602+ // Test 8^(1/3) = 2 (cube root)
603+ // 8 with scale 1 = 80, result should be 20 (2.0 with scale 1)
604+ let result: i128 = pow_decimal_float ( 80i128 , 1 , 1.0 / 3.0 ) . unwrap ( ) ;
605+ assert_eq ! ( result, 20 ) ;
606+
607+ // Test negative base with integer exponent still works
608+ // (-2)^3 = -8
609+ // -2 with scale 1 = -20, result should be -80 (-8.0 with scale 1)
610+ let result: i128 = pow_decimal_float ( -20i128 , 1 , 3.0 ) . unwrap ( ) ;
611+ assert_eq ! ( result, -80 ) ;
612+
613+ // Test positive integer exponent goes through fast path
614+ // 2.5^4 = 39.0625
615+ // 25 with scale 1, result should be 390 (39.0 with scale 1) - truncated
616+ let result: i128 = pow_decimal_float ( 25i128 , 1 , 4.0 ) . unwrap ( ) ;
617+ assert_eq ! ( result, 390 ) ; // Uses integer path
413618
414- assert_eq ! ( pow_decimal_int( 25 , -1 , 4 ) . unwrap( ) , i128 :: from( 390625000 ) ) ;
619+ // Test non-finite exponent returns error
620+ assert ! ( pow_decimal_float( 100i128 , 2 , f64 :: NAN ) . is_err( ) ) ;
621+ assert ! ( pow_decimal_float( 100i128 , 2 , f64 :: INFINITY ) . is_err( ) ) ;
415622 }
416623}
0 commit comments