diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs
index db297f297..1d413358c 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -13,7 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use datafusion::{common::Result, logical_expr::ScalarFunctionImplementation};
 use datafusion_ext_commons::df_unimplemented_err;
@@ -39,51 +39,60 @@ pub fn create_auron_ext_function(
     name: &str,
     spark_partition_id: usize,
 ) -> Result<ScalarFunctionImplementation> {
+    // Caches the `Arc` built for a function path: each expansion declares its
+    // own `static` cell, fills it on first use through `get_or_init`, and
+    // returns a clone of the stored `Arc` on every call.
+    macro_rules! cache {
+        ($func:path) => {{
+            static CELL: OnceLock<ScalarFunctionImplementation> = OnceLock::new();
+            CELL.get_or_init(|| Arc::new($func)).clone()
+        }};
+    }
     // auron ext functions, if used for spark should be start with 'Spark_',
     // if used for flink should be start with 'Flink_',
     // same to other engines.
     Ok(match name {
         "Placeholder" => Arc::new(|_| panic!("placeholder() should never be called")),
-        "Spark_NullIf" => Arc::new(spark_null_if::spark_null_if),
-        "Spark_NullIfZero" => Arc::new(spark_null_if::spark_null_if_zero),
-        "Spark_UnscaledValue" => Arc::new(spark_unscaled_value::spark_unscaled_value),
-        "Spark_MakeDecimal" => Arc::new(spark_make_decimal::spark_make_decimal),
-        "Spark_CheckOverflow" => Arc::new(spark_check_overflow::spark_check_overflow),
-        "Spark_Murmur3Hash" => Arc::new(spark_hash::spark_murmur3_hash),
-        "Spark_XxHash64" => Arc::new(spark_hash::spark_xxhash64),
-        "Spark_Sha224" => Arc::new(spark_crypto::spark_sha224),
-        "Spark_Sha256" => Arc::new(spark_crypto::spark_sha256),
-        "Spark_Sha384" => Arc::new(spark_crypto::spark_sha384),
-        "Spark_Sha512" => Arc::new(spark_crypto::spark_sha512),
-        "Spark_MD5" => Arc::new(spark_crypto::spark_md5),
-        "Spark_GetJsonObject" => Arc::new(spark_get_json_object::spark_get_json_object),
+        "Spark_NullIf" => cache!(spark_null_if::spark_null_if),
+        "Spark_NullIfZero" => cache!(spark_null_if::spark_null_if_zero),
+        "Spark_UnscaledValue" => cache!(spark_unscaled_value::spark_unscaled_value),
+        "Spark_MakeDecimal" => cache!(spark_make_decimal::spark_make_decimal),
+        "Spark_CheckOverflow" => cache!(spark_check_overflow::spark_check_overflow),
+        "Spark_Murmur3Hash" => cache!(spark_hash::spark_murmur3_hash),
+        "Spark_XxHash64" => cache!(spark_hash::spark_xxhash64),
+        "Spark_Sha224" => cache!(spark_crypto::spark_sha224),
+        "Spark_Sha256" => cache!(spark_crypto::spark_sha256),
+        "Spark_Sha384" => cache!(spark_crypto::spark_sha384),
+        "Spark_Sha512" => cache!(spark_crypto::spark_sha512),
+        "Spark_MD5" => cache!(spark_crypto::spark_md5),
+        "Spark_GetJsonObject" => cache!(spark_get_json_object::spark_get_json_object),
         "Spark_GetParsedJsonObject" => {
-            Arc::new(spark_get_json_object::spark_get_parsed_json_object)
+            cache!(spark_get_json_object::spark_get_parsed_json_object)
         }
-        "Spark_ParseJson" => Arc::new(spark_get_json_object::spark_parse_json),
-        "Spark_MakeArray" => Arc::new(spark_make_array::array),
-        "Spark_StringSpace" => Arc::new(spark_strings::string_space),
-        "Spark_StringRepeat" => Arc::new(spark_strings::string_repeat),
-        "Spark_StringSplit" => Arc::new(spark_strings::string_split),
-        "Spark_StringConcat" => Arc::new(spark_strings::string_concat),
-        "Spark_StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
-        "Spark_StringLower" => Arc::new(spark_strings::string_lower),
-        "Spark_StringUpper" => Arc::new(spark_strings::string_upper),
-        "Spark_InitCap" => Arc::new(spark_initcap::string_initcap),
-        "Spark_Year" => Arc::new(spark_dates::spark_year),
-        "Spark_Month" => Arc::new(spark_dates::spark_month),
-        "Spark_Day" => Arc::new(spark_dates::spark_day),
-        "Spark_Quarter" => Arc::new(spark_dates::spark_quarter),
-        "Spark_Hour" => Arc::new(spark_dates::spark_hour),
-        "Spark_Minute" => Arc::new(spark_dates::spark_minute),
-        "Spark_Second" => Arc::new(spark_dates::spark_second),
-        "Spark_BrickhouseArrayUnion" => Arc::new(brickhouse::array_union::array_union),
-        "Spark_Round" => Arc::new(spark_round::spark_round),
-        "Spark_BRound" => Arc::new(spark_bround::spark_bround),
+        "Spark_ParseJson" => cache!(spark_get_json_object::spark_parse_json),
+        "Spark_MakeArray" => cache!(spark_make_array::array),
+        "Spark_StringSpace" => cache!(spark_strings::string_space),
+        "Spark_StringRepeat" => cache!(spark_strings::string_repeat),
+        "Spark_StringSplit" => cache!(spark_strings::string_split),
+        "Spark_StringConcat" => cache!(spark_strings::string_concat),
+        "Spark_StringConcatWs" => cache!(spark_strings::string_concat_ws),
+        "Spark_StringLower" => cache!(spark_strings::string_lower),
+        "Spark_StringUpper" => cache!(spark_strings::string_upper),
+        "Spark_InitCap" => cache!(spark_initcap::string_initcap),
+        "Spark_Year" => cache!(spark_dates::spark_year),
+        "Spark_Month" => cache!(spark_dates::spark_month),
+        "Spark_Day" => cache!(spark_dates::spark_day),
+        "Spark_Quarter" => cache!(spark_dates::spark_quarter),
+        "Spark_Hour" => cache!(spark_dates::spark_hour),
+        "Spark_Minute" => cache!(spark_dates::spark_minute),
+        "Spark_Second" => cache!(spark_dates::spark_second),
+        "Spark_BrickhouseArrayUnion" => cache!(brickhouse::array_union::array_union),
+        "Spark_Round" => cache!(spark_round::spark_round),
+        "Spark_BRound" => cache!(spark_bround::spark_bround),
         "Spark_NormalizeNanAndZero" => {
-            Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero)
+            cache!(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero)
         }
-        "Spark_IsNaN" => Arc::new(spark_isnan::spark_isnan),
+        "Spark_IsNaN" => cache!(spark_isnan::spark_isnan),
         _ => df_unimplemented_err!("spark ext function not implemented: {name}")?,
     })
 }