Skip to content

Commit 1198b30

Browse files
committed
fix: preserve argument types in max_by/min_by with dictionary inputs
1 parent f969b49 commit 1198b30

File tree

1 file changed

+212
-1
lines changed

1 file changed

+212
-1
lines changed

src/max_min_by.rs

Lines changed: 212 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ fn get_min_max_by_result_type(
4545
match &input_types[0] {
4646
arrow::datatypes::DataType::Dictionary(_, dict_value_type) => {
4747
// x add checker, if the value type is complex data type
48-
Ok(vec![dict_value_type.deref().clone()])
48+
let mut result = vec![dict_value_type.deref().clone()];
49+
result.extend_from_slice(&input_types[1..]); // Preserve all other argument types
50+
Ok(result)
4951
}
5052
_ => Ok(input_types.to_vec()),
5153
}
@@ -207,3 +209,212 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
207209
Some(Box::new(simplify))
208210
}
209211
}
212+
213+
#[cfg(test)]
214+
mod tests {
215+
use datafusion::arrow::array::{
216+
ArrayRef, Float64Array, Int64Array, RecordBatch, StringArray, UInt64Array,
217+
};
218+
use datafusion::arrow::datatypes::{DataType, Field, Schema};
219+
use datafusion::datasource::MemTable;
220+
use datafusion::prelude::SessionContext;
221+
use std::any::Any;
222+
use std::sync::Arc;
223+
224+
#[cfg(test)]
225+
mod tests_max_by {
226+
use crate::max_min_by::max_by_udaf;
227+
use crate::max_min_by::tests::{
228+
extract_single_float64, extract_single_int64, extract_single_string, test_ctx,
229+
};
230+
use datafusion::error::Result;
231+
use datafusion::prelude::SessionContext;
232+
233+
#[tokio::test]
234+
async fn test_max_by_string_int() -> Result<()> {
235+
let df = ctx()?
236+
.sql("SELECT max_by(string, int64) FROM types")
237+
.await?;
238+
assert_eq!(extract_single_string(df.collect().await?), "h");
239+
Ok(())
240+
}
241+
242+
#[tokio::test]
243+
async fn test_max_by_string_float() -> Result<()> {
244+
let df = ctx()?
245+
.sql("SELECT max_by(string, float64) FROM types")
246+
.await?;
247+
assert_eq!(extract_single_string(df.collect().await?), "h");
248+
Ok(())
249+
}
250+
251+
#[tokio::test]
252+
async fn test_max_by_float_string() -> Result<()> {
253+
let df = ctx()?
254+
.sql("SELECT max_by(float64, string) FROM types")
255+
.await?;
256+
assert_eq!(extract_single_float64(df.collect().await?), 8.0);
257+
Ok(())
258+
}
259+
260+
#[tokio::test]
261+
async fn test_max_by_int_string() -> Result<()> {
262+
let df = ctx()?
263+
.sql("SELECT max_by(int64, string) FROM types")
264+
.await?;
265+
assert_eq!(extract_single_int64(df.collect().await?), 8);
266+
Ok(())
267+
}
268+
269+
#[tokio::test]
270+
async fn test_max_by_dictionary_int() -> Result<()> {
271+
let df = ctx()?
272+
.sql("SELECT max_by(dict_string, int64) FROM types")
273+
.await?;
274+
assert_eq!(extract_single_string(df.collect().await?), "h");
275+
Ok(())
276+
}
277+
278+
fn ctx() -> Result<SessionContext> {
279+
let ctx = test_ctx()?;
280+
ctx.register_udaf(max_by_udaf().as_ref().clone());
281+
Ok(ctx)
282+
}
283+
}
284+
285+
#[cfg(test)]
286+
mod test_min_by {
287+
use crate::max_min_by::min_by_udaf;
288+
use crate::max_min_by::tests::{
289+
extract_single_float64, extract_single_int64, extract_single_string, test_ctx,
290+
};
291+
use datafusion::error::Result;
292+
use datafusion::prelude::SessionContext;
293+
294+
#[tokio::test]
295+
async fn test_min_by_string_int() -> Result<()> {
296+
let df = ctx()?
297+
.sql("SELECT min_by(string, int64) FROM types")
298+
.await?;
299+
assert_eq!(extract_single_string(df.collect().await?), "a");
300+
Ok(())
301+
}
302+
303+
#[tokio::test]
304+
async fn test_min_by_string_float() -> Result<()> {
305+
let df = ctx()?
306+
.sql("SELECT min_by(string, float64) FROM types")
307+
.await?;
308+
assert_eq!(extract_single_string(df.collect().await?), "a");
309+
Ok(())
310+
}
311+
312+
#[tokio::test]
313+
async fn test_min_by_float_string() -> Result<()> {
314+
let df = ctx()?
315+
.sql("SELECT min_by(float64, string) FROM types")
316+
.await?;
317+
assert_eq!(extract_single_float64(df.collect().await?), 0.5);
318+
Ok(())
319+
}
320+
321+
#[tokio::test]
322+
async fn test_min_by_int_string() -> Result<()> {
323+
let df = ctx()?
324+
.sql("SELECT min_by(int64, string) FROM types")
325+
.await?;
326+
assert_eq!(extract_single_int64(df.collect().await?), 1);
327+
Ok(())
328+
}
329+
330+
#[tokio::test]
331+
async fn test_min_by_dictionary_int() -> Result<()> {
332+
let df = ctx()?
333+
.sql("SELECT min_by(dict_string, int64) FROM types")
334+
.await?;
335+
assert_eq!(extract_single_string(df.collect().await?), "a");
336+
Ok(())
337+
}
338+
339+
fn ctx() -> Result<SessionContext> {
340+
let ctx = test_ctx()?;
341+
ctx.register_udaf(min_by_udaf().as_ref().clone());
342+
Ok(ctx)
343+
}
344+
}
345+
346+
pub(super) fn test_schema() -> Arc<Schema> {
347+
Arc::new(Schema::new(vec![
348+
Field::new("string", DataType::Utf8, false),
349+
Field::new_dictionary("dict_string", DataType::Int32, DataType::Utf8, false),
350+
Field::new("int64", DataType::Int64, false),
351+
Field::new("uint64", DataType::UInt64, false),
352+
Field::new("float64", DataType::Float64, false),
353+
]))
354+
}
355+
356+
pub(super) fn test_data(schema: Arc<Schema>) -> Vec<RecordBatch> {
357+
use datafusion::arrow::array::DictionaryArray;
358+
use datafusion::arrow::datatypes::Int32Type;
359+
360+
vec![
361+
RecordBatch::try_new(
362+
schema.clone(),
363+
vec![
364+
Arc::new(StringArray::from(vec!["a", "b", "c", "d"])),
365+
Arc::new(
366+
vec![Some("a"), Some("b"), Some("c"), Some("d")]
367+
.into_iter()
368+
.collect::<DictionaryArray<Int32Type>>(),
369+
),
370+
Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
371+
Arc::new(UInt64Array::from(vec![1, 2, 3, 4])),
372+
Arc::new(Float64Array::from(vec![0.5, 2.0, 3.0, 4.0])),
373+
],
374+
)
375+
.unwrap(),
376+
RecordBatch::try_new(
377+
schema.clone(),
378+
vec![
379+
Arc::new(StringArray::from(vec!["e", "f", "g", "h"])),
380+
Arc::new(
381+
vec![Some("e"), Some("f"), Some("g"), Some("h")]
382+
.into_iter()
383+
.collect::<DictionaryArray<Int32Type>>(),
384+
),
385+
Arc::new(Int64Array::from(vec![5, 6, 7, 8])),
386+
Arc::new(UInt64Array::from(vec![5, 6, 7, 8])),
387+
Arc::new(Float64Array::from(vec![5.0, 6.0, 7.0, 8.0])),
388+
],
389+
)
390+
.unwrap(),
391+
]
392+
}
393+
394+
pub(crate) fn test_ctx() -> datafusion::common::Result<SessionContext> {
395+
let schema = test_schema();
396+
let table = MemTable::try_new(schema.clone(), vec![test_data(schema)])?;
397+
let ctx = SessionContext::new();
398+
ctx.register_table("types", Arc::new(table))?;
399+
Ok(ctx)
400+
}
401+
402+
fn downcast<T: Any>(col: &ArrayRef) -> &T {
403+
col.as_any().downcast_ref::<T>().unwrap()
404+
}
405+
406+
pub(crate) fn extract_single_string(results: Vec<RecordBatch>) -> String {
407+
let v1 = downcast::<StringArray>(results[0].column(0));
408+
v1.value(0).to_string()
409+
}
410+
411+
pub(crate) fn extract_single_int64(results: Vec<RecordBatch>) -> i64 {
412+
let v1 = downcast::<Int64Array>(results[0].column(0));
413+
v1.value(0)
414+
}
415+
416+
pub(crate) fn extract_single_float64(results: Vec<RecordBatch>) -> f64 {
417+
let v1 = downcast::<Float64Array>(results[0].column(0));
418+
v1.value(0)
419+
}
420+
}

0 commit comments

Comments
 (0)