Skip to content

Commit 9c37c65

Browse files
committed
use coercion in starts_with
1 parent a5791bc commit 9c37c65

File tree

5 files changed

+62
-40
lines changed

5 files changed

+62
-40
lines changed

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1256,7 +1256,7 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
12561256
/// Coercion rules for binary (Binary/LargeBinary) to string (Utf8/LargeUtf8):
12571257
/// If one argument is binary and the other is a string then coerce to string
12581258
/// (e.g. for `like`)
1259-
fn binary_to_string_coercion(
1259+
pub fn binary_to_string_coercion(
12601260
lhs_type: &DataType,
12611261
rhs_type: &DataType,
12621262
) -> Option<DataType> {

datafusion/functions/src/string/starts_with.rs

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,39 @@ use std::sync::Arc;
2020

2121
use arrow::array::ArrayRef;
2222
use arrow::datatypes::DataType;
23+
use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
2324
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
2425

2526
use crate::utils::make_scalar_function;
2627
use datafusion_common::types::logical_string;
2728
use datafusion_common::{internal_err, Result, ScalarValue};
2829
use datafusion_expr::{
29-
Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
30+
cast, Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
3031
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
3132
};
3233
use datafusion_macros::user_doc;
3334

3435
/// Returns true if string starts with prefix.
3536
/// starts_with('alphabet', 'alph') = 't'
3637
pub fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
37-
let arg0 = arrow::compute::kernels::cast::cast(&args[0], args[1].data_type())?;
38-
let result = arrow::compute::kernels::comparison::starts_with(&arg0, &args[1])?;
39-
Ok(Arc::new(result) as ArrayRef)
38+
if args[0].data_type() == args[1].data_type() {
39+
let result =
40+
arrow::compute::kernels::comparison::starts_with(&args[0], &args[1])?;
41+
return Ok(Arc::new(result) as ArrayRef);
42+
}
43+
44+
if let Some(coercion_data_type) =
45+
string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
46+
binary_to_string_coercion(args[0].data_type(), args[1].data_type())
47+
})
48+
{
49+
let arg0 = arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?;
50+
let arg1 = arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?;
51+
let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?;
52+
return Ok(Arc::new(result) as ArrayRef);
53+
} else {
54+
return internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View");
55+
}
4056
}
4157

4258
#[user_doc(
@@ -108,7 +124,7 @@ impl ScalarUDFImpl for StartsWithFunc {
108124
fn simplify(
109125
&self,
110126
args: Vec<Expr>,
111-
_info: &dyn SimplifyInfo,
127+
info: &dyn SimplifyInfo,
112128
) -> Result<ExprSimplifyResult> {
113129
if let Expr::Literal(scalar_value) = &args[1] {
114130
// Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping
@@ -117,31 +133,42 @@ impl ScalarUDFImpl for StartsWithFunc {
117133
// 2. 'ja\%' (escape special char '%')
118134
// 3. 'ja\%%' (add suffix for starts_with)
119135
let like_expr = match scalar_value {
120-
ScalarValue::Utf8(Some(pattern)) => {
136+
ScalarValue::Utf8(Some(pattern))
137+
| ScalarValue::LargeUtf8(Some(pattern))
138+
| ScalarValue::Utf8View(Some(pattern)) => {
121139
let escaped_pattern = pattern.replace("%", "\\%");
122140
let like_pattern = format!("{}%", escaped_pattern);
123141
Expr::Literal(ScalarValue::Utf8(Some(like_pattern)))
124142
}
125-
ScalarValue::LargeUtf8(Some(pattern)) => {
126-
let escaped_pattern = pattern.replace("%", "\\%");
127-
let like_pattern = format!("{}%", escaped_pattern);
128-
Expr::Literal(ScalarValue::LargeUtf8(Some(like_pattern)))
129-
}
130-
ScalarValue::Utf8View(Some(pattern)) => {
131-
let escaped_pattern = pattern.replace("%", "\\%");
132-
let like_pattern = format!("{}%", escaped_pattern);
133-
Expr::Literal(ScalarValue::Utf8View(Some(like_pattern)))
134-
}
135143
_ => return Ok(ExprSimplifyResult::Original(args)),
136144
};
137145

138-
return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
139-
negated: false,
140-
expr: Box::new(args[0].clone()),
141-
pattern: Box::new(like_expr),
142-
escape_char: None,
143-
case_insensitive: false,
144-
})));
146+
let expr_data_type = info.get_data_type(&args[0])?;
147+
let pattern_data_type = info.get_data_type(&like_expr)?;
148+
149+
if expr_data_type == pattern_data_type {
150+
return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
151+
negated: false,
152+
expr: Box::new(args[0].clone()),
153+
pattern: Box::new(like_expr),
154+
escape_char: None,
155+
case_insensitive: false,
156+
})));
157+
}
158+
159+
if let Some(coercion_data_type) =
160+
string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
161+
binary_to_string_coercion(&expr_data_type, &pattern_data_type)
162+
})
163+
{
164+
return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
165+
negated: false,
166+
expr: Box::new(cast(args[0].clone(), coercion_data_type.clone())),
167+
pattern: Box::new(cast(like_expr, coercion_data_type)),
168+
escape_char: None,
169+
case_insensitive: false,
170+
})));
171+
}
145172
}
146173

147174
Ok(ExprSimplifyResult::Original(args))

datafusion/physical-expr/src/expressions/like.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ use std::hash::Hash;
1919
use std::{any::Any, sync::Arc};
2020

2121
use crate::PhysicalExpr;
22-
use arrow::compute::can_cast_types;
2322
use arrow::datatypes::{DataType, Schema};
2423
use arrow::record_batch::RecordBatch;
2524
use datafusion_common::{internal_err, Result};
@@ -122,10 +121,7 @@ impl PhysicalExpr for LikeExpr {
122121
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
123122
use arrow::compute::*;
124123
let lhs = self.expr.evaluate(batch)?;
125-
let rhs = self
126-
.pattern
127-
.evaluate(batch)?
128-
.cast_to(&lhs.data_type(), None)?;
124+
let rhs = self.pattern.evaluate(batch)?;
129125
match (self.negated, self.case_insensitive) {
130126
(false, false) => apply_cmp(&lhs, &rhs, like),
131127
(false, true) => apply_cmp(&lhs, &rhs, ilike),
@@ -169,10 +165,7 @@ pub fn like(
169165
) -> Result<Arc<dyn PhysicalExpr>> {
170166
let expr_type = &expr.data_type(input_schema)?;
171167
let pattern_type = &pattern.data_type(input_schema)?;
172-
if !expr_type.eq(pattern_type)
173-
&& !can_cast_types(expr_type, pattern_type)
174-
&& !can_like_type(expr_type)
175-
{
168+
if !expr_type.eq(pattern_type) && !can_like_type(expr_type) {
176169
return internal_err!(
177170
"The type of {expr_type} AND {pattern_type} of like physical should be same"
178171
);

datafusion/sqllogictest/test_files/parquet.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,8 +619,8 @@ query TT
619619
explain select * from foo where starts_with(column1, 'f');
620620
----
621621
logical_plan
622-
01)Filter: foo.column1 LIKE Utf8("f%")
623-
02)--TableScan: foo projection=[column1], partial_filters=[foo.column1 LIKE Utf8("f%")]
622+
01)Filter: CAST(foo.column1 AS Utf8View) LIKE Utf8View("f%")
623+
02)--TableScan: foo projection=[column1], partial_filters=[CAST(foo.column1 AS Utf8View) LIKE Utf8View("f%")]
624624
physical_plan
625625
01)CoalesceBatchesExec: target_batch_size=8192
626626
02)--FilterExec: column1@0 LIKE f%

datafusion/sqllogictest/test_files/string/string_view.slt

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,9 @@ EXPLAIN SELECT
355355
FROM test;
356356
----
357357
logical_plan
358-
01)Projection: test.column1_utf8view LIKE Utf8("äöüß%") AS c1, CASE test.column1_utf8view IS NOT NULL WHEN Boolean(true) THEN Boolean(true) END AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4
359-
02)--TableScan: test projection=[column1_utf8view]
358+
01)Projection: __common_expr_1 LIKE Utf8View("äöüß%") AS c1, CASE __common_expr_1 IS NOT NULL WHEN Boolean(true) THEN Boolean(true) END AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4
359+
02)--Projection: CAST(test.column1_utf8view AS Utf8View) AS __common_expr_1, test.column1_utf8view
360+
03)----TableScan: test projection=[column1_utf8view]
360361

361362
## Test STARTS_WITH is rewitten to LIKE when the pattern is a constant
362363
query TT
@@ -370,8 +371,9 @@ EXPLAIN SELECT
370371
FROM test;
371372
----
372373
logical_plan
373-
01)Projection: test.column1_utf8 LIKE Utf8("foo\%%") AS c1, test.column1_large_utf8 LIKE Utf8("foo\%%") AS c2, test.column1_utf8view LIKE Utf8("foo\%%") AS c3, test.column1_utf8 LIKE Utf8("f_o%") AS c4, test.column1_large_utf8 LIKE Utf8("f_o%") AS c5, test.column1_utf8view LIKE Utf8("f_o%") AS c6
374-
02)--TableScan: test projection=[column1_utf8, column1_large_utf8, column1_utf8view]
374+
01)Projection: test.column1_utf8 LIKE Utf8("foo\%%") AS c1, __common_expr_1 LIKE LargeUtf8("foo\%%") AS c2, __common_expr_2 LIKE Utf8View("foo\%%") AS c3, test.column1_utf8 LIKE Utf8("f_o%") AS c4, __common_expr_1 LIKE LargeUtf8("f_o%") AS c5, __common_expr_2 LIKE Utf8View("f_o%") AS c6
375+
02)--Projection: CAST(test.column1_large_utf8 AS LargeUtf8) AS __common_expr_1, CAST(test.column1_utf8view AS Utf8View) AS __common_expr_2, test.column1_utf8
376+
03)----TableScan: test projection=[column1_utf8, column1_large_utf8, column1_utf8view]
375377

376378
## Test STARTS_WITH works with column arguments
377379
query TT
@@ -940,7 +942,7 @@ EXPLAIN SELECT
940942
FROM test;
941943
----
942944
logical_plan
943-
01)Projection: test.column1_utf8view LIKE Utf8("foo%") AS c, starts_with(test.column1_utf8view, test.column2_utf8view) AS c2
945+
01)Projection: CAST(test.column1_utf8view AS Utf8View) LIKE Utf8View("foo%") AS c, starts_with(test.column1_utf8view, test.column2_utf8view) AS c2
944946
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]
945947

946948
## Ensure no casts for TRANSLATE

0 commit comments

Comments
 (0)