Skip to content

Commit 2a94b15

Browse files
committed
Df52 migration
1 parent 2a26ac5 commit 2a94b15

File tree

2 files changed

+57
-19
lines changed

2 files changed

+57
-19
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,9 @@ use crate::utils::array_with_timezone;
1919
use crate::{timezone, BinaryOutputStyle};
2020
use crate::{EvalMode, SparkError, SparkResult};
2121
use arrow::array::builder::StringBuilder;
22-
use arrow::array::{
23-
BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
24-
PrimitiveBuilder, StringArray, StructArray,
25-
};
22+
use arrow::array::{BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, PrimitiveBuilder, StringArray, StructArray, TimestampMicrosecondArray, TimestampMillisecondArray};
2623
use arrow::compute::can_cast_types;
27-
use arrow::datatypes::{
28-
i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
29-
Schema,
30-
};
24+
use arrow::datatypes::{i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType, Schema, TimeUnit};
3125
use arrow::{
3226
array::{
3327
cast::AsArray,
@@ -964,9 +958,11 @@ fn cast_array(
964958
cast_options: &SparkCastOptions,
965959
) -> DataFusionResult<ArrayRef> {
966960
use DataType::*;
967-
let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
968961
let from_type = array.data_type().clone();
969962

963+
let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
964+
let eval_mode = cast_options.eval_mode;
965+
970966
let native_cast_options: CastOptions = CastOptions {
971967
safe: !matches!(cast_options.eval_mode, EvalMode::Ansi), // take safe mode from cast_options passed
972968
format_options: FormatOptions::new()
@@ -1015,10 +1011,8 @@ fn cast_array(
10151011
}
10161012
}
10171013
};
1018-
let from_type = array.data_type();
1019-
let eval_mode = cast_options.eval_mode;
10201014

1021-
let cast_result = match (from_type, to_type) {
1015+
let cast_result = match (&from_type, to_type) {
10221016
(Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
10231017
(LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
10241018
(Utf8, Timestamp(_, _)) => {
@@ -1044,10 +1038,10 @@ fn cast_array(
10441038
| (Int16, Int8)
10451039
if eval_mode != EvalMode::Try =>
10461040
{
1047-
spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
1041+
spark_cast_int_to_int(&array, eval_mode, &from_type, to_type)
10481042
}
10491043
(Int8 | Int16 | Int32 | Int64, Decimal128(precision, scale)) => {
1050-
cast_int_to_decimal128(&array, eval_mode, from_type, to_type, *precision, *scale)
1044+
cast_int_to_decimal128(&array, eval_mode, &from_type, to_type, *precision, *scale)
10511045
}
10521046
(Utf8, Int8 | Int16 | Int32 | Int64) => {
10531047
cast_string_to_int::<i32>(to_type, &array, eval_mode)
@@ -1079,19 +1073,19 @@ fn cast_array(
10791073
| (Decimal128(_, _), Int64)
10801074
if eval_mode != EvalMode::Try =>
10811075
{
1082-
spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
1076+
spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, &from_type, to_type)
10831077
}
10841078
(Decimal128(_p, _s), Boolean) => spark_cast_decimal_to_boolean(&array),
10851079
(Utf8View, Utf8) => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?),
10861080
(Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
10871081
(Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
10881082
array.as_struct(),
1089-
from_type,
1083+
&from_type,
10901084
to_type,
10911085
cast_options,
10921086
)?),
10931087
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
1094-
(List(_), List(_)) if can_cast_types(from_type, to_type) => {
1088+
(List(_), List(_)) if can_cast_types(&from_type, to_type) => {
10951089
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
10961090
}
10971091
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
@@ -1101,7 +1095,7 @@ fn cast_array(
11011095
}
11021096
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
11031097
_ if cast_options.is_adapting_schema
1104-
|| is_datafusion_spark_compatible(from_type, to_type) =>
1098+
|| is_datafusion_spark_compatible(&from_type, to_type) =>
11051099
{
11061100
// use DataFusion cast only when we know that it is compatible with Spark
11071101
Ok(cast_with_options(&array, to_type, &native_cast_options)?)
@@ -1115,7 +1109,7 @@ fn cast_array(
11151109
)))
11161110
}
11171111
};
1118-
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
1112+
Ok(spark_cast_postprocess(cast_result?, &from_type, to_type))
11191113
}
11201114

11211115
fn cast_string_to_float(

native/spark-expr/src/utils.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ use arrow::{
3535
array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
3636
temporal_conversions::as_datetime,
3737
};
38+
use arrow::array::TimestampMicrosecondArray;
3839
use chrono::{DateTime, Offset, TimeZone};
3940

4041
/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
@@ -71,6 +72,49 @@ pub fn array_with_timezone(
7172
to_type: Option<&DataType>,
7273
) -> Result<ArrayRef, ArrowError> {
7374
match array.data_type() {
75+
DataType::Timestamp(TimeUnit::Millisecond, None) => {
76+
assert!(!timezone.is_empty());
77+
match to_type {
78+
Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array),
79+
Some(DataType::Timestamp(_, Some(_))) => {
80+
timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
81+
}
82+
Some(DataType::Timestamp(TimeUnit::Microsecond, None)) => {
83+
// Convert from Timestamp(Millisecond, None) to Timestamp(Microsecond, None)
84+
let millis_array = as_primitive_array::<TimestampMillisecondType>(&array);
85+
let micros_array: TimestampMicrosecondArray = millis_array
86+
.iter()
87+
.map(|opt| opt.map(|v| v * 1000))
88+
.collect();
89+
Ok(Arc::new(micros_array))
90+
}
91+
_ => {
92+
// Not supported
93+
panic!(
94+
"Cannot convert from {:?} to {:?}",
95+
array.data_type(),
96+
to_type.unwrap()
97+
)
98+
}
99+
}
100+
}
101+
DataType::Timestamp(TimeUnit::Microsecond, None) => {
102+
assert!(!timezone.is_empty());
103+
match to_type {
104+
Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array),
105+
Some(DataType::Timestamp(_, Some(_))) => {
106+
timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
107+
}
108+
_ => {
109+
// Not supported
110+
panic!(
111+
"Cannot convert from {:?} to {:?}",
112+
array.data_type(),
113+
to_type.unwrap()
114+
)
115+
}
116+
}
117+
}
74118
DataType::Timestamp(_, None) => {
75119
assert!(!timezone.is_empty());
76120
match to_type {

0 commit comments

Comments (0)