Skip to content

Commit 5d18648

Browse files
authored
Add ScalarValue::try_as_str to get str value from logical strings (apache#14167)
1 parent 57eb17f commit 5d18648

File tree

11 files changed

+132
-129
lines changed

11 files changed

+132
-129
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2849,6 +2849,50 @@ impl ScalarValue {
28492849
ScalarValue::from(value).cast_to(target_type)
28502850
}
28512851

2852+
/// Returns the Some(`&str`) representation of `ScalarValue` of logical string type
2853+
///
2854+
/// Returns `None` if this `ScalarValue` is not a logical string type or the
2855+
/// `ScalarValue` represents the `NULL` value.
2856+
///
2857+
/// Note you can use [`Option::flatten`] to check for non null logical
2858+
/// strings.
2859+
///
2860+
/// For example, [`ScalarValue::Utf8`], [`ScalarValue::LargeUtf8`], and
2861+
/// [`ScalarValue::Dictionary`] with a logical string value and store
2862+
/// strings and can be accessed as `&str` using this method.
2863+
///
2864+
/// # Example: logical strings
2865+
/// ```
2866+
/// # use datafusion_common::ScalarValue;
2867+
/// /// non strings return None
2868+
/// let scalar = ScalarValue::from(42);
2869+
/// assert_eq!(scalar.try_as_str(), None);
2870+
/// // Non null logical string returns Some(Some(&str))
2871+
/// let scalar = ScalarValue::from("hello");
2872+
/// assert_eq!(scalar.try_as_str(), Some(Some("hello")));
2873+
/// // Null logical string returns Some(None)
2874+
/// let scalar = ScalarValue::Utf8(None);
2875+
/// assert_eq!(scalar.try_as_str(), Some(None));
2876+
/// ```
2877+
///
2878+
/// # Example: use [`Option::flatten`] to check for non-null logical strings
2879+
/// ```
2880+
/// # use datafusion_common::ScalarValue;
2881+
/// // Non null logical string returns Some(Some(&str))
2882+
/// let scalar = ScalarValue::from("hello");
2883+
/// assert_eq!(scalar.try_as_str().flatten(), Some("hello"));
2884+
/// ```
2885+
pub fn try_as_str(&self) -> Option<Option<&str>> {
2886+
let v = match self {
2887+
ScalarValue::Utf8(v) => v,
2888+
ScalarValue::LargeUtf8(v) => v,
2889+
ScalarValue::Utf8View(v) => v,
2890+
ScalarValue::Dictionary(_, v) => return v.try_as_str(),
2891+
_ => return None,
2892+
};
2893+
Some(v.as_ref().map(|v| v.as_str()))
2894+
}
2895+
28522896
/// Try to cast this value to a ScalarValue of type `data_type`
28532897
pub fn cast_to(&self, target_type: &DataType) -> Result<Self> {
28542898
self.cast_to_with_options(target_type, &DEFAULT_CAST_OPTIONS)

datafusion/core/tests/sql/path_partition.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,11 @@ async fn parquet_distinct_partition_col() -> Result<()> {
218218
assert_eq!(min_limit, resulting_limit);
219219

220220
let s = ScalarValue::try_from_array(results[0].column(1), 0)?;
221-
let month = match extract_as_utf(&s) {
222-
Some(month) => month,
223-
s => panic!("Expected month as Dict(_, Utf8) found {s:?}"),
224-
};
221+
assert!(
222+
matches!(s.data_type(), DataType::Dictionary(_, v) if v.as_ref() == &DataType::Utf8),
223+
"Expected month as Dict(_, Utf8) found {s:?}"
224+
);
225+
let month = s.try_as_str().flatten().unwrap();
225226

226227
let sql_on_partition_boundary = format!(
227228
"SELECT month from t where month = '{}' LIMIT {}",
@@ -241,15 +242,6 @@ async fn parquet_distinct_partition_col() -> Result<()> {
241242
Ok(())
242243
}
243244

244-
fn extract_as_utf(v: &ScalarValue) -> Option<String> {
245-
if let ScalarValue::Dictionary(_, v) = v {
246-
if let ScalarValue::Utf8(v) = v.as_ref() {
247-
return v.clone();
248-
}
249-
}
250-
None
251-
}
252-
253245
#[tokio::test]
254246
async fn csv_filter_with_file_col() -> Result<()> {
255247
let ctx = SessionContext::new_with_config(

datafusion/functions-aggregate/src/string_agg.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,14 @@ impl AggregateUDFImpl for StringAgg {
108108

109109
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
110110
if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() {
111-
return match lit.value() {
112-
ScalarValue::Utf8(Some(delimiter))
113-
| ScalarValue::LargeUtf8(Some(delimiter)) => {
114-
Ok(Box::new(StringAggAccumulator::new(delimiter.as_str())))
111+
return match lit.value().try_as_str() {
112+
Some(Some(delimiter)) => {
113+
Ok(Box::new(StringAggAccumulator::new(delimiter)))
114+
}
115+
Some(None) => Ok(Box::new(StringAggAccumulator::new(""))),
116+
None => {
117+
not_impl_err!("StringAgg not supported for delimiter {}", lit.value())
115118
}
116-
ScalarValue::Utf8(None)
117-
| ScalarValue::LargeUtf8(None)
118-
| ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))),
119-
e => not_impl_err!("StringAgg not supported for delimiter {}", e),
120119
};
121120
}
122121

datafusion/functions/src/crypto/basic.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,9 @@ pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
121121
);
122122
}
123123
let digest_algorithm = match &args[1] {
124-
ColumnarValue::Scalar(scalar) => match scalar {
125-
ScalarValue::Utf8View(Some(method))
126-
| ScalarValue::Utf8(Some(method))
127-
| ScalarValue::LargeUtf8(Some(method)) => method.parse::<DigestAlgorithm>(),
128-
other => exec_err!("Unsupported data type {other:?} for function digest"),
124+
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
125+
Some(Some(method)) => method.parse::<DigestAlgorithm>(),
126+
_ => exec_err!("Unsupported data type {scalar:?} for function digest"),
129127
},
130128
ColumnarValue::Array(_) => {
131129
internal_err!("Digest using dynamically decided method is not yet supported")

datafusion/functions/src/datetime/common.rs

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -211,14 +211,12 @@ where
211211
))),
212212
other => exec_err!("Unsupported data type {other:?} for function {name}"),
213213
},
214-
ColumnarValue::Scalar(scalar) => match scalar {
215-
ScalarValue::Utf8View(a)
216-
| ScalarValue::LargeUtf8(a)
217-
| ScalarValue::Utf8(a) => {
214+
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
215+
Some(a) => {
218216
let result = a.as_ref().map(|x| op(x)).transpose()?;
219217
Ok(ColumnarValue::Scalar(S::scalar(result)))
220218
}
221-
other => exec_err!("Unsupported data type {other:?} for function {name}"),
219+
_ => exec_err!("Unsupported data type {scalar:?} for function {name}"),
222220
},
223221
}
224222
}
@@ -270,10 +268,8 @@ where
270268
}
271269
},
272270
// if the first argument is a scalar utf8 all arguments are expected to be scalar utf8
273-
ColumnarValue::Scalar(scalar) => match scalar {
274-
ScalarValue::Utf8View(a)
275-
| ScalarValue::LargeUtf8(a)
276-
| ScalarValue::Utf8(a) => {
271+
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
272+
Some(a) => {
277273
let a = a.as_ref();
278274
// ASK: Why do we trust `a` to be non-null at this point?
279275
let a = unwrap_or_internal_err!(a);
@@ -291,7 +287,7 @@ where
291287
};
292288

293289
if let Some(s) = x {
294-
match op(a.as_str(), s.as_str()) {
290+
match op(a, s.as_str()) {
295291
Ok(r) => {
296292
ret = Some(Ok(ColumnarValue::Scalar(S::scalar(Some(
297293
op2(r),
@@ -408,19 +404,10 @@ where
408404
DataType::Utf8 => Ok(a.as_string::<i32>().value(pos)),
409405
other => exec_err!("Unexpected type encountered '{other}'"),
410406
},
411-
ColumnarValue::Scalar(s) => match s {
412-
ScalarValue::Utf8View(a)
413-
| ScalarValue::LargeUtf8(a)
414-
| ScalarValue::Utf8(a) => {
415-
if let Some(v) = a {
416-
Ok(v.as_str())
417-
} else {
418-
continue;
419-
}
420-
}
421-
other => {
422-
exec_err!("Unexpected scalar type encountered '{other}'")
423-
}
407+
ColumnarValue::Scalar(s) => match s.try_as_str() {
408+
Some(Some(v)) => Ok(v),
409+
Some(None) => continue, // null string
410+
None => exec_err!("Unexpected scalar type encountered '{s}'"),
424411
},
425412
}?;
426413

datafusion/functions/src/encoding/inner.rs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -546,12 +546,10 @@ fn encode(args: &[ColumnarValue]) -> Result<ColumnarValue> {
546546
);
547547
}
548548
let encoding = match &args[1] {
549-
ColumnarValue::Scalar(scalar) => match scalar {
550-
ScalarValue::Utf8(Some(method)) | ScalarValue::Utf8View(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => {
551-
method.parse::<Encoding>()
552-
}
549+
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
550+
Some(Some(method)) => method.parse::<Encoding>(),
553551
_ => not_impl_err!(
554-
"Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported"
552+
"Second argument to encode must be non null constant string: Encode using dynamically decided method is not yet supported. Got {scalar:?}"
555553
),
556554
},
557555
ColumnarValue::Array(_) => not_impl_err!(
@@ -572,12 +570,10 @@ fn decode(args: &[ColumnarValue]) -> Result<ColumnarValue> {
572570
);
573571
}
574572
let encoding = match &args[1] {
575-
ColumnarValue::Scalar(scalar) => match scalar {
576-
ScalarValue::Utf8(Some(method)) | ScalarValue::Utf8View(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => {
577-
method.parse::<Encoding>()
578-
}
573+
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
574+
Some(Some(method))=> method.parse::<Encoding>(),
579575
_ => not_impl_err!(
580-
"Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported"
576+
"Second argument to decode must be a non null constant string: Decode using dynamically decided method is not yet supported. Got {scalar:?}"
581577
),
582578
},
583579
ColumnarValue::Array(_) => not_impl_err!(

datafusion/functions/src/string/concat.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,18 +134,16 @@ impl ScalarUDFImpl for ConcatFunc {
134134
if array_len.is_none() {
135135
let mut result = String::new();
136136
for arg in args {
137-
match arg {
138-
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v)))
139-
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v)))
140-
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => {
141-
result.push_str(v);
142-
}
143-
ColumnarValue::Scalar(ScalarValue::Utf8(None))
144-
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
145-
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
146-
other => plan_err!(
137+
let ColumnarValue::Scalar(scalar) = arg else {
138+
return internal_err!("concat expected scalar value, got {arg:?}");
139+
};
140+
141+
match scalar.try_as_str() {
142+
Some(Some(v)) => result.push_str(v),
143+
Some(None) => {} // null literal
144+
None => plan_err!(
147145
"Concat function does not support scalar type {:?}",
148-
other
146+
scalar
149147
)?,
150148
}
151149
}

datafusion/functions/src/string/concat_ws.rs

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -124,48 +124,54 @@ impl ScalarUDFImpl for ConcatWsFunc {
124124

125125
// Scalar
126126
if array_len.is_none() {
127-
let sep = match &args[0] {
128-
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
129-
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
130-
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => s,
131-
ColumnarValue::Scalar(ScalarValue::Utf8(None))
132-
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
133-
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {
127+
let ColumnarValue::Scalar(scalar) = &args[0] else {
128+
// loop above checks for all args being scalar
129+
unreachable!()
130+
};
131+
let sep = match scalar.try_as_str() {
132+
Some(Some(s)) => s,
133+
Some(None) => {
134+
// null literal string
134135
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
135136
}
136-
_ => unreachable!(),
137+
None => return internal_err!("Expected string literal, got {scalar:?}"),
137138
};
138139

139140
let mut result = String::new();
140-
let iter = &mut args[1..].iter();
141-
142-
for arg in iter.by_ref() {
143-
match arg {
144-
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
145-
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
146-
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
141+
// iterator over Option<str>
142+
let iter = &mut args[1..].iter().map(|arg| {
143+
let ColumnarValue::Scalar(scalar) = arg else {
144+
// loop above checks for all args being scalar
145+
unreachable!()
146+
};
147+
scalar.try_as_str()
148+
});
149+
150+
// append first non null arg
151+
for scalar in iter.by_ref() {
152+
match scalar {
153+
Some(Some(s)) => {
147154
result.push_str(s);
148155
break;
149156
}
150-
ColumnarValue::Scalar(ScalarValue::Utf8(None))
151-
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
152-
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
153-
_ => unreachable!(),
157+
Some(None) => {} // null literal string
158+
None => {
159+
return internal_err!("Expected string literal, got {scalar:?}")
160+
}
154161
}
155162
}
156163

157-
for arg in iter.by_ref() {
158-
match arg {
159-
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
160-
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
161-
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
164+
// handle subsequent non null args
165+
for scalar in iter.by_ref() {
166+
match scalar {
167+
Some(Some(s)) => {
162168
result.push_str(sep);
163169
result.push_str(s);
164170
}
165-
ColumnarValue::Scalar(ScalarValue::Utf8(None))
166-
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
167-
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
168-
_ => unreachable!(),
171+
Some(None) => {} // null literal string
172+
None => {
173+
return internal_err!("Expected string literal, got {scalar:?}")
174+
}
169175
}
170176
}
171177

datafusion/optimizer/src/unwrap_cast_in_comparison.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -475,12 +475,7 @@ fn try_cast_string_literal(
475475
lit_value: &ScalarValue,
476476
target_type: &DataType,
477477
) -> Option<ScalarValue> {
478-
let string_value = match lit_value {
479-
ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => {
480-
s.clone()
481-
}
482-
_ => return None,
483-
};
478+
let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
484479
let scalar_value = match target_type {
485480
DataType::Utf8 => ScalarValue::Utf8(string_value),
486481
DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -251,22 +251,13 @@ macro_rules! compute_utf8_flag_op_scalar {
251251
.downcast_ref::<$ARRAYTYPE>()
252252
.expect("compute_utf8_flag_op_scalar failed to downcast array");
253253

254-
let string_value = match $RIGHT {
255-
ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value,
256-
ScalarValue::Dictionary(_, value) => {
257-
match *value {
258-
ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value,
259-
other => return internal_err!(
260-
"compute_utf8_flag_op_scalar failed to cast dictionary value {} for operation '{}'",
261-
other, stringify!($OP)
262-
)
263-
}
264-
},
254+
let string_value = match $RIGHT.try_as_str() {
255+
Some(Some(string_value)) => string_value,
256+
// null literal or non string
265257
_ => return internal_err!(
266-
"compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'",
267-
$RIGHT, stringify!($OP)
268-
)
269-
258+
"compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'",
259+
$RIGHT, stringify!($OP)
260+
)
270261
};
271262

272263
let flag = $FLAG.then_some("i");

0 commit comments

Comments
 (0)