Skip to content

Commit 2671e0c

Browse files
authored
Stop passing Java config map into native createPlan (#1101)
1 parent 36a2307 commit 2671e0c

File tree

3 files changed

+28
-83
lines changed

3 files changed

+28
-83
lines changed

native/core/src/execution/jni_api.rs

Lines changed: 16 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ use futures::poll;
3131
use jni::{
3232
errors::Result as JNIResult,
3333
objects::{
34-
JByteArray, JClass, JIntArray, JLongArray, JMap, JObject, JObjectArray, JPrimitiveArray,
35-
JString, ReleaseMode,
34+
JByteArray, JClass, JIntArray, JLongArray, JObject, JObjectArray, JPrimitiveArray, JString,
35+
ReleaseMode,
3636
},
3737
sys::{jbyteArray, jint, jlong, jlongArray},
3838
JNIEnv,
@@ -77,8 +77,6 @@ struct ExecutionContext {
7777
pub input_sources: Vec<Arc<GlobalRef>>,
7878
/// The record batch stream to pull results from
7979
pub stream: Option<SendableRecordBatchStream>,
80-
/// Configurations for DF execution
81-
pub conf: HashMap<String, String>,
8280
/// The Tokio runtime used for async.
8381
pub runtime: Runtime,
8482
/// Native metrics
@@ -103,11 +101,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
103101
e: JNIEnv,
104102
_class: JClass,
105103
id: jlong,
106-
config_object: JObject,
107104
iterators: jobjectArray,
108105
serialized_query: jbyteArray,
109106
metrics_node: JObject,
110107
comet_task_memory_manager_obj: JObject,
108+
batch_size: jint,
109+
debug_native: jboolean,
110+
explain_native: jboolean,
111+
worker_threads: jint,
112+
blocking_threads: jint,
111113
) -> jlong {
112114
try_unwrap_or_throw(&e, |mut env| {
113115
// Init JVM classes
@@ -121,36 +123,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
121123
// Deserialize query plan
122124
let spark_plan = serde::deserialize_op(bytes.as_slice())?;
123125

124-
// Sets up context
125-
let mut configs = HashMap::new();
126-
127-
let config_map = JMap::from_env(&mut env, &config_object)?;
128-
let mut map_iter = config_map.iter(&mut env)?;
129-
while let Some((key, value)) = map_iter.next(&mut env)? {
130-
let key: String = env.get_string(&JString::from(key)).unwrap().into();
131-
let value: String = env.get_string(&JString::from(value)).unwrap().into();
132-
configs.insert(key, value);
133-
}
134-
135-
// Whether we've enabled additional debugging on the native side
136-
let debug_native = parse_bool(&configs, "debug_native")?;
137-
let explain_native = parse_bool(&configs, "explain_native")?;
138-
139-
let worker_threads = configs
140-
.get("worker_threads")
141-
.map(String::as_str)
142-
.unwrap_or("4")
143-
.parse::<usize>()?;
144-
let blocking_threads = configs
145-
.get("blocking_threads")
146-
.map(String::as_str)
147-
.unwrap_or("10")
148-
.parse::<usize>()?;
149-
150126
// Use multi-threaded tokio runtime to prevent blocking spawned tasks if any
151127
let runtime = tokio::runtime::Builder::new_multi_thread()
152-
.worker_threads(worker_threads)
153-
.max_blocking_threads(blocking_threads)
128+
.worker_threads(worker_threads as usize)
129+
.max_blocking_threads(blocking_threads as usize)
154130
.enable_all()
155131
.build()?;
156132

@@ -171,7 +147,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
171147
// We need to keep the session context alive. Some session state like temporary
172148
// dictionaries are stored in session context. If it is dropped, the temporary
173149
// dictionaries will be dropped as well.
174-
let session = prepare_datafusion_session_context(&configs, task_memory_manager)?;
150+
let session = prepare_datafusion_session_context(batch_size as usize, task_memory_manager)?;
175151

176152
let plan_creation_time = start.elapsed();
177153

@@ -182,33 +158,24 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
182158
scans: vec![],
183159
input_sources,
184160
stream: None,
185-
conf: configs,
186161
runtime,
187162
metrics,
188163
plan_creation_time,
189164
session_ctx: Arc::new(session),
190-
debug_native,
191-
explain_native,
165+
debug_native: debug_native == 1,
166+
explain_native: explain_native == 1,
192167
metrics_jstrings: HashMap::new(),
193168
});
194169

195170
Ok(Box::into_raw(exec_context) as i64)
196171
})
197172
}
198173

199-
/// Parse Comet configs and configure DataFusion session context.
174+
/// Configure DataFusion session context.
200175
fn prepare_datafusion_session_context(
201-
conf: &HashMap<String, String>,
176+
batch_size: usize,
202177
comet_task_memory_manager: Arc<GlobalRef>,
203178
) -> CometResult<SessionContext> {
204-
// Get the batch size from Comet JVM side
205-
let batch_size = conf
206-
.get("batch_size")
207-
.ok_or(CometError::Internal(
208-
"Config 'batch_size' is not specified from Comet JVM side".to_string(),
209-
))?
210-
.parse::<usize>()?;
211-
212179
let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);
213180

214181
// Set Comet memory pool for native
@@ -218,7 +185,7 @@ fn prepare_datafusion_session_context(
218185
// Get Datafusion configuration from Spark Execution context
219186
// can be configured in Comet Spark JVM using Spark --conf parameters
220187
// e.g: spark-shell --conf spark.datafusion.sql_parser.parse_float_as_decimal=true
221-
let mut session_config = SessionConfig::new()
188+
let session_config = SessionConfig::new()
222189
.with_batch_size(batch_size)
223190
// DataFusion partial aggregates can emit duplicate rows so we disable the
224191
// skip partial aggregation feature because this is not compatible with Spark's
@@ -231,11 +198,7 @@ fn prepare_datafusion_session_context(
231198
&ScalarValue::Float64(Some(1.1)),
232199
);
233200

234-
for (key, value) in conf.iter().filter(|(k, _)| k.starts_with("datafusion.")) {
235-
session_config = session_config.set_str(key, value);
236-
}
237-
238-
let runtime = RuntimeEnv::try_new(rt_config).unwrap();
201+
let runtime = RuntimeEnv::try_new(rt_config)?;
239202

240203
let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime));
241204

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -60,43 +60,23 @@ class CometExecIterator(
6060
new CometBatchIterator(iterator, nativeUtil)
6161
}.toArray
6262
private val plan = {
63-
val configs = createNativeConf
6463
nativeLib.createPlan(
6564
id,
66-
configs,
6765
cometBatchIterators,
6866
protobufQueryPlan,
6967
nativeMetrics,
70-
new CometTaskMemoryManager(id))
68+
new CometTaskMemoryManager(id),
69+
batchSize = COMET_BATCH_SIZE.get(),
70+
debug = COMET_DEBUG_ENABLED.get(),
71+
explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
72+
workerThreads = COMET_WORKER_THREADS.get(),
73+
blockingThreads = COMET_BLOCKING_THREADS.get())
7174
}
7275

7376
private var nextBatch: Option[ColumnarBatch] = None
7477
private var currentBatch: ColumnarBatch = null
7578
private var closed: Boolean = false
7679

77-
/**
78-
* Creates a new configuration map to be passed to the native side.
79-
*/
80-
private def createNativeConf: java.util.HashMap[String, String] = {
81-
val result = new java.util.HashMap[String, String]()
82-
val conf = SparkEnv.get.conf
83-
84-
result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
85-
result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
86-
result.put("explain_native", String.valueOf(COMET_EXPLAIN_NATIVE_ENABLED.get()))
87-
result.put("worker_threads", String.valueOf(COMET_WORKER_THREADS.get()))
88-
result.put("blocking_threads", String.valueOf(COMET_BLOCKING_THREADS.get()))
89-
90-
// Strip mandatory prefix spark. which is not required for DataFusion session params
91-
conf.getAll.foreach {
92-
case (k, v) if k.startsWith("spark.datafusion") =>
93-
result.put(k.replaceFirst("spark\\.", ""), v)
94-
case _ =>
95-
}
96-
97-
result
98-
}
99-
10080
def getNextBatch(): Option[ColumnarBatch] = {
10181
assert(partitionIndex >= 0 && partitionIndex < numParts)
10282

spark/src/main/scala/org/apache/comet/Native.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919

2020
package org.apache.comet
2121

22-
import java.util.Map
23-
2422
import org.apache.spark.CometTaskMemoryManager
2523
import org.apache.spark.sql.comet.CometMetricNode
2624

@@ -47,11 +45,15 @@ class Native extends NativeBase {
4745
*/
4846
@native def createPlan(
4947
id: Long,
50-
configMap: Map[String, String],
5148
iterators: Array[CometBatchIterator],
5249
plan: Array[Byte],
5350
metrics: CometMetricNode,
54-
taskMemoryManager: CometTaskMemoryManager): Long
51+
taskMemoryManager: CometTaskMemoryManager,
52+
batchSize: Int,
53+
debug: Boolean,
54+
explain: Boolean,
55+
workerThreads: Int,
56+
blockingThreads: Int): Long
5557

5658
/**
5759
* Execute a native query plan based on given input Arrow arrays.

0 commit comments

Comments
 (0)