diff --git a/src/max_min_by.rs b/src/max_min_by.rs
index 4c8399d..6c33e09 100644
--- a/src/max_min_by.rs
+++ b/src/max_min_by.rs
@@ -45,7 +45,9 @@ fn get_min_max_by_result_type(
     match &input_types[0] {
         arrow::datatypes::DataType::Dictionary(_, dict_value_type) => {
             // x add checker, if the value type is complex data type
-            Ok(vec![dict_value_type.deref().clone()])
+            let mut result = vec![dict_value_type.deref().clone()];
+            result.extend_from_slice(&input_types[1..]); // Preserve all other argument types
+            Ok(result)
         }
         _ => Ok(input_types.to_vec()),
     }
@@ -207,3 +209,223 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
         Some(Box::new(simplify))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use datafusion::arrow::array::{
+        ArrayRef, Float64Array, Int64Array, RecordBatch, StringArray, UInt64Array,
+    };
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::datasource::MemTable;
+    use datafusion::prelude::SessionContext;
+    use std::any::Any;
+    use std::sync::Arc;
+
+    #[cfg(test)]
+    mod tests_max_by {
+        use crate::max_min_by::max_by_udaf;
+        use crate::max_min_by::tests::{
+            extract_single_float64, extract_single_int64, extract_single_string, test_ctx,
+        };
+        use datafusion::error::Result;
+        use datafusion::prelude::SessionContext;
+
+        #[tokio::test]
+        async fn test_max_by_string_int() -> Result<()> {
+            let df = ctx()?
+                .sql("SELECT max_by(string, int64) FROM types")
+                .await?;
+            assert_eq!(extract_single_string(df.collect().await?), "h");
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_max_by_string_float() -> Result<()> {
+            let df = ctx()?
+                .sql("SELECT max_by(string, float64) FROM types")
+                .await?;
+            assert_eq!(extract_single_string(df.collect().await?), "h");
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_max_by_float_string() -> Result<()> {
+            let df = ctx()?
+                .sql("SELECT max_by(float64, string) FROM types")
+                .await?;
+            assert_eq!(extract_single_float64(df.collect().await?), 8.0);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_max_by_int_string() -> Result<()> {
+            let df = ctx()?
+                .sql("SELECT max_by(int64, string) FROM types")
+                .await?;
+            assert_eq!(extract_single_int64(df.collect().await?), 8);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_max_by_dictionary_int() -> Result<()> {
+            let df = ctx()?
+                .sql("SELECT max_by(dict_string, int64) FROM types")
+                .await?;
+            assert_eq!(extract_single_string(df.collect().await?), "h");
+            Ok(())
+        }
+
+        /// Returns the shared test context with the `max_by` UDAF registered.
+        fn ctx() -> Result<SessionContext> {
+            let ctx = test_ctx()?;
+            ctx.register_udaf(max_by_udaf().as_ref().clone());
+            Ok(ctx)
+        }
+    }
+
+    #[cfg(test)]
+    mod test_min_by {
+        use crate::max_min_by::min_by_udaf;
+        use crate::max_min_by::tests::{
+            extract_single_float64, extract_single_int64, extract_single_string, test_ctx,
+        };
+        use datafusion::error::Result;
+        use datafusion::prelude::SessionContext;
+
+        #[tokio::test]
+        async fn test_min_by_string_int() -> Result<()> {
+            let df = ctx()?
+                .sql("SELECT min_by(string, int64) FROM types")
+                .await?;
+            assert_eq!(extract_single_string(df.collect().await?), "a");
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_min_by_string_float() -> Result<()> {
+            let df = ctx()?
+                .sql("SELECT min_by(string, float64) FROM types")
+                .await?;
+            assert_eq!(extract_single_string(df.collect().await?), "a");
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_min_by_float_string() -> Result<()> {
+            let df = ctx()?
+                .sql("SELECT min_by(float64, string) FROM types")
+                .await?;
+            assert_eq!(extract_single_float64(df.collect().await?), 0.5);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_min_by_int_string() -> Result<()> {
+            let df = ctx()?
+                .sql("SELECT min_by(int64, string) FROM types")
+                .await?;
+            assert_eq!(extract_single_int64(df.collect().await?), 1);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_min_by_dictionary_int() -> Result<()> {
+            let df = ctx()?
+                .sql("SELECT min_by(dict_string, int64) FROM types")
+                .await?;
+            assert_eq!(extract_single_string(df.collect().await?), "a");
+            Ok(())
+        }
+
+        /// Returns the shared test context with the `min_by` UDAF registered.
+        fn ctx() -> Result<SessionContext> {
+            let ctx = test_ctx()?;
+            ctx.register_udaf(min_by_udaf().as_ref().clone());
+            Ok(ctx)
+        }
+    }
+
+    /// Schema of the `types` test table: one non-nullable column per data type under test.
+    pub(super) fn test_schema() -> Arc<Schema> {
+        Arc::new(Schema::new(vec![
+            Field::new("string", DataType::Utf8, false),
+            Field::new_dictionary("dict_string", DataType::Int32, DataType::Utf8, false),
+            Field::new("int64", DataType::Int64, false),
+            Field::new("uint64", DataType::UInt64, false),
+            Field::new("float64", DataType::Float64, false),
+        ]))
+    }
+
+    /// Two four-row batches for [`test_schema`]; every column is strictly increasing.
+    pub(super) fn test_data(schema: Arc<Schema>) -> Vec<RecordBatch> {
+        use datafusion::arrow::array::DictionaryArray;
+        use datafusion::arrow::datatypes::Int32Type;
+
+        vec![
+            RecordBatch::try_new(
+                schema.clone(),
+                vec![
+                    Arc::new(StringArray::from(vec!["a", "b", "c", "d"])),
+                    Arc::new(
+                        vec![Some("a"), Some("b"), Some("c"), Some("d")]
+                            .into_iter()
+                            .collect::<DictionaryArray<Int32Type>>(),
+                    ),
+                    Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
+                    Arc::new(UInt64Array::from(vec![1, 2, 3, 4])),
+                    Arc::new(Float64Array::from(vec![0.5, 2.0, 3.0, 4.0])),
+                ],
+            )
+            .unwrap(),
+            RecordBatch::try_new(
+                schema.clone(),
+                vec![
+                    Arc::new(StringArray::from(vec!["e", "f", "g", "h"])),
+                    Arc::new(
+                        vec![Some("e"), Some("f"), Some("g"), Some("h")]
+                            .into_iter()
+                            .collect::<DictionaryArray<Int32Type>>(),
+                    ),
+                    Arc::new(Int64Array::from(vec![5, 6, 7, 8])),
+                    Arc::new(UInt64Array::from(vec![5, 6, 7, 8])),
+                    Arc::new(Float64Array::from(vec![5.0, 6.0, 7.0, 8.0])),
+                ],
+            )
+            .unwrap(),
+        ]
+    }
+
+    /// Creates a `SessionContext` with the test batches registered as table `types`.
+    pub(crate) fn test_ctx() -> datafusion::common::Result<SessionContext> {
+        let schema = test_schema();
+        let table = MemTable::try_new(schema.clone(), vec![test_data(schema)])?;
+        let ctx = SessionContext::new();
+        ctx.register_table("types", Arc::new(table))?;
+        Ok(ctx)
+    }
+
+    /// Downcasts `col` to the concrete array type `T`.
+    ///
+    /// Panics if the column is not a `T`.
+    fn downcast<T: Any>(col: &ArrayRef) -> &T {
+        col.as_any().downcast_ref::<T>().unwrap()
+    }
+
+    /// Returns row 0 of column 0 of the first batch as an owned `String`.
+    pub(crate) fn extract_single_string(results: Vec<RecordBatch>) -> String {
+        let v1 = downcast::<StringArray>(results[0].column(0));
+        v1.value(0).to_string()
+    }
+
+    /// Returns row 0 of column 0 of the first batch as an `i64`.
+    pub(crate) fn extract_single_int64(results: Vec<RecordBatch>) -> i64 {
+        let v1 = downcast::<Int64Array>(results[0].column(0));
+        v1.value(0)
+    }
+
+    /// Returns row 0 of column 0 of the first batch as an `f64`.
+    pub(crate) fn extract_single_float64(results: Vec<RecordBatch>) -> f64 {
+        let v1 = downcast::<Float64Array>(results[0].column(0));
+        v1.value(0)
+    }
+}