Skip to content

Commit 8469666

Browse files
zjregeeavantgardnerio
authored andcommitted
fix starts_with (apache#14812, apache#19077)
1 parent 31382b7 commit 8469666

File tree

4 files changed

+106
-29
lines changed

4 files changed

+106
-29
lines changed

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1236,7 +1236,7 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
12361236
/// Coercion rules for binary (Binary/LargeBinary) to string (Utf8/LargeUtf8):
12371237
/// If one argument is binary and the other is a string then coerce to string
12381238
/// (e.g. for `like`)
1239-
fn binary_to_string_coercion(
1239+
pub fn binary_to_string_coercion(
12401240
lhs_type: &DataType,
12411241
rhs_type: &DataType,
12421242
) -> Option<DataType> {

datafusion/functions/src/string/starts_with.rs

Lines changed: 65 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,40 @@ use std::sync::Arc;
2121
use arrow::array::ArrayRef;
2222
use arrow::datatypes::DataType;
2323
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
24+
use datafusion_expr::type_coercion::binary::{
25+
binary_to_string_coercion, string_coercion,
26+
};
2427

2528
use crate::utils::make_scalar_function;
2629
use datafusion_common::{internal_err, Result, ScalarValue};
30+
use datafusion_expr::cast;
2731
use datafusion_expr::{ColumnarValue, Documentation, Expr, Like};
2832
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2933
use datafusion_macros::user_doc;
3034

3135
/// Returns true if string starts with prefix.
3236
/// starts_with('alphabet', 'alph') = 't'
33-
pub fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
34-
let result = arrow::compute::kernels::comparison::starts_with(&args[0], &args[1])?;
35-
Ok(Arc::new(result) as ArrayRef)
37+
fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
38+
if let Some(coercion_data_type) =
39+
string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
40+
binary_to_string_coercion(args[0].data_type(), args[1].data_type())
41+
})
42+
{
43+
let arg0 = if args[0].data_type() == &coercion_data_type {
44+
Arc::clone(&args[0])
45+
} else {
46+
arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
47+
};
48+
let arg1 = if args[1].data_type() == &coercion_data_type {
49+
Arc::clone(&args[1])
50+
} else {
51+
arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
52+
};
53+
let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?;
54+
Ok(Arc::new(result) as ArrayRef)
55+
} else {
56+
internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")
57+
}
3658
}
3759

3860
#[user_doc(
@@ -102,40 +124,56 @@ impl ScalarUDFImpl for StartsWithFunc {
102124
fn simplify(
103125
&self,
104126
args: Vec<Expr>,
105-
_info: &dyn SimplifyInfo,
127+
info: &dyn SimplifyInfo,
106128
) -> Result<ExprSimplifyResult> {
107129
if let Expr::Literal(scalar_value) = &args[1] {
108130
// Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping
109-
// Example: starts_with(col, 'ja%') -> col LIKE 'ja\%%'
110-
// 1. 'ja%' (input pattern)
111-
// 2. 'ja\%' (escape special char '%')
112-
// 3. 'ja\%%' (add suffix for starts_with)
131+
// Escapes pattern characters: starts_with(col, 'j\_a%') -> col LIKE 'j\\\_a\%%'
132+
// 1. 'j\_a%' (input pattern)
133+
// 2. 'j\\\_a\%' (escape special chars '%', '_' and '\')
134+
// 3. 'j\\\_a\%%' (add unescaped % suffix for starts_with)
113135
let like_expr = match scalar_value {
114-
ScalarValue::Utf8(Some(pattern)) => {
115-
let escaped_pattern = pattern.replace("%", "\\%");
136+
ScalarValue::Utf8(Some(pattern))
137+
| ScalarValue::LargeUtf8(Some(pattern))
138+
| ScalarValue::Utf8View(Some(pattern)) => {
139+
let escaped_pattern = pattern
140+
.replace("\\", "\\\\")
141+
.replace("%", "\\%")
142+
.replace("_", "\\_");
116143
let like_pattern = format!("{}%", escaped_pattern);
117144
Expr::Literal(ScalarValue::Utf8(Some(like_pattern)))
118145
}
119-
ScalarValue::LargeUtf8(Some(pattern)) => {
120-
let escaped_pattern = pattern.replace("%", "\\%");
121-
let like_pattern = format!("{}%", escaped_pattern);
122-
Expr::Literal(ScalarValue::LargeUtf8(Some(like_pattern)))
123-
}
124-
ScalarValue::Utf8View(Some(pattern)) => {
125-
let escaped_pattern = pattern.replace("%", "\\%");
126-
let like_pattern = format!("{}%", escaped_pattern);
127-
Expr::Literal(ScalarValue::Utf8View(Some(like_pattern)))
128-
}
129146
_ => return Ok(ExprSimplifyResult::Original(args)),
130147
};
131148

132-
return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
133-
negated: false,
134-
expr: Box::new(args[0].clone()),
135-
pattern: Box::new(like_expr),
136-
escape_char: None,
137-
case_insensitive: false,
138-
})));
149+
let expr_data_type = info.get_data_type(&args[0])?;
150+
let pattern_data_type = info.get_data_type(&like_expr)?;
151+
152+
if let Some(coercion_data_type) =
153+
string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
154+
binary_to_string_coercion(&expr_data_type, &pattern_data_type)
155+
})
156+
{
157+
let expr = if expr_data_type == coercion_data_type {
158+
args[0].clone()
159+
} else {
160+
cast(args[0].clone(), coercion_data_type.clone())
161+
};
162+
163+
let pattern = if pattern_data_type == coercion_data_type {
164+
like_expr
165+
} else {
166+
cast(like_expr, coercion_data_type)
167+
};
168+
169+
return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
170+
negated: false,
171+
expr: Box::new(expr),
172+
pattern: Box::new(pattern),
173+
escape_char: None,
174+
case_insensitive: false,
175+
})));
176+
}
139177
}
140178

141179
Ok(ExprSimplifyResult::Original(args))

datafusion/sqllogictest/test_files/string/string_literal.slt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,25 @@ SELECT ends_with('foobar', 'foo')
207207
----
208208
false
209209

210+
query B
211+
SELECT ends_with(a, '%bar') from (values ('foobar'), ('foo%bar')) as t(a);
212+
----
213+
false
214+
true
215+
216+
query B
217+
SELECT ends_with(a, '_bar') from (values ('foobar'), ('foo_bar')) as t(a);
218+
----
219+
false
220+
true
221+
222+
query B
223+
SELECT ends_with(a, '\_bar') from (values ('foobar'), ('foo\\bar'), ('foo\_bar')) as t(a);
224+
----
225+
false
226+
false
227+
true
228+
210229
query I
211230
SELECT levenshtein('kitten', 'sitting')
212231
----
@@ -826,6 +845,26 @@ SELECT starts_with('foobar', 'bar')
826845
----
827846
false
828847

848+
849+
query B
850+
SELECT starts_with(a, 'foo%') from (values ('foobar'), ('foo%bar')) as t(a);
851+
----
852+
false
853+
true
854+
855+
query B
856+
SELECT starts_with(a, 'foo\_') from (values ('foobar'), ('foo\\_bar'), ('foo\_bar')) as t(a);
857+
----
858+
false
859+
false
860+
true
861+
862+
query B
863+
SELECT starts_with(a, 'foo_') from (values ('foobar'), ('foo_bar')) as t(a);
864+
----
865+
false
866+
true
867+
829868
query TT
830869
select ' ', '|'
831870
----

datafusion/sqllogictest/test_files/string/string_view.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ EXPLAIN SELECT
370370
FROM test;
371371
----
372372
logical_plan
373-
01)Projection: test.column1_utf8 LIKE Utf8("foo\%%") AS c1, test.column1_large_utf8 LIKE LargeUtf8("foo\%%") AS c2, test.column1_utf8view LIKE Utf8View("foo\%%") AS c3, test.column1_utf8 LIKE Utf8("f_o%") AS c4, test.column1_large_utf8 LIKE LargeUtf8("f_o%") AS c5, test.column1_utf8view LIKE Utf8View("f_o%") AS c6
373+
01)Projection: test.column1_utf8 LIKE Utf8("foo\%%") AS c1, test.column1_large_utf8 LIKE LargeUtf8("foo\%%") AS c2, test.column1_utf8view LIKE Utf8View("foo\%%") AS c3, test.column1_utf8 LIKE Utf8("f\_o%") AS c4, test.column1_large_utf8 LIKE LargeUtf8("f\_o%") AS c5, test.column1_utf8view LIKE Utf8View("f\_o%") AS c6
374374
02)--TableScan: test projection=[column1_utf8, column1_large_utf8, column1_utf8view]
375375

376376
## Test STARTS_WITH works with column arguments

0 commit comments

Comments
 (0)