Skip to content

Commit 69537c5

Browse files
committed
feat: transfer Apache Spark runtime conf to native engine
1 parent 0ba03a1 commit 69537c5

File tree

3 files changed

+49
-33
lines changed

3 files changed

+49
-33
lines changed

native/core/src/execution/jni_api.rs

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,31 @@
1818
//! Define JNI APIs which can be called from Java/Scala.
1919
2020
use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
21+
use crate::{
22+
errors::{try_unwrap_or_throw, CometError, CometResult},
23+
execution::{
24+
metrics::utils::update_comet_metric, planner::PhysicalPlanner, serde::to_arrow_datatype,
25+
shuffle::row::process_sorted_row_partition, sort::RdxSort,
26+
},
27+
jvm_bridge::{jni_new_global_ref, JVMClasses},
28+
};
2129
use arrow::array::RecordBatch;
2230
use arrow::datatypes::DataType as ArrowDataType;
31+
use datafusion::common::ScalarValue;
2332
use datafusion::execution::memory_pool::{
2433
FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, UnboundedMemoryPool,
2534
};
35+
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
2636
use datafusion::{
2737
execution::{disk_manager::DiskManagerConfig, runtime_env::RuntimeEnv},
2838
physical_plan::{display::DisplayableExecutionPlan, SendableRecordBatchStream},
2939
prelude::{SessionConfig, SessionContext},
3040
};
41+
use datafusion_comet_proto::spark_operator::Operator;
3142
use futures::poll;
43+
use futures::stream::StreamExt;
44+
use jni::objects::{JByteBuffer, JMap};
45+
use jni::sys::JNI_FALSE;
3246
use jni::{
3347
errors::Result as JNIResult,
3448
objects::{
@@ -38,30 +52,15 @@ use jni::{
3852
sys::{jbyteArray, jint, jlong, jlongArray},
3953
JNIEnv,
4054
};
41-
use std::path::PathBuf;
42-
use std::time::{Duration, Instant};
43-
use std::{collections::HashMap, sync::Arc, task::Poll};
44-
45-
use crate::{
46-
errors::{try_unwrap_or_throw, CometError, CometResult},
47-
execution::{
48-
metrics::utils::update_comet_metric, planner::PhysicalPlanner, serde::to_arrow_datatype,
49-
shuffle::row::process_sorted_row_partition, sort::RdxSort,
50-
},
51-
jvm_bridge::{jni_new_global_ref, JVMClasses},
52-
};
53-
use datafusion::common::ScalarValue;
54-
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
55-
use datafusion_comet_proto::spark_operator::Operator;
56-
use futures::stream::StreamExt;
57-
use jni::objects::JByteBuffer;
58-
use jni::sys::JNI_FALSE;
5955
use jni::{
6056
objects::GlobalRef,
6157
sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
6258
};
6359
use std::num::NonZeroUsize;
60+
use std::path::PathBuf;
6461
use std::sync::Mutex;
62+
use std::time::{Duration, Instant};
63+
use std::{collections::HashMap, sync::Arc, task::Poll};
6564
use tokio::runtime::Runtime;
6665

6766
use crate::execution::fair_memory_pool::CometFairMemoryPool;
@@ -128,6 +127,8 @@ struct ExecutionContext {
128127
pub explain_native: bool,
129128
/// Memory pool config
130129
pub memory_pool_config: MemoryPoolConfig,
130+
/// Apache Spark config
131+
pub spark_config: HashMap<String, String>,
131132
}
132133

133134
#[derive(PartialEq, Eq)]
@@ -198,6 +199,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
198199
task_attempt_id: jlong,
199200
debug_native: jboolean,
200201
explain_native: jboolean,
202+
spark_conf: JObject,
201203
) -> jlong {
202204
try_unwrap_or_throw(&e, |mut env| {
203205
// Init JVM classes
@@ -261,6 +263,17 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
261263
None
262264
};
263265

266+
// Read Apache Spark runtime config
267+
let spark_conf_map = JMap::from_env(&mut env, &spark_conf)?;
268+
let mut spark_conf_iter = spark_conf_map.iter(&mut env)?;
269+
let mut spark_conf = HashMap::new();
270+
271+
while let Some((key, value)) = spark_conf_iter.next(&mut env)? {
272+
let key: String = env.get_string(&JString::from(key)).unwrap().into();
273+
let value: String = env.get_string(&JString::from(value)).unwrap().into();
274+
spark_conf.insert(key, value);
275+
}
276+
264277
let exec_context = Box::new(ExecutionContext {
265278
id,
266279
task_attempt_id,
@@ -278,6 +291,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
278291
debug_native: debug_native == 1,
279292
explain_native: explain_native == 1,
280293
memory_pool_config,
294+
spark_config: spark_conf,
281295
});
282296

283297
Ok(Box::into_raw(exec_context) as i64)
@@ -632,17 +646,17 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
632646
exec_context: jlong,
633647
) {
634648
try_unwrap_or_throw(&e, |mut env| unsafe {
635-
let execution_context = get_execution_context(exec_context);
649+
let exec_context = get_execution_context(exec_context);
636650

637651
// Update metrics
638-
update_metrics(&mut env, execution_context)?;
652+
update_metrics(&mut env, exec_context)?;
639653

640-
if execution_context.memory_pool_config.pool_type == MemoryPoolType::FairSpillTaskShared
641-
|| execution_context.memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared
654+
if exec_context.memory_pool_config.pool_type == MemoryPoolType::FairSpillTaskShared
655+
|| exec_context.memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared
642656
{
643657
// Decrement the number of native plans using the per-task shared memory pool, and
644658
// remove the memory pool if the released native plan is the last native plan using it.
645-
let task_attempt_id = execution_context.task_attempt_id;
659+
let task_attempt_id = exec_context.task_attempt_id;
646660
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
647661
if let Some(per_task_memory_pool) = memory_pool_map.get_mut(&task_attempt_id) {
648662
per_task_memory_pool.num_plans -= 1;
@@ -653,7 +667,7 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
653667
}
654668
}
655669
}
656-
let _: Box<ExecutionContext> = Box::from_raw(execution_context);
670+
let _: Box<ExecutionContext> = Box::from_raw(exec_context);
657671
Ok(())
658672
})
659673
}

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ class CometExecIterator(
9292
memoryLimitPerTask = getMemoryLimitPerTask(conf),
9393
taskAttemptId = TaskContext.get().taskAttemptId,
9494
debug = COMET_DEBUG_ENABLED.get(),
95-
explain = COMET_EXPLAIN_NATIVE_ENABLED.get())
95+
explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
96+
sparkConfig = SparkEnv.get.conf.getAll.toMap)
9697
}
9798

9899
private var nextBatch: Option[ColumnarBatch] = None

spark/src/main/scala/org/apache/comet/Native.scala

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ class Native extends NativeBase {
3131
* Create a native query plan from execution SparkPlan serialized in bytes.
3232
* @param id
3333
* The id of the query plan.
34-
* @param configMap
35-
* The Java Map object for the configs of native engine.
3634
* @param iterators
3735
* the input iterators to the native query plan. It should be the same number as the number of
3836
* scan nodes in the SparkPlan.
@@ -46,6 +44,8 @@ class Native extends NativeBase {
4644
* @param taskMemoryManager
4745
* the task-level memory manager that is responsible for tracking memory usage across JVM and
4846
* native side.
47+
* @param sparkConfig
48+
* Apache Spark runtime configuration
4949
* @return
5050
* the address to native query plan.
5151
*/
@@ -59,14 +59,15 @@ class Native extends NativeBase {
5959
metricsUpdateInterval: Long,
6060
taskMemoryManager: CometTaskMemoryManager,
6161
localDirs: Array[String],
62-
batchSize: Int,
63-
offHeapMode: Boolean,
64-
memoryPoolType: String,
65-
memoryLimit: Long,
66-
memoryLimitPerTask: Long,
62+
batchSize: Int, // move to spark conf ?
63+
offHeapMode: Boolean, // move to spark conf ?
64+
memoryPoolType: String, // move to spark conf ?
65+
memoryLimit: Long, // move to spark conf ?
66+
memoryLimitPerTask: Long, // move to spark conf ?
6767
taskAttemptId: Long,
6868
debug: Boolean,
69-
explain: Boolean): Long
69+
explain: Boolean,
70+
sparkConfig: Map[String, String]): Long
7071
// scalastyle:on
7172

7273
/**

0 commit comments

Comments
 (0)