1818//! Define JNI APIs which can be called from Java/Scala.
1919
2020use super :: { serde, utils:: SparkArrowConvert , CometMemoryPool } ;
21+ use crate :: {
22+ errors:: { try_unwrap_or_throw, CometError , CometResult } ,
23+ execution:: {
24+ metrics:: utils:: update_comet_metric, planner:: PhysicalPlanner , serde:: to_arrow_datatype,
25+ shuffle:: row:: process_sorted_row_partition, sort:: RdxSort ,
26+ } ,
27+ jvm_bridge:: { jni_new_global_ref, JVMClasses } ,
28+ } ;
2129use arrow:: array:: RecordBatch ;
2230use arrow:: datatypes:: DataType as ArrowDataType ;
31+ use datafusion:: common:: ScalarValue ;
2332use datafusion:: execution:: memory_pool:: {
2433 FairSpillPool , GreedyMemoryPool , MemoryPool , TrackConsumersPool , UnboundedMemoryPool ,
2534} ;
35+ use datafusion:: execution:: runtime_env:: RuntimeEnvBuilder ;
2636use datafusion:: {
2737 execution:: { disk_manager:: DiskManagerConfig , runtime_env:: RuntimeEnv } ,
2838 physical_plan:: { display:: DisplayableExecutionPlan , SendableRecordBatchStream } ,
2939 prelude:: { SessionConfig , SessionContext } ,
3040} ;
41+ use datafusion_comet_proto:: spark_operator:: Operator ;
3142use futures:: poll;
43+ use futures:: stream:: StreamExt ;
44+ use jni:: objects:: { JByteBuffer , JMap } ;
45+ use jni:: sys:: JNI_FALSE ;
3246use jni:: {
3347 errors:: Result as JNIResult ,
3448 objects:: {
@@ -38,30 +52,15 @@ use jni::{
3852 sys:: { jbyteArray, jint, jlong, jlongArray} ,
3953 JNIEnv ,
4054} ;
41- use std:: path:: PathBuf ;
42- use std:: time:: { Duration , Instant } ;
43- use std:: { collections:: HashMap , sync:: Arc , task:: Poll } ;
44-
45- use crate :: {
46- errors:: { try_unwrap_or_throw, CometError , CometResult } ,
47- execution:: {
48- metrics:: utils:: update_comet_metric, planner:: PhysicalPlanner , serde:: to_arrow_datatype,
49- shuffle:: row:: process_sorted_row_partition, sort:: RdxSort ,
50- } ,
51- jvm_bridge:: { jni_new_global_ref, JVMClasses } ,
52- } ;
53- use datafusion:: common:: ScalarValue ;
54- use datafusion:: execution:: runtime_env:: RuntimeEnvBuilder ;
55- use datafusion_comet_proto:: spark_operator:: Operator ;
56- use futures:: stream:: StreamExt ;
57- use jni:: objects:: JByteBuffer ;
58- use jni:: sys:: JNI_FALSE ;
5955use jni:: {
6056 objects:: GlobalRef ,
6157 sys:: { jboolean, jdouble, jintArray, jobjectArray, jstring} ,
6258} ;
6359use std:: num:: NonZeroUsize ;
60+ use std:: path:: PathBuf ;
6461use std:: sync:: Mutex ;
62+ use std:: time:: { Duration , Instant } ;
63+ use std:: { collections:: HashMap , sync:: Arc , task:: Poll } ;
6564use tokio:: runtime:: Runtime ;
6665
6766use crate :: execution:: fair_memory_pool:: CometFairMemoryPool ;
@@ -128,6 +127,8 @@ struct ExecutionContext {
128127 pub explain_native : bool ,
129128 /// Memory pool config
130129 pub memory_pool_config : MemoryPoolConfig ,
130+ /// Apache Spark config
131+ pub spark_config : HashMap < String , String > ,
131132}
132133
133134#[ derive( PartialEq , Eq ) ]
@@ -198,6 +199,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
198199 task_attempt_id : jlong ,
199200 debug_native : jboolean ,
200201 explain_native : jboolean ,
202+ spark_conf : JObject ,
201203) -> jlong {
202204 try_unwrap_or_throw ( & e, |mut env| {
203205 // Init JVM classes
@@ -261,6 +263,17 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
261263 None
262264 } ;
263265
266+ // Read Apache Spark runtime config
267+ let spark_conf_map = JMap :: from_env ( & mut env, & spark_conf) ?;
268+ let mut spark_conf_iter = spark_conf_map. iter ( & mut env) ?;
269+ let mut spark_conf = HashMap :: new ( ) ;
270+
271+ while let Some ( ( key, value) ) = spark_conf_iter. next ( & mut env) ? {
272+ let key: String = env. get_string ( & JString :: from ( key) ) . unwrap ( ) . into ( ) ;
273+ let value: String = env. get_string ( & JString :: from ( value) ) . unwrap ( ) . into ( ) ;
274+ spark_conf. insert ( key, value) ;
275+ }
276+
264277 let exec_context = Box :: new ( ExecutionContext {
265278 id,
266279 task_attempt_id,
@@ -278,6 +291,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
278291 debug_native : debug_native == 1 ,
279292 explain_native : explain_native == 1 ,
280293 memory_pool_config,
294+ spark_config : spark_conf,
281295 } ) ;
282296
283297 Ok ( Box :: into_raw ( exec_context) as i64 )
@@ -632,17 +646,17 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
632646 exec_context : jlong ,
633647) {
634648 try_unwrap_or_throw ( & e, |mut env| unsafe {
635- let execution_context = get_execution_context ( exec_context) ;
649+ let exec_context = get_execution_context ( exec_context) ;
636650
637651 // Update metrics
638- update_metrics ( & mut env, execution_context ) ?;
652+ update_metrics ( & mut env, exec_context ) ?;
639653
640- if execution_context . memory_pool_config . pool_type == MemoryPoolType :: FairSpillTaskShared
641- || execution_context . memory_pool_config . pool_type == MemoryPoolType :: GreedyTaskShared
654+ if exec_context . memory_pool_config . pool_type == MemoryPoolType :: FairSpillTaskShared
655+ || exec_context . memory_pool_config . pool_type == MemoryPoolType :: GreedyTaskShared
642656 {
643657 // Decrement the number of native plans using the per-task shared memory pool, and
644658 // remove the memory pool if the released native plan is the last native plan using it.
645- let task_attempt_id = execution_context . task_attempt_id ;
659+ let task_attempt_id = exec_context . task_attempt_id ;
646660 let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS . lock ( ) . unwrap ( ) ;
647661 if let Some ( per_task_memory_pool) = memory_pool_map. get_mut ( & task_attempt_id) {
648662 per_task_memory_pool. num_plans -= 1 ;
@@ -653,7 +667,7 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
653667 }
654668 }
655669 }
656- let _: Box < ExecutionContext > = Box :: from_raw ( execution_context ) ;
670+ let _: Box < ExecutionContext > = Box :: from_raw ( exec_context ) ;
657671 Ok ( ( ) )
658672 } )
659673}
0 commit comments