Skip to content

Commit aaaa6a6

Browse files
authored
chore: Refactor cast module temporal types (apache#3624)
1 parent ba9b842 commit aaaa6a6

File tree

3 files changed

+158
-120
lines changed

3 files changed

+158
-120
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 12 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -29,38 +29,37 @@ use crate::conversion_funcs::string::{
2929
cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int,
3030
cast_string_to_timestamp, is_df_cast_from_string_spark_compatible, spark_cast_utf8_to_boolean,
3131
};
32+
use crate::conversion_funcs::temporal::{
33+
cast_date_to_timestamp, is_df_cast_from_date_spark_compatible,
34+
is_df_cast_from_timestamp_spark_compatible,
35+
};
3236
use crate::conversion_funcs::utils::spark_cast_postprocess;
3337
use crate::utils::array_with_timezone;
3438
use crate::EvalMode::Legacy;
35-
use crate::{cast_whole_num_to_binary, timezone, BinaryOutputStyle};
36-
use crate::{EvalMode, SparkError, SparkResult};
39+
use crate::{cast_whole_num_to_binary, BinaryOutputStyle};
40+
use crate::{EvalMode, SparkError};
3741
use arrow::array::builder::StringBuilder;
3842
use arrow::array::{
39-
BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray,
40-
StructArray, TimestampMicrosecondBuilder,
43+
BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray,
4144
};
4245
use arrow::compute::can_cast_types;
4346
use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
4447
use arrow::datatypes::{Field, Fields, GenericBinaryType};
4548
use arrow::error::ArrowError;
4649
use arrow::{
4750
array::{
48-
cast::AsArray,
49-
types::{Date32Type, Int32Type},
50-
Array, ArrayRef, GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array,
51-
OffsetSizeTrait, PrimitiveArray,
51+
cast::AsArray, types::Int32Type, Array, ArrayRef, GenericStringArray, Int16Array,
52+
Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray,
5253
},
5354
compute::{cast_with_options, take, CastOptions},
5455
record_batch::RecordBatch,
5556
util::display::FormatOptions,
5657
};
5758
use base64::prelude::BASE64_STANDARD_NO_PAD;
5859
use base64::Engine;
59-
use chrono::{NaiveDate, TimeZone};
6060
use datafusion::common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue};
6161
use datafusion::physical_expr::PhysicalExpr;
6262
use datafusion::physical_plan::ColumnarValue;
63-
use std::str::FromStr;
6463
use std::{
6564
any::Any,
6665
fmt::{Debug, Display, Formatter},
@@ -404,50 +403,6 @@ pub(crate) fn cast_array(
404403
Ok(spark_cast_postprocess(cast_result?, &from_type, to_type))
405404
}
406405

407-
fn cast_date_to_timestamp(
408-
array_ref: &ArrayRef,
409-
cast_options: &SparkCastOptions,
410-
target_tz: &Option<Arc<str>>,
411-
) -> SparkResult<ArrayRef> {
412-
let tz_str = if cast_options.timezone.is_empty() {
413-
"UTC"
414-
} else {
415-
cast_options.timezone.as_str()
416-
};
417-
// safe to unwrap since we are falling back to UTC above
418-
let tz = timezone::Tz::from_str(tz_str)?;
419-
let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
420-
let date_array = array_ref.as_primitive::<Date32Type>();
421-
422-
let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len());
423-
424-
for date in date_array.iter() {
425-
match date {
426-
Some(date) => {
427-
// safe to unwrap since chrono's range ( 262,143 yrs) is higher than
428-
// number of years possible with days as i32 (~ 6 mil yrs)
429-
// convert date in session timezone to timestamp in UTC
430-
let naive_date = epoch + chrono::Duration::days(date as i64);
431-
let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap();
432-
let local_midnight_in_microsec = tz
433-
.from_local_datetime(&local_midnight)
434-
// return earliest possible time (edge case with spring / fall DST changes)
435-
.earliest()
436-
.map(|dt| dt.timestamp_micros())
437-
// in case there is an issue with DST and returns None , we fall back to UTC
438-
.unwrap_or((date as i64) * 86_400 * 1_000_000);
439-
builder.append_value(local_midnight_in_microsec);
440-
}
441-
None => {
442-
builder.append_null();
443-
}
444-
}
445-
}
446-
Ok(Arc::new(
447-
builder.finish().with_timezone_opt(target_tz.clone()),
448-
))
449-
}
450-
451406
/// Determines if DataFusion supports the given cast in a way that is
452407
/// compatible with Spark
453408
fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
@@ -467,13 +422,8 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
467422
is_df_cast_from_decimal_spark_compatible(to_type)
468423
}
469424
DataType::Utf8 => is_df_cast_from_string_spark_compatible(to_type),
470-
DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8),
471-
DataType::Timestamp(_, _) => {
472-
matches!(
473-
to_type,
474-
DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
475-
)
476-
}
425+
DataType::Date32 => is_df_cast_from_date_spark_compatible(to_type),
426+
DataType::Timestamp(_, _) => is_df_cast_from_timestamp_spark_compatible(to_type),
477427
DataType::Binary => {
478428
// note that this is not completely Spark compatible because
479429
// DataFusion only supports binary data containing valid UTF-8 strings
@@ -827,7 +777,7 @@ mod tests {
827777
use super::*;
828778
use arrow::array::StringArray;
829779
use arrow::datatypes::TimestampMicrosecondType;
830-
use arrow::datatypes::{Field, Fields, TimeUnit};
780+
use arrow::datatypes::{Field, Fields};
831781
#[test]
832782
fn test_cast_unsupported_timestamp_to_date() {
833783
// Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported
@@ -853,64 +803,6 @@ mod tests {
853803
assert!(result.is_err())
854804
}
855805

856-
#[test]
857-
fn test_cast_date_to_timestamp() {
858-
use arrow::array::Date32Array;
859-
860-
// verifying epoch , DST change dates (US) and a null value (comprehensive tests on spark side)
861-
let dates: ArrayRef = Arc::new(Date32Array::from(vec![
862-
Some(0),
863-
Some(19723),
864-
Some(19793),
865-
None,
866-
]));
867-
868-
let non_dst_date = 1704067200000000i64;
869-
let dst_date = 1710115200000000i64;
870-
let seven_hours_ts = 25200000000i64;
871-
let eight_hours_ts = 28800000000i64;
872-
873-
// validate UTC
874-
let result = cast_array(
875-
Arc::clone(&dates),
876-
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
877-
&SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
878-
)
879-
.unwrap();
880-
let ts = result.as_primitive::<TimestampMicrosecondType>();
881-
assert_eq!(ts.value(0), 0);
882-
assert_eq!(ts.value(1), non_dst_date);
883-
assert_eq!(ts.value(2), dst_date);
884-
assert!(ts.is_null(3));
885-
886-
// validate LA timezone (follows Daylight savings)
887-
let result = cast_array(
888-
Arc::clone(&dates),
889-
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
890-
&SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles", false),
891-
)
892-
.unwrap();
893-
let ts = result.as_primitive::<TimestampMicrosecondType>();
894-
assert_eq!(ts.value(0), eight_hours_ts);
895-
assert_eq!(ts.value(1), non_dst_date + eight_hours_ts);
896-
// should adjust for DST
897-
assert_eq!(ts.value(2), dst_date + seven_hours_ts);
898-
assert!(ts.is_null(3));
899-
900-
// Phoenix timezone (does not follow Daylight savings)
901-
let result = cast_array(
902-
Arc::clone(&dates),
903-
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
904-
&SparkCastOptions::new(EvalMode::Legacy, "America/Phoenix", false),
905-
)
906-
.unwrap();
907-
let ts = result.as_primitive::<TimestampMicrosecondType>();
908-
assert_eq!(ts.value(0), seven_hours_ts);
909-
assert_eq!(ts.value(1), non_dst_date + seven_hours_ts);
910-
assert_eq!(ts.value(2), dst_date + seven_hours_ts);
911-
assert!(ts.is_null(3));
912-
}
913-
914806
#[test]
915807
fn test_cast_struct_to_utf8() {
916808
let a: ArrayRef = Arc::new(Int32Array::from(vec![

native/spark-expr/src/conversion_funcs/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ mod boolean;
1919
pub mod cast;
2020
mod numeric;
2121
mod string;
22+
mod temporal;
2223
mod utils;
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::{timezone, SparkCastOptions, SparkResult};
19+
use arrow::array::{ArrayRef, AsArray, TimestampMicrosecondBuilder};
20+
use arrow::datatypes::{DataType, Date32Type};
21+
use chrono::{NaiveDate, TimeZone};
22+
use std::str::FromStr;
23+
use std::sync::Arc;
24+
25+
pub(crate) fn is_df_cast_from_date_spark_compatible(to_type: &DataType) -> bool {
26+
matches!(to_type, DataType::Int32 | DataType::Utf8)
27+
}
28+
29+
pub(crate) fn is_df_cast_from_timestamp_spark_compatible(to_type: &DataType) -> bool {
30+
matches!(
31+
to_type,
32+
DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
33+
)
34+
}
35+
36+
pub(crate) fn cast_date_to_timestamp(
37+
array_ref: &ArrayRef,
38+
cast_options: &SparkCastOptions,
39+
target_tz: &Option<Arc<str>>,
40+
) -> SparkResult<ArrayRef> {
41+
let tz_str = if cast_options.timezone.is_empty() {
42+
"UTC"
43+
} else {
44+
cast_options.timezone.as_str()
45+
};
46+
// safe to unwrap since we are falling back to UTC above
47+
let tz = timezone::Tz::from_str(tz_str)?;
48+
let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
49+
let date_array = array_ref.as_primitive::<Date32Type>();
50+
51+
let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len());
52+
53+
for date in date_array.iter() {
54+
match date {
55+
Some(date) => {
56+
// safe to unwrap since chrono's range ( 262,143 yrs) is higher than
57+
// number of years possible with days as i32 (~ 6 mil yrs)
58+
// convert date in session timezone to timestamp in UTC
59+
let naive_date = epoch + chrono::Duration::days(date as i64);
60+
let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap();
61+
let local_midnight_in_microsec = tz
62+
.from_local_datetime(&local_midnight)
63+
// return earliest possible time (edge case with spring / fall DST changes)
64+
.earliest()
65+
.map(|dt| dt.timestamp_micros())
66+
// in case there is an issue with DST and returns None , we fall back to UTC
67+
.unwrap_or((date as i64) * 86_400 * 1_000_000);
68+
builder.append_value(local_midnight_in_microsec);
69+
}
70+
None => {
71+
builder.append_null();
72+
}
73+
}
74+
}
75+
Ok(Arc::new(
76+
builder.finish().with_timezone_opt(target_tz.clone()),
77+
))
78+
}
79+
80+
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_cast_date_to_timestamp() {
        use crate::EvalMode;
        use arrow::array::Date32Array;
        use arrow::array::{Array, ArrayRef};
        use arrow::datatypes::TimestampMicrosecondType;

        // Verify the epoch, US DST change dates, and a null value
        // (comprehensive tests live on the Spark side).
        let dates: ArrayRef = Arc::new(Date32Array::from(vec![
            Some(0),
            Some(19723),
            Some(19793),
            None,
        ]));

        let non_dst_date = 1704067200000000i64;
        let dst_date = 1710115200000000i64;
        let seven_hours_ts = 25200000000i64;
        let eight_hours_ts = 28800000000i64;

        let target_tz: Option<Arc<str>> = Some("UTC".into());

        // Run the cast under a given session timezone.
        let cast_with_tz = |session_tz: &str| {
            cast_date_to_timestamp(
                &dates,
                &SparkCastOptions::new(EvalMode::Legacy, session_tz, false),
                &target_tz,
            )
            .unwrap()
        };

        // UTC: local midnight is UTC midnight, no offset applied.
        let result = cast_with_tz("UTC");
        let ts = result.as_primitive::<TimestampMicrosecondType>();
        assert_eq!(ts.value(0), 0);
        assert_eq!(ts.value(1), non_dst_date);
        assert_eq!(ts.value(2), dst_date);
        assert!(ts.is_null(3));

        // LA timezone (observes Daylight Saving Time).
        let result = cast_with_tz("America/Los_Angeles");
        let ts = result.as_primitive::<TimestampMicrosecondType>();
        assert_eq!(ts.value(0), eight_hours_ts);
        assert_eq!(ts.value(1), non_dst_date + eight_hours_ts);
        // should adjust for DST
        assert_eq!(ts.value(2), dst_date + seven_hours_ts);
        assert!(ts.is_null(3));

        // Phoenix timezone (does not observe Daylight Saving Time).
        let result = cast_with_tz("America/Phoenix");
        let ts = result.as_primitive::<TimestampMicrosecondType>();
        assert_eq!(ts.value(0), seven_hours_ts);
        assert_eq!(ts.value(1), non_dst_date + seven_hours_ts);
        assert_eq!(ts.value(2), dst_date + seven_hours_ts);
        assert!(ts.is_null(3));
    }
}

0 commit comments

Comments
 (0)