
Commit 79807d5

fix

1 parent a654964

File tree

3 files changed: 44 additions & 40 deletions


native/spark-expr/src/datetime_funcs/unix_timestamp.rs

Lines changed: 42 additions & 37 deletions
@@ -17,6 +17,7 @@
 
 use crate::utils::array_with_timezone;
 use arrow::array::AsArray;
+use arrow::compute::cast;
 use arrow::datatypes::{DataType, Field, TimeUnit::Microsecond};
 use datafusion::common::{internal_datafusion_err, DataFusionError};
 use datafusion::config::ConfigOptions;
@@ -27,7 +28,6 @@ use num::integer::div_floor;
 use std::{any::Any, fmt::Debug, sync::Arc};
 
 const MICROS_PER_SECOND: i64 = 1_000_000;
-const SECONDS_PER_DAY: i64 = 86400;
 
 #[derive(Debug, PartialEq, Eq, Hash)]
 pub struct SparkUnixTimestamp {
@@ -72,45 +72,50 @@ impl ScalarUDFImpl for SparkUnixTimestamp {
         &self,
         args: ScalarFunctionArgs,
     ) -> datafusion::common::Result<ColumnarValue> {
-        let args: [ColumnarValue; 1] = args.args.try_into().map_err(|_| {
-            internal_datafusion_err!("unix_timestamp expects exactly one argument")
-        })?;
+        let args: [ColumnarValue; 1] = args
+            .args
+            .try_into()
+            .map_err(|_| internal_datafusion_err!("unix_timestamp expects exactly one argument"))?;
 
         match args {
-            [ColumnarValue::Array(array)] => {
-                match array.data_type() {
-                    DataType::Timestamp(_, _) => {
-                        let array = array_with_timezone(
-                            array,
-                            self.timezone.clone(),
-                            Some(&DataType::Timestamp(
-                                Microsecond,
-                                Some("UTC".into()),
-                            )),
-                        )?;
-
-                        let timestamp_array =
-                            array.as_primitive::<arrow::datatypes::TimestampMicrosecondType>();
-                        let result: arrow::array::Int64Array = timestamp_array
-                            .iter()
-                            .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
-                            .collect();
-                        Ok(ColumnarValue::Array(Arc::new(result)))
-                    }
-                    DataType::Date32 => {
-                        let date_array = array.as_primitive::<arrow::datatypes::Date32Type>();
-                        let result: arrow::array::Int64Array = date_array
-                            .iter()
-                            .map(|v| v.map(|days| (days as i64) * SECONDS_PER_DAY))
-                            .collect();
-                        Ok(ColumnarValue::Array(Arc::new(result)))
-                    }
-                    _ => Err(DataFusionError::Execution(format!(
-                        "unix_timestamp does not support input type: {:?}",
-                        array.data_type()
-                    ))),
+            [ColumnarValue::Array(array)] => match array.data_type() {
+                DataType::Timestamp(_, _) => {
+                    let array = array_with_timezone(
+                        array,
+                        self.timezone.clone(),
+                        Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
+                    )?;
+
+                    let timestamp_array =
+                        array.as_primitive::<arrow::datatypes::TimestampMicrosecondType>();
+                    let result: arrow::array::Int64Array = timestamp_array
+                        .iter()
+                        .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
+                        .collect();
+                    Ok(ColumnarValue::Array(Arc::new(result)))
                 }
-            }
+                DataType::Date32 => {
+                    let timestamp_array = cast(&array, &DataType::Timestamp(Microsecond, None))?;
+
+                    let array = array_with_timezone(
+                        timestamp_array,
+                        self.timezone.clone(),
+                        Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
+                    )?;
+
+                    let timestamp_array =
+                        array.as_primitive::<arrow::datatypes::TimestampMicrosecondType>();
+                    let result: arrow::array::Int64Array = timestamp_array
+                        .iter()
+                        .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
+                        .collect();
+                    Ok(ColumnarValue::Array(Arc::new(result)))
+                }
+                _ => Err(DataFusionError::Execution(format!(
+                    "unix_timestamp does not support input type: {:?}",
+                    array.data_type()
+                ))),
+            },
             _ => Err(DataFusionError::Execution(
                 "unix_timestamp(scalar) should be fold in Spark JVM side.".to_string(),
             )),
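
The substantive change is in the Date32 arm: instead of multiplying the day count by a hard-coded SECONDS_PER_DAY (which pins every date to UTC midnight), the array is first cast to a timezone-naive microsecond timestamp and then routed through the same array_with_timezone conversion as the Timestamp arm, presumably so a date is evaluated at midnight in the session time zone, as Spark does. A rough standalone sketch of just the cast step, assuming a recent arrow-rs; the variable names and sample value are illustrative, not from the commit:

use std::sync::Arc;

use arrow::array::{ArrayRef, Date32Array};
use arrow::compute::cast;
use arrow::datatypes::{DataType, TimeUnit};

fn main() -> Result<(), arrow::error::ArrowError> {
    // Date32 stores days since the Unix epoch; 19723 days is 2024-01-01.
    let dates: ArrayRef = Arc::new(Date32Array::from(vec![Some(19723), None]));

    // The kernel the new Date32 arm calls before the timezone step: each
    // date becomes midnight of that day as a naive microsecond timestamp.
    let ts = cast(&dates, &DataType::Timestamp(TimeUnit::Microsecond, None))?;
    println!("{ts:?}");
    Ok(())
}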

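Both arms then reduce microseconds to whole seconds with num::integer::div_floor rather than Rust's truncating / operator, which matters for pre-epoch timestamps. A minimal dependency-free sketch (the function name is illustrative; i64::div_euclid stands in for div_floor, with which it agrees for a positive divisor):

// Mirrors the micros -> seconds step in both match arms above.
fn micros_to_unix_seconds(micros: i64) -> i64 {
    micros.div_euclid(1_000_000)
}

fn main() {
    // 1.5 s after the epoch floors to second 1.
    assert_eq!(micros_to_unix_seconds(1_500_000), 1);
    // 1 us before the epoch must floor to second -1 ...
    assert_eq!(micros_to_unix_seconds(-1), -1);
    // ... whereas truncating division would yield 0.
    assert_eq!(-1_i64 / 1_000_000, 0);
}
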
spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala

Lines changed: 0 additions & 1 deletion
@@ -135,7 +135,6 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
     }
   }
 
-
   private def createTimestampTestData = {
     val r = new Random(42)
     val schema = StructType(

spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala

Lines changed: 2 additions & 2 deletions
@@ -84,7 +84,7 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
         spark.sql(s"select timestamp_micros(cast(value/100000 as integer)) as ts FROM $tbl"))
       val isDictionary = if (useDictionary) "(Dictionary)" else ""
       runWithComet(s"Unix Timestamp from Timestamp $isDictionary", values) {
-        spark.sql(s"select unix_timestamp(ts) from parquetV1Table").noop()
+        spark.sql("select unix_timestamp(ts) from parquetV1Table").noop()
       }
     }
   }
@@ -99,7 +99,7 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
           s"select cast(timestamp_micros(cast(value/100000 as integer)) as date) as dt FROM $tbl"))
       val isDictionary = if (useDictionary) "(Dictionary)" else ""
       runWithComet(s"Unix Timestamp from Date $isDictionary", values) {
-        spark.sql(s"select unix_timestamp(dt) from parquetV1Table").noop()
+        spark.sql("select unix_timestamp(dt) from parquetV1Table").noop()
       }
     }
   }
