diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index ce991d014c..1d41552e7e 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -146,6 +146,8 @@ struct ExecutionContext {
     pub memory_pool_config: MemoryPoolConfig,
     /// Whether to log memory usage on each call to execute_plan
     pub tracing_enabled: bool,
+    /// Spark session timeZoneId
+    pub session_time_zone_id: String,
 }
 
 /// Accept serialized query plan and return the address of the native query plan.
@@ -171,12 +173,14 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     memory_limit_per_task: jlong,
     task_attempt_id: jlong,
     key_unwrapper_obj: JObject,
+    session_timezone: JString,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| {
         // Deserialize Spark configs
         let bytes = env.convert_byte_array(serialized_spark_configs)?;
         let spark_configs = serde::deserialize_config(bytes.as_slice())?;
         let spark_config: HashMap<String, String> = spark_configs.entries.into_iter().collect();
+        let session_time_zone_id: String = env.get_string(&session_timezone).unwrap().into();
 
         // Access Comet configs
         let debug_native = spark_config.get_bool(COMET_DEBUG_ENABLED);
@@ -276,6 +280,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
             explain_native,
             memory_pool_config,
             tracing_enabled,
+            session_time_zone_id,
         });
 
         Ok(Box::into_raw(exec_context) as i64)
@@ -464,9 +469,12 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
         // query plan, we need to defer stream initialization to first time execution.
         if exec_context.root_op.is_none() {
             let start = Instant::now();
-            let planner =
-                PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
-                    .with_exec_id(exec_context_id);
+            let planner = PhysicalPlanner::new(
+                Arc::clone(&exec_context.session_ctx),
+                &exec_context.session_time_zone_id,
+                partition,
+            )
+            .with_exec_id(exec_context_id);
             let (scans, root_op) = planner.create_plan(
                 &exec_context.spark_plan,
                 &mut exec_context.input_sources.clone(),
@@ -594,6 +602,7 @@ fn update_metrics(env: &mut JNIEnv, exec_context: &mut ExecutionContext) -> Come
 fn convert_datatype_arrays(
     env: &'_ mut JNIEnv<'_>,
     serialized_datatypes: JObjectArray,
+    session_time_zone_id: &str,
 ) -> JNIResult<Vec<ArrowDataType>> {
     let array_len = env.get_array_length(&serialized_datatypes)?;
     let mut res: Vec<ArrowDataType> = Vec::new();
@@ -603,7 +612,7 @@ fn convert_datatype_arrays(
         let inner_array: JByteArray = inner_array.into();
         let bytes = env.convert_byte_array(inner_array)?;
         let data_type = serde::deserialize_data_type(bytes.as_slice()).unwrap();
-        let arrow_dt = to_arrow_datatype(&data_type);
+        let arrow_dt = to_arrow_datatype(&data_type, session_time_zone_id);
         res.push(arrow_dt);
     }
 
@@ -637,13 +646,17 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative
     compression_codec: JString,
     compression_level: jint,
     tracing_enabled: jboolean,
+    session_timezone: JString,
 ) -> jlongArray {
     try_unwrap_or_throw(&e, |mut env| unsafe {
         with_trace(
             "writeSortedFileNative",
             tracing_enabled != JNI_FALSE,
             || {
-                let data_types = convert_datatype_arrays(&mut env, serialized_datatypes)?;
+                let session_time_zone_id: String =
+                    env.get_string(&session_timezone).unwrap().into();
+                let data_types =
+                    convert_datatype_arrays(&mut env, serialized_datatypes, &session_time_zone_id)?;
                 let row_num = env.get_array_length(&row_addresses)? as usize;
                 let row_addresses =
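For context (not part of the patch): the JNI plumbing above boils down to copying the Java string into an owned Rust String before it is stored on ExecutionContext or handed to the planner. A minimal sketch of that pattern, assuming the jni crate; the helper name is illustrative only:

    use jni::objects::JString;
    use jni::JNIEnv;

    // Mirrors `env.get_string(&session_timezone).unwrap().into()` used in
    // createPlan and writeSortedFileNative; the function name is hypothetical.
    fn read_session_time_zone(env: &mut JNIEnv<'_>, session_timezone: &JString<'_>) -> String {
        env.get_string(session_timezone)
            .expect("session timezone must be a valid Java string")
            .into()
    }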
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index a33df705b3..f3ba972f00 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -152,19 +152,25 @@ pub struct PhysicalPlanner {
     exec_context_id: i64,
     partition: i32,
     session_ctx: Arc<SessionContext>,
+    session_time_zone_id: String,
 }
 
 impl Default for PhysicalPlanner {
     fn default() -> Self {
-        Self::new(Arc::new(SessionContext::new()), 0)
+        Self::new(Arc::new(SessionContext::new()), "UTC", 0)
     }
 }
 
 impl PhysicalPlanner {
-    pub fn new(session_ctx: Arc<SessionContext>, partition: i32) -> Self {
+    pub fn new(
+        session_ctx: Arc<SessionContext>,
+        session_time_zone_id: &str,
+        partition: i32,
+    ) -> Self {
         Self {
             exec_context_id: TEST_EXEC_CONTEXT_ID,
             session_ctx,
+            session_time_zone_id: session_time_zone_id.to_string(),
             partition,
         }
     }
@@ -173,6 +179,7 @@ impl PhysicalPlanner {
         Self {
             exec_context_id,
             partition: self.partition,
+            session_time_zone_id: self.session_time_zone_id.clone(),
             session_ctx: Arc::clone(&self.session_ctx),
         }
     }
@@ -311,7 +318,10 @@ impl PhysicalPlanner {
                 let result = create_modulo_expr(
                     left,
                     right,
-                    expr.return_type.as_ref().map(to_arrow_datatype).unwrap(),
+                    expr.return_type
+                        .as_ref()
+                        .map(|dt| to_arrow_datatype(dt, &self.session_time_zone_id))
+                        .unwrap(),
                     input_schema,
                     eval_mode == EvalMode::Ansi,
                     &self.session_ctx.state(),
@@ -371,7 +381,10 @@ impl PhysicalPlanner {
                 Ok(Arc::new(Column::new(field.name().as_str(), idx)))
             }
             ExprStruct::Unbound(unbound) => {
-                let data_type = to_arrow_datatype(unbound.datatype.as_ref().unwrap());
+                let data_type = to_arrow_datatype(
+                    unbound.datatype.as_ref().unwrap(),
+                    &self.session_time_zone_id,
+                );
                 Ok(Arc::new(UnboundColumn::new(
                     unbound.name.as_str(),
                     data_type,
@@ -400,7 +413,10 @@ impl PhysicalPlanner {
                 Ok(Arc::new(BinaryExpr::new(left, op, right)))
             }
             ExprStruct::Literal(literal) => {
-                let data_type = to_arrow_datatype(literal.datatype.as_ref().unwrap());
+                let data_type = to_arrow_datatype(
+                    literal.datatype.as_ref().unwrap(),
+                    &self.session_time_zone_id,
+                );
                 let scalar_value = if literal.is_null {
                     match data_type {
                         DataType::Boolean => ScalarValue::Boolean(None),
@@ -491,7 +507,8 @@ impl PhysicalPlanner {
             }
             ExprStruct::Cast(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let datatype =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 Ok(Arc::new(Cast::new(
                     child,
@@ -593,7 +610,8 @@ impl PhysicalPlanner {
             }
             ExprStruct::CheckOverflow(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
-                let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let data_type =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 let fail_on_error = expr.fail_on_error;
 
                 Ok(Arc::new(CheckOverflow::new(
@@ -734,12 +752,14 @@ impl PhysicalPlanner {
             }
             ExprStruct::NormalizeNanAndZero(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
-                let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let data_type =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 Ok(Arc::new(NormalizeNaNAndZero::new(data_type, child)))
             }
             ExprStruct::Subquery(expr) => {
                 let id = expr.id;
-                let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let data_type =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 Ok(Arc::new(Subquery::new(self.exec_context_id, id, data_type)))
             }
             ExprStruct::BloomFilterMightContain(expr) => {
@@ -937,7 +957,9 @@ impl PhysicalPlanner {
                         >= DECIMAL128_MAX_PRECISION)
                     || (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) =>
                 {
-                    let data_type = return_type.map(to_arrow_datatype).unwrap();
+                    let data_type = return_type
+                        .map(|dt| to_arrow_datatype(dt, &self.session_time_zone_id))
+                        .unwrap();
                     // For some Decimal128 operations, we need wider internal digits.
                     // Cast left and right to Decimal256 and cast the result back to Decimal128
                     let left = Arc::new(Cast::new(
@@ -962,7 +984,9 @@ impl PhysicalPlanner {
                     Ok(DataType::Decimal128(_p1, _s1)),
                     Ok(DataType::Decimal128(_p2, _s2)),
                 ) => {
-                    let data_type = return_type.map(to_arrow_datatype).unwrap();
+                    let data_type = return_type
+                        .map(|dt| to_arrow_datatype(dt, &self.session_time_zone_id))
+                        .unwrap();
                     let func_name = if options.is_integral_div {
                         // Decimal256 division in Arrow may overflow, so we still need this variant of decimal_div.
                         // Otherwise, we may be able to reuse the previous case-match instead of here,
@@ -987,7 +1011,9 @@ impl PhysicalPlanner {
                     )))
                 }
                 _ => {
-                    let data_type = return_type.map(to_arrow_datatype).unwrap();
+                    let data_type = return_type
+                        .map(|dt| to_arrow_datatype(dt, &self.session_time_zone_id))
+                        .unwrap();
                     if [EvalMode::Try, EvalMode::Ansi].contains(&eval_mode)
                         && (data_type.is_integer()
                             || (data_type.is_floating() && op == DataFusionOperator::Divide))
@@ -1231,11 +1257,18 @@ impl PhysicalPlanner {
                 ))
             }
             OpStruct::NativeScan(scan) => {
-                let data_schema = convert_spark_types_to_arrow_schema(scan.data_schema.as_slice());
-                let required_schema: SchemaRef =
-                    convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
-                let partition_schema: SchemaRef =
-                    convert_spark_types_to_arrow_schema(scan.partition_schema.as_slice());
+                let data_schema = convert_spark_types_to_arrow_schema(
+                    scan.data_schema.as_slice(),
+                    &self.session_time_zone_id,
+                );
+                let required_schema: SchemaRef = convert_spark_types_to_arrow_schema(
+                    scan.required_schema.as_slice(),
+                    &self.session_time_zone_id,
+                );
+                let partition_schema: SchemaRef = convert_spark_types_to_arrow_schema(
+                    scan.partition_schema.as_slice(),
+                    &self.session_time_zone_id,
+                );
                 let projection_vector: Vec<usize> = scan
                     .projection_vector
                     .iter()
@@ -1337,7 +1370,11 @@ impl PhysicalPlanner {
                 ))
             }
             OpStruct::Scan(scan) => {
-                let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec();
+                let data_types = scan
+                    .fields
+                    .iter()
+                    .map(|dt| to_arrow_datatype(dt, &self.session_time_zone_id))
+                    .collect_vec();
 
                 // If it is not test execution context for unit test, we should have at least one
                 // input source
@@ -1800,7 +1837,8 @@ impl PhysicalPlanner {
             }
             AggExprStruct::Min(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let datatype =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 let child = Arc::new(CastExpr::new(child, datatype.clone(), None));
 
                 AggregateExprBuilder::new(min_udaf(), vec![child])
@@ -1813,7 +1851,8 @@ impl PhysicalPlanner {
             }
             AggExprStruct::Max(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let datatype =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 let child = Arc::new(CastExpr::new(child, datatype.clone(), None));
 
                 AggregateExprBuilder::new(max_udaf(), vec![child])
@@ -1826,7 +1865,8 @@ impl PhysicalPlanner {
             }
             AggExprStruct::Sum(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let datatype =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
 
                 let builder = match datatype {
                     DataType::Decimal128(_, _) => {
@@ -1851,8 +1891,12 @@ impl PhysicalPlanner {
             }
             AggExprStruct::Avg(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
-                let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
+                let datatype =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
+                let input_datatype = to_arrow_datatype(
+                    expr.sum_datatype.as_ref().unwrap(),
+                    &self.session_time_zone_id,
+                );
                 let builder = match datatype {
                     DataType::Decimal128(_, _) => {
                         let func =
@@ -1939,7 +1983,8 @@ impl PhysicalPlanner {
                     self.create_expr(expr.child1.as_ref().unwrap(), Arc::clone(&schema))?;
                 let child2 =
                     self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let datatype =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 match expr.stats_type {
                     0 => {
                         let func = AggregateUDF::new_from_impl(Covariance::new(
@@ -1978,7 +2023,8 @@ impl PhysicalPlanner {
             }
             AggExprStruct::Variance(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let datatype =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 match expr.stats_type {
                     0 => {
                         let func = AggregateUDF::new_from_impl(Variance::new(
@@ -2007,7 +2053,8 @@ impl PhysicalPlanner {
             }
             AggExprStruct::Stddev(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let datatype =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 match expr.stats_type {
                     0 => {
                         let func = AggregateUDF::new_from_impl(Stddev::new(
@@ -2039,7 +2086,8 @@ impl PhysicalPlanner {
                     self.create_expr(expr.child1.as_ref().unwrap(), Arc::clone(&schema))?;
                 let child2 =
                     self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let datatype =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 let func = AggregateUDF::new_from_impl(Correlation::new(
                     "correlation",
                     datatype,
@@ -2053,7 +2101,8 @@ impl PhysicalPlanner {
                     self.create_expr(expr.num_items.as_ref().unwrap(), Arc::clone(&schema))?;
                 let num_bits =
                     self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let datatype =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
                     Arc::clone(&num_items),
                     Arc::clone(&num_bits),
@@ -2263,7 +2312,8 @@ impl PhysicalPlanner {
             }
             Some(AggExprStruct::Sum(expr)) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
-                let arrow_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let arrow_type =
+                    to_arrow_datatype(expr.datatype.as_ref().unwrap(), &self.session_time_zone_id);
                 let datatype = child.data_type(&schema)?;
 
                 let child = if datatype != arrow_type {
@@ -2389,48 +2439,51 @@ impl PhysicalPlanner {
                     .map(|x| x.data_type(input_schema.as_ref()))
                     .collect::<Result<Vec<_>, _>>()?;
 
-                let (data_type, coerced_input_types) =
-                    match expr.return_type.as_ref().map(to_arrow_datatype) {
-                        Some(t) => (t, input_expr_types.clone()),
-                        None => {
-                            let fun_name = match fun_name.as_ref() {
-                                "read_side_padding" => "rpad", // use the same return type as rpad
-                                other => other,
-                            };
-                            let func = self.session_ctx.udf(fun_name)?;
-                            let coerced_types = func
-                                .coerce_types(&input_expr_types)
-                                .unwrap_or_else(|_| input_expr_types.clone());
+                let (data_type, coerced_input_types) = match expr
+                    .return_type
+                    .as_ref()
+                    .map(|dt| to_arrow_datatype(dt, &self.session_time_zone_id))
+                {
+                    Some(t) => (t, input_expr_types.clone()),
+                    None => {
+                        let fun_name = match fun_name.as_ref() {
+                            "read_side_padding" => "rpad", // use the same return type as rpad
+                            other => other,
+                        };
+                        let func = self.session_ctx.udf(fun_name)?;
+                        let coerced_types = func
+                            .coerce_types(&input_expr_types)
+                            .unwrap_or_else(|_| input_expr_types.clone());
 
-                            let arg_fields = coerced_types
-                                .iter()
-                                .enumerate()
-                                .map(|(i, dt)| Arc::new(Field::new(format!("arg{i}"), dt.clone(), true)))
-                                .collect::<Vec<_>>();
+                        let arg_fields = coerced_types
+                            .iter()
+                            .enumerate()
+                            .map(|(i, dt)| Arc::new(Field::new(format!("arg{i}"), dt.clone(), true)))
+                            .collect::<Vec<_>>();
 
-                            // TODO this should try and find scalar
-                            let arguments = args
-                                .iter()
-                                .map(|e| {
-                                    e.as_ref()
-                                        .as_any()
-                                        .downcast_ref::<Literal>()
-                                        .map(|lit| lit.value())
-                                })
-                                .collect::<Vec<_>>();
+                        // TODO this should try and find scalar
+                        let arguments = args
+                            .iter()
+                            .map(|e| {
+                                e.as_ref()
+                                    .as_any()
+                                    .downcast_ref::<Literal>()
+                                    .map(|lit| lit.value())
+                            })
+                            .collect::<Vec<_>>();
 
-                            let args = ReturnFieldArgs {
-                                arg_fields: &arg_fields,
-                                scalar_arguments: &arguments,
-                            };
+                        let args = ReturnFieldArgs {
+                            arg_fields: &arg_fields,
+                            scalar_arguments: &arguments,
+                        };
 
-                            let data_type = Arc::clone(&func.inner().return_field_from_args(args)?)
-                                .data_type()
-                                .clone();
+                        let data_type = Arc::clone(&func.inner().return_field_from_args(args)?)
+                            .data_type()
+                            .clone();
 
-                            (data_type, coerced_types)
-                        }
-                    };
+                        (data_type, coerced_types)
+                    }
+                };
 
                 let fun_expr = create_comet_physical_fun(
                     fun_name,
@@ -2641,13 +2694,14 @@ fn from_protobuf_eval_mode(value: i32) -> Result<EvalMode, ExecutionError> {
 
 fn convert_spark_types_to_arrow_schema(
     spark_types: &[spark_operator::SparkStructField],
+    session_time_zone_id: &str,
 ) -> SchemaRef {
     let arrow_fields = spark_types
         .iter()
         .map(|spark_type| {
             Field::new(
                 String::clone(&spark_type.name),
-                to_arrow_datatype(spark_type.data_type.as_ref().unwrap()),
+                to_arrow_datatype(spark_type.data_type.as_ref().unwrap(), session_time_zone_id),
                 spark_type.nullable,
             )
         })
@@ -3243,7 +3297,7 @@ mod tests {
            datafusion_functions_nested::make_array::MakeArray::new(),
        ));
        let task_ctx = session_ctx.task_ctx();
-        let planner = PhysicalPlanner::new(Arc::from(session_ctx), 0);
+        let planner = PhysicalPlanner::new(Arc::from(session_ctx), "UTC", 0);
 
        // Create a plan for
        // ProjectionExec: expr=[make_array(col_0@0) as col_0]
@@ -3361,7 +3415,7 @@
     fn test_array_repeat() {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
-        let planner = PhysicalPlanner::new(Arc::from(session_ctx), 0);
+        let planner = PhysicalPlanner::new(Arc::from(session_ctx), "UTC", 0);
 
        // Mock scan operator with 3 INT32 columns
        let op_scan = Operator {
diff --git a/native/core/src/execution/serde.rs b/native/core/src/execution/serde.rs
index e95fd7eca2..f3082a47c7 100644
--- a/native/core/src/execution/serde.rs
+++ b/native/core/src/execution/serde.rs
@@ -91,7 +91,7 @@ pub fn deserialize_data_type(buf: &[u8]) -> Result<DataType, ExecutionError> {
     }
 }
 
-pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType {
+pub fn to_arrow_datatype(dt_value: &DataType, session_time_zone_id: &str) -> ArrowDataType {
     match DataTypeId::try_from(dt_value.type_id).unwrap() {
         Bool => ArrowDataType::Boolean,
         Int8 => ArrowDataType::Int8,
@@ -115,9 +115,10 @@ pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType {
             }
             _ => unreachable!(),
         },
-        DataTypeId::Timestamp => {
-            ArrowDataType::Timestamp(TimeUnit::Microsecond, Some("UTC".to_string().into()))
-        }
+        DataTypeId::Timestamp => ArrowDataType::Timestamp(
+            TimeUnit::Microsecond,
+            Some(session_time_zone_id.to_string().into()),
+        ),
         DataTypeId::TimestampNtz => ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
         DataTypeId::Date => ArrowDataType::Date32,
         DataTypeId::Null => ArrowDataType::Null,
@@ -132,7 +133,7 @@ pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType {
         DatatypeStruct::List(info) => {
             let field = Field::new(
                 "item",
-                to_arrow_datatype(info.element_type.as_ref().unwrap()),
+                to_arrow_datatype(info.element_type.as_ref().unwrap(), session_time_zone_id),
                 info.contains_null,
             );
             ArrowDataType::List(Arc::new(field))
@@ -150,12 +151,12 @@ pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType {
         DatatypeStruct::Map(info) => {
             let key_field = Field::new(
                 "key",
-                to_arrow_datatype(info.key_type.as_ref().unwrap()),
+                to_arrow_datatype(info.key_type.as_ref().unwrap(), session_time_zone_id),
                 false,
             );
             let value_field = Field::new(
                 "value",
-                to_arrow_datatype(info.value_type.as_ref().unwrap()),
+                to_arrow_datatype(info.value_type.as_ref().unwrap(), session_time_zone_id),
                 info.value_contains_null,
             );
             let struct_field = Field::new(
@@ -183,7 +184,7 @@ pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType {
             .map(|(idx, name)| {
                 Field::new(
                     name,
-                    to_arrow_datatype(&info.field_datatypes[idx]),
+                    to_arrow_datatype(&info.field_datatypes[idx], session_time_zone_id),
                     info.field_nullable[idx],
                 )
             })
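For context (not part of the patch): after the serde.rs change, a Spark TIMESTAMP column maps to an Arrow timestamp carrying the session zone instead of a hard-coded UTC, while TIMESTAMP_NTZ stays zone-less. A small sketch of the resulting Arrow types, assuming the arrow-schema crate:

    use arrow_schema::{DataType, TimeUnit};

    fn main() {
        // What to_arrow_datatype now produces for a Spark TIMESTAMP (DataTypeId::Timestamp)
        let with_session_zone =
            DataType::Timestamp(TimeUnit::Microsecond, Some("America/Los_Angeles".into()));
        // TIMESTAMP_NTZ (DataTypeId::TimestampNtz) is unchanged and carries no zone
        let without_zone = DataType::Timestamp(TimeUnit::Microsecond, None);
        assert_ne!(with_session_zone, without_zone);
    }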
diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs
index c8a480e97a..9c5aa6efa9 100644
--- a/native/core/src/parquet/mod.rs
+++ b/native/core/src/parquet/mod.rs
@@ -669,8 +669,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_validateObjec
 ) {
     try_unwrap_or_throw(&e, |mut env| {
         let session_config = SessionConfig::new();
-        let planner =
-            PhysicalPlanner::new(Arc::new(SessionContext::new_with_config(session_config)), 0);
+        let planner = PhysicalPlanner::new(
+            Arc::new(SessionContext::new_with_config(session_config)),
+            "UTC",
+            0,
+        );
         let session_ctx = planner.session_ctx();
         let path: String = env.get_string(&file_path).unwrap().into();
         let object_store_config = get_object_store_options(&mut env, object_store_options)?;
@@ -706,8 +709,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
     try_unwrap_or_throw(&e, |mut env| unsafe {
         JVMClasses::init(&mut env);
         let session_config = SessionConfig::new().with_batch_size(batch_size as usize);
-        let planner =
-            PhysicalPlanner::new(Arc::new(SessionContext::new_with_config(session_config)), 0);
+        let planner = PhysicalPlanner::new(
+            Arc::new(SessionContext::new_with_config(session_config)),
+            "UTC",
+            0,
+        );
         let session_ctx = planner.session_ctx();
 
         let path: String = env.get_string(&file_path).unwrap().into();
diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
index 044c7842f0..0f3e5d85f2 100644
--- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
+++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
@@ -196,7 +196,10 @@ protected long doSpilling(
             currentChecksum,
             compressionCodec,
             compressionLevel,
-            tracingEnabled);
+            tracingEnabled,
+            // TODO using session time zone causes regressions in Parquet scan
+            // SQLConf.get().sessionLocalTimeZone()
+            "UTC");
 
     long written = results[0];
     checksum = results[1];
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index a680cbf592..2e90aba0fc 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -30,6 +30,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.comet.CometMetricNode
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized._
 import org.apache.spark.util.SerializableConfiguration
 
@@ -124,7 +125,8 @@ class CometExecIterator(
       memoryConfig.memoryLimit,
       memoryConfig.memoryLimitPerTask,
       taskAttemptId,
-      keyUnwrapper)
+      keyUnwrapper,
+      SQLConf.get.sessionLocalTimeZone)
   }
 
   private var nextBatch: Option[ColumnarBatch] = None
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index 6ef92d0a67..136d004004 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -68,7 +68,8 @@ class Native extends NativeBase {
       memoryLimit: Long,
       memoryLimitPerTask: Long,
       taskAttemptId: Long,
-      keyUnwrapper: CometFileKeyUnwrapper): Long
+      keyUnwrapper: CometFileKeyUnwrapper,
+      timeZoneId: String): Long
   // scalastyle:on
 
   /**
@@ -147,7 +148,8 @@ class Native extends NativeBase {
       currentChecksum: Long,
       compressionCodec: String,
       compressionLevel: Int,
-      tracingEnabled: Boolean): Array[Long]
+      tracingEnabled: Boolean,
+      timeZoneId: String): Array[Long]
   // scalastyle:on
 
   /**
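End-to-end (not part of the patch): the Scala side now reads SQLConf.get.sessionLocalTimeZone and sends it through createPlan, so every native PhysicalPlanner is built with the session zone; only the shuffle spill path keeps "UTC" because of the TODO noted in SpillWriter. A hedged sketch of the native-side construction, assuming the crate-internal PhysicalPlanner from this patch and a DataFusion SessionContext; the helper name is hypothetical:

    use std::sync::Arc;
    use datafusion::prelude::SessionContext;

    // Only compiles inside the Comet native crate, where PhysicalPlanner is defined.
    fn planner_for_session(session_time_zone_id: &str, partition: i32) -> PhysicalPlanner {
        // `session_time_zone_id` is whatever the JVM sent, e.g. "Asia/Tokyo".
        PhysicalPlanner::new(Arc::new(SessionContext::new()), session_time_zone_id, partition)
    }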