Commit d682d22

feat: Support reading CSV files with inconsistent column counts (#17553)
* feat: Support CSV files with inconsistent column counts

  Enable DataFusion to read directories containing CSV files with different numbers
  of columns by implementing schema union during inference.

  Changes:
  - Modified CSV schema inference to create union schema from all files
  - Extended infer_schema_from_stream to handle varying column counts
  - Added tests for schema building logic and integration scenarios

  Requires CsvReadOptions::new().truncated_rows(true) to handle files with fewer
  columns than the inferred schema.

  Fixes #17516

* refactor: Address review comments for CSV union schema feature

  Addresses all review feedback from PR #17553 to improve the CSV schema union
  implementation that allows reading CSV files with different column counts.

  Changes based on review:
  - Moved unit tests from separate tests.rs to bottom of file_format.rs
  - Updated documentation wording from "now supports" to "can handle"
  - Removed all println statements from integration test
  - Added comprehensive assertions for actual row content verification
  - Simplified HashSet initialization using HashSet::from([...]) syntax
  - Updated truncated_rows config documentation to reflect expanded purpose
  - Removed unnecessary min() calculation in column processing loop
  - Fixed clippy warnings by using enumerate() instead of range loop

  Technical improvements:
  - Tests now verify null patterns correctly across union schema
  - Cleaner iteration logic without redundant bounds checking
  - Better documentation explaining union schema behavior

  The feature continues to work as designed:
  - Creates union schema from all CSV files in a directory
  - Files with fewer columns have nulls for missing fields
  - Requires explicit opt-in via truncated_rows(true)
  - Maintains full backward compatibility

* Apply cargo fmt formatting fixes

* refactor: Address PR review comments for CSV union schema feature

  - Remove pub(crate) visibility from build_schema_helper function
  - Refactor column type processing to use zip iterator before extension logic
  - Add missing error handling for truncated_rows=false case
  - Improve truncated_rows documentation to clarify dual purpose
  - Replace manual testing with assert_snapshot for better test coverage
  - Fix clippy warnings and ensure all tests pass

  Addresses all reviewer feedback from PR #17553 while maintaining backward
  compatibility and the original CSV union schema functionality.
1 parent 1cc4daf commit d682d22
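For context, a minimal usage sketch of the opt-in described in the commit message and exercised by the new integration test below (the directory path is hypothetical):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Opt in to union-schema inference; without truncated_rows(true), reading
    // CSV files with different column counts errors during schema inference.
    let df = ctx
        .read_csv(
            "/path/to/csv_dir", // hypothetical directory containing the CSV files
            CsvReadOptions::new().truncated_rows(true),
        )
        .await?;

    // Files missing columns from the union schema are padded with nulls.
    df.show().await?;
    Ok(())
}
```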

File tree

3 files changed (+247, -11 lines)


datafusion/common/src/config.rs

Lines changed: 9 additions & 3 deletions
@@ -2535,9 +2535,15 @@ config_namespace! {
         // The input regex for Nulls when loading CSVs.
         pub null_regex: Option<String>, default = None
         pub comment: Option<u8>, default = None
-        // Whether to allow truncated rows when parsing.
-        // By default this is set to false and will error if the CSV rows have different lengths.
-        // When set to true then it will allow records with less than the expected number of columns
+        /// Whether to allow truncated rows when parsing, both within a single file and across files.
+        ///
+        /// When set to false (default), reading a single CSV file which has rows of different lengths will
+        /// error; if reading multiple CSV files with different number of columns, it will also fail.
+        ///
+        /// When set to true, reading a single CSV file with rows of different lengths will pad the truncated
+        /// rows with null values for the missing columns; if reading multiple CSV files with different number
+        /// of columns, it creates a union schema containing all columns found across the files, and will
+        /// pad any files missing columns with null values for their rows.
         pub truncated_rows: Option<bool>, default = None
     }
 }
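As a small sketch of turning the option on programmatically, assuming `CsvOptions` is the struct generated by this `config_namespace!` block, with public fields and a `Default` impl:

```rust
use datafusion_common::config::CsvOptions;

fn main() {
    // Sketch only: assumes the config_namespace! block above generates
    // `CsvOptions` with public fields and a Default impl.
    let options = CsvOptions {
        truncated_rows: Some(true), // allow short rows and union schemas across files
        ..Default::default()
    };
    assert_eq!(options.truncated_rows, Some(true));
}
```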
Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Test for CSV schema inference with different column counts (GitHub issue #17516)

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::test_util::batches_to_sort_string;
use insta::assert_snapshot;
use std::fs;
use tempfile::TempDir;

#[tokio::test]
async fn test_csv_schema_inference_different_column_counts() -> Result<()> {
    // Create temporary directory for test files
    let temp_dir = TempDir::new().expect("Failed to create temp dir");
    let temp_path = temp_dir.path();

    // Create CSV file 1 with 3 columns (simulating older railway services format)
    let csv1_content = r#"service_id,route_type,agency_id
1,bus,agency1
2,rail,agency2
3,bus,agency3
"#;
    fs::write(temp_path.join("services_2024.csv"), csv1_content)?;

    // Create CSV file 2 with 6 columns (simulating newer railway services format)
    let csv2_content = r#"service_id,route_type,agency_id,stop_platform_change,stop_planned_platform,stop_actual_platform
4,rail,agency2,true,Platform A,Platform B
5,bus,agency1,false,Stop 1,Stop 1
6,rail,agency3,true,Platform C,Platform D
"#;
    fs::write(temp_path.join("services_2025.csv"), csv2_content)?;

    // Create DataFusion context
    let ctx = SessionContext::new();

    // This should now work (previously would have failed with column count mismatch)
    // Enable truncated_rows to handle files with different column counts
    let df = ctx
        .read_csv(
            temp_path.to_str().unwrap(),
            CsvReadOptions::new().truncated_rows(true),
        )
        .await
        .expect("Should successfully read CSV directory with different column counts");

    // Verify the schema contains all 6 columns (union of both files)
    let df_clone = df.clone();
    let schema = df_clone.schema();
    assert_eq!(
        schema.fields().len(),
        6,
        "Schema should contain all 6 columns"
    );

    // Check that we have all expected columns
    let field_names: Vec<&str> =
        schema.fields().iter().map(|f| f.name().as_str()).collect();
    assert!(field_names.contains(&"service_id"));
    assert!(field_names.contains(&"route_type"));
    assert!(field_names.contains(&"agency_id"));
    assert!(field_names.contains(&"stop_platform_change"));
    assert!(field_names.contains(&"stop_planned_platform"));
    assert!(field_names.contains(&"stop_actual_platform"));

    // All fields should be nullable since they don't appear in all files
    for field in schema.fields() {
        assert!(
            field.is_nullable(),
            "Field {} should be nullable",
            field.name()
        );
    }

    // Verify we can actually read the data
    let results = df.collect().await?;

    // Calculate total rows across all batches
    let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum();
    assert_eq!(total_rows, 6, "Should have 6 total rows across all batches");

    // All batches should have 6 columns (the union schema)
    for batch in &results {
        assert_eq!(batch.num_columns(), 6, "All batches should have 6 columns");
        assert_eq!(
            batch.schema().fields().len(),
            6,
            "Each batch should use the union schema with 6 fields"
        );
    }

    // Verify the actual content of the data using snapshot testing
    assert_snapshot!(batches_to_sort_string(&results), @r"
    +------------+------------+-----------+----------------------+-----------------------+----------------------+
    | service_id | route_type | agency_id | stop_platform_change | stop_planned_platform | stop_actual_platform |
    +------------+------------+-----------+----------------------+-----------------------+----------------------+
    | 1          | bus        | agency1   |                      |                       |                      |
    | 2          | rail       | agency2   |                      |                       |                      |
    | 3          | bus        | agency3   |                      |                       |                      |
    | 4          | rail       | agency2   | true                 | Platform A            | Platform B           |
    | 5          | bus        | agency1   | false                | Stop 1                | Stop 1                |
    | 6          | rail       | agency3   | true                 | Platform C            | Platform D           |
    +------------+------------+-----------+----------------------+-----------------------+----------------------+
    ");

    Ok(())
}

datafusion/datasource-csv/src/file_format.rs

Lines changed: 116 additions & 8 deletions
@@ -497,7 +497,20 @@ impl FileFormat for CsvFormat {
 impl CsvFormat {
     /// Return the inferred schema reading up to records_to_read from a
     /// stream of delimited chunks returning the inferred schema and the
-    /// number of lines that were read
+    /// number of lines that were read.
+    ///
+    /// This method can handle CSV files with different numbers of columns.
+    /// The inferred schema will be the union of all columns found across all files.
+    /// Files with fewer columns will have missing columns filled with null values.
+    ///
+    /// # Example
+    ///
+    /// If you have two CSV files:
+    /// - `file1.csv`: `col1,col2,col3`
+    /// - `file2.csv`: `col1,col2,col3,col4,col5`
+    ///
+    /// The inferred schema will contain all 5 columns, with files that don't
+    /// have columns 4 and 5 having null values for those columns.
     pub async fn infer_schema_from_stream(
         &self,
         state: &dyn Session,
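To make the doc-comment example concrete, here is a tiny illustrative check of the union schema it describes; the `Utf8` types are an assumption for illustration only, since the real inferred types depend on the file contents:

```rust
use arrow::datatypes::{DataType, Field, Schema};

fn main() {
    // The union schema for file1 (col1..col3) and file2 (col1..col5):
    // all five columns present, every field nullable.
    let expected = Schema::new(vec![
        Field::new("col1", DataType::Utf8, true),
        Field::new("col2", DataType::Utf8, true),
        Field::new("col3", DataType::Utf8, true),
        Field::new("col4", DataType::Utf8, true),
        Field::new("col5", DataType::Utf8, true),
    ]);
    assert_eq!(expected.fields().len(), 5);
    assert!(expected.fields().iter().all(|f| f.is_nullable()));
}
```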
@@ -560,21 +573,37 @@ impl CsvFormat {
                     })
                     .unzip();
             } else {
-                if fields.len() != column_type_possibilities.len() {
+                if fields.len() != column_type_possibilities.len()
+                    && !self.options.truncated_rows.unwrap_or(false)
+                {
                     return exec_err!(
-                        "Encountered unequal lengths between records on CSV file whilst inferring schema. \
-                     Expected {} fields, found {} fields at record {}",
-                        column_type_possibilities.len(),
-                        fields.len(),
-                        record_number + 1
-                    );
+                        "Encountered unequal lengths between records on CSV file whilst inferring schema. \
+                         Expected {} fields, found {} fields at record {}",
+                        column_type_possibilities.len(),
+                        fields.len(),
+                        record_number + 1
+                    );
                 }
 
+                // First update type possibilities for existing columns using zip
                 column_type_possibilities.iter_mut().zip(&fields).for_each(
                     |(possibilities, field)| {
                         possibilities.insert(field.data_type().clone());
                     },
                 );
+
+                // Handle files with different numbers of columns by extending the schema
+                if fields.len() > column_type_possibilities.len() {
+                    // New columns found - extend our tracking structures
+                    for field in fields.iter().skip(column_type_possibilities.len()) {
+                        column_names.push(field.name().clone());
+                        let mut possibilities = HashSet::new();
+                        if records_read > 0 {
+                            possibilities.insert(field.data_type().clone());
+                        }
+                        column_type_possibilities.push(possibilities);
+                    }
+                }
             }
 
             if records_to_read == 0 {
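The hunk above is the heart of the change: existing columns get their type possibilities updated via `zip`, and any extra columns in a wider file extend the tracking vectors. A simplified, std-only sketch of that bookkeeping, with column types modeled as strings and `merge_record` standing in for the per-record inference (both are illustrative assumptions, not DataFusion APIs):

```rust
use std::collections::HashSet;

fn merge_record(
    column_names: &mut Vec<String>,
    column_type_possibilities: &mut Vec<HashSet<&'static str>>,
    record: &[(&str, &'static str)], // (column name, inferred type) per field
) {
    // Update type possibilities for the columns we already track.
    column_type_possibilities
        .iter_mut()
        .zip(record)
        .for_each(|(possibilities, (_, ty))| {
            possibilities.insert(*ty);
        });

    // If this record has more columns than seen so far, extend the tracking
    // structures so the final schema is the union of all columns.
    if record.len() > column_type_possibilities.len() {
        for (name, ty) in record.iter().skip(column_type_possibilities.len()) {
            column_names.push(name.to_string());
            column_type_possibilities.push(HashSet::from([*ty]));
        }
    }
}

fn main() {
    let mut names = vec!["col1".to_string(), "col2".to_string()];
    let mut types = vec![HashSet::from(["Int64"]), HashSet::from(["Utf8"])];

    // A later file contributes an extra column.
    merge_record(
        &mut names,
        &mut types,
        &[("col1", "Int64"), ("col2", "Utf8"), ("col3", "Float64")],
    );

    assert_eq!(names, ["col1", "col2", "col3"]);
    assert_eq!(types[2], HashSet::from(["Float64"]));
}
```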
@@ -769,3 +798,82 @@ impl DataSink for CsvSink {
         FileSink::write_all(self, data, context).await
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::build_schema_helper;
+    use arrow::datatypes::DataType;
+    use std::collections::HashSet;
+
+    #[test]
+    fn test_build_schema_helper_different_column_counts() {
+        // Test the core schema building logic with different column counts
+        let mut column_names =
+            vec!["col1".to_string(), "col2".to_string(), "col3".to_string()];
+
+        // Simulate adding two more columns from another file
+        column_names.push("col4".to_string());
+        column_names.push("col5".to_string());
+
+        let column_type_possibilities = vec![
+            HashSet::from([DataType::Int64]),
+            HashSet::from([DataType::Utf8]),
+            HashSet::from([DataType::Float64]),
+            HashSet::from([DataType::Utf8]), // col4
+            HashSet::from([DataType::Utf8]), // col5
+        ];
+
+        let schema = build_schema_helper(column_names, &column_type_possibilities);
+
+        // Verify schema has 5 columns
+        assert_eq!(schema.fields().len(), 5);
+        assert_eq!(schema.field(0).name(), "col1");
+        assert_eq!(schema.field(1).name(), "col2");
+        assert_eq!(schema.field(2).name(), "col3");
+        assert_eq!(schema.field(3).name(), "col4");
+        assert_eq!(schema.field(4).name(), "col5");
+
+        // All fields should be nullable
+        for field in schema.fields() {
+            assert!(
+                field.is_nullable(),
+                "Field {} should be nullable",
+                field.name()
+            );
+        }
+    }
+
+    #[test]
+    fn test_build_schema_helper_type_merging() {
+        // Test type merging logic
+        let column_names = vec!["col1".to_string(), "col2".to_string()];
+
+        let column_type_possibilities = vec![
+            HashSet::from([DataType::Int64, DataType::Float64]), // Should resolve to Float64
+            HashSet::from([DataType::Utf8]),                     // Should remain Utf8
+        ];
+
+        let schema = build_schema_helper(column_names, &column_type_possibilities);
+
+        // col1 should be Float64 due to Int64 + Float64 = Float64
+        assert_eq!(*schema.field(0).data_type(), DataType::Float64);
+
+        // col2 should remain Utf8
+        assert_eq!(*schema.field(1).data_type(), DataType::Utf8);
+    }
+
+    #[test]
+    fn test_build_schema_helper_conflicting_types() {
+        // Test when we have incompatible types - should default to Utf8
+        let column_names = vec!["col1".to_string()];
+
+        let column_type_possibilities = vec![
+            HashSet::from([DataType::Boolean, DataType::Int64, DataType::Utf8]), // Should resolve to Utf8 due to conflicts
+        ];
+
+        let schema = build_schema_helper(column_names, &column_type_possibilities);
+
+        // Should default to Utf8 for conflicting types
+        assert_eq!(*schema.field(0).data_type(), DataType::Utf8);
+    }
+}
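The tests above pin down the per-column type-resolution rule: a single inferred type is kept, `Int64` plus `Float64` widens to `Float64`, and conflicting combinations fall back to `Utf8`. A sketch of that rule follows; this is not the actual `build_schema_helper` body, just the behavior the tests assert:

```rust
use arrow::datatypes::DataType;
use std::collections::HashSet;

// Sketch of the per-column resolution rule exercised by the tests above;
// the real build_schema_helper in this file may differ in detail.
fn resolve_type(possibilities: &HashSet<DataType>) -> DataType {
    if possibilities.len() == 1 {
        return possibilities.iter().next().unwrap().clone();
    }
    if possibilities.len() == 2
        && possibilities.contains(&DataType::Int64)
        && possibilities.contains(&DataType::Float64)
    {
        // Int64 and Float64 widen to Float64.
        return DataType::Float64;
    }
    // Conflicting (or empty) possibilities fall back to Utf8.
    DataType::Utf8
}

fn main() {
    assert_eq!(
        resolve_type(&HashSet::from([DataType::Int64, DataType::Float64])),
        DataType::Float64
    );
    assert_eq!(
        resolve_type(&HashSet::from([
            DataType::Boolean,
            DataType::Int64,
            DataType::Utf8
        ])),
        DataType::Utf8
    );
    assert_eq!(resolve_type(&HashSet::from([DataType::Utf8])), DataType::Utf8);
}
```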
