@@ -31,8 +31,8 @@ use futures::poll;
3131use jni:: {
3232 errors:: Result as JNIResult ,
3333 objects:: {
34- JByteArray , JClass , JIntArray , JLongArray , JMap , JObject , JObjectArray , JPrimitiveArray ,
35- JString , ReleaseMode ,
34+ JByteArray , JClass , JIntArray , JLongArray , JObject , JObjectArray , JPrimitiveArray , JString ,
35+ ReleaseMode ,
3636 } ,
3737 sys:: { jbyteArray, jint, jlong, jlongArray} ,
3838 JNIEnv ,
@@ -77,8 +77,6 @@ struct ExecutionContext {
7777 pub input_sources : Vec < Arc < GlobalRef > > ,
7878 /// The record batch stream to pull results from
7979 pub stream : Option < SendableRecordBatchStream > ,
80- /// Configurations for DF execution
81- pub conf : HashMap < String , String > ,
8280 /// The Tokio runtime used for async.
8381 pub runtime : Runtime ,
8482 /// Native metrics
@@ -103,11 +101,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
103101 e : JNIEnv ,
104102 _class : JClass ,
105103 id : jlong ,
106- config_object : JObject ,
107104 iterators : jobjectArray ,
108105 serialized_query : jbyteArray ,
109106 metrics_node : JObject ,
110107 comet_task_memory_manager_obj : JObject ,
108+ batch_size : jint ,
109+ debug_native : jboolean ,
110+ explain_native : jboolean ,
111+ worker_threads : jint ,
112+ blocking_threads : jint ,
111113) -> jlong {
112114 try_unwrap_or_throw ( & e, |mut env| {
113115 // Init JVM classes
@@ -121,36 +123,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
121123 // Deserialize query plan
122124 let spark_plan = serde:: deserialize_op ( bytes. as_slice ( ) ) ?;
123125
124- // Sets up context
125- let mut configs = HashMap :: new ( ) ;
126-
127- let config_map = JMap :: from_env ( & mut env, & config_object) ?;
128- let mut map_iter = config_map. iter ( & mut env) ?;
129- while let Some ( ( key, value) ) = map_iter. next ( & mut env) ? {
130- let key: String = env. get_string ( & JString :: from ( key) ) . unwrap ( ) . into ( ) ;
131- let value: String = env. get_string ( & JString :: from ( value) ) . unwrap ( ) . into ( ) ;
132- configs. insert ( key, value) ;
133- }
134-
135- // Whether we've enabled additional debugging on the native side
136- let debug_native = parse_bool ( & configs, "debug_native" ) ?;
137- let explain_native = parse_bool ( & configs, "explain_native" ) ?;
138-
139- let worker_threads = configs
140- . get ( "worker_threads" )
141- . map ( String :: as_str)
142- . unwrap_or ( "4" )
143- . parse :: < usize > ( ) ?;
144- let blocking_threads = configs
145- . get ( "blocking_threads" )
146- . map ( String :: as_str)
147- . unwrap_or ( "10" )
148- . parse :: < usize > ( ) ?;
149-
150126 // Use multi-threaded tokio runtime to prevent blocking spawned tasks if any
151127 let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
152- . worker_threads ( worker_threads)
153- . max_blocking_threads ( blocking_threads)
128+ . worker_threads ( worker_threads as usize )
129+ . max_blocking_threads ( blocking_threads as usize )
154130 . enable_all ( )
155131 . build ( ) ?;
156132
@@ -171,7 +147,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
171147 // We need to keep the session context alive. Some session state like temporary
172148 // dictionaries are stored in session context. If it is dropped, the temporary
173149 // dictionaries will be dropped as well.
174- let session = prepare_datafusion_session_context ( & configs , task_memory_manager) ?;
150+ let session = prepare_datafusion_session_context ( batch_size as usize , task_memory_manager) ?;
175151
176152 let plan_creation_time = start. elapsed ( ) ;
177153
@@ -182,33 +158,24 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
182158 scans : vec ! [ ] ,
183159 input_sources,
184160 stream : None ,
185- conf : configs,
186161 runtime,
187162 metrics,
188163 plan_creation_time,
189164 session_ctx : Arc :: new ( session) ,
190- debug_native,
191- explain_native,
165+ debug_native : debug_native == 1 ,
166+ explain_native : explain_native == 1 ,
192167 metrics_jstrings : HashMap :: new ( ) ,
193168 } ) ;
194169
195170 Ok ( Box :: into_raw ( exec_context) as i64 )
196171 } )
197172}
198173
199- /// Parse Comet configs and configure DataFusion session context.
174+ /// Configure DataFusion session context.
200175fn prepare_datafusion_session_context (
201- conf : & HashMap < String , String > ,
176+ batch_size : usize ,
202177 comet_task_memory_manager : Arc < GlobalRef > ,
203178) -> CometResult < SessionContext > {
204- // Get the batch size from Comet JVM side
205- let batch_size = conf
206- . get ( "batch_size" )
207- . ok_or ( CometError :: Internal (
208- "Config 'batch_size' is not specified from Comet JVM side" . to_string ( ) ,
209- ) ) ?
210- . parse :: < usize > ( ) ?;
211-
212179 let mut rt_config = RuntimeConfig :: new ( ) . with_disk_manager ( DiskManagerConfig :: NewOs ) ;
213180
214181 // Set Comet memory pool for native
@@ -218,7 +185,7 @@ fn prepare_datafusion_session_context(
218185 // Get Datafusion configuration from Spark Execution context
219186 // can be configured in Comet Spark JVM using Spark --conf parameters
220187 // e.g: spark-shell --conf spark.datafusion.sql_parser.parse_float_as_decimal=true
221- let mut session_config = SessionConfig :: new ( )
188+ let session_config = SessionConfig :: new ( )
222189 . with_batch_size ( batch_size)
223190 // DataFusion partial aggregates can emit duplicate rows so we disable the
224191 // skip partial aggregation feature because this is not compatible with Spark's
@@ -231,11 +198,7 @@ fn prepare_datafusion_session_context(
231198 & ScalarValue :: Float64 ( Some ( 1.1 ) ) ,
232199 ) ;
233200
234- for ( key, value) in conf. iter ( ) . filter ( |( k, _) | k. starts_with ( "datafusion." ) ) {
235- session_config = session_config. set_str ( key, value) ;
236- }
237-
238- let runtime = RuntimeEnv :: try_new ( rt_config) . unwrap ( ) ;
201+ let runtime = RuntimeEnv :: try_new ( rt_config) ?;
239202
240203 let mut session_ctx = SessionContext :: new_with_config_rt ( session_config, Arc :: new ( runtime) ) ;
241204
0 commit comments