From 1198b30458da908b46dc9ba300dbe08d700b5b9d Mon Sep 17 00:00:00 2001 From: Ahmed Mezghani Date: Fri, 17 Oct 2025 10:16:50 +0200 Subject: [PATCH 1/2] fix: preserve argument types in max_by/min_by with dictionary inputs --- src/max_min_by.rs | 213 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 212 insertions(+), 1 deletion(-) diff --git a/src/max_min_by.rs b/src/max_min_by.rs index 4c8399d..6c33e09 100644 --- a/src/max_min_by.rs +++ b/src/max_min_by.rs @@ -45,7 +45,9 @@ fn get_min_max_by_result_type( match &input_types[0] { arrow::datatypes::DataType::Dictionary(_, dict_value_type) => { // x add checker, if the value type is complex data type - Ok(vec![dict_value_type.deref().clone()]) + let mut result = vec![dict_value_type.deref().clone()]; + result.extend_from_slice(&input_types[1..]); // Preserve all other argument types + Ok(result) } _ => Ok(input_types.to_vec()), } @@ -207,3 +209,212 @@ impl logical_expr::AggregateUDFImpl for MinByFunction { Some(Box::new(simplify)) } } + +#[cfg(test)] +mod tests { + use datafusion::arrow::array::{ + ArrayRef, Float64Array, Int64Array, RecordBatch, StringArray, UInt64Array, + }; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::datasource::MemTable; + use datafusion::prelude::SessionContext; + use std::any::Any; + use std::sync::Arc; + + #[cfg(test)] + mod tests_max_by { + use crate::max_min_by::max_by_udaf; + use crate::max_min_by::tests::{ + extract_single_float64, extract_single_int64, extract_single_string, test_ctx, + }; + use datafusion::error::Result; + use datafusion::prelude::SessionContext; + + #[tokio::test] + async fn test_max_by_string_int() -> Result<()> { + let df = ctx()? + .sql("SELECT max_by(string, int64) FROM types") + .await?; + assert_eq!(extract_single_string(df.collect().await?), "h"); + Ok(()) + } + + #[tokio::test] + async fn test_max_by_string_float() -> Result<()> { + let df = ctx()? + .sql("SELECT max_by(string, float64) FROM types") + .await?; + assert_eq!(extract_single_string(df.collect().await?), "h"); + Ok(()) + } + + #[tokio::test] + async fn test_max_by_float_string() -> Result<()> { + let df = ctx()? + .sql("SELECT max_by(float64, string) FROM types") + .await?; + assert_eq!(extract_single_float64(df.collect().await?), 8.0); + Ok(()) + } + + #[tokio::test] + async fn test_max_by_int_string() -> Result<()> { + let df = ctx()? + .sql("SELECT max_by(int64, string) FROM types") + .await?; + assert_eq!(extract_single_int64(df.collect().await?), 8); + Ok(()) + } + + #[tokio::test] + async fn test_max_by_dictionary_int() -> Result<()> { + let df = ctx()? + .sql("SELECT max_by(dict_string, int64) FROM types") + .await?; + assert_eq!(extract_single_string(df.collect().await?), "h"); + Ok(()) + } + + fn ctx() -> Result { + let ctx = test_ctx()?; + ctx.register_udaf(max_by_udaf().as_ref().clone()); + Ok(ctx) + } + } + + #[cfg(test)] + mod test_min_by { + use crate::max_min_by::min_by_udaf; + use crate::max_min_by::tests::{ + extract_single_float64, extract_single_int64, extract_single_string, test_ctx, + }; + use datafusion::error::Result; + use datafusion::prelude::SessionContext; + + #[tokio::test] + async fn test_min_by_string_int() -> Result<()> { + let df = ctx()? + .sql("SELECT min_by(string, int64) FROM types") + .await?; + assert_eq!(extract_single_string(df.collect().await?), "a"); + Ok(()) + } + + #[tokio::test] + async fn test_min_by_string_float() -> Result<()> { + let df = ctx()? + .sql("SELECT min_by(string, float64) FROM types") + .await?; + assert_eq!(extract_single_string(df.collect().await?), "a"); + Ok(()) + } + + #[tokio::test] + async fn test_min_by_float_string() -> Result<()> { + let df = ctx()? + .sql("SELECT min_by(float64, string) FROM types") + .await?; + assert_eq!(extract_single_float64(df.collect().await?), 0.5); + Ok(()) + } + + #[tokio::test] + async fn test_min_by_int_string() -> Result<()> { + let df = ctx()? + .sql("SELECT min_by(int64, string) FROM types") + .await?; + assert_eq!(extract_single_int64(df.collect().await?), 1); + Ok(()) + } + + #[tokio::test] + async fn test_min_by_dictionary_int() -> Result<()> { + let df = ctx()? + .sql("SELECT min_by(dict_string, int64) FROM types") + .await?; + assert_eq!(extract_single_string(df.collect().await?), "a"); + Ok(()) + } + + fn ctx() -> Result { + let ctx = test_ctx()?; + ctx.register_udaf(min_by_udaf().as_ref().clone()); + Ok(ctx) + } + } + + pub(super) fn test_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("string", DataType::Utf8, false), + Field::new_dictionary("dict_string", DataType::Int32, DataType::Utf8, false), + Field::new("int64", DataType::Int64, false), + Field::new("uint64", DataType::UInt64, false), + Field::new("float64", DataType::Float64, false), + ])) + } + + pub(super) fn test_data(schema: Arc) -> Vec { + use datafusion::arrow::array::DictionaryArray; + use datafusion::arrow::datatypes::Int32Type; + + vec![ + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["a", "b", "c", "d"])), + Arc::new( + vec![Some("a"), Some("b"), Some("c"), Some("d")] + .into_iter() + .collect::>(), + ), + Arc::new(Int64Array::from(vec![1, 2, 3, 4])), + Arc::new(UInt64Array::from(vec![1, 2, 3, 4])), + Arc::new(Float64Array::from(vec![0.5, 2.0, 3.0, 4.0])), + ], + ) + .unwrap(), + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["e", "f", "g", "h"])), + Arc::new( + vec![Some("e"), Some("f"), Some("g"), Some("h")] + .into_iter() + .collect::>(), + ), + Arc::new(Int64Array::from(vec![5, 6, 7, 8])), + Arc::new(UInt64Array::from(vec![5, 6, 7, 8])), + Arc::new(Float64Array::from(vec![5.0, 6.0, 7.0, 8.0])), + ], + ) + .unwrap(), + ] + } + + pub(crate) fn test_ctx() -> datafusion::common::Result { + let schema = test_schema(); + let table = MemTable::try_new(schema.clone(), vec![test_data(schema)])?; + let ctx = SessionContext::new(); + ctx.register_table("types", Arc::new(table))?; + Ok(ctx) + } + + fn downcast(col: &ArrayRef) -> &T { + col.as_any().downcast_ref::().unwrap() + } + + pub(crate) fn extract_single_string(results: Vec) -> String { + let v1 = downcast::(results[0].column(0)); + v1.value(0).to_string() + } + + pub(crate) fn extract_single_int64(results: Vec) -> i64 { + let v1 = downcast::(results[0].column(0)); + v1.value(0) + } + + pub(crate) fn extract_single_float64(results: Vec) -> f64 { + let v1 = downcast::(results[0].column(0)); + v1.value(0) + } +} From df669316e77285b3e513fc3d5a6eb04d44423e41 Mon Sep 17 00:00:00 2001 From: dario curreri Date: Tue, 21 Oct 2025 15:28:52 +0200 Subject: [PATCH 2/2] tests: make tests more readable --- src/max_min_by.rs | 354 ++++++++++++++++++++++++++-------------------- 1 file changed, 202 insertions(+), 152 deletions(-) diff --git a/src/max_min_by.rs b/src/max_min_by.rs index 6c33e09..bfa3754 100644 --- a/src/max_min_by.rs +++ b/src/max_min_by.rs @@ -46,7 +46,8 @@ fn get_min_max_by_result_type( arrow::datatypes::DataType::Dictionary(_, dict_value_type) => { // x add checker, if the value type is complex data type let mut result = vec![dict_value_type.deref().clone()]; - result.extend_from_slice(&input_types[1..]); // Preserve all other argument types + // Preserve all other argument types + result.extend_from_slice(&input_types[1..]); Ok(result) } _ => Ok(input_types.to_vec()), @@ -212,209 +213,258 @@ impl logical_expr::AggregateUDFImpl for MinByFunction { #[cfg(test)] mod tests { - use datafusion::arrow::array::{ - ArrayRef, Float64Array, Int64Array, RecordBatch, StringArray, UInt64Array, - }; - use datafusion::arrow::datatypes::{DataType, Field, Schema}; - use datafusion::datasource::MemTable; - use datafusion::prelude::SessionContext; - use std::any::Any; - use std::sync::Arc; + use super::*; + + use datafusion::arrow::array::ArrayAccessor; + use datafusion::{arrow, datasource, error, prelude}; + use std::sync; + + const TEST_TABLE_NAME: &str = "types"; + const STRING_COLUMN_NAME: &str = "string"; + const DICTIONARY_COLUMN_NAME: &str = "dict_string"; + const INT64_COLUMN_NAME: &str = "int64"; + const FLOAT64_COLUMN_NAME: &str = "float64"; + + const MIN_STRING_VALUE: &str = "a"; + const MID_STRING_VALUE: &str = "b"; + const MAX_STRING_VALUE: &str = "c"; + const MIN_FLOAT_VALUE: f64 = 0.25; + const MID_FLOAT_VALUE: f64 = 0.5; + const MAX_FLOAT_VALUE: f64 = 0.75; + const MIN_INT_VALUE: i64 = -1; + const MID_INT_VALUE: i64 = 0; + const MAX_INT_VALUE: i64 = 1; + const MIN_DICTIONARY_VALUE: &str = "a"; + const MID_DICTIONARY_VALUE: &str = "b"; + const MAX_DICTIONARY_VALUE: &str = "c"; + + fn test_schema() -> sync::Arc { + sync::Arc::new(arrow::datatypes::Schema::new(vec![ + arrow::datatypes::Field::new( + STRING_COLUMN_NAME, + arrow::datatypes::DataType::Utf8, + false, + ), + arrow::datatypes::Field::new_dictionary( + DICTIONARY_COLUMN_NAME, + arrow::datatypes::DataType::Int32, + arrow::datatypes::DataType::Utf8, + false, + ), + arrow::datatypes::Field::new( + INT64_COLUMN_NAME, + arrow::datatypes::DataType::Int64, + false, + ), + arrow::datatypes::Field::new( + FLOAT64_COLUMN_NAME, + arrow::datatypes::DataType::Float64, + false, + ), + ])) + } + + fn test_data( + schema: sync::Arc, + ) -> Vec { + vec![ + arrow::record_batch::RecordBatch::try_new( + schema, + vec![ + sync::Arc::new(arrow::array::StringArray::from(vec![ + MID_STRING_VALUE, + MIN_STRING_VALUE, + MAX_STRING_VALUE, + ])), + sync::Arc::new( + vec![ + Some(MID_DICTIONARY_VALUE), + Some(MIN_DICTIONARY_VALUE), + Some(MAX_DICTIONARY_VALUE), + ] + .into_iter() + .collect::>(), + ), + sync::Arc::new(arrow::array::Int64Array::from(vec![ + MID_INT_VALUE, + MIN_INT_VALUE, + MAX_INT_VALUE, + ])), + sync::Arc::new(arrow::array::Float64Array::from(vec![ + MID_FLOAT_VALUE, + MIN_FLOAT_VALUE, + MAX_FLOAT_VALUE, + ])), + ], + ) + .unwrap(), + ] + } + + fn test_ctx() -> datafusion::common::Result { + let schema = test_schema(); + let data = test_data(schema.clone()); + let table = datasource::MemTable::try_new(schema, vec![data])?; + let ctx = prelude::SessionContext::new(); + ctx.register_table(TEST_TABLE_NAME, sync::Arc::new(table))?; + Ok(ctx) + } + + async fn extract_single_value(df: prelude::DataFrame) -> error::Result + where + A: arrow::array::Array + 'static, + for<'a> &'a A: arrow::array::ArrayAccessor, + for<'a> <&'a A as arrow::array::ArrayAccessor>::Item: Into, + { + let results = df.collect().await?; + let col = results[0].column(0); + let v1 = col.as_any().downcast_ref::().unwrap(); + let value = v1.value(0).into(); + Ok(value) + } #[cfg(test)] - mod tests_max_by { - use crate::max_min_by::max_by_udaf; - use crate::max_min_by::tests::{ - extract_single_float64, extract_single_int64, extract_single_string, test_ctx, - }; - use datafusion::error::Result; - use datafusion::prelude::SessionContext; + mod max_by { + use super::*; #[tokio::test] - async fn test_max_by_string_int() -> Result<()> { - let df = ctx()? - .sql("SELECT max_by(string, int64) FROM types") - .await?; - assert_eq!(extract_single_string(df.collect().await?), "h"); + async fn test_max_by_string_int() -> error::Result<()> { + let query = format!( + "SELECT max_by({}, {}) FROM {}", + STRING_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME + ); + let df = ctx()?.sql(&query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, MAX_STRING_VALUE); Ok(()) } #[tokio::test] - async fn test_max_by_string_float() -> Result<()> { - let df = ctx()? - .sql("SELECT max_by(string, float64) FROM types") - .await?; - assert_eq!(extract_single_string(df.collect().await?), "h"); + async fn test_max_by_string_float() -> error::Result<()> { + let query = format!( + "SELECT max_by({}, {}) FROM {}", + STRING_COLUMN_NAME, FLOAT64_COLUMN_NAME, TEST_TABLE_NAME + ); + let df = ctx()?.sql(&query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, MAX_STRING_VALUE); Ok(()) } #[tokio::test] - async fn test_max_by_float_string() -> Result<()> { - let df = ctx()? - .sql("SELECT max_by(float64, string) FROM types") - .await?; - assert_eq!(extract_single_float64(df.collect().await?), 8.0); + async fn test_max_by_float_string() -> error::Result<()> { + let query = format!( + "SELECT max_by({}, {}) FROM {}", + FLOAT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME + ); + let df = ctx()?.sql(&query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, MAX_FLOAT_VALUE); Ok(()) } #[tokio::test] - async fn test_max_by_int_string() -> Result<()> { - let df = ctx()? - .sql("SELECT max_by(int64, string) FROM types") - .await?; - assert_eq!(extract_single_int64(df.collect().await?), 8); + async fn test_max_by_int_string() -> error::Result<()> { + let query = format!( + "SELECT max_by({}, {}) FROM {}", + INT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME + ); + let df = ctx()?.sql(&query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, MAX_INT_VALUE); Ok(()) } #[tokio::test] - async fn test_max_by_dictionary_int() -> Result<()> { - let df = ctx()? - .sql("SELECT max_by(dict_string, int64) FROM types") - .await?; - assert_eq!(extract_single_string(df.collect().await?), "h"); + async fn test_max_by_dictionary_int() -> error::Result<()> { + let query = format!( + "SELECT max_by({}, {}) FROM {}", + DICTIONARY_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME + ); + let df = ctx()?.sql(&query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, MAX_DICTIONARY_VALUE); Ok(()) } - fn ctx() -> Result { + fn ctx() -> error::Result { let ctx = test_ctx()?; - ctx.register_udaf(max_by_udaf().as_ref().clone()); + let max_by_udaf = MaxByFunction::new(); + ctx.register_udaf(max_by_udaf.into()); Ok(ctx) } } #[cfg(test)] - mod test_min_by { - use crate::max_min_by::min_by_udaf; - use crate::max_min_by::tests::{ - extract_single_float64, extract_single_int64, extract_single_string, test_ctx, - }; - use datafusion::error::Result; - use datafusion::prelude::SessionContext; + mod min_by { + + use super::*; #[tokio::test] - async fn test_min_by_string_int() -> Result<()> { - let df = ctx()? - .sql("SELECT min_by(string, int64) FROM types") - .await?; - assert_eq!(extract_single_string(df.collect().await?), "a"); + async fn test_min_by_string_int() -> error::Result<()> { + let query = format!( + "SELECT min_by({}, {}) FROM {}", + STRING_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME + ); + let df = ctx()?.sql(&query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, MIN_STRING_VALUE); Ok(()) } #[tokio::test] - async fn test_min_by_string_float() -> Result<()> { - let df = ctx()? - .sql("SELECT min_by(string, float64) FROM types") - .await?; - assert_eq!(extract_single_string(df.collect().await?), "a"); + async fn test_min_by_string_float() -> error::Result<()> { + let query = format!( + "SELECT min_by({}, {}) FROM {}", + STRING_COLUMN_NAME, FLOAT64_COLUMN_NAME, TEST_TABLE_NAME + ); + let df = ctx()?.sql(&query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, MIN_STRING_VALUE); Ok(()) } #[tokio::test] - async fn test_min_by_float_string() -> Result<()> { - let df = ctx()? - .sql("SELECT min_by(float64, string) FROM types") - .await?; - assert_eq!(extract_single_float64(df.collect().await?), 0.5); + async fn test_min_by_float_string() -> error::Result<()> { + let query = format!( + "SELECT min_by({}, {}) FROM {}", + FLOAT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME + ); + let df = ctx()?.sql(&query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, MIN_FLOAT_VALUE); Ok(()) } #[tokio::test] - async fn test_min_by_int_string() -> Result<()> { - let df = ctx()? - .sql("SELECT min_by(int64, string) FROM types") - .await?; - assert_eq!(extract_single_int64(df.collect().await?), 1); + async fn test_min_by_int_string() -> error::Result<()> { + let query = format!( + "SELECT min_by({}, {}) FROM {}", + INT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME + ); + let df = ctx()?.sql(&query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, MIN_INT_VALUE); Ok(()) } #[tokio::test] - async fn test_min_by_dictionary_int() -> Result<()> { - let df = ctx()? - .sql("SELECT min_by(dict_string, int64) FROM types") - .await?; - assert_eq!(extract_single_string(df.collect().await?), "a"); + async fn test_min_by_dictionary_int() -> error::Result<()> { + let query = format!( + "SELECT min_by({}, {}) FROM {}", + DICTIONARY_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME + ); + let df = ctx()?.sql(&query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, MIN_DICTIONARY_VALUE); Ok(()) } - fn ctx() -> Result { + fn ctx() -> error::Result { let ctx = test_ctx()?; - ctx.register_udaf(min_by_udaf().as_ref().clone()); + let min_by_udaf = MinByFunction::new(); + ctx.register_udaf(min_by_udaf.into()); Ok(ctx) } } - - pub(super) fn test_schema() -> Arc { - Arc::new(Schema::new(vec![ - Field::new("string", DataType::Utf8, false), - Field::new_dictionary("dict_string", DataType::Int32, DataType::Utf8, false), - Field::new("int64", DataType::Int64, false), - Field::new("uint64", DataType::UInt64, false), - Field::new("float64", DataType::Float64, false), - ])) - } - - pub(super) fn test_data(schema: Arc) -> Vec { - use datafusion::arrow::array::DictionaryArray; - use datafusion::arrow::datatypes::Int32Type; - - vec![ - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(StringArray::from(vec!["a", "b", "c", "d"])), - Arc::new( - vec![Some("a"), Some("b"), Some("c"), Some("d")] - .into_iter() - .collect::>(), - ), - Arc::new(Int64Array::from(vec![1, 2, 3, 4])), - Arc::new(UInt64Array::from(vec![1, 2, 3, 4])), - Arc::new(Float64Array::from(vec![0.5, 2.0, 3.0, 4.0])), - ], - ) - .unwrap(), - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(StringArray::from(vec!["e", "f", "g", "h"])), - Arc::new( - vec![Some("e"), Some("f"), Some("g"), Some("h")] - .into_iter() - .collect::>(), - ), - Arc::new(Int64Array::from(vec![5, 6, 7, 8])), - Arc::new(UInt64Array::from(vec![5, 6, 7, 8])), - Arc::new(Float64Array::from(vec![5.0, 6.0, 7.0, 8.0])), - ], - ) - .unwrap(), - ] - } - - pub(crate) fn test_ctx() -> datafusion::common::Result { - let schema = test_schema(); - let table = MemTable::try_new(schema.clone(), vec![test_data(schema)])?; - let ctx = SessionContext::new(); - ctx.register_table("types", Arc::new(table))?; - Ok(ctx) - } - - fn downcast(col: &ArrayRef) -> &T { - col.as_any().downcast_ref::().unwrap() - } - - pub(crate) fn extract_single_string(results: Vec) -> String { - let v1 = downcast::(results[0].column(0)); - v1.value(0).to_string() - } - - pub(crate) fn extract_single_int64(results: Vec) -> i64 { - let v1 = downcast::(results[0].column(0)); - v1.value(0) - } - - pub(crate) fn extract_single_float64(results: Vec) -> f64 { - let v1 = downcast::(results[0].column(0)); - v1.value(0) - } }