diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index ef5435cbc9..1694d81ea7 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -75,6 +75,9 @@ use crate::execution::spark_plan::SparkPlan; use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_trace}; +use crate::execution::spark_config::{ +    SparkConfig, COMET_DEBUG_ENABLED, COMET_EXPLAIN_NATIVE_ENABLED, COMET_TRACING_ENABLED, +}; use datafusion_comet_proto::spark_operator::operator::OpStruct; use log::info; use once_cell::sync::Lazy; @@ -164,12 +167,20 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( memory_limit: jlong, memory_limit_per_task: jlong, task_attempt_id: jlong, - debug_native: jboolean, - explain_native: jboolean, - tracing_enabled: jboolean, ) -> jlong { try_unwrap_or_throw(&e, |mut env| { - with_trace("createPlan", tracing_enabled != JNI_FALSE, || { + // Deserialize Spark configs + let array = unsafe { JPrimitiveArray::from_raw(serialized_spark_configs) }; + let bytes = env.convert_byte_array(array)?; + let spark_configs = serde::deserialize_config(bytes.as_slice())?; + let spark_config: HashMap<String, String> = spark_configs.entries.into_iter().collect(); + + // Access Spark configs + let debug_native = spark_config.get_bool(COMET_DEBUG_ENABLED); + let explain_native = spark_config.get_bool(COMET_EXPLAIN_NATIVE_ENABLED); + let tracing_enabled = spark_config.get_bool(COMET_TRACING_ENABLED); + + with_trace("createPlan", tracing_enabled, || { // Init JVM classes JVMClasses::init(&mut env); @@ -180,15 +191,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( let bytes = env.convert_byte_array(array)?; let spark_plan = serde::deserialize_op(bytes.as_slice())?; - // Deserialize Spark configs - let array = unsafe { JPrimitiveArray::from_raw(serialized_spark_configs) }; - let bytes = env.convert_byte_array(array)?; - let spark_configs = 
serde::deserialize_config(bytes.as_slice())?; - - // Convert Spark configs to HashMap - let _spark_config_map: HashMap<String, String> = - spark_configs.entries.into_iter().collect(); - let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?); // Get the global references of input sources @@ -253,10 +255,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( metrics_last_update_time: Instant::now(), plan_creation_time, session_ctx: Arc::new(session), - debug_native: debug_native == 1, - explain_native: explain_native == 1, + debug_native, + explain_native, memory_pool_config, - tracing_enabled: tracing_enabled != JNI_FALSE, + tracing_enabled, }); Ok(Box::into_raw(exec_context) as i64) diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs index c55b96f2a9..b8a3d546b3 100644 --- a/native/core/src/execution/mod.rs +++ b/native/core/src/execution/mod.rs @@ -27,6 +27,7 @@ pub(crate) mod sort; pub(crate) mod spark_plan; pub use datafusion_comet_spark_expr::timezone; mod memory_pools; +pub(crate) mod spark_config; pub(crate) mod tracing; pub(crate) mod utils; diff --git a/native/core/src/execution/spark_config.rs b/native/core/src/execution/spark_config.rs new file mode 100644 index 0000000000..7465a1ea9f --- /dev/null +++ b/native/core/src/execution/spark_config.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; + +pub(crate) const COMET_TRACING_ENABLED: &str = "spark.comet.tracing.enabled"; +pub(crate) const COMET_DEBUG_ENABLED: &str = "spark.comet.debug.enabled"; +pub(crate) const COMET_EXPLAIN_NATIVE_ENABLED: &str = "spark.comet.explain.native.enabled"; + +pub(crate) trait SparkConfig { + fn get_bool(&self, name: &str) -> bool; +} + +impl SparkConfig for HashMap<String, String> { + fn get_bool(&self, name: &str) -> bool { + self.get(name) + .and_then(|str_val| str_val.parse::<bool>().ok()) + .unwrap_or(false) + } +} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 67d044f8c5..700e786e35 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -30,7 +30,7 @@ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.comet.CometMetricNode import org.apache.spark.sql.vectorized._ -import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_METRICS_UPDATE_INTERVAL} +import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_EXEC_MEMORY_POOL_TYPE, COMET_METRICS_UPDATE_INTERVAL} import org.apache.comet.Tracing.withTrace import org.apache.comet.serde.Config.ConfigMap import org.apache.comet.vector.NativeUtil @@ -108,10 +108,7 @@ class CometExecIterator( memoryPoolType = COMET_EXEC_MEMORY_POOL_TYPE.get(), memoryLimit, memoryLimitPerTask = getMemoryLimitPerTask(conf), - 
taskAttemptId = TaskContext.get().taskAttemptId, - debug = COMET_DEBUG_ENABLED.get(), - explain = COMET_EXPLAIN_NATIVE_ENABLED.get(), - tracingEnabled) + taskAttemptId = TaskContext.get().taskAttemptId) } private var nextBatch: Option[ColumnarBatch] = None diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 7430a4322c..a269993bb1 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -65,10 +65,7 @@ class Native extends NativeBase { memoryPoolType: String, memoryLimit: Long, memoryLimitPerTask: Long, - taskAttemptId: Long, - debug: Boolean, - explain: Boolean, - tracingEnabled: Boolean): Long + taskAttemptId: Long): Long // scalastyle:on /**