Commit 9da11c5

fix: default values for native_datafusion scan (apache#1756)
1 parent e25be9e commit 9da11c5

File tree: 9 files changed (+184 -29 lines)


native/core/src/execution/planner.rs

Lines changed: 41 additions & 1 deletion
@@ -1115,6 +1115,42 @@ impl PhysicalPlanner {
             .map(|expr| self.create_expr(expr, Arc::clone(&required_schema)))
             .collect();
 
+        let default_values: Option<HashMap<usize, ScalarValue>> = if !scan
+            .default_values
+            .is_empty()
+        {
+            // We have default values. Extract the two lists (same length) of values and
+            // indexes in the schema, and then create a HashMap to use in the SchemaMapper.
+            let default_values: Result<Vec<ScalarValue>, DataFusionError> = scan
+                .default_values
+                .iter()
+                .map(|expr| {
+                    let literal = self.create_expr(expr, Arc::clone(&required_schema))?;
+                    let df_literal = literal
+                        .as_any()
+                        .downcast_ref::<DataFusionLiteral>()
+                        .ok_or_else(|| {
+                            GeneralError("Expected literal of default value.".to_string())
+                        })?;
+                    Ok(df_literal.value().clone())
+                })
+                .collect();
+            let default_values = default_values?;
+            let default_values_indexes: Vec<usize> = scan
+                .default_values_indexes
+                .iter()
+                .map(|offset| *offset as usize)
+                .collect();
+            Some(
+                default_values_indexes
+                    .into_iter()
+                    .zip(default_values)
+                    .collect(),
+            )
+        } else {
+            None
+        };
+
         // Get one file from the list of files
         let one_file = scan
             .file_partitions
@@ -1152,6 +1188,7 @@ impl PhysicalPlanner {
             file_groups,
             Some(projection_vector),
             Some(data_filters?),
+            default_values,
             scan.session_timezone.as_str(),
         )?;
         Ok((
@@ -3164,7 +3201,10 @@ mod tests {
 
         let source = Arc::new(
             ParquetSource::default().with_schema_adapter_factory(Arc::new(
-                SparkSchemaAdapterFactory::new(SparkParquetOptions::new(EvalMode::Ansi, "", false)),
+                SparkSchemaAdapterFactory::new(
+                    SparkParquetOptions::new(EvalMode::Ansi, "", false),
+                    None,
+                ),
             )),
         );

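For readers unfamiliar with the planner internals: the hunk above decodes each serialized default-value expression into a DataFusion literal and pairs it with its index in the required schema. A minimal standalone sketch of that decode-and-zip step follows; the string default and the index 1 are illustrative only and not taken from the commit.

use std::collections::HashMap;
use std::sync::Arc;

use datafusion::physical_expr::expressions::Literal;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::scalar::ScalarValue;

fn main() {
    // A default value as the planner would receive it after `create_expr`:
    // a physical expression that is expected to be a literal.
    let expr: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Utf8(Some("hello".to_string()))));

    // Same downcast the planner performs; anything that is not a literal
    // is reported as an error in the real code.
    let scalar = expr
        .as_any()
        .downcast_ref::<Literal>()
        .map(|lit| lit.value().clone())
        .expect("default value should be a literal");

    // Zip the schema indexes with the decoded values into the map handed
    // to the schema adapter.
    let default_values: HashMap<usize, ScalarValue> =
        vec![1usize].into_iter().zip(vec![scalar]).collect();
    assert!(default_values.contains_key(&1));
}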
native/core/src/parquet/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -715,6 +715,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
         file_groups,
         None,
         data_filters,
+        None,
         session_timezone.as_str(),
     )?;

native/core/src/parquet/parquet_exec.rs

Lines changed: 7 additions & 3 deletions
@@ -28,8 +28,10 @@ use datafusion::datasource::source::DataSourceExec;
 use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::physical_expr::expressions::BinaryExpr;
 use datafusion::physical_expr::PhysicalExpr;
+use datafusion::scalar::ScalarValue;
 use datafusion_comet_spark_expr::EvalMode;
 use itertools::Itertools;
+use std::collections::HashMap;
 use std::sync::Arc;
 
 /// Initializes a DataSourceExec plan with a ParquetSource. This may be used by either the
@@ -61,12 +63,14 @@ pub(crate) fn init_datasource_exec(
     file_groups: Vec<Vec<PartitionedFile>>,
     projection_vector: Option<Vec<usize>>,
     data_filters: Option<Vec<Arc<dyn PhysicalExpr>>>,
+    default_values: Option<HashMap<usize, ScalarValue>>,
     session_timezone: &str,
 ) -> Result<Arc<DataSourceExec>, ExecutionError> {
     let (table_parquet_options, spark_parquet_options) = get_options(session_timezone);
-    let mut parquet_source = ParquetSource::new(table_parquet_options).with_schema_adapter_factory(
-        Arc::new(SparkSchemaAdapterFactory::new(spark_parquet_options)),
-    );
+    let mut parquet_source =
+        ParquetSource::new(table_parquet_options).with_schema_adapter_factory(Arc::new(
+            SparkSchemaAdapterFactory::new(spark_parquet_options, default_values),
+        ));
     // Create a conjunctive form of the vector because ParquetExecBuilder takes
     // a single expression
     if let Some(data_filters) = data_filters {

native/core/src/parquet/parquet_support.rs

Lines changed: 0 additions & 5 deletions
@@ -62,9 +62,6 @@ pub struct SparkParquetOptions {
     pub allow_incompat: bool,
     /// Support casting unsigned ints to signed ints (used by Parquet SchemaAdapter)
     pub allow_cast_unsigned_ints: bool,
-    /// We also use the cast logic for adapting Parquet schemas, so this flag is used
-    /// for that use case
-    pub is_adapting_schema: bool,
     /// Whether to always represent decimals using 128 bits. If false, the native reader may represent decimals using 32 or 64 bits, depending on the precision.
     pub use_decimal_128: bool,
     /// Whether to read dates/timestamps that were written in the legacy hybrid Julian + Gregorian calendar as it is. If false, throw exceptions instead. If the spark type is TimestampNTZ, this should be true.
@@ -80,7 +77,6 @@ impl SparkParquetOptions {
             timezone: timezone.to_string(),
             allow_incompat,
             allow_cast_unsigned_ints: false,
-            is_adapting_schema: false,
             use_decimal_128: false,
             use_legacy_date_timestamp_or_ntz: false,
             case_sensitive: false,
@@ -93,7 +89,6 @@ impl SparkParquetOptions {
             timezone: "".to_string(),
             allow_incompat,
             allow_cast_unsigned_ints: false,
-            is_adapting_schema: false,
             use_decimal_128: false,
             use_legacy_date_timestamp_or_ntz: false,
             case_sensitive: false,

native/core/src/parquet/schema_adapter.rs

Lines changed: 46 additions & 16 deletions
@@ -18,11 +18,13 @@
 //! Custom schema adapter that uses Spark-compatible conversions
 
 use crate::parquet::parquet_support::{spark_parquet_convert, SparkParquetOptions};
-use arrow::array::{new_null_array, RecordBatch, RecordBatchOptions};
+use arrow::array::{RecordBatch, RecordBatchOptions};
 use arrow::datatypes::{Schema, SchemaRef};
 use datafusion::common::ColumnStatistics;
 use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
 use datafusion::physical_plan::ColumnarValue;
+use datafusion::scalar::ScalarValue;
+use std::collections::HashMap;
 use std::sync::Arc;
 
 /// An implementation of DataFusion's `SchemaAdapterFactory` that uses a Spark-compatible
@@ -31,12 +33,17 @@ use std::sync::Arc;
 pub struct SparkSchemaAdapterFactory {
     /// Spark cast options
     parquet_options: SparkParquetOptions,
+    default_values: Option<HashMap<usize, ScalarValue>>,
 }
 
 impl SparkSchemaAdapterFactory {
-    pub fn new(options: SparkParquetOptions) -> Self {
+    pub fn new(
+        options: SparkParquetOptions,
+        default_values: Option<HashMap<usize, ScalarValue>>,
+    ) -> Self {
         Self {
             parquet_options: options,
+            default_values,
         }
     }
 }
@@ -56,6 +63,7 @@ impl SchemaAdapterFactory for SparkSchemaAdapterFactory {
         Box::new(SparkSchemaAdapter {
             required_schema,
             parquet_options: self.parquet_options.clone(),
+            default_values: self.default_values.clone(),
         })
     }
 }
@@ -69,6 +77,7 @@ pub struct SparkSchemaAdapter {
     required_schema: SchemaRef,
     /// Spark cast options
     parquet_options: SparkParquetOptions,
+    default_values: Option<HashMap<usize, ScalarValue>>,
 }
 
 impl SchemaAdapter for SparkSchemaAdapter {
@@ -134,6 +143,7 @@ impl SchemaAdapter for SparkSchemaAdapter {
                 required_schema: Arc::<Schema>::clone(&self.required_schema),
                 field_mappings,
                 parquet_options: self.parquet_options.clone(),
+                default_values: self.default_values.clone(),
             }),
             projection,
         ))
@@ -158,16 +168,7 @@
 /// out of the execution of this query. Thus `map_batch` uses
 /// `projected_table_schema` as it can only operate on the projected fields.
 ///
-/// [`map_partial_batch`] is used to create a RecordBatch with a schema that
-/// can be used for Parquet predicate pushdown, meaning that it may contain
-/// fields which are not in the projected schema (as the fields that parquet
-/// pushdown filters operate can be completely distinct from the fields that are
-/// projected (output) out of the ParquetExec). `map_partial_batch` thus uses
-/// `table_schema` to create the resulting RecordBatch (as it could be operating
-/// on any fields in the schema).
-///
 /// [`map_batch`]: Self::map_batch
-/// [`map_partial_batch`]: Self::map_partial_batch
 #[derive(Debug)]
 pub struct SchemaMapping {
     /// The schema of the table. This is the expected schema after conversion
@@ -181,6 +182,7 @@ pub struct SchemaMapping {
     field_mappings: Vec<Option<usize>>,
     /// Spark cast options
     parquet_options: SparkParquetOptions,
+    default_values: Option<HashMap<usize, ScalarValue>>,
 }
 
 impl SchemaMapper for SchemaMapping {
@@ -197,15 +199,43 @@ impl SchemaMapper for SchemaMapping {
             // go through each field in the projected schema
             .fields()
             .iter()
+            .enumerate()
             // and zip it with the index that maps fields from the projected table schema to the
             // projected file schema in `batch`
             .zip(&self.field_mappings)
             // and for each one...
-            .map(|(field, file_idx)| {
+            .map(|((field_idx, field), file_idx)| {
                 file_idx.map_or_else(
-                    // If this field only exists in the table, and not in the file, then we know
-                    // that it's null, so just return that.
-                    || Ok(new_null_array(field.data_type(), batch_rows)),
+                    // If this field only exists in the table, and not in the file, then we need to
+                    // populate a default value for it.
+                    || {
+                        if self.default_values.is_some() {
+                            // We have a map of default values, see if this field is in there.
+                            if let Some(value) =
+                                self.default_values.as_ref().unwrap().get(&field_idx)
+                            // Default value exists, construct a column from it.
+                            {
+                                let cv = if field.data_type() == &value.data_type() {
+                                    ColumnarValue::Scalar(value.clone())
+                                } else {
+                                    // Data types don't match. This can happen when default values
+                                    // are stored by Spark in a format different than the column's
+                                    // type (e.g., INT32 when the column is DATE32)
+                                    spark_parquet_convert(
+                                        ColumnarValue::Scalar(value.clone()),
+                                        field.data_type(),
+                                        &self.parquet_options,
+                                    )?
+                                };
+                                return cv.into_array(batch_rows);
+                            }
+                        }
+                        // Construct an entire column of nulls. We use the Scalar representation
+                        // for better performance.
+                        let cv =
+                            ColumnarValue::Scalar(ScalarValue::try_new_null(field.data_type())?);
+                        cv.into_array(batch_rows)
+                    },
                     // However, if it does exist in both, then try to cast it to the correct output
                     // type
                     |batch_idx| {
@@ -316,7 +346,7 @@ mod test {
 
         let parquet_source = Arc::new(
             ParquetSource::new(TableParquetOptions::new()).with_schema_adapter_factory(Arc::new(
-                SparkSchemaAdapterFactory::new(spark_parquet_options),
+                SparkSchemaAdapterFactory::new(spark_parquet_options, None),
            )),
         );

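The core of this change sits in SchemaMapping::map_batch above: when a field exists in the required schema but not in the Parquet file, the column is now materialized from a single ScalarValue default instead of always being filled with nulls, with spark_parquet_convert handling defaults stored in a different physical type. Below is a standalone sketch of just the materialization step, using the same public arrow and datafusion APIs as the diff; the helper name, the string default, and the row count are illustrative only, and the type-conversion branch is intentionally omitted.

use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use datafusion::common::Result;
use datafusion::physical_plan::ColumnarValue;
use datafusion::scalar::ScalarValue;

// Build a column of `batch_rows` values for a field that is absent from the
// Parquet file: use the default value when one exists, otherwise nulls.
// (The real adapter also converts the scalar when its type differs from the
// column's Arrow type; that step is left out of this sketch.)
fn missing_column(
    default: Option<&ScalarValue>,
    data_type: &DataType,
    batch_rows: usize,
) -> Result<ArrayRef> {
    let scalar = match default {
        Some(value) => value.clone(),
        None => ScalarValue::try_new_null(data_type)?,
    };
    // A single scalar is expanded into an array of `batch_rows` rows.
    ColumnarValue::Scalar(scalar).into_array(batch_rows)
}

fn main() -> Result<()> {
    // Illustrative values: a string column with default 'hello', 3 rows.
    let col = missing_column(
        Some(&ScalarValue::Utf8(Some("hello".to_string()))),
        &DataType::Utf8,
        3,
    )?;
    assert_eq!(col.len(), 3);
    Ok(())
}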
native/proto/src/proto/operator.proto

Lines changed: 2 additions & 0 deletions
@@ -91,6 +91,8 @@ message NativeScan {
   repeated SparkFilePartition file_partitions = 7;
   repeated int64 projection_vector = 8;
   string session_timezone = 9;
+  repeated spark.spark_expression.Expr default_values = 10;
+  repeated int64 default_values_indexes = 11;
 }
 
 message Projection {

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 23 additions & 4 deletions
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, Normalize
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution
@@ -2307,6 +2308,24 @@ object QueryPlanSerde extends Logging with CometExprShim {
       nativeScanBuilder.addAllDataFilters(dataFilters.asJava)
     }
 
+    val possibleDefaultValues = getExistenceDefaultValues(scan.requiredSchema)
+    if (possibleDefaultValues.exists(_ != null)) {
+      // Our schema has default values. Serialize two lists, one with the default values
+      // and another with the indexes in the schema so the native side can map missing
+      // columns to these default values.
+      val (defaultValues, indexes) = possibleDefaultValues.zipWithIndex
+        .filter { case (expr, _) => expr != null }
+        .map { case (expr, index) =>
+          // ResolveDefaultColumnsUtil.getExistenceDefaultValues has evaluated these
+          // expressions and they should now just be literals.
+          (Literal(expr), index.toLong.asInstanceOf[java.lang.Long])
+        }
+        .unzip
+      nativeScanBuilder.addAllDefaultValues(
+        defaultValues.flatMap(exprToProto(_, scan.output)).toIterable.asJava)
+      nativeScanBuilder.addAllDefaultValuesIndexes(indexes.toIterable.asJava)
+    }
+
     // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD.
     scan.inputRDD match {
       case rdd: DataSourceRDD =>
@@ -2331,18 +2350,18 @@
     val requiredSchema = schema2Proto(scan.requiredSchema.fields)
     val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
 
-    val data_schema_idxs = scan.requiredSchema.fields.map(field => {
+    val dataSchemaIndexes = scan.requiredSchema.fields.map(field => {
       scan.relation.dataSchema.fieldIndex(field.name)
     })
-    val partition_schema_idxs = Array
+    val partitionSchemaIndexes = Array
      .range(
        scan.relation.dataSchema.fields.length,
        scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length)
 
-    val projection_vector = (data_schema_idxs ++ partition_schema_idxs).map(idx =>
+    val projectionVector = (dataSchemaIndexes ++ partitionSchemaIndexes).map(idx =>
      idx.toLong.asInstanceOf[java.lang.Long])
 
-    nativeScanBuilder.addAllProjectionVector(projection_vector.toIterable.asJava)
+    nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava)
 
     // In `CometScanRule`, we ensure partitionSchema is supported.
     assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 9 additions & 0 deletions
@@ -57,6 +57,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("parquet default values") {
+    withTable("t1") {
+      sql("create table t1(col1 boolean) using parquet")
+      sql("insert into t1 values(true)")
+      sql("alter table t1 add column col2 string default 'hello'")
+      checkSparkAnswerAndOperator("select * from t1")
+    }
+  }
+
   test("coalesce should return correct datatype") {
     Seq(true, false).foreach { dictionaryEnabled =>
       withTempDir { dir =>
