Skip to content

Commit 9caeec1

Browse files
authored
chore: Pass Comet configs to native createPlan (#2543)
1 parent cd29597 commit 9caeec1

File tree

6 files changed

+71
-30
lines changed

6 files changed

+71
-30
lines changed

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,11 @@ object CometConf extends ShimCometConf {
6363

6464
def conf(key: String): ConfigBuilder = ConfigBuilder(key)
6565

66-
val COMET_EXEC_CONFIG_PREFIX = "spark.comet.exec";
66+
val COMET_PREFIX = "spark.comet";
6767

68-
val COMET_EXPR_CONFIG_PREFIX = "spark.comet.expression";
68+
val COMET_EXEC_CONFIG_PREFIX: String = s"$COMET_PREFIX.exec";
69+
70+
val COMET_EXPR_CONFIG_PREFIX: String = s"$COMET_PREFIX.expression";
6971

7072
val COMET_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.enabled")
7173
.doc(

native/core/src/execution/jni_api.rs

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ use crate::execution::spark_plan::SparkPlan;
7878

7979
use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_trace};
8080

81+
use crate::execution::spark_config::{
82+
SparkConfig, COMET_DEBUG_ENABLED, COMET_EXPLAIN_NATIVE_ENABLED, COMET_MAX_TEMP_DIRECTORY_SIZE,
83+
COMET_TRACING_ENABLED,
84+
};
8185
use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID};
8286
use datafusion_comet_proto::spark_operator::operator::OpStruct;
8387
use log::info;
@@ -168,14 +172,23 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
168172
memory_limit: jlong,
169173
memory_limit_per_task: jlong,
170174
task_attempt_id: jlong,
171-
debug_native: jboolean,
172-
explain_native: jboolean,
173-
tracing_enabled: jboolean,
174-
max_temp_directory_size: jlong,
175175
key_unwrapper_obj: JObject,
176176
) -> jlong {
177177
try_unwrap_or_throw(&e, |mut env| {
178-
with_trace("createPlan", tracing_enabled != JNI_FALSE, || {
178+
// Deserialize Spark configs
179+
let array = unsafe { JPrimitiveArray::from_raw(serialized_spark_configs) };
180+
let bytes = env.convert_byte_array(array)?;
181+
let spark_configs = serde::deserialize_config(bytes.as_slice())?;
182+
let spark_config: HashMap<String, String> = spark_configs.entries.into_iter().collect();
183+
184+
// Access Comet configs
185+
let debug_native = spark_config.get_bool(COMET_DEBUG_ENABLED);
186+
let explain_native = spark_config.get_bool(COMET_EXPLAIN_NATIVE_ENABLED);
187+
let tracing_enabled = spark_config.get_bool(COMET_TRACING_ENABLED);
188+
let max_temp_directory_size =
189+
spark_config.get_u64(COMET_MAX_TEMP_DIRECTORY_SIZE, 100 * 1024 * 1024 * 1024);
190+
191+
with_trace("createPlan", tracing_enabled, || {
179192
// Init JVM classes
180193
JVMClasses::init(&mut env);
181194

@@ -186,15 +199,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
186199
let bytes = env.convert_byte_array(array)?;
187200
let spark_plan = serde::deserialize_op(bytes.as_slice())?;
188201

189-
// Deserialize Spark configs
190-
let array = unsafe { JPrimitiveArray::from_raw(serialized_spark_configs) };
191-
let bytes = env.convert_byte_array(array)?;
192-
let spark_configs = serde::deserialize_config(bytes.as_slice())?;
193-
194-
// Convert Spark configs to HashMap
195-
let _spark_config_map: HashMap<String, String> =
196-
spark_configs.entries.into_iter().collect();
197-
198202
let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?);
199203

200204
// Get the global references of input sources
@@ -238,7 +242,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
238242
batch_size as usize,
239243
memory_pool,
240244
local_dirs,
241-
max_temp_directory_size as u64,
245+
max_temp_directory_size,
242246
)?;
243247

244248
let plan_creation_time = start.elapsed();
@@ -274,10 +278,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
274278
metrics_last_update_time: Instant::now(),
275279
plan_creation_time,
276280
session_ctx: Arc::new(session),
277-
debug_native: debug_native == 1,
278-
explain_native: explain_native == 1,
281+
debug_native,
282+
explain_native,
279283
memory_pool_config,
280-
tracing_enabled: tracing_enabled != JNI_FALSE,
284+
tracing_enabled,
281285
});
282286

283287
Ok(Box::into_raw(exec_context) as i64)

native/core/src/execution/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub(crate) mod sort;
2727
pub(crate) mod spark_plan;
2828
pub use datafusion_comet_spark_expr::timezone;
2929
mod memory_pools;
30+
pub(crate) mod spark_config;
3031
pub(crate) mod tracing;
3132
pub(crate) mod utils;
3233

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::collections::HashMap;
19+
20+
pub(crate) const COMET_TRACING_ENABLED: &str = "spark.comet.tracing.enabled";
21+
pub(crate) const COMET_DEBUG_ENABLED: &str = "spark.comet.debug.enabled";
22+
pub(crate) const COMET_EXPLAIN_NATIVE_ENABLED: &str = "spark.comet.explain.native.enabled";
23+
pub(crate) const COMET_MAX_TEMP_DIRECTORY_SIZE: &str = "spark.comet.maxTempDirectorySize";
24+
25+
pub(crate) trait SparkConfig {
26+
fn get_bool(&self, name: &str) -> bool;
27+
fn get_u64(&self, name: &str, default_value: u64) -> u64;
28+
}
29+
30+
impl SparkConfig for HashMap<String, String> {
31+
fn get_bool(&self, name: &str) -> bool {
32+
self.get(name)
33+
.and_then(|str_val| str_val.parse::<bool>().ok())
34+
.unwrap_or(false)
35+
}
36+
37+
fn get_u64(&self, name: &str, default_value: u64) -> u64 {
38+
self.get(name)
39+
.and_then(|str_val| str_val.parse::<u64>().ok())
40+
.unwrap_or(default_value)
41+
}
42+
}

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ class CometExecIterator(
9696
CometSparkSessionExtensions.getCometMemoryOverhead(conf)
9797
}
9898

99-
// serialize Spark conf in protobuf format
99+
// serialize Comet related Spark configs in protobuf format
100100
val builder = ConfigMap.newBuilder()
101-
conf.getAll.foreach { case (k, v) =>
101+
conf.getAll.filter(_._1.startsWith(CometConf.COMET_PREFIX)).foreach { case (k, v) =>
102102
builder.putEntries(k, v)
103103
}
104104
val protobufSparkConfigs = builder.build().toByteArray
@@ -140,10 +140,6 @@ class CometExecIterator(
140140
memoryLimit,
141141
memoryLimitPerTask,
142142
taskAttemptId,
143-
debug = COMET_DEBUG_ENABLED.get(),
144-
explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
145-
tracingEnabled,
146-
maxTempDirectorySize = CometConf.COMET_MAX_TEMP_DIRECTORY_SIZE.get(),
147143
keyUnwrapper)
148144
}
149145

spark/src/main/scala/org/apache/comet/Native.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,6 @@ class Native extends NativeBase {
6868
memoryLimit: Long,
6969
memoryLimitPerTask: Long,
7070
taskAttemptId: Long,
71-
debug: Boolean,
72-
explain: Boolean,
73-
tracingEnabled: Boolean,
74-
maxTempDirectorySize: Long,
7571
keyUnwrapper: CometFileKeyUnwrapper): Long
7672
// scalastyle:on
7773

0 commit comments

Comments
 (0)