Skip to content

Commit 4df8fbe

Browse files
committed
replace type signature for starts_with
1 parent 7299d0e commit 4df8fbe

File tree

4 files changed

+70
-29
lines changed

4 files changed

+70
-29
lines changed

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1256,7 +1256,7 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
12561256
/// Coercion rules for binary (Binary/LargeBinary) to string (Utf8/LargeUtf8):
12571257
/// If one argument is binary and the other is a string then coerce to string
12581258
/// (e.g. for `like`)
1259-
fn binary_to_string_coercion(
1259+
pub fn binary_to_string_coercion(
12601260
lhs_type: &DataType,
12611261
rhs_type: &DataType,
12621262
) -> Option<DataType> {

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ pub use datafusion_expr_common::columnar_value::ColumnarValue;
7676
pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
7777
pub use datafusion_expr_common::operator::Operator;
7878
pub use datafusion_expr_common::signature::{
79-
ArrayFunctionArgument, ArrayFunctionSignature, Signature, TypeSignature,
79+
ArrayFunctionArgument, ArrayFunctionSignature, Coercion, Signature, TypeSignature,
8080
TypeSignatureClass, Volatility, TIMEZONE_WILDCARD,
8181
};
8282
pub use datafusion_expr_common::type_coercion::binary;

datafusion/functions/src/string/starts_with.rs

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,41 @@ use std::sync::Arc;
2020

2121
use arrow::array::ArrayRef;
2222
use arrow::datatypes::DataType;
23+
use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
2324
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
2425

2526
use crate::utils::make_scalar_function;
27+
use datafusion_common::types::logical_string;
2628
use datafusion_common::{internal_err, Result, ScalarValue};
27-
use datafusion_expr::{ColumnarValue, Documentation, Expr, Like};
28-
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
29+
use datafusion_expr::{
30+
cast, Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
31+
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
32+
};
2933
use datafusion_macros::user_doc;
3034

3135
/// Returns true if string starts with prefix.
3236
/// starts_with('alphabet', 'alph') = 't'
3337
pub fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
34-
let result = arrow::compute::kernels::comparison::starts_with(&args[0], &args[1])?;
35-
Ok(Arc::new(result) as ArrayRef)
38+
if let Some(coercion_data_type) =
39+
string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
40+
binary_to_string_coercion(args[0].data_type(), args[1].data_type())
41+
})
42+
{
43+
let arg0 = if args[0].data_type() == &coercion_data_type {
44+
args[0].clone()
45+
} else {
46+
arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
47+
};
48+
let arg1 = if args[1].data_type() == &coercion_data_type {
49+
args[1].clone()
50+
} else {
51+
arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
52+
};
53+
let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?;
54+
Ok(Arc::new(result) as ArrayRef)
55+
} else {
56+
internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")
57+
}
3658
}
3759

3860
#[user_doc(
@@ -64,7 +86,13 @@ impl Default for StartsWithFunc {
6486
impl StartsWithFunc {
6587
pub fn new() -> Self {
6688
Self {
67-
signature: Signature::string(2, Volatility::Immutable),
89+
signature: Signature::coercible(
90+
vec![
91+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
92+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
93+
],
94+
Volatility::Immutable,
95+
),
6896
}
6997
}
7098
}
@@ -98,7 +126,7 @@ impl ScalarUDFImpl for StartsWithFunc {
98126
fn simplify(
99127
&self,
100128
args: Vec<Expr>,
101-
_info: &dyn SimplifyInfo,
129+
info: &dyn SimplifyInfo,
102130
) -> Result<ExprSimplifyResult> {
103131
if let Expr::Literal(scalar_value) = &args[1] {
104132
// Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping
@@ -107,31 +135,44 @@ impl ScalarUDFImpl for StartsWithFunc {
107135
// 2. 'ja\%' (escape special char '%')
108136
// 3. 'ja\%%' (add suffix for starts_with)
109137
let like_expr = match scalar_value {
110-
ScalarValue::Utf8(Some(pattern)) => {
138+
ScalarValue::Utf8(Some(pattern))
139+
| ScalarValue::LargeUtf8(Some(pattern))
140+
| ScalarValue::Utf8View(Some(pattern)) => {
111141
let escaped_pattern = pattern.replace("%", "\\%");
112142
let like_pattern = format!("{}%", escaped_pattern);
113143
Expr::Literal(ScalarValue::Utf8(Some(like_pattern)))
114144
}
115-
ScalarValue::LargeUtf8(Some(pattern)) => {
116-
let escaped_pattern = pattern.replace("%", "\\%");
117-
let like_pattern = format!("{}%", escaped_pattern);
118-
Expr::Literal(ScalarValue::LargeUtf8(Some(like_pattern)))
119-
}
120-
ScalarValue::Utf8View(Some(pattern)) => {
121-
let escaped_pattern = pattern.replace("%", "\\%");
122-
let like_pattern = format!("{}%", escaped_pattern);
123-
Expr::Literal(ScalarValue::Utf8View(Some(like_pattern)))
124-
}
125145
_ => return Ok(ExprSimplifyResult::Original(args)),
126146
};
127147

128-
return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
129-
negated: false,
130-
expr: Box::new(args[0].clone()),
131-
pattern: Box::new(like_expr),
132-
escape_char: None,
133-
case_insensitive: false,
134-
})));
148+
let expr_data_type = info.get_data_type(&args[0])?;
149+
let pattern_data_type = info.get_data_type(&like_expr)?;
150+
151+
if let Some(coercion_data_type) =
152+
string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
153+
binary_to_string_coercion(&expr_data_type, &pattern_data_type)
154+
})
155+
{
156+
let expr = if expr_data_type == coercion_data_type {
157+
args[0].clone()
158+
} else {
159+
cast(args[0].clone(), coercion_data_type.clone())
160+
};
161+
162+
let pattern = if pattern_data_type == coercion_data_type {
163+
like_expr
164+
} else {
165+
cast(like_expr, coercion_data_type)
166+
};
167+
168+
return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
169+
negated: false,
170+
expr: Box::new(expr),
171+
pattern: Box::new(pattern),
172+
escape_char: None,
173+
case_insensitive: false,
174+
})));
175+
}
135176
}
136177

137178
Ok(ExprSimplifyResult::Original(args))

datafusion/sqllogictest/test_files/string/string_view.slt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ EXPLAIN SELECT
299299
FROM test;
300300
----
301301
logical_plan
302-
01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, CAST(test.column2_utf8 AS Utf8View)) AS c2, starts_with(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3
302+
01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, test.column2_utf8) AS c2, starts_with(test.column1_utf8view, test.column2_large_utf8) AS c3
303303
02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view, column2_utf8view]
304304

305305
query BBB
@@ -326,7 +326,7 @@ EXPLAIN SELECT
326326
FROM test;
327327
----
328328
logical_plan
329-
01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(CAST(test.column1_utf8 AS LargeUtf8), test.column2_large_utf8) AS c4
329+
01)Projection: starts_with(test.column1_utf8, test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(test.column1_utf8, test.column2_large_utf8) AS c4
330330
02)--TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view]
331331

332332
query BBB
@@ -382,7 +382,7 @@ EXPLAIN SELECT
382382
FROM test;
383383
----
384384
logical_plan
385-
01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), substr(test.column1_utf8, Int64(1), Int64(2))) AS c1, starts_with(CAST(test.column1_large_utf8 AS Utf8View), substr(test.column1_large_utf8, Int64(1), Int64(2))) AS c2, starts_with(test.column1_utf8view, substr(test.column1_utf8view, Int64(1), Int64(2))) AS c3
385+
01)Projection: starts_with(test.column1_utf8, substr(test.column1_utf8, Int64(1), Int64(2))) AS c1, starts_with(test.column1_large_utf8, substr(test.column1_large_utf8, Int64(1), Int64(2))) AS c2, starts_with(test.column1_utf8view, substr(test.column1_utf8view, Int64(1), Int64(2))) AS c3
386386
02)--TableScan: test projection=[column1_utf8, column1_large_utf8, column1_utf8view]
387387

388388
query BBB

0 commit comments

Comments
 (0)