Skip to content

Commit 7873e5c

Browse files
gstvgOmega359
andauthored
Add union_extract scalar function (#12116)
* feat: add union_extract scalar function * fix: docs fmt, add clippy atr, sql error msg * use arrow-rs implementation * docs: add union functions section * docs: simplify union_extract docs * test: simplify union_extract sqllogictests * refactor(union_extract): new udf api, docs macro, use any signature * fix: remove user_doc include attribute * fix: generate docs * fix: manually trim sqllogictest generated errors * fix: fmt * docs: add union functions section description * docs: update functions docs * docs: clarify union_extract description Co-authored-by: Bruce Ritchie <[email protected]> * fix: use return_type_from_args, tests, docs --------- Co-authored-by: Bruce Ritchie <[email protected]>
1 parent c1338b7 commit 7873e5c

File tree

6 files changed

+381
-3
lines changed

6 files changed

+381
-3
lines changed

datafusion/expr/src/udf.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,7 @@ pub mod scalar_doc_sections {
980980
DOC_SECTION_STRUCT,
981981
DOC_SECTION_MAP,
982982
DOC_SECTION_HASHING,
983+
DOC_SECTION_UNION,
983984
DOC_SECTION_OTHER,
984985
]
985986
}
@@ -996,6 +997,7 @@ pub mod scalar_doc_sections {
996997
DOC_SECTION_STRUCT,
997998
DOC_SECTION_MAP,
998999
DOC_SECTION_HASHING,
1000+
DOC_SECTION_UNION,
9991001
DOC_SECTION_OTHER,
10001002
]
10011003
}
@@ -1070,4 +1072,10 @@ The following regular expression functions are supported:"#,
10701072
label: "Other Functions",
10711073
description: None,
10721074
};
1075+
1076+
pub const DOC_SECTION_UNION: DocSection = DocSection {
1077+
include: true,
1078+
label: "Union Functions",
1079+
description: Some("Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator"),
1080+
};
10731081
}

datafusion/functions/src/core/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pub mod nvl;
3434
pub mod nvl2;
3535
pub mod planner;
3636
pub mod r#struct;
37+
pub mod union_extract;
3738
pub mod version;
3839

3940
// create UDFs
@@ -48,6 +49,7 @@ make_udf_function!(getfield::GetFieldFunc, get_field);
4849
make_udf_function!(coalesce::CoalesceFunc, coalesce);
4950
make_udf_function!(greatest::GreatestFunc, greatest);
5051
make_udf_function!(least::LeastFunc, least);
52+
make_udf_function!(union_extract::UnionExtractFun, union_extract);
5153
make_udf_function!(version::VersionFunc, version);
5254

5355
pub mod expr_fn {
@@ -99,6 +101,11 @@ pub mod expr_fn {
99101
pub fn get_field(arg1: Expr, arg2: impl Literal) -> Expr {
100102
super::get_field().call(vec![arg1, arg2.lit()])
101103
}
104+
105+
#[doc = "Returns the value of the field with the given name from the union when it's selected, or NULL otherwise"]
106+
pub fn union_extract(arg1: Expr, arg2: impl Literal) -> Expr {
107+
super::union_extract().call(vec![arg1, arg2.lit()])
108+
}
102109
}
103110

104111
/// Returns all DataFusion functions defined in this package
@@ -121,6 +128,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
121128
coalesce(),
122129
greatest(),
123130
least(),
131+
union_extract(),
124132
version(),
125133
r#struct(),
126134
]
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::Array;
19+
use arrow::datatypes::{DataType, FieldRef, UnionFields};
20+
use datafusion_common::cast::as_union_array;
21+
use datafusion_common::{
22+
exec_datafusion_err, exec_err, internal_err, Result, ScalarValue,
23+
};
24+
use datafusion_doc::Documentation;
25+
use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs};
26+
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
27+
use datafusion_macros::user_doc;
28+
29+
#[user_doc(
30+
doc_section(label = "Union Functions"),
31+
description = "Returns the value of the given field in the union when selected, or NULL otherwise.",
32+
syntax_example = "union_extract(union, field_name)",
33+
sql_example = r#"```sql
34+
❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union;
35+
+--------------+----------------------------------+----------------------------------+
36+
| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') |
37+
+--------------+----------------------------------+----------------------------------+
38+
| {a=1} | 1 | |
39+
| {b=3.0} | | 3.0 |
40+
| {a=4} | 4 | |
41+
| {b=} | | |
42+
| {a=} | | |
43+
+--------------+----------------------------------+----------------------------------+
44+
```"#,
45+
standard_argument(name = "union", prefix = "Union"),
46+
argument(
47+
name = "field_name",
48+
description = "String expression to operate on. Must be a constant."
49+
)
50+
)]
51+
#[derive(Debug)]
52+
pub struct UnionExtractFun {
53+
signature: Signature,
54+
}
55+
56+
impl Default for UnionExtractFun {
57+
fn default() -> Self {
58+
Self::new()
59+
}
60+
}
61+
62+
impl UnionExtractFun {
63+
pub fn new() -> Self {
64+
Self {
65+
signature: Signature::any(2, Volatility::Immutable),
66+
}
67+
}
68+
}
69+
70+
impl ScalarUDFImpl for UnionExtractFun {
71+
fn as_any(&self) -> &dyn std::any::Any {
72+
self
73+
}
74+
75+
fn name(&self) -> &str {
76+
"union_extract"
77+
}
78+
79+
fn signature(&self) -> &Signature {
80+
&self.signature
81+
}
82+
83+
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
84+
// should be using return_type_from_exprs and not calling the default implementation
85+
internal_err!("union_extract should return type from exprs")
86+
}
87+
88+
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
89+
if args.arg_types.len() != 2 {
90+
return exec_err!(
91+
"union_extract expects 2 arguments, got {} instead",
92+
args.arg_types.len()
93+
);
94+
}
95+
96+
let DataType::Union(fields, _) = &args.arg_types[0] else {
97+
return exec_err!(
98+
"union_extract first argument must be a union, got {} instead",
99+
args.arg_types[0]
100+
);
101+
};
102+
103+
let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else {
104+
return exec_err!(
105+
"union_extract second argument must be a non-null string literal, got {} instead",
106+
args.arg_types[1]
107+
);
108+
};
109+
110+
let field = find_field(fields, field_name)?.1;
111+
112+
Ok(ReturnInfo::new_nullable(field.data_type().clone()))
113+
}
114+
115+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
116+
let args = args.args;
117+
118+
if args.len() != 2 {
119+
return exec_err!(
120+
"union_extract expects 2 arguments, got {} instead",
121+
args.len()
122+
);
123+
}
124+
125+
let target_name = match &args[1] {
126+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name),
127+
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"),
128+
_ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()),
129+
};
130+
131+
match &args[0] {
132+
ColumnarValue::Array(array) => {
133+
let union_array = as_union_array(&array).map_err(|_| {
134+
exec_datafusion_err!(
135+
"union_extract first argument must be a union, got {} instead",
136+
array.data_type()
137+
)
138+
})?;
139+
140+
Ok(ColumnarValue::Array(
141+
arrow::compute::kernels::union_extract::union_extract(
142+
union_array,
143+
target_name?,
144+
)?,
145+
))
146+
}
147+
ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => {
148+
let target_name = target_name?;
149+
let (target_type_id, target) = find_field(fields, target_name)?;
150+
151+
let result = match value {
152+
Some((type_id, value)) if target_type_id == *type_id => {
153+
*value.clone()
154+
}
155+
_ => ScalarValue::try_from(target.data_type())?,
156+
};
157+
158+
Ok(ColumnarValue::Scalar(result))
159+
}
160+
other => exec_err!(
161+
"union_extract first argument must be a union, got {} instead",
162+
other.data_type()
163+
),
164+
}
165+
}
166+
167+
fn documentation(&self) -> Option<&Documentation> {
168+
self.doc()
169+
}
170+
}
171+
172+
fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> {
173+
fields
174+
.iter()
175+
.find(|field| field.1.name() == name)
176+
.ok_or_else(|| exec_datafusion_err!("field {name} not found on union"))
177+
}
178+
179+
#[cfg(test)]
180+
mod tests {
181+
182+
use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
183+
use datafusion_common::{Result, ScalarValue};
184+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
185+
186+
use super::UnionExtractFun;
187+
188+
// when it becomes possible to construct union scalars in SQL, this should go to sqllogictests
189+
#[test]
190+
fn test_scalar_value() -> Result<()> {
191+
let fun = UnionExtractFun::new();
192+
193+
let fields = UnionFields::new(
194+
vec![1, 3],
195+
vec![
196+
Field::new("str", DataType::Utf8, false),
197+
Field::new("int", DataType::Int32, false),
198+
],
199+
);
200+
201+
let result = fun.invoke_with_args(ScalarFunctionArgs {
202+
args: vec![
203+
ColumnarValue::Scalar(ScalarValue::Union(
204+
None,
205+
fields.clone(),
206+
UnionMode::Dense,
207+
)),
208+
ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
209+
],
210+
number_rows: 1,
211+
return_type: &DataType::Utf8,
212+
})?;
213+
214+
assert_scalar(result, ScalarValue::Utf8(None));
215+
216+
let result = fun.invoke_with_args(ScalarFunctionArgs {
217+
args: vec![
218+
ColumnarValue::Scalar(ScalarValue::Union(
219+
Some((3, Box::new(ScalarValue::Int32(Some(42))))),
220+
fields.clone(),
221+
UnionMode::Dense,
222+
)),
223+
ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
224+
],
225+
number_rows: 1,
226+
return_type: &DataType::Utf8,
227+
})?;
228+
229+
assert_scalar(result, ScalarValue::Utf8(None));
230+
231+
let result = fun.invoke_with_args(ScalarFunctionArgs {
232+
args: vec![
233+
ColumnarValue::Scalar(ScalarValue::Union(
234+
Some((1, Box::new(ScalarValue::new_utf8("42")))),
235+
fields.clone(),
236+
UnionMode::Dense,
237+
)),
238+
ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
239+
],
240+
number_rows: 1,
241+
return_type: &DataType::Utf8,
242+
})?;
243+
244+
assert_scalar(result, ScalarValue::new_utf8("42"));
245+
246+
Ok(())
247+
}
248+
249+
fn assert_scalar(value: ColumnarValue, expected: ScalarValue) {
250+
match value {
251+
ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"),
252+
ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected),
253+
}
254+
}
255+
}

datafusion/sqllogictest/src/test_context.rs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ use std::path::Path;
2222
use std::sync::Arc;
2323

2424
use arrow::array::{
25-
ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray,
26-
StringArray, TimestampNanosecondArray,
25+
Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray,
26+
LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray,
2727
};
28-
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
28+
use arrow::buffer::ScalarBuffer;
29+
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields};
2930
use arrow::record_batch::RecordBatch;
3031
use datafusion::catalog::{
3132
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session,
@@ -113,6 +114,10 @@ impl TestContext {
113114
info!("Registering metadata table tables");
114115
register_metadata_tables(test_ctx.session_ctx()).await;
115116
}
117+
"union_function.slt" => {
118+
info!("Registering table with union column");
119+
register_union_table(test_ctx.session_ctx())
120+
}
116121
_ => {
117122
info!("Using default SessionContext");
118123
}
@@ -402,3 +407,24 @@ fn create_example_udf() -> ScalarUDF {
402407
adder,
403408
)
404409
}
410+
411+
fn register_union_table(ctx: &SessionContext) {
412+
let union = UnionArray::try_new(
413+
UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]),
414+
ScalarBuffer::from(vec![3, 3]),
415+
None,
416+
vec![Arc::new(Int32Array::from(vec![1, 2]))],
417+
)
418+
.unwrap();
419+
420+
let schema = Schema::new(vec![Field::new(
421+
"union_column",
422+
union.data_type().clone(),
423+
false,
424+
)]);
425+
426+
let batch =
427+
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union)]).unwrap();
428+
429+
ctx.register_batch("union_table", batch).unwrap();
430+
}

0 commit comments

Comments
 (0)