Skip to content

Commit 48af872

Browse files
authored
fix: get_struct field is incorrect when struct in array (#1687)
* fix: get_struct field is incorrect when struct in array * comments * fix: cast list of structs and other cast fixes * clippy * fix: cast list of structs and other cast fixes * fix: cast list of structs and other cast fixes * clippy * clippy
1 parent c89e456 commit 48af872

File tree

3 files changed

+205
-5
lines changed

3 files changed

+205
-5
lines changed

native/core/src/execution/planner.rs

Lines changed: 102 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,6 @@ use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}
6969
use crate::execution::operators::ExecutionError::GeneralError;
7070
use crate::execution::shuffle::CompressionCodec;
7171
use crate::execution::spark_plan::SparkPlan;
72-
use crate::parquet::parquet_exec::init_datasource_exec;
7372
use crate::parquet::parquet_support::prepare_object_store;
7473
use datafusion::common::scalar::ScalarStructBuilder;
7574
use datafusion::common::{
@@ -86,6 +85,7 @@ use datafusion::physical_expr::expressions::{Literal, StatsType};
8685
use datafusion::physical_expr::window::WindowExpr;
8786
use datafusion::physical_expr::LexOrdering;
8887

88+
use crate::parquet::parquet_exec::init_datasource_exec;
8989
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
9090
use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
9191
use datafusion_comet_proto::spark_operator::SparkFilePartition;
@@ -2503,19 +2503,27 @@ fn create_case_expr(
25032503

25042504
#[cfg(test)]
25052505
mod tests {
2506-
use std::{sync::Arc, task::Poll};
2507-
25082506
use futures::{poll, StreamExt};
2507+
use std::{sync::Arc, task::Poll};
25092508

25102509
use arrow::array::{Array, DictionaryArray, Int32Array, StringArray};
2511-
use arrow::datatypes::DataType;
2510+
use arrow::datatypes::{DataType, Field, Fields, Schema};
2511+
use datafusion::catalog::memory::DataSourceExec;
2512+
use datafusion::datasource::listing::PartitionedFile;
2513+
use datafusion::datasource::object_store::ObjectStoreUrl;
2514+
use datafusion::datasource::physical_plan::{FileGroup, FileScanConfigBuilder, ParquetSource};
2515+
use datafusion::error::DataFusionError;
25122516
use datafusion::logical_expr::ScalarUDF;
2517+
use datafusion::physical_plan::ExecutionPlan;
25132518
use datafusion::{assert_batches_eq, physical_plan::common::collect, prelude::SessionContext};
2519+
use tempfile::TempDir;
25142520
use tokio::sync::mpsc;
25152521

25162522
use crate::execution::{operators::InputBatch, planner::PhysicalPlanner};
25172523

25182524
use crate::execution::operators::ExecutionError;
2525+
use crate::parquet::parquet_support::SparkParquetOptions;
2526+
use crate::parquet::schema_adapter::SparkSchemaAdapterFactory;
25192527
use datafusion_comet_proto::spark_expression::expr::ExprStruct;
25202528
use datafusion_comet_proto::{
25212529
spark_expression::expr::ExprStruct::*,
@@ -2524,6 +2532,7 @@ mod tests {
25242532
spark_operator,
25252533
spark_operator::{operator::OpStruct, Operator},
25262534
};
2535+
use datafusion_comet_spark_expr::EvalMode;
25272536

25282537
#[test]
25292538
fn test_unpack_dictionary_primitive() {
@@ -3083,4 +3092,93 @@ mod tests {
30833092
}
30843093
});
30853094
}
3095+
3096+
/*
3097+
Testing a nested types scenario
3098+
3099+
select arr[0].a, arr[0].c from (
3100+
select array(named_struct('a', 1, 'b', 'n', 'c', 'x')) arr)
3101+
*/
3102+
#[tokio::test]
3103+
async fn test_nested_types() -> Result<(), DataFusionError> {
3104+
let session_ctx = SessionContext::new();
3105+
3106+
// generate test data in the temp folder
3107+
let test_data = "select make_array(named_struct('a', 1, 'b', 'n', 'c', 'x')) c0";
3108+
let tmp_dir = TempDir::new()?;
3109+
let test_path = tmp_dir.path().to_str().unwrap().to_string();
3110+
3111+
let plan = session_ctx
3112+
.sql(test_data)
3113+
.await?
3114+
.create_physical_plan()
3115+
.await?;
3116+
3117+
// Write parquet file into temp folder
3118+
session_ctx
3119+
.write_parquet(plan, test_path.clone(), None)
3120+
.await?;
3121+
3122+
// Define schema Comet reads with
3123+
let required_schema = Schema::new(Fields::from(vec![Field::new(
3124+
"c0",
3125+
DataType::List(
3126+
Field::new(
3127+
"element",
3128+
DataType::Struct(Fields::from(vec![
3129+
Field::new("a", DataType::Int32, true),
3130+
Field::new("c", DataType::Utf8, true),
3131+
] as Vec<Field>)),
3132+
true,
3133+
)
3134+
.into(),
3135+
),
3136+
true,
3137+
)]));
3138+
3139+
// Register all parquet with temp data as file groups
3140+
let mut file_groups: Vec<FileGroup> = vec![];
3141+
for entry in std::fs::read_dir(&test_path)? {
3142+
let entry = entry?;
3143+
let path = entry.path();
3144+
3145+
if path.extension().and_then(|ext| ext.to_str()) == Some("parquet") {
3146+
if let Some(path_str) = path.to_str() {
3147+
file_groups.push(FileGroup::new(vec![PartitionedFile::from_path(
3148+
path_str.into(),
3149+
)?]));
3150+
}
3151+
}
3152+
}
3153+
3154+
let source = Arc::new(
3155+
ParquetSource::default().with_schema_adapter_factory(Arc::new(
3156+
SparkSchemaAdapterFactory::new(SparkParquetOptions::new(EvalMode::Ansi, "", false)),
3157+
)),
3158+
);
3159+
3160+
let object_store_url = ObjectStoreUrl::local_filesystem();
3161+
let file_scan_config =
3162+
FileScanConfigBuilder::new(object_store_url, required_schema.into(), source)
3163+
.with_file_groups(file_groups)
3164+
.build();
3165+
3166+
// Run native read
3167+
let scan = Arc::new(DataSourceExec::new(Arc::new(file_scan_config.clone())));
3168+
let stream = scan.execute(0, session_ctx.task_ctx())?;
3169+
let result: Vec<_> = stream.collect().await;
3170+
3171+
let actual = result.first().unwrap().as_ref().unwrap();
3172+
3173+
let expected = [
3174+
"+----------------+",
3175+
"| c0 |",
3176+
"+----------------+",
3177+
"| [{a: 1, c: x}] |",
3178+
"+----------------+",
3179+
];
3180+
assert_batches_eq!(expected, &[actual.clone()]);
3181+
3182+
Ok(())
3183+
}
30863184
}

native/core/src/parquet/parquet_support.rs

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@
1616
// under the License.
1717

1818
use crate::execution::operators::ExecutionError;
19+
use arrow::array::ListArray;
20+
use arrow::compute::can_cast_types;
1921
use arrow::{
2022
array::{
2123
cast::AsArray, new_null_array, types::Int32Type, types::TimestampMicrosecondType, Array,
@@ -156,13 +158,30 @@ fn cast_array(
156158
};
157159
let from_type = array.data_type();
158160

161+
// Try Comet specific handlers first, then arrow-rs cast if supported,
162+
// return uncasted data otherwise
159163
match (from_type, to_type) {
160164
(Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
161165
array.as_struct(),
162166
from_type,
163167
to_type,
164168
parquet_options,
165169
)?),
170+
(List(_), List(to_inner_type)) => {
171+
let list_arr: &ListArray = array.as_list();
172+
let cast_field = cast_array(
173+
Arc::clone(list_arr.values()),
174+
to_inner_type.data_type(),
175+
parquet_options,
176+
)?;
177+
178+
Ok(Arc::new(ListArray::new(
179+
Arc::clone(to_inner_type),
180+
list_arr.offsets().clone(),
181+
cast_field,
182+
list_arr.nulls().cloned(),
183+
)))
184+
}
166185
(Timestamp(TimeUnit::Microsecond, None), Timestamp(TimeUnit::Microsecond, Some(tz))) => {
167186
Ok(Arc::new(
168187
array
@@ -171,7 +190,11 @@ fn cast_array(
171190
.with_timezone(Arc::clone(tz)),
172191
))
173192
}
174-
_ => Ok(cast_with_options(&array, to_type, &PARQUET_OPTIONS)?),
193+
// If Arrow cast supports the cast, delegate the cast to Arrow
194+
_ if can_cast_types(from_type, to_type) => {
195+
Ok(cast_with_options(&array, to_type, &PARQUET_OPTIONS)?)
196+
}
197+
_ => Ok(array),
175198
}
176199
}
177200

spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala

Lines changed: 79 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -224,4 +224,83 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper
224224
|""".stripMargin,
225225
"select c0 from tbl")
226226
}
227+
test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - second field") {
228+
testSingleLineQuery(
229+
"""
230+
| select array(str0, str1) c0 from
231+
| (
232+
| select
233+
| named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
234+
| named_struct('a', 2, 'b', 'w', 'c', 'y') str1
235+
| )
236+
|""".stripMargin,
237+
"select c0[0].b col0 from tbl")
238+
}
239+
240+
test("native reader - read a STRUCT subfield - field from second") {
241+
withSQLConf(
242+
CometConf.COMET_EXEC_ENABLED.key -> "true",
243+
SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
244+
CometConf.COMET_ENABLED.key -> "true",
245+
CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "false",
246+
CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
247+
testSingleLineQuery(
248+
"""
249+
|select 1 a, named_struct('a', 1, 'b', 'n') c0
250+
|""".stripMargin,
251+
"select c0.b from tbl")
252+
}
253+
}
254+
255+
test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - field from first") {
256+
testSingleLineQuery(
257+
"""
258+
| select array(str0, str1) c0 from
259+
| (
260+
| select
261+
| named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
262+
| named_struct('a', 2, 'b', 'w', 'c', 'y') str1
263+
| )
264+
|""".stripMargin,
265+
"select c0[0].a, c0[0].b, c0[0].c from tbl")
266+
}
267+
268+
test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - reverse fields") {
269+
testSingleLineQuery(
270+
"""
271+
| select array(str0, str1) c0 from
272+
| (
273+
| select
274+
| named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
275+
| named_struct('a', 2, 'b', 'w', 'c', 'y') str1
276+
| )
277+
|""".stripMargin,
278+
"select c0[0].c, c0[0].b, c0[0].a from tbl")
279+
}
280+
281+
test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - skip field") {
282+
testSingleLineQuery(
283+
"""
284+
| select array(str0, str1) c0 from
285+
| (
286+
| select
287+
| named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
288+
| named_struct('a', 2, 'b', 'w', 'c', 'y') str1
289+
| )
290+
|""".stripMargin,
291+
"select c0[0].a, c0[0].c from tbl")
292+
}
293+
294+
test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - duplicate first field") {
295+
testSingleLineQuery(
296+
"""
297+
| select array(str0, str1) c0 from
298+
| (
299+
| select
300+
| named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
301+
| named_struct('a', 2, 'b', 'w', 'c', 'y') str1
302+
| )
303+
|""".stripMargin,
304+
"select c0[0].a, c0[0].a from tbl")
305+
}
227306
}

0 commit comments

Comments (0)