Skip to content

Commit dd338ce

Browse files
authored
Merge pull request #19 from datafusion-contrib/min-max-dict
Min max dict
2 parents 32028ea + df66931 commit dd338ce

File tree

1 file changed

+262
-1
lines changed

1 file changed

+262
-1
lines changed

src/max_min_by.rs

Lines changed: 262 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ fn get_min_max_by_result_type(
4545
match &input_types[0] {
4646
arrow::datatypes::DataType::Dictionary(_, dict_value_type) => {
4747
// x add checker, if the value type is complex data type
48-
Ok(vec![dict_value_type.deref().clone()])
48+
let mut result = vec![dict_value_type.deref().clone()];
49+
// Preserve all other argument types
50+
result.extend_from_slice(&input_types[1..]);
51+
Ok(result)
4952
}
5053
_ => Ok(input_types.to_vec()),
5154
}
@@ -207,3 +210,261 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
207210
Some(Box::new(simplify))
208211
}
209212
}
213+
214+
#[cfg(test)]
215+
mod tests {
216+
use super::*;
217+
218+
use datafusion::arrow::array::ArrayAccessor;
219+
use datafusion::{arrow, datasource, error, prelude};
220+
use std::sync;
221+
222+
const TEST_TABLE_NAME: &str = "types";
223+
const STRING_COLUMN_NAME: &str = "string";
224+
const DICTIONARY_COLUMN_NAME: &str = "dict_string";
225+
const INT64_COLUMN_NAME: &str = "int64";
226+
const FLOAT64_COLUMN_NAME: &str = "float64";
227+
228+
const MIN_STRING_VALUE: &str = "a";
229+
const MID_STRING_VALUE: &str = "b";
230+
const MAX_STRING_VALUE: &str = "c";
231+
const MIN_FLOAT_VALUE: f64 = 0.25;
232+
const MID_FLOAT_VALUE: f64 = 0.5;
233+
const MAX_FLOAT_VALUE: f64 = 0.75;
234+
const MIN_INT_VALUE: i64 = -1;
235+
const MID_INT_VALUE: i64 = 0;
236+
const MAX_INT_VALUE: i64 = 1;
237+
const MIN_DICTIONARY_VALUE: &str = "a";
238+
const MID_DICTIONARY_VALUE: &str = "b";
239+
const MAX_DICTIONARY_VALUE: &str = "c";
240+
241+
fn test_schema() -> sync::Arc<arrow::datatypes::Schema> {
242+
sync::Arc::new(arrow::datatypes::Schema::new(vec![
243+
arrow::datatypes::Field::new(
244+
STRING_COLUMN_NAME,
245+
arrow::datatypes::DataType::Utf8,
246+
false,
247+
),
248+
arrow::datatypes::Field::new_dictionary(
249+
DICTIONARY_COLUMN_NAME,
250+
arrow::datatypes::DataType::Int32,
251+
arrow::datatypes::DataType::Utf8,
252+
false,
253+
),
254+
arrow::datatypes::Field::new(
255+
INT64_COLUMN_NAME,
256+
arrow::datatypes::DataType::Int64,
257+
false,
258+
),
259+
arrow::datatypes::Field::new(
260+
FLOAT64_COLUMN_NAME,
261+
arrow::datatypes::DataType::Float64,
262+
false,
263+
),
264+
]))
265+
}
266+
267+
fn test_data(
268+
schema: sync::Arc<arrow::datatypes::Schema>,
269+
) -> Vec<arrow::record_batch::RecordBatch> {
270+
vec![
271+
arrow::record_batch::RecordBatch::try_new(
272+
schema,
273+
vec![
274+
sync::Arc::new(arrow::array::StringArray::from(vec![
275+
MID_STRING_VALUE,
276+
MIN_STRING_VALUE,
277+
MAX_STRING_VALUE,
278+
])),
279+
sync::Arc::new(
280+
vec![
281+
Some(MID_DICTIONARY_VALUE),
282+
Some(MIN_DICTIONARY_VALUE),
283+
Some(MAX_DICTIONARY_VALUE),
284+
]
285+
.into_iter()
286+
.collect::<arrow::array::DictionaryArray<arrow::datatypes::Int32Type>>(),
287+
),
288+
sync::Arc::new(arrow::array::Int64Array::from(vec![
289+
MID_INT_VALUE,
290+
MIN_INT_VALUE,
291+
MAX_INT_VALUE,
292+
])),
293+
sync::Arc::new(arrow::array::Float64Array::from(vec![
294+
MID_FLOAT_VALUE,
295+
MIN_FLOAT_VALUE,
296+
MAX_FLOAT_VALUE,
297+
])),
298+
],
299+
)
300+
.unwrap(),
301+
]
302+
}
303+
304+
fn test_ctx() -> datafusion::common::Result<prelude::SessionContext> {
305+
let schema = test_schema();
306+
let data = test_data(schema.clone());
307+
let table = datasource::MemTable::try_new(schema, vec![data])?;
308+
let ctx = prelude::SessionContext::new();
309+
ctx.register_table(TEST_TABLE_NAME, sync::Arc::new(table))?;
310+
Ok(ctx)
311+
}
312+
313+
async fn extract_single_value<T, A>(df: prelude::DataFrame) -> error::Result<T>
314+
where
315+
A: arrow::array::Array + 'static,
316+
for<'a> &'a A: arrow::array::ArrayAccessor,
317+
for<'a> <&'a A as arrow::array::ArrayAccessor>::Item: Into<T>,
318+
{
319+
let results = df.collect().await?;
320+
let col = results[0].column(0);
321+
let v1 = col.as_any().downcast_ref::<A>().unwrap();
322+
let value = v1.value(0).into();
323+
Ok(value)
324+
}
325+
326+
#[cfg(test)]
327+
mod max_by {
328+
use super::*;
329+
330+
#[tokio::test]
331+
async fn test_max_by_string_int() -> error::Result<()> {
332+
let query = format!(
333+
"SELECT max_by({}, {}) FROM {}",
334+
STRING_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME
335+
);
336+
let df = ctx()?.sql(&query).await?;
337+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
338+
assert_eq!(result, MAX_STRING_VALUE);
339+
Ok(())
340+
}
341+
342+
#[tokio::test]
343+
async fn test_max_by_string_float() -> error::Result<()> {
344+
let query = format!(
345+
"SELECT max_by({}, {}) FROM {}",
346+
STRING_COLUMN_NAME, FLOAT64_COLUMN_NAME, TEST_TABLE_NAME
347+
);
348+
let df = ctx()?.sql(&query).await?;
349+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
350+
assert_eq!(result, MAX_STRING_VALUE);
351+
Ok(())
352+
}
353+
354+
#[tokio::test]
355+
async fn test_max_by_float_string() -> error::Result<()> {
356+
let query = format!(
357+
"SELECT max_by({}, {}) FROM {}",
358+
FLOAT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME
359+
);
360+
let df = ctx()?.sql(&query).await?;
361+
let result = extract_single_value::<f64, arrow::array::Float64Array>(df).await?;
362+
assert_eq!(result, MAX_FLOAT_VALUE);
363+
Ok(())
364+
}
365+
366+
#[tokio::test]
367+
async fn test_max_by_int_string() -> error::Result<()> {
368+
let query = format!(
369+
"SELECT max_by({}, {}) FROM {}",
370+
INT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME
371+
);
372+
let df = ctx()?.sql(&query).await?;
373+
let result = extract_single_value::<i64, arrow::array::Int64Array>(df).await?;
374+
assert_eq!(result, MAX_INT_VALUE);
375+
Ok(())
376+
}
377+
378+
#[tokio::test]
379+
async fn test_max_by_dictionary_int() -> error::Result<()> {
380+
let query = format!(
381+
"SELECT max_by({}, {}) FROM {}",
382+
DICTIONARY_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME
383+
);
384+
let df = ctx()?.sql(&query).await?;
385+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
386+
assert_eq!(result, MAX_DICTIONARY_VALUE);
387+
Ok(())
388+
}
389+
390+
fn ctx() -> error::Result<prelude::SessionContext> {
391+
let ctx = test_ctx()?;
392+
let max_by_udaf = MaxByFunction::new();
393+
ctx.register_udaf(max_by_udaf.into());
394+
Ok(ctx)
395+
}
396+
}
397+
398+
#[cfg(test)]
399+
mod min_by {
400+
401+
use super::*;
402+
403+
#[tokio::test]
404+
async fn test_min_by_string_int() -> error::Result<()> {
405+
let query = format!(
406+
"SELECT min_by({}, {}) FROM {}",
407+
STRING_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME
408+
);
409+
let df = ctx()?.sql(&query).await?;
410+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
411+
assert_eq!(result, MIN_STRING_VALUE);
412+
Ok(())
413+
}
414+
415+
#[tokio::test]
416+
async fn test_min_by_string_float() -> error::Result<()> {
417+
let query = format!(
418+
"SELECT min_by({}, {}) FROM {}",
419+
STRING_COLUMN_NAME, FLOAT64_COLUMN_NAME, TEST_TABLE_NAME
420+
);
421+
let df = ctx()?.sql(&query).await?;
422+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
423+
assert_eq!(result, MIN_STRING_VALUE);
424+
Ok(())
425+
}
426+
427+
#[tokio::test]
428+
async fn test_min_by_float_string() -> error::Result<()> {
429+
let query = format!(
430+
"SELECT min_by({}, {}) FROM {}",
431+
FLOAT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME
432+
);
433+
let df = ctx()?.sql(&query).await?;
434+
let result = extract_single_value::<f64, arrow::array::Float64Array>(df).await?;
435+
assert_eq!(result, MIN_FLOAT_VALUE);
436+
Ok(())
437+
}
438+
439+
#[tokio::test]
440+
async fn test_min_by_int_string() -> error::Result<()> {
441+
let query = format!(
442+
"SELECT min_by({}, {}) FROM {}",
443+
INT64_COLUMN_NAME, STRING_COLUMN_NAME, TEST_TABLE_NAME
444+
);
445+
let df = ctx()?.sql(&query).await?;
446+
let result = extract_single_value::<i64, arrow::array::Int64Array>(df).await?;
447+
assert_eq!(result, MIN_INT_VALUE);
448+
Ok(())
449+
}
450+
451+
#[tokio::test]
452+
async fn test_min_by_dictionary_int() -> error::Result<()> {
453+
let query = format!(
454+
"SELECT min_by({}, {}) FROM {}",
455+
DICTIONARY_COLUMN_NAME, INT64_COLUMN_NAME, TEST_TABLE_NAME
456+
);
457+
let df = ctx()?.sql(&query).await?;
458+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
459+
assert_eq!(result, MIN_DICTIONARY_VALUE);
460+
Ok(())
461+
}
462+
463+
fn ctx() -> error::Result<prelude::SessionContext> {
464+
let ctx = test_ctx()?;
465+
let min_by_udaf = MinByFunction::new();
466+
ctx.register_udaf(min_by_udaf.into());
467+
Ok(ctx)
468+
}
469+
}
470+
}

0 commit comments

Comments
 (0)