Skip to content

Commit d3ea9fd

Browse files
authored
feat: pass spark.comet.datafusion.* configs through to DataFusion session (apache#3455)
1 parent 53f4cf7 commit d3ea9fd

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,16 @@ object CometConf extends ShimCometConf {
828828
.bytesConf(ByteUnit.BYTE)
829829
.createWithDefault(100L * 1024 * 1024 * 1024) // 100 GB
830830

831+
val COMET_RESPECT_DATAFUSION_CONFIGS: ConfigEntry[Boolean] =
832+
conf(s"$COMET_EXEC_CONFIG_PREFIX.respectDataFusionConfigs")
833+
.category(CATEGORY_TESTING)
834+
.doc(
835+
"Development and testing configuration option to allow DataFusion configs set in " +
836+
"Spark configuration settings starting with `spark.comet.datafusion.` to be passed " +
837+
"into native execution.")
838+
.booleanConf
839+
.createWithDefault(false)
840+
831841
val COMET_STRICT_TESTING: ConfigEntry[Boolean] = conf(s"$COMET_PREFIX.testing.strict")
832842
.category(CATEGORY_TESTING)
833843
.doc("Experimental option to enable strict testing, which will fail tests that could be " +

native/core/src/execution/jni_api.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
255255
local_dirs_vec,
256256
max_temp_directory_size,
257257
task_cpus as usize,
258+
&spark_config,
258259
)?;
259260

260261
let plan_creation_time = start.elapsed();
@@ -309,6 +310,7 @@ fn prepare_datafusion_session_context(
309310
local_dirs: Vec<String>,
310311
max_temp_directory_size: u64,
311312
task_cpus: usize,
313+
spark_config: &HashMap<String, String>,
312314
) -> CometResult<SessionContext> {
313315
let paths = local_dirs.into_iter().map(PathBuf::from).collect();
314316
let disk_manager = DiskManagerBuilder::default()
@@ -317,10 +319,7 @@ fn prepare_datafusion_session_context(
317319
let mut rt_config = RuntimeEnvBuilder::new().with_disk_manager_builder(disk_manager);
318320
rt_config = rt_config.with_memory_pool(memory_pool);
319321

320-
// Get Datafusion configuration from Spark Execution context
321-
// can be configured in Comet Spark JVM using Spark --conf parameters
322-
// e.g: spark-shell --conf spark.datafusion.sql_parser.parse_float_as_decimal=true
323-
let session_config = SessionConfig::new()
322+
let mut session_config = SessionConfig::new()
324323
.with_target_partitions(task_cpus)
325324
// This DataFusion context is within the scope of an executing Spark Task. We want to set
326325
// its internal parallelism to the number of CPUs allocated to Spark Tasks. This can be
@@ -337,6 +336,17 @@ fn prepare_datafusion_session_context(
337336
&ScalarValue::Float64(Some(1.1)),
338337
);
339338

339+
// Pass through DataFusion configs from Spark.
340+
// e.g: spark-shell --conf spark.comet.datafusion.sql_parser.parse_float_as_decimal=true
341+
// becomes datafusion.sql_parser.parse_float_as_decimal=true
342+
const SPARK_COMET_DF_PREFIX: &str = "spark.comet.datafusion.";
343+
for (key, value) in spark_config {
344+
if let Some(df_key) = key.strip_prefix(SPARK_COMET_DF_PREFIX) {
345+
let df_key = format!("datafusion.{df_key}");
346+
session_config = session_config.set_str(&df_key, value);
347+
}
348+
}
349+
340350
let runtime = rt_config.build()?;
341351

342352
let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime));

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,13 @@ object CometExecIterator extends Logging {
270270
def serializeCometSQLConfs(): Array[Byte] = {
271271
val builder = ConfigMap.newBuilder()
272272
cometSqlConfs.foreach { case (k, v) =>
273-
builder.putEntries(k, v)
273+
if (k.startsWith(s"${CometConf.COMET_PREFIX}.datafusion.")) {
274+
if (CometConf.COMET_RESPECT_DATAFUSION_CONFIGS.get(SQLConf.get)) {
275+
builder.putEntries(k, v)
276+
}
277+
} else {
278+
builder.putEntries(k, v)
279+
}
274280
}
275281
builder.build().toByteArray
276282
}

0 commit comments

Comments
 (0)