
Commit c4c998c

feat: pushdown filter for native_iceberg_compat (#1566)
* feat: pushdown filter for native_iceberg_compat
* fix style
* add data schema
* fix filter bound
* fix in expr
* add primitive type tests
* enable native_datafusion test
1 parent e982aad commit c4c998c

10 files changed, +490 -33 lines

common/src/main/java/org/apache/comet/parquet/Native.java

Lines changed: 2 additions & 0 deletions
@@ -253,7 +253,9 @@ public static native long initRecordBatchReader(
       long fileSize,
       long start,
       long length,
+      byte[] filter,
       byte[] requiredSchema,
+      byte[] dataSchema,
       String sessionTimezone);

   // arrow native version of read batch
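
The hunk above only shows the tail of the declaration. For orientation, here is a sketch of the full contract as callers now see it; the leading filePath parameter and its type are not in the hunk and are inferred from the call site in NativeBatchReader, so treat them as assumptions. The comments reflect how the new arguments are consumed on the native side.

// Sketch only, not verbatim Comet source.
public final class Native {
  public static native long initRecordBatchReader(
      String filePath,        // inferred from the call site in NativeBatchReader (assumption)
      long fileSize,
      long start,
      long length,
      byte[] filter,          // new: serialized filter expression; null when nothing was pushed down
      byte[] requiredSchema,  // Arrow IPC-serialized projection (requested) schema
      byte[] dataSchema,      // new: Arrow IPC-serialized full file schema, used to bind the filter
      String sessionTimezone);
}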

common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java

Lines changed: 24 additions & 5 deletions
@@ -108,6 +108,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
   private final Map<String, SQLMetric> metrics;

   private StructType sparkSchema;
+  private StructType dataSchema;
   private MessageType requestedSchema;
   private CometVector[] vectors;
   private AbstractColumnReader[] columnReaders;
@@ -117,6 +118,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
   private boolean[] missingColumns;
   private boolean isInitialized;
   private ParquetMetadata footer;
+  private byte[] nativeFilter;

   /**
    * Whether the native scan should always return decimal represented by 128 bits, regardless of its
@@ -190,8 +192,10 @@ public NativeBatchReader(AbstractColumnReader[] columnReaders) {
       Configuration conf,
       PartitionedFile inputSplit,
       ParquetMetadata footer,
+      byte[] nativeFilter,
       int capacity,
       StructType sparkSchema,
+      StructType dataSchema,
       boolean isCaseSensitive,
       boolean useFieldId,
       boolean ignoreMissingIds,
@@ -202,6 +206,7 @@ public NativeBatchReader(AbstractColumnReader[] columnReaders) {
     this.conf = conf;
     this.capacity = capacity;
     this.sparkSchema = sparkSchema;
+    this.dataSchema = dataSchema;
     this.isCaseSensitive = isCaseSensitive;
     this.useFieldId = useFieldId;
     this.ignoreMissingIds = ignoreMissingIds;
@@ -210,6 +215,7 @@ public NativeBatchReader(AbstractColumnReader[] columnReaders) {
     this.partitionValues = partitionValues;
     this.file = inputSplit;
     this.footer = footer;
+    this.nativeFilter = nativeFilter;
     this.metrics = metrics;
     this.taskContext = TaskContext$.MODULE$.get();
   }
@@ -262,10 +268,9 @@ public void init() throws URISyntaxException, IOException {
     String timeZoneId = conf.get("spark.sql.session.timeZone");
     // Native code uses "UTC" always as the timeZoneId when converting from spark to arrow schema.
     Schema arrowSchema = Utils$.MODULE$.toArrowSchema(sparkSchema, "UTC");
-    ByteArrayOutputStream out = new ByteArrayOutputStream();
-    WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out));
-    MessageSerializer.serialize(writeChannel, arrowSchema);
-    byte[] serializedRequestedArrowSchema = out.toByteArray();
+    byte[] serializedRequestedArrowSchema = serializeArrowSchema(arrowSchema);
+    Schema dataArrowSchema = Utils$.MODULE$.toArrowSchema(dataSchema, "UTC");
+    byte[] serializedDataArrowSchema = serializeArrowSchema(dataArrowSchema);

     //// Create Column readers
     List<ColumnDescriptor> columns = requestedSchema.getColumns();
@@ -350,7 +355,14 @@ public void init() throws URISyntaxException, IOException {

     this.handle =
         Native.initRecordBatchReader(
-            filePath, fileSize, start, length, serializedRequestedArrowSchema, timeZoneId);
+            filePath,
+            fileSize,
+            start,
+            length,
+            nativeFilter,
+            serializedRequestedArrowSchema,
+            serializedDataArrowSchema,
+            timeZoneId);
     isInitialized = true;
   }

@@ -524,4 +536,11 @@ private int loadNextBatch() throws Throwable {
       return Option.apply(null); // None
     }
   }
+
+  private byte[] serializeArrowSchema(Schema schema) throws IOException {
+    ByteArrayOutputStream out = new ByteArrayOutputStream();
+    WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out));
+    MessageSerializer.serialize(writeChannel, schema);
+    return out.toByteArray();
+  }
 }
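
init() now routes both the requested schema and the full data schema through the new serializeArrowSchema helper before crossing the JNI boundary. Below is a minimal, self-contained round-trip sketch of that Arrow IPC pattern using only stock Arrow Java APIs; the two-column example schema is invented for illustration, and the deserialize step mirrors what the native side's deserialize_schema conceptually does with the same bytes.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.util.Arrays;

import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;

public class ArrowSchemaRoundTrip {
  // Same pattern as NativeBatchReader#serializeArrowSchema: write the schema as an
  // Arrow IPC message into a byte array that can be handed across the JNI boundary.
  static byte[] serialize(Schema schema) throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), schema);
    return out.toByteArray();
  }

  // Conceptually what the native side does when it reconstructs the schema from the bytes.
  static Schema deserialize(byte[] bytes) throws IOException {
    return MessageSerializer.deserializeSchema(
        new ReadChannel(Channels.newChannel(new ByteArrayInputStream(bytes))));
  }

  public static void main(String[] args) throws IOException {
    // Invented example schema, standing in for the Spark-to-Arrow converted schemas.
    Schema schema =
        new Schema(
            Arrays.asList(
                Field.nullable("id", new ArrowType.Int(64, true)),
                Field.nullable("name", new ArrowType.Utf8())));
    byte[] bytes = serialize(schema);
    System.out.println("serialized " + bytes.length + " bytes");
    System.out.println("round trip ok: " + schema.equals(deserialize(bytes)));
  }
}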

native/core/src/execution/planner.rs

Lines changed: 1 addition & 1 deletion
@@ -230,7 +230,7 @@ impl PhysicalPlanner {
     }

     /// Create a DataFusion physical expression from Spark physical expression
-    fn create_expr(
+    pub(crate) fn create_expr(
         &self,
         spark_expr: &Expr,
         input_schema: SchemaRef,
native/core/src/parquet/mod.rs

Lines changed: 23 additions & 2 deletions
@@ -45,6 +45,8 @@ use jni::{

 use self::util::jni::TypePromotionInfo;
 use crate::execution::operators::ExecutionError;
+use crate::execution::planner::PhysicalPlanner;
+use crate::execution::serde;
 use crate::execution::utils::SparkArrowConvert;
 use crate::parquet::data_type::AsBytes;
 use crate::parquet::parquet_exec::init_datasource_exec;
@@ -644,7 +646,9 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
     file_size: jlong,
     start: jlong,
     length: jlong,
+    filter: jbyteArray,
     required_schema: jbyteArray,
+    data_schema: jbyteArray,
     session_timezone: jstring,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| unsafe {
@@ -666,6 +670,23 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
         let required_schema_buffer = env.convert_byte_array(&required_schema_array)?;
         let required_schema = Arc::new(deserialize_schema(required_schema_buffer.as_bytes())?);

+        let data_schema_array = JByteArray::from_raw(data_schema);
+        let data_schema_buffer = env.convert_byte_array(&data_schema_array)?;
+        let data_schema = Arc::new(deserialize_schema(data_schema_buffer.as_bytes())?);
+
+        let planer = PhysicalPlanner::default();
+
+        let data_filters = if !filter.is_null() {
+            let filter_array = JByteArray::from_raw(filter);
+            let filter_buffer = env.convert_byte_array(&filter_array)?;
+            let filter_expr = serde::deserialize_expr(filter_buffer.as_slice())?;
+            Some(vec![
+                planer.create_expr(&filter_expr, Arc::clone(&data_schema))?
+            ])
+        } else {
+            None
+        };
+
         let file_groups =
             get_file_groups_single_file(&object_store_path, file_size as u64, start, length);

@@ -676,13 +697,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat

         let scan = init_datasource_exec(
             required_schema,
-            None,
+            Some(data_schema),
             None,
             None,
             object_store_url,
             file_groups,
             None,
-            None,
+            data_filters,
             session_timezone.as_str(),
         )?;

spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala

Lines changed: 30 additions & 21 deletions
@@ -114,36 +114,33 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with
         footerFileMetaData,
         datetimeRebaseModeInRead)

-      val pushed = if (parquetFilterPushDown) {
-        val parquetSchema = footerFileMetaData.getSchema
-        val parquetFilters = new ParquetFilters(
-          parquetSchema,
-          pushDownDate,
-          pushDownTimestamp,
-          pushDownDecimal,
-          pushDownStringPredicate,
-          pushDownInFilterThreshold,
-          isCaseSensitive,
-          datetimeRebaseSpec)
-        filters
-        // Collects all converted Parquet filter predicates. Notice that not all predicates can
-        // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a
-        // `flatMap` is used here.
-          .flatMap(parquetFilters.createFilter)
-          .reduceOption(FilterApi.and)
-      } else {
-        None
-      }
-      pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p))
+      val parquetSchema = footerFileMetaData.getSchema
+      val parquetFilters = new ParquetFilters(
+        parquetSchema,
+        dataSchema,
+        pushDownDate,
+        pushDownTimestamp,
+        pushDownDecimal,
+        pushDownStringPredicate,
+        pushDownInFilterThreshold,
+        isCaseSensitive,
+        datetimeRebaseSpec)

       val recordBatchReader =
         if (nativeIcebergCompat) {
+          val pushed = if (parquetFilterPushDown) {
+            parquetFilters.createNativeFilters(filters)
+          } else {
+            None
+          }
           val batchReader = new NativeBatchReader(
             sharedConf,
             file,
             footer,
+            pushed.orNull,
             capacity,
             requiredSchema,
+            dataSchema,
             isCaseSensitive,
             useFieldId,
             ignoreMissingIds,
@@ -160,6 +157,18 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with
           }
           batchReader
         } else {
+          val pushed = if (parquetFilterPushDown) {
+            filters
+            // Collects all converted Parquet filter predicates. Notice that not all predicates
+            // can be converted (`ParquetFilters.createFilter` returns an `Option`). That's why
+            // a `flatMap` is used here.
+              .flatMap(parquetFilters.createFilter)
+              .reduceOption(FilterApi.and)
+          } else {
+            None
+          }
+          pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p))
+
           val batchReader = new BatchReader(
             sharedConf,
             file,

spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala

Lines changed: 1 addition & 0 deletions
@@ -199,6 +199,7 @@ case class CometParquetPartitionReaderFactory(
       val parquetSchema = footerFileMetaData.getSchema
       val parquetFilters = new ParquetFilters(
         parquetSchema,
+        readDataSchema,
         pushDownDate,
         pushDownTimestamp,
         pushDownDecimal,
0 commit comments

Comments
 (0)