@@ -21,9 +21,7 @@ use std::any::Any;
2121
2222use super :: power:: PowerFunc ;
2323
24- use crate :: utils:: {
25- calculate_binary_math, decimal32_to_i32, decimal64_to_i64, decimal128_to_i128,
26- } ;
24+ use crate :: utils:: calculate_binary_math;
2725use arrow:: array:: { Array , ArrayRef } ;
2826use arrow:: datatypes:: {
2927 DataType , Decimal32Type , Decimal64Type , Decimal128Type , Decimal256Type , Float16Type ,
@@ -44,7 +42,7 @@ use datafusion_expr::{
4442} ;
4543use datafusion_expr:: { ScalarUDFImpl , Signature , Volatility } ;
4644use datafusion_macros:: user_doc;
47- use num_traits:: Float ;
45+ use num_traits:: { Float , ToPrimitive } ;
4846
4947#[ user_doc(
5048 doc_section( label = "Math Functions" ) ,
@@ -104,91 +102,70 @@ impl LogFunc {
104102 }
105103}
106104
107- /// Binary function to calculate logarithm of Decimal32 `value` using `base` base
108- /// Returns error if base is invalid
109- fn log_decimal32 ( value : i32 , scale : i8 , base : f64 ) -> Result < f64 , ArrowError > {
110- if !base. is_finite ( ) || base. trunc ( ) != base {
111- return Err ( ArrowError :: ComputeError ( format ! (
112- "Log cannot use non-integer base: {base}"
113- ) ) ) ;
114- }
115- if ( base as u32 ) < 2 {
116- return Err ( ArrowError :: ComputeError ( format ! (
117- "Log base must be greater than 1: {base}"
118- ) ) ) ;
119- }
120-
121- let unscaled_value = decimal32_to_i32 ( value, scale) ?;
122- if unscaled_value > 0 {
123- let log_value: u32 = unscaled_value. ilog ( base as i32 ) ;
124- Ok ( log_value as f64 )
125- } else {
126- // Reflect f64::log behaviour
127- Ok ( f64:: NAN )
128- }
105+ /// Checks if the base is valid for the efficient integer logarithm algorithm.
106+ #[ inline]
107+ fn is_valid_integer_base ( base : f64 ) -> bool {
108+ base. trunc ( ) == base && base >= 2.0 && base <= u32:: MAX as f64
129109}
130110
131- /// Binary function to calculate logarithm of Decimal64 `value` using `base` base
132- /// Returns error if base is invalid
133- fn log_decimal64 ( value : i64 , scale : i8 , base : f64 ) -> Result < f64 , ArrowError > {
134- if !base. is_finite ( ) || base. trunc ( ) != base {
135- return Err ( ArrowError :: ComputeError ( format ! (
136- "Log cannot use non-integer base: {base}"
137- ) ) ) ;
138- }
139- if ( base as u32 ) < 2 {
140- return Err ( ArrowError :: ComputeError ( format ! (
141- "Log base must be greater than 1: {base}"
142- ) ) ) ;
111+ /// Generic function to calculate logarithm of a decimal value using the given base.
112+ ///
113+ /// For integer bases >= 2 with non-negative scale, uses the efficient integer `ilog` algorithm.
114+ /// For all other cases (non-integer bases, negative bases, non-finite bases),
115+ /// falls back to f64 computation which naturally returns NaN for invalid inputs,
116+ /// matching the behavior of `f64::log`.
117+ fn log_decimal < T > ( value : T , scale : i8 , base : f64 ) -> Result < f64 , ArrowError >
118+ where
119+ T : ToPrimitive + Copy ,
120+ {
121+ // For integer bases >= 2 and non-negative scale, try the efficient integer algorithm
122+ if is_valid_integer_base ( base)
123+ && scale >= 0
124+ && let Some ( unscaled) = unscale_decimal_value ( value, scale)
125+ {
126+ return if unscaled > 0 {
127+ Ok ( unscaled. ilog ( base as u128 ) as f64 )
128+ } else {
129+ Ok ( f64:: NAN )
130+ } ;
143131 }
144132
145- let unscaled_value = decimal64_to_i64 ( value, scale) ?;
146- if unscaled_value > 0 {
147- let log_value: u32 = unscaled_value. ilog ( base as i64 ) ;
148- Ok ( log_value as f64 )
149- } else {
150- // Reflect f64::log behaviour
151- Ok ( f64:: NAN )
152- }
133+ // Fallback to f64 computation for non-integer bases, negative scale, etc.
134+ // This naturally returns NaN for invalid inputs (base <= 1, non-finite, value <= 0)
135+ decimal_to_f64 ( value, scale) . map ( |v| v. log ( base) )
153136}
154137
155- /// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
156- /// Returns error if base is invalid
157- fn log_decimal128 ( value : i128 , scale : i8 , base : f64 ) -> Result < f64 , ArrowError > {
158- if !base. is_finite ( ) || base. trunc ( ) != base {
159- return Err ( ArrowError :: ComputeError ( format ! (
160- "Log cannot use non-integer base: {base}"
161- ) ) ) ;
162- }
163- if ( base as u32 ) < 2 {
164- return Err ( ArrowError :: ComputeError ( format ! (
165- "Log base must be greater than 1: {base}"
166- ) ) ) ;
167- }
168-
169- if value <= 0 {
170- // Reflect f64::log behaviour
171- return Ok ( f64:: NAN ) ;
172- }
138+ /// Unscale a decimal value by dividing by 10^scale, returning the result as u128.
139+ /// Returns None if the value is negative or the conversion fails.
140+ #[ inline]
141+ fn unscale_decimal_value < T : ToPrimitive > ( value : T , scale : i8 ) -> Option < u128 > {
142+ let value_u128 = value. to_u128 ( ) ?;
143+ let divisor = 10u128 . checked_pow ( scale as u32 ) ?;
144+ Some ( value_u128 / divisor)
145+ }
173146
174- if scale < 0 {
175- let actual_value = ( value as f64 ) * 10.0_f64 . powi ( - ( scale as i32 ) ) ;
176- Ok ( actual_value . log ( base ) )
177- } else {
178- let unscaled_value = decimal128_to_i128 ( value , scale ) ? ;
179- let log_value : u32 = unscaled_value . ilog ( base as i128 ) ;
180- Ok ( log_value as f64 )
181- }
147+ /// Convert a scaled decimal value to f64.
148+ # [ inline ]
149+ fn decimal_to_f64 < T : ToPrimitive > ( value : T , scale : i8 ) -> Result < f64 , ArrowError > {
150+ let value_f64 = value
151+ . to_f64 ( )
152+ . ok_or_else ( || ArrowError :: ComputeError ( "Cannot convert value to f64" . to_string ( ) ) ) ? ;
153+ let scale_factor = 10f64 . powi ( scale as i32 ) ;
154+ Ok ( value_f64 / scale_factor )
182155}
183156
184- /// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
185- /// Returns error if base is invalid or if value is out of bounds of Decimal128
186157fn log_decimal256 ( value : i256 , scale : i8 , base : f64 ) -> Result < f64 , ArrowError > {
158+ // Try to convert to i128 for the optimized path
187159 match value. to_i128 ( ) {
188- Some ( value) => log_decimal128 ( value, scale, base) ,
189- None => Err ( ArrowError :: NotYetImplemented ( format ! (
190- "Log of Decimal256 larger than Decimal128 is not yet supported: {value}"
191- ) ) ) ,
160+ Some ( v) => log_decimal ( v, scale, base) ,
161+ None => {
162+ // For very large Decimal256 values, use f64 computation
163+ let value_f64 = value. to_f64 ( ) . ok_or_else ( || {
164+ ArrowError :: ComputeError ( format ! ( "Cannot convert {value} to f64" ) )
165+ } ) ?;
166+ let scale_factor = 10f64 . powi ( scale as i32 ) ;
167+ Ok ( ( value_f64 / scale_factor) . log ( base) )
168+ }
192169 }
193170}
194171
@@ -282,21 +259,21 @@ impl ScalarUDFImpl for LogFunc {
282259 calculate_binary_math :: < Decimal32Type , Float64Type , Float64Type , _ > (
283260 & value,
284261 & base,
285- |value, base| log_decimal32 ( value, * scale, base) ,
262+ |value, base| log_decimal ( value, * scale, base) ,
286263 ) ?
287264 }
288265 DataType :: Decimal64 ( _, scale) => {
289266 calculate_binary_math :: < Decimal64Type , Float64Type , Float64Type , _ > (
290267 & value,
291268 & base,
292- |value, base| log_decimal64 ( value, * scale, base) ,
269+ |value, base| log_decimal ( value, * scale, base) ,
293270 ) ?
294271 }
295272 DataType :: Decimal128 ( _, scale) => {
296273 calculate_binary_math :: < Decimal128Type , Float64Type , Float64Type , _ > (
297274 & value,
298275 & base,
299- |value, base| log_decimal128 ( value, * scale, base) ,
276+ |value, base| log_decimal ( value, * scale, base) ,
300277 ) ?
301278 }
302279 DataType :: Decimal256 ( _, scale) => {
@@ -433,7 +410,7 @@ mod tests {
433410 let value = 10_i128 . pow ( 35 ) ;
434411 assert_eq ! ( ( value as f64 ) . log2( ) , 116.26748332105768 ) ;
435412 assert_eq ! (
436- log_decimal128 ( value, 0 , 2.0 ) . unwrap( ) ,
413+ log_decimal ( value, 0 , 2.0 ) . unwrap( ) ,
437414 // TODO: see we're losing our decimal points compared to above
438415 // https://github.com/apache/datafusion/issues/18524
439416 116.0
@@ -1151,7 +1128,8 @@ mod tests {
11511128 }
11521129
11531130 #[ test]
1154- fn test_log_decimal128_wrong_base ( ) {
1131+ fn test_log_decimal128_invalid_base ( ) {
1132+ // Invalid base (-2.0) should return NaN, matching f64::log behavior
11551133 let arg_fields = vec ! [
11561134 Field :: new( "b" , DataType :: Float64 , false ) . into( ) ,
11571135 Field :: new( "x" , DataType :: Decimal128 ( 38 , 0 ) , false ) . into( ) ,
@@ -1166,16 +1144,26 @@ mod tests {
11661144 return_field : Field :: new ( "f" , DataType :: Float64 , true ) . into ( ) ,
11671145 config_options : Arc :: new ( ConfigOptions :: default ( ) ) ,
11681146 } ;
1169- let result = LogFunc :: new ( ) . invoke_with_args ( args) ;
1170- assert ! ( result. is_err( ) ) ;
1171- assert_eq ! (
1172- "Arrow error: Compute error: Log base must be greater than 1: -2" ,
1173- result. unwrap_err( ) . to_string( ) . lines( ) . next( ) . unwrap( )
1174- ) ;
1147+ let result = LogFunc :: new ( )
1148+ . invoke_with_args ( args)
1149+ . expect ( "should not error on invalid base" ) ;
1150+
1151+ match result {
1152+ ColumnarValue :: Array ( arr) => {
1153+ let floats = as_float64_array ( & arr)
1154+ . expect ( "failed to convert result to a Float64Array" ) ;
1155+ assert_eq ! ( floats. len( ) , 1 ) ;
1156+ assert ! ( floats. value( 0 ) . is_nan( ) ) ;
1157+ }
1158+ ColumnarValue :: Scalar ( _) => {
1159+ panic ! ( "Expected an array value" )
1160+ }
1161+ }
11751162 }
11761163
11771164 #[ test]
1178- fn test_log_decimal256_error ( ) {
1165+ fn test_log_decimal256_large ( ) {
1166+ // Large Decimal256 values that don't fit in i128 now use f64 fallback
11791167 let arg_field = Field :: new ( "a" , DataType :: Decimal256 ( 38 , 0 ) , false ) . into ( ) ;
11801168 let args = ScalarFunctionArgs {
11811169 args : vec ! [
@@ -1189,11 +1177,26 @@ mod tests {
11891177 return_field : Field :: new ( "f" , DataType :: Float64 , true ) . into ( ) ,
11901178 config_options : Arc :: new ( ConfigOptions :: default ( ) ) ,
11911179 } ;
1192- let result = LogFunc :: new ( ) . invoke_with_args ( args) ;
1193- assert ! ( result. is_err( ) ) ;
1194- assert_eq ! (
1195- result. unwrap_err( ) . to_string( ) . lines( ) . next( ) . unwrap( ) ,
1196- "Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727"
1197- ) ;
1180+ let result = LogFunc :: new ( )
1181+ . invoke_with_args ( args)
1182+ . expect ( "should handle large Decimal256 via f64 fallback" ) ;
1183+
1184+ match result {
1185+ ColumnarValue :: Array ( arr) => {
1186+ let floats = as_float64_array ( & arr)
1187+ . expect ( "failed to convert result to a Float64Array" ) ;
1188+ assert_eq ! ( floats. len( ) , 1 ) ;
1189+ // The f64 fallback may lose some precision for very large numbers,
1190+ // but we verify we get a reasonable positive result (not NaN/infinity)
1191+ let log_result = floats. value ( 0 ) ;
1192+ assert ! (
1193+ log_result. is_finite( ) && log_result > 0.0 ,
1194+ "Expected positive finite log result, got {log_result}"
1195+ ) ;
1196+ }
1197+ ColumnarValue :: Scalar ( _) => {
1198+ panic ! ( "Expected an array value" )
1199+ }
1200+ }
11981201 }
11991202}
0 commit comments