Skip to content

Commit 3f66495

Browse files
authored
chore: Pass Spark configs to native createPlan (#2180)
1 parent fa27427 commit 3f66495

File tree

7 files changed

+65
-3
lines changed

7 files changed

+65
-3
lines changed

native/core/src/execution/jni_api.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ use jni::{
6060
objects::GlobalRef,
6161
sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
6262
};
63+
use std::collections::HashMap;
6364
use std::path::PathBuf;
6465
use std::time::{Duration, Instant};
6566
use std::{sync::Arc, task::Poll};
@@ -151,6 +152,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
151152
id: jlong,
152153
iterators: jobjectArray,
153154
serialized_query: jbyteArray,
155+
serialized_spark_configs: jbyteArray,
154156
partition_count: jint,
155157
metrics_node: JObject,
156158
metrics_update_interval: jlong,
@@ -173,12 +175,20 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
173175

174176
let start = Instant::now();
175177

178+
// Deserialize query plan
176179
let array = unsafe { JPrimitiveArray::from_raw(serialized_query) };
177180
let bytes = env.convert_byte_array(array)?;
178-
179-
// Deserialize query plan
180181
let spark_plan = serde::deserialize_op(bytes.as_slice())?;
181182

183+
// Deserialize Spark configs
184+
let array = unsafe { JPrimitiveArray::from_raw(serialized_spark_configs) };
185+
let bytes = env.convert_byte_array(array)?;
186+
let spark_configs = serde::deserialize_config(bytes.as_slice())?;
187+
188+
// Convert Spark configs to HashMap
189+
let _spark_config_map: HashMap<String, String> =
190+
spark_configs.entries.into_iter().collect();
191+
182192
let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?);
183193

184194
// Get the global references of input sources

native/core/src/execution/serde.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::errors::ExpressionError;
2222
use arrow::datatypes::{DataType as ArrowDataType, TimeUnit};
2323
use arrow::datatypes::{Field, Fields};
2424
use datafusion_comet_proto::{
25-
spark_expression,
25+
spark_config, spark_expression,
2626
spark_expression::data_type::{
2727
data_type_info::DatatypeStruct,
2828
DataTypeId,
@@ -74,6 +74,14 @@ pub fn deserialize_op(buf: &[u8]) -> Result<spark_operator::Operator, ExecutionE
7474
}
7575
}
7676

77+
/// Deserialize bytes to protobuf type of Spark config map
78+
pub fn deserialize_config(buf: &[u8]) -> Result<spark_config::ConfigMap, ExecutionError> {
79+
match spark_config::ConfigMap::decode(&mut Cursor::new(buf)) {
80+
Ok(e) => Ok(e),
81+
Err(err) => Err(ExecutionError::from(err)),
82+
}
83+
}
84+
7785
/// Deserialize bytes to protobuf type of data type
7886
pub fn deserialize_data_type(buf: &[u8]) -> Result<spark_expression::DataType, ExecutionError> {
7987
match spark_expression::DataType::decode(&mut Cursor::new(buf)) {

native/proto/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ fn main() -> Result<()> {
3333
"src/proto/metric.proto",
3434
"src/proto/partitioning.proto",
3535
"src/proto/operator.proto",
36+
"src/proto/config.proto",
3637
],
3738
&["src/proto"],
3839
)?;

native/proto/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,9 @@ pub mod spark_operator {
4343
pub mod spark_metric {
4444
include!(concat!("generated", "/spark.spark_metric.rs"));
4545
}
46+
47+
// Include generated modules from .proto files.
48+
#[allow(missing_docs)]
49+
pub mod spark_config {
50+
include!(concat!("generated", "/spark.spark_config.rs"));
51+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
syntax = "proto3";
19+
20+
package spark.spark_config;
21+
22+
option java_package = "org.apache.comet.serde";
23+
24+
message ConfigMap {
25+
map<string, string> entries = 1;
26+
}

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.sql.vectorized._
3232

3333
import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_METRICS_UPDATE_INTERVAL}
3434
import org.apache.comet.Tracing.withTrace
35+
import org.apache.comet.serde.Config.ConfigMap
3536
import org.apache.comet.vector.NativeUtil
3637

3738
/**
@@ -84,10 +85,19 @@ class CometExecIterator(
8485
// and `memory_fraction` below.
8586
CometSparkSessionExtensions.getCometMemoryOverhead(conf)
8687
}
88+
89+
// Serialize Spark configs in protobuf format
90+
val builder = ConfigMap.newBuilder()
91+
conf.getAll.foreach { case (k, v) =>
92+
builder.putEntries(k, v)
93+
}
94+
val protobufSparkConfigs = builder.build().toByteArray
95+
8796
nativeLib.createPlan(
8897
id,
8998
cometBatchIterators,
9099
protobufQueryPlan,
100+
protobufSparkConfigs,
91101
numParts,
92102
nativeMetrics,
93103
metricsUpdateInterval = COMET_METRICS_UPDATE_INTERVAL.get(),

spark/src/main/scala/org/apache/comet/Native.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class Native extends NativeBase {
5454
id: Long,
5555
iterators: Array[CometBatchIterator],
5656
plan: Array[Byte],
57+
configMapProto: Array[Byte],
5758
partitionCount: Int,
5859
metrics: CometMetricNode,
5960
metricsUpdateInterval: Long,

0 commit comments

Comments
 (0)