diff --git a/src/max_min_by.rs b/src/max_min_by.rs
index 4c8399d..bfa3754 100644
--- a/src/max_min_by.rs
+++ b/src/max_min_by.rs
@@ -45,7 +45,10 @@ fn get_min_max_by_result_type(
     match &input_types[0] {
         arrow::datatypes::DataType::Dictionary(_, dict_value_type) => {
             // x add checker, if the value type is complex data type
-            Ok(vec![dict_value_type.deref().clone()])
+            let mut result = vec![dict_value_type.deref().clone()];
+            // Preserve all other argument types
+            result.extend_from_slice(&input_types[1..]);
+            Ok(result)
         }
         _ => Ok(input_types.to_vec()),
     }
@@ -207,3 +210,261 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
         Some(Box::new(simplify))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use datafusion::arrow::array::ArrayAccessor;
+    use datafusion::{arrow, datasource, error, prelude};
+    use std::sync;
+
+    const TEST_TABLE_NAME: &str = "types";
+    const STRING_COLUMN_NAME: &str = "string";
+    const DICTIONARY_COLUMN_NAME: &str = "dict_string";
+    const INT64_COLUMN_NAME: &str = "int64";
+    const FLOAT64_COLUMN_NAME: &str = "float64";
+
+    const MIN_STRING_VALUE: &str = "a";
+    const MID_STRING_VALUE: &str = "b";
+    const MAX_STRING_VALUE: &str = "c";
+    const MIN_FLOAT_VALUE: f64 = 0.25;
+    const MID_FLOAT_VALUE: f64 = 0.5;
+    const MAX_FLOAT_VALUE: f64 = 0.75;
+    const MIN_INT_VALUE: i64 = -1;
+    const MID_INT_VALUE: i64 = 0;
+    const MAX_INT_VALUE: i64 = 1;
+    const MIN_DICTIONARY_VALUE: &str = "a";
+    const MID_DICTIONARY_VALUE: &str = "b";
+    const MAX_DICTIONARY_VALUE: &str = "c";
+
+    fn test_schema() -> sync::Arc<arrow::datatypes::Schema> {
+        sync::Arc::new(arrow::datatypes::Schema::new(vec![
+            arrow::datatypes::Field::new(
+                STRING_COLUMN_NAME,
+                arrow::datatypes::DataType::Utf8,
+                false,
+            ),
+            arrow::datatypes::Field::new_dictionary(
+                DICTIONARY_COLUMN_NAME,
+                arrow::datatypes::DataType::Int32,
+                arrow::datatypes::DataType::Utf8,
+                false,
+            ),
+            arrow::datatypes::Field::new(
+                INT64_COLUMN_NAME,
+                arrow::datatypes::DataType::Int64,
+                false,
+            ),
+            arrow::datatypes::Field::new(
+                FLOAT64_COLUMN_NAME,
+                arrow::datatypes::DataType::Float64,
+                false,
+            ),
+        ]))
+    }
+
+    fn test_data(
+        schema: sync::Arc<arrow::datatypes::Schema>,
+    ) -> Vec<arrow::record_batch::RecordBatch> {
+        vec![
+            arrow::record_batch::RecordBatch::try_new(
+                schema,
+                vec![
+                    sync::Arc::new(arrow::array::StringArray::from(vec![
+                        MID_STRING_VALUE,
+                        MIN_STRING_VALUE,
+                        MAX_STRING_VALUE,
+                    ])),
+                    sync::Arc::new(
+                        vec![
+                            Some(MID_DICTIONARY_VALUE),
+                            Some(MIN_DICTIONARY_VALUE),
+                            Some(MAX_DICTIONARY_VALUE),
+                        ]
+                        .into_iter()
+                        .collect::<arrow::array::DictionaryArray<arrow::datatypes::Int32Type>>(),
+                    ),
+                    sync::Arc::new(arrow::array::Int64Array::from(vec![
+                        MID_INT_VALUE,
+                        MIN_INT_VALUE,
+                        MAX_INT_VALUE,
+                    ])),
+                    sync::Arc::new(arrow::array::Float64Array::from(vec![
+                        MID_FLOAT_VALUE,
+                        MIN_FLOAT_VALUE,
+                        MAX_FLOAT_VALUE,
+                    ])),
+                ],
+            )
+            .unwrap(),
+        ]
+    }
+
+    fn test_ctx() -> datafusion::common::Result<prelude::SessionContext> {
+        let schema = test_schema();
+        let data = test_data(schema.clone());
+        let table = datasource::MemTable::try_new(schema, vec![data])?;
+        let ctx = prelude::SessionContext::new();
+        ctx.register_table(TEST_TABLE_NAME, sync::Arc::new(table))?;
+        Ok(ctx)
+    }
+
+    async fn extract_single_value<T, A>(df: prelude::DataFrame) -> error::Result<T>
+    where
+        A: arrow::array::Array + 'static,
+        for<'a> &'a A: arrow::array::ArrayAccessor,
+        for<'a> <&'a A as arrow::array::ArrayAccessor>::Item: Into<T>,
+    {
+        let results = df.collect().await?;
+        let col = results[0].column(0);
+        let v1 = col.as_any().downcast_ref::<A>().unwrap();
+        let value = v1.value(0).into();
+        Ok(value)
+    }
+
+    #[cfg(test)]
+    mod max_by {
+        use super::*;
+
+        #[tokio::test]
+        async fn test_max_by_string_int() -> error::Result<()> {
+            let query = format!(
+                "SELECT max_by({}, {}) FROM {}",
+                STRING_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME
+            );
+            let df = ctx()?.sql(&query).await?;
+            let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
+            assert_eq!(result, MAX_STRING_VALUE);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_max_by_string_float() -> error::Result<()> {
+            let query = format!(
+                "SELECT max_by({}, {}) FROM {}",
+                STRING_COLUMN_NAME, FLOAT64_COLUMN_NAME, TEST_TABLE_NAME
+            );
+            let df = ctx()?.sql(&query).await?;
+            let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
+            assert_eq!(result, MAX_STRING_VALUE);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_max_by_float_string() -> error::Result<()> {
+            let query = format!(
+                "SELECT max_by({}, {}) FROM {}",
+                FLOAT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME
+            );
+            let df = ctx()?.sql(&query).await?;
+            let result = extract_single_value::<f64, arrow::array::Float64Array>(df).await?;
+            assert_eq!(result, MAX_FLOAT_VALUE);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_max_by_int_string() -> error::Result<()> {
+            let query = format!(
+                "SELECT max_by({}, {}) FROM {}",
+                INT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME
+            );
+            let df = ctx()?.sql(&query).await?;
+            let result = extract_single_value::<i64, arrow::array::Int64Array>(df).await?;
+            assert_eq!(result, MAX_INT_VALUE);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_max_by_dictionary_int() -> error::Result<()> {
+            let query = format!(
+                "SELECT max_by({}, {}) FROM {}",
+                DICTIONARY_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME
+            );
+            let df = ctx()?.sql(&query).await?;
+            let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
+            assert_eq!(result, MAX_DICTIONARY_VALUE);
+            Ok(())
+        }
+
+        fn ctx() -> error::Result<prelude::SessionContext> {
+            let ctx = test_ctx()?;
+            let max_by_udaf = MaxByFunction::new();
+            ctx.register_udaf(max_by_udaf.into());
+            Ok(ctx)
+        }
+    }
+
+    #[cfg(test)]
+    mod min_by {
+
+        use super::*;
+
+        #[tokio::test]
+        async fn test_min_by_string_int() -> error::Result<()> {
+            let query = format!(
+                "SELECT min_by({}, {}) FROM {}",
+                STRING_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME
+            );
+            let df = ctx()?.sql(&query).await?;
+            let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
+            assert_eq!(result, MIN_STRING_VALUE);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_min_by_string_float() -> error::Result<()> {
+            let query = format!(
+                "SELECT min_by({}, {}) FROM {}",
+                STRING_COLUMN_NAME, FLOAT64_COLUMN_NAME, TEST_TABLE_NAME
+            );
+            let df = ctx()?.sql(&query).await?;
+            let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
+            assert_eq!(result, MIN_STRING_VALUE);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_min_by_float_string() -> error::Result<()> {
+            let query = format!(
+                "SELECT min_by({}, {}) FROM {}",
+                FLOAT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME
+            );
+            let df = ctx()?.sql(&query).await?;
+            let result = extract_single_value::<f64, arrow::array::Float64Array>(df).await?;
+            assert_eq!(result, MIN_FLOAT_VALUE);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_min_by_int_string() -> error::Result<()> {
+            let query = format!(
+                "SELECT min_by({}, {}) FROM {}",
+                INT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME
+            );
+            let df = ctx()?.sql(&query).await?;
+            let result = extract_single_value::<i64, arrow::array::Int64Array>(df).await?;
+            assert_eq!(result, MIN_INT_VALUE);
+            Ok(())
+        }
+
+        #[tokio::test]
+        async fn test_min_by_dictionary_int() -> error::Result<()> {
+            let query = format!(
+                "SELECT min_by({}, {}) FROM {}",
+                DICTIONARY_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME
+            );
+            let df = ctx()?.sql(&query).await?;
+            let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
+            assert_eq!(result, MIN_DICTIONARY_VALUE);
+            Ok(())
+        }
+
+        fn ctx() -> error::Result<prelude::SessionContext> {
+            let ctx = test_ctx()?;
+            let min_by_udaf = MinByFunction::new();
+            ctx.register_udaf(min_by_udaf.into());
+            Ok(ctx)
+        }
+    }
+}