Skip to content

Commit c183ead

Browse files
committed
improve benchmarks
1 parent d0ef8bf commit c183ead

File tree

2 files changed

+75
-32
lines changed

2 files changed

+75
-32
lines changed

native/spark-expr/src/datetime_funcs/unix_timestamp.rs

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
// under the License.
1717

1818
use crate::utils::array_with_timezone;
19-
use arrow::array::AsArray;
19+
use arrow::array::{Array, AsArray, PrimitiveArray};
2020
use arrow::compute::cast;
21-
use arrow::datatypes::{DataType, TimeUnit::Microsecond};
21+
use arrow::datatypes::{DataType, Int64Type, TimeUnit::Microsecond};
2222
use datafusion::common::{internal_datafusion_err, DataFusionError};
2323
use datafusion::logical_expr::{
2424
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
@@ -79,35 +79,67 @@ impl ScalarUDFImpl for SparkUnixTimestamp {
7979
match args {
8080
[ColumnarValue::Array(array)] => match array.data_type() {
8181
DataType::Timestamp(_, _) => {
82-
let array = array_with_timezone(
83-
array,
84-
self.timezone.clone(),
85-
Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
86-
)?;
82+
let is_utc = self.timezone == "UTC";
83+
let array = if is_utc
84+
&& matches!(array.data_type(), DataType::Timestamp(Microsecond, Some(tz)) if tz.as_ref() == "UTC")
85+
{
86+
array
87+
} else {
88+
array_with_timezone(
89+
array,
90+
self.timezone.clone(),
91+
Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
92+
)?
93+
};
8794

8895
let timestamp_array =
8996
array.as_primitive::<arrow::datatypes::TimestampMicrosecondType>();
90-
let result: arrow::array::Int64Array = timestamp_array
91-
.iter()
92-
.map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
93-
.collect();
97+
98+
let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 {
99+
timestamp_array
100+
.values()
101+
.iter()
102+
.map(|&micros| micros / MICROS_PER_SECOND)
103+
.collect()
104+
} else {
105+
timestamp_array
106+
.iter()
107+
.map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
108+
.collect()
109+
};
110+
94111
Ok(ColumnarValue::Array(Arc::new(result)))
95112
}
96113
DataType::Date32 => {
97114
let timestamp_array = cast(&array, &DataType::Timestamp(Microsecond, None))?;
98115

99-
let array = array_with_timezone(
100-
timestamp_array,
101-
self.timezone.clone(),
102-
Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
103-
)?;
116+
let is_utc = self.timezone == "UTC";
117+
let array = if is_utc {
118+
timestamp_array
119+
} else {
120+
array_with_timezone(
121+
timestamp_array,
122+
self.timezone.clone(),
123+
Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
124+
)?
125+
};
104126

105127
let timestamp_array =
106128
array.as_primitive::<arrow::datatypes::TimestampMicrosecondType>();
107-
let result: arrow::array::Int64Array = timestamp_array
108-
.iter()
109-
.map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
110-
.collect();
129+
130+
let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 {
131+
timestamp_array
132+
.values()
133+
.iter()
134+
.map(|&micros| micros / MICROS_PER_SECOND)
135+
.collect()
136+
} else {
137+
timestamp_array
138+
.iter()
139+
.map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
140+
.collect()
141+
};
142+
111143
Ok(ColumnarValue::Array(Arc::new(result)))
112144
}
113145
_ => Err(DataFusionError::Execution(format!(
@@ -130,6 +162,8 @@ impl ScalarUDFImpl for SparkUnixTimestamp {
130162
mod tests {
131163
use super::*;
132164
use arrow::array::{Array, Date32Array, TimestampMicrosecondArray};
165+
use arrow::datatypes::Field;
166+
use datafusion::config::ConfigOptions;
133167
use std::sync::Arc;
134168

135169
#[test]

spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,28 +77,32 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
7777
}
7878
}
7979

80-
def unixTimestampBenchmark(values: Int): Unit = {
80+
def unixTimestampBenchmark(values: Int, timeZone: String): Unit = {
8181
withTempPath { dir =>
8282
withTempTable("parquetV1Table") {
8383
prepareTable(
8484
dir,
8585
spark.sql(s"select timestamp_micros(cast(value/100000 as integer)) as ts FROM $tbl"))
86-
runWithComet(s"Unix Timestamp from Timestamp", values) {
87-
spark.sql("select unix_timestamp(ts) from parquetV1Table").noop()
86+
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
87+
runWithComet(s"Unix Timestamp from Timestamp ($timeZone)", values) {
88+
spark.sql("select unix_timestamp(ts) from parquetV1Table").noop()
89+
}
8890
}
8991
}
9092
}
9193
}
9294

93-
def unixTimestampFromDateBenchmark(values: Int): Unit = {
95+
def unixTimestampFromDateBenchmark(values: Int, timeZone: String): Unit = {
9496
withTempPath { dir =>
9597
withTempTable("parquetV1Table") {
9698
prepareTable(
9799
dir,
98100
spark.sql(
99101
s"select cast(timestamp_micros(cast(value/100000 as integer)) as date) as dt FROM $tbl"))
100-
runWithComet(s"Unix Timestamp from Date", values) {
101-
spark.sql("select unix_timestamp(dt) from parquetV1Table").noop()
102+
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
103+
runWithComet(s"Unix Timestamp from Date ($timeZone)", values) {
104+
spark.sql("select unix_timestamp(dt) from parquetV1Table").noop()
105+
}
102106
}
103107
}
104108
}
@@ -107,6 +111,17 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
107111
override def runCometBenchmark(mainArgs: Array[String]): Unit = {
108112
val values = 1024 * 1024;
109113

114+
for (timeZone <- Seq("UTC", "America/Los_Angeles")) {
115+
withSQLConf("spark.sql.parquet.datetimeRebaseModeInWrite" -> "CORRECTED") {
116+
runBenchmarkWithTable(s"UnixTimestamp(timestamp) - $timeZone", values) { v =>
117+
unixTimestampBenchmark(v, timeZone)
118+
}
119+
runBenchmarkWithTable(s"UnixTimestamp(date) - $timeZone", values) { v =>
120+
unixTimestampFromDateBenchmark(v, timeZone)
121+
}
122+
}
123+
}
124+
110125
withDefaultTimeZone(LA) {
111126
withSQLConf(
112127
SQLConf.SESSION_LOCAL_TIMEZONE.key -> LA.getId,
@@ -124,12 +139,6 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
124139
runBenchmarkWithTable("TimestampTrunc (Dictionary)", values, useDictionary = true) { v =>
125140
timestampTruncExprBenchmark(v, useDictionary = true)
126141
}
127-
runBenchmarkWithTable("UnixTimestamp(timestamp)", values) { v =>
128-
unixTimestampBenchmark(v)
129-
}
130-
runBenchmarkWithTable("UnixTimestamp(date))", values) { v =>
131-
unixTimestampFromDateBenchmark(v)
132-
}
133142
}
134143
}
135144
}

0 commit comments

Comments (0)