Skip to content

Commit 40c7a6f

Browse files
committed
Refactor column selection in REPL pipeline to use ColumnSpec enum
- Updated the `PipelineStage` enum to replace `Vec<String>` with `Vec<ColumnSpec>` for column selection, enhancing type safety and clarity. - Refactored the `exec_select` method and related functions to handle `ColumnSpec`, allowing for both exact and case-insensitive column matching. - Introduced `resolve_column_specs` function to resolve column specifications against the schema, improving the selection logic. - Updated tests to validate the new column selection behavior, ensuring robust functionality across various scenarios.
1 parent bb4d7e0 commit 40c7a6f

File tree

2 files changed

+196
-30
lines changed

2 files changed

+196
-30
lines changed

src/cli/repl.rs

Lines changed: 100 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ use crate::pipeline::VecRecordBatchReaderSource;
1111
use crate::pipeline::display::write_record_batches_as_csv;
1212
use crate::pipeline::read_to_batches;
1313
use crate::pipeline::select;
14+
use crate::pipeline::select::ColumnSpec;
1415
use crate::pipeline::write_batches;
1516

1617
/// A planned pipeline stage with validated, extracted arguments.
1718
#[derive(Debug, PartialEq)]
1819
pub enum PipelineStage {
1920
Read { path: String },
20-
Select { columns: Vec<String> },
21+
Select { columns: Vec<ColumnSpec> },
2122
Head { n: usize },
2223
Tail { n: usize },
2324
Count,
@@ -30,7 +31,13 @@ impl fmt::Display for PipelineStage {
3031
match self {
3132
PipelineStage::Read { path } => write!(f, r#"read("{path}")"#),
3233
PipelineStage::Select { columns } => {
33-
let cols: Vec<String> = columns.iter().map(|c| format!(":{c}")).collect::<Vec<_>>();
34+
let cols: Vec<String> = columns
35+
.iter()
36+
.map(|c| match c {
37+
ColumnSpec::Exact(s) => format!(r#""{s}""#),
38+
ColumnSpec::CaseInsensitive(s) => format!(":{s}"),
39+
})
40+
.collect::<Vec<_>>();
3441
write!(f, "select({})", cols.join(", "))
3542
}
3643
PipelineStage::Head { n } => write!(f, "head({n})"),
@@ -72,6 +79,7 @@ impl ReplPipelineBuilder {
7279
}
7380

7481
/// Evaluates a binary expression to a pipeline.
82+
#[allow(clippy::boxed_local)]
7583
fn eval_binary_expr(
7684
&self,
7785
left: Box<Expr>,
@@ -121,7 +129,7 @@ impl ReplPipelineBuilder {
121129
}
122130

123131
/// Selects columns from the batches in context.
124-
async fn exec_select(&mut self, columns: &[String]) -> crate::Result<()> {
132+
async fn exec_select(&mut self, columns: &[ColumnSpec]) -> crate::Result<()> {
125133
let batches = self.batches.take().ok_or_else(|| {
126134
Error::GenericError("select requires a preceding read in the pipe".to_string())
127135
})?;
@@ -235,13 +243,14 @@ fn extract_path_from_args(func_name: &str, args: &[Expr]) -> crate::Result<Strin
235243
}
236244
}
237245

238-
/// Extracts column names from select args (symbols like :one, identifiers, or strings like "one").
239-
fn extract_column_names(args: &[Expr]) -> crate::Result<Vec<String>> {
246+
/// Extracts column specs from select args. Symbols (:one) and identifiers use case-insensitive
247+
/// match; strings ("one") use exact match.
248+
fn extract_column_specs(args: &[Expr]) -> crate::Result<Vec<ColumnSpec>> {
240249
args.iter()
241250
.map(|expr| match expr {
242-
Expr::Literal(Literal::Symbol(s)) => Ok(s.clone()),
243-
Expr::Literal(Literal::String(s)) => Ok(s.clone()),
244-
Expr::Ident(s) => Ok(s.clone()),
251+
Expr::Literal(Literal::Symbol(s)) => Ok(ColumnSpec::CaseInsensitive(s.clone())),
252+
Expr::Literal(Literal::String(s)) => Ok(ColumnSpec::Exact(s.clone())),
253+
Expr::Ident(s) => Ok(ColumnSpec::CaseInsensitive(s.clone())),
245254
_ => Err(Error::UnsupportedFunctionCall(format!(
246255
"select expects symbol or string column names, got {expr:?}"
247256
))),
@@ -262,7 +271,7 @@ impl ReplPipelineBuilder {
262271
}
263272

264273
async fn eval_select(&mut self, args: Vec<Expr>) -> crate::Result<()> {
265-
let columns = extract_column_names(&args)?;
274+
let columns = extract_column_specs(&args)?;
266275
if columns.is_empty() {
267276
return Err(Error::UnsupportedFunctionCall(
268277
"select expects at least one column name".to_string(),
@@ -311,7 +320,7 @@ fn plan_stage(expr: Expr) -> crate::Result<PipelineStage> {
311320
Ok(PipelineStage::Read { path })
312321
}
313322
"select" => {
314-
let columns = extract_column_names(&args)?;
323+
let columns = extract_column_specs(&args)?;
315324
if columns.is_empty() {
316325
return Err(Error::UnsupportedFunctionCall(
317326
"select expects at least one column name".to_string(),
@@ -433,7 +442,10 @@ mod tests {
433442
assert_eq!(
434443
stage,
435444
PipelineStage::Select {
436-
columns: vec!["one".to_string(), "two".to_string()]
445+
columns: vec![
446+
ColumnSpec::CaseInsensitive("one".into()),
447+
ColumnSpec::CaseInsensitive("two".into())
448+
]
437449
}
438450
);
439451
}
@@ -497,7 +509,7 @@ mod tests {
497509
assert_eq!(
498510
pipeline[1],
499511
PipelineStage::Select {
500-
columns: vec!["x".to_string()]
512+
columns: vec![ColumnSpec::CaseInsensitive("x".into())]
501513
}
502514
);
503515
assert_eq!(
@@ -600,56 +612,81 @@ mod tests {
600612
assert!(matches!(&stages[0], Expr::BinaryExpr(_, BinaryOp::Add, _)));
601613
}
602614

603-
// ── extract_column_names ────────────────────────────────────────
615+
// ── extract_column_specs ────────────────────────────────────────
604616

605617
#[test]
606-
fn test_extract_column_names_symbols() {
618+
fn test_extract_column_specs_symbols() {
607619
let args = vec![
608620
Expr::Literal(Literal::Symbol("one".into())),
609621
Expr::Literal(Literal::Symbol("two".into())),
610622
];
611-
let result = extract_column_names(&args).unwrap();
612-
assert_eq!(result, vec!["one", "two"]);
623+
let result = extract_column_specs(&args).unwrap();
624+
assert_eq!(
625+
result,
626+
vec![
627+
ColumnSpec::CaseInsensitive("one".into()),
628+
ColumnSpec::CaseInsensitive("two".into())
629+
]
630+
);
613631
}
614632

615633
#[test]
616-
fn test_extract_column_names_strings() {
634+
fn test_extract_column_specs_strings() {
617635
let args = vec![
618636
Expr::Literal(Literal::String("col_a".into())),
619637
Expr::Literal(Literal::String("col_b".into())),
620638
];
621-
let result = extract_column_names(&args).unwrap();
622-
assert_eq!(result, vec!["col_a", "col_b"]);
639+
let result = extract_column_specs(&args).unwrap();
640+
assert_eq!(
641+
result,
642+
vec![
643+
ColumnSpec::Exact("col_a".into()),
644+
ColumnSpec::Exact("col_b".into())
645+
]
646+
);
623647
}
624648

625649
#[test]
626-
fn test_extract_column_names_idents() {
650+
fn test_extract_column_specs_idents() {
627651
let args = vec![Expr::Ident("foo".into()), Expr::Ident("bar".into())];
628-
let result = extract_column_names(&args).unwrap();
629-
assert_eq!(result, vec!["foo", "bar"]);
652+
let result = extract_column_specs(&args).unwrap();
653+
assert_eq!(
654+
result,
655+
vec![
656+
ColumnSpec::CaseInsensitive("foo".into()),
657+
ColumnSpec::CaseInsensitive("bar".into())
658+
]
659+
);
630660
}
631661

632662
#[test]
633-
fn test_extract_column_names_mixed() {
663+
fn test_extract_column_specs_mixed() {
634664
let args = vec![
635665
Expr::Literal(Literal::Symbol("sym".into())),
636666
Expr::Literal(Literal::String("str".into())),
637667
Expr::Ident("ident".into()),
638668
];
639-
let result = extract_column_names(&args).unwrap();
640-
assert_eq!(result, vec!["sym", "str", "ident"]);
669+
let result = extract_column_specs(&args).unwrap();
670+
assert_eq!(
671+
result,
672+
vec![
673+
ColumnSpec::CaseInsensitive("sym".into()),
674+
ColumnSpec::Exact("str".into()),
675+
ColumnSpec::CaseInsensitive("ident".into())
676+
]
677+
);
641678
}
642679

643680
#[test]
644-
fn test_extract_column_names_empty() {
645-
let result = extract_column_names(&[]).unwrap();
681+
fn test_extract_column_specs_empty() {
682+
let result = extract_column_specs(&[]).unwrap();
646683
assert!(result.is_empty());
647684
}
648685

649686
#[test]
650-
fn test_extract_column_names_unsupported_expr() {
687+
fn test_extract_column_specs_unsupported_expr() {
651688
let args = vec![Expr::Literal(Literal::Boolean(true))];
652-
let result = extract_column_names(&args);
689+
let result = extract_column_specs(&args);
653690
assert!(result.is_err());
654691
let err = result.unwrap_err();
655692
assert!(
@@ -913,6 +950,40 @@ mod tests {
913950
assert!(batches.is_none(), "batches consumed by write");
914951
}
915952

953+
#[tokio::test(flavor = "multi_thread")]
954+
async fn test_repl_pipeline_select_symbol_case_insensitive() {
955+
let mut ctx = new_context();
956+
ctx.eval_read(vec![Expr::Literal(Literal::String(
957+
"fixtures/table.parquet".into(),
958+
))])
959+
.await
960+
.expect("read");
961+
962+
let args = vec![
963+
Expr::Literal(Literal::Symbol("ONE".into())),
964+
Expr::Literal(Literal::Symbol("TWO".into())),
965+
];
966+
ctx.eval_select(args).await.expect("select");
967+
let batches = ctx.batches.as_ref().expect("batches after select");
968+
let schema = batches[0].schema();
969+
let col_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
970+
assert_eq!(col_names, vec!["one", "two"]);
971+
}
972+
973+
#[tokio::test(flavor = "multi_thread")]
974+
async fn test_repl_pipeline_select_string_exact_match_fails_on_wrong_case() {
975+
let mut ctx = new_context();
976+
ctx.eval_read(vec![Expr::Literal(Literal::String(
977+
"fixtures/table.parquet".into(),
978+
))])
979+
.await
980+
.expect("read");
981+
982+
let args = vec![Expr::Literal(Literal::String("ONE".into()))];
983+
let result = ctx.eval_select(args).await;
984+
assert!(result.is_err());
985+
}
986+
916987
// ── eval_pipe ───────────────────────────────────────────────────
917988

918989
#[tokio::test(flavor = "multi_thread")]

src/pipeline/select.rs

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! DataFusion DataFrame API for column selection.
22
3+
use arrow::datatypes::Schema;
34
use arrow::record_batch::RecordBatch;
45
use datafusion::execution::context::SessionContext;
56
use datafusion::prelude::AvroReadOptions;
@@ -8,6 +9,42 @@ use datafusion::prelude::ParquetReadOptions;
89
use crate::pipeline::RecordBatchReaderSource;
910
use crate::pipeline::VecRecordBatchReaderSource;
1011

12+
/// How to match a column name: exact (case-sensitive) or case-insensitive.
13+
#[derive(Clone, Debug, PartialEq)]
14+
pub enum ColumnSpec {
15+
/// Exact match (from string literal like "column").
16+
Exact(String),
17+
/// Case-insensitive match (from symbol like :column or bare identifier).
18+
CaseInsensitive(String),
19+
}
20+
21+
impl ColumnSpec {
22+
/// Resolves this spec against a schema, returning the actual column name.
23+
pub fn resolve(&self, schema: &Schema) -> crate::Result<String> {
24+
match self {
25+
ColumnSpec::Exact(name) => schema
26+
.index_of(name)
27+
.map(|_| name.clone())
28+
.map_err(|e| crate::Error::GenericError(format!("Column '{name}' not found: {e}"))),
29+
ColumnSpec::CaseInsensitive(name) => schema
30+
.fields()
31+
.iter()
32+
.find(|f| f.name().eq_ignore_ascii_case(name))
33+
.map(|f| f.name().clone())
34+
.ok_or_else(|| {
35+
crate::Error::GenericError(format!(
36+
"Column '{name}' not found (case-insensitive match)"
37+
))
38+
}),
39+
}
40+
}
41+
}
42+
43+
/// Resolves column specs to actual schema column names.
44+
pub fn resolve_column_specs(schema: &Schema, specs: &[ColumnSpec]) -> crate::Result<Vec<String>> {
45+
specs.iter().map(|s| s.resolve(schema)).collect()
46+
}
47+
1148
/// Reads a Parquet file and selects columns using the DataFusion DataFrame API.
1249
pub async fn read_parquet_select(
1350
path: &str,
@@ -68,13 +105,17 @@ pub async fn read_avro_select(
68105

69106
/// Applies column selection to record batches using the DataFusion DataFrame API.
70107
/// Returns the selected batches directly (for use when RecordBatchReaderSource is not needed).
108+
/// Resolves ColumnSpec against the schema: Exact uses case-sensitive match, CaseInsensitive uses
109+
/// case-insensitive match.
71110
pub async fn select_columns_to_batches(
72111
batches: Vec<RecordBatch>,
73-
columns: &[String],
112+
specs: &[ColumnSpec],
74113
) -> crate::Result<Vec<RecordBatch>> {
75114
if batches.is_empty() {
76115
return Ok(batches);
77116
}
117+
let schema = batches[0].schema();
118+
let columns = resolve_column_specs(&schema, specs)?;
78119
let ctx = SessionContext::new();
79120
let col_refs: Vec<&str> = columns.iter().map(String::as_str).collect();
80121
let df = ctx
@@ -110,3 +151,57 @@ pub async fn select_columns_from_batches(
110151
.map_err(|e| crate::Error::GenericError(e.to_string()))?;
111152
Ok(Box::new(VecRecordBatchReaderSource::new(result_batches)))
112153
}
154+
155+
#[cfg(test)]
156+
mod tests {
157+
use arrow::datatypes::DataType;
158+
use arrow::datatypes::Field;
159+
160+
use super::*;
161+
162+
fn schema_with_columns(names: &[&str]) -> Schema {
163+
let fields: Vec<Field> = names
164+
.iter()
165+
.map(|n| Field::new(*n, DataType::Utf8, true))
166+
.collect();
167+
Schema::new(fields)
168+
}
169+
170+
#[test]
171+
fn test_resolve_exact_match() {
172+
let schema = schema_with_columns(&["one", "two", "three"]);
173+
let specs = vec![
174+
ColumnSpec::Exact("one".into()),
175+
ColumnSpec::Exact("three".into()),
176+
];
177+
let resolved = resolve_column_specs(&schema, &specs).unwrap();
178+
assert_eq!(resolved, vec!["one", "three"]);
179+
}
180+
181+
#[test]
182+
fn test_resolve_exact_no_match_wrong_case() {
183+
let schema = schema_with_columns(&["one", "two"]);
184+
let specs = vec![ColumnSpec::Exact("ONE".into())];
185+
let result = resolve_column_specs(&schema, &specs);
186+
assert!(result.is_err());
187+
}
188+
189+
#[test]
190+
fn test_resolve_case_insensitive_match() {
191+
let schema = schema_with_columns(&["one", "two", "Email"]);
192+
let specs = vec![
193+
ColumnSpec::CaseInsensitive("ONE".into()),
194+
ColumnSpec::CaseInsensitive("email".into()),
195+
];
196+
let resolved = resolve_column_specs(&schema, &specs).unwrap();
197+
assert_eq!(resolved, vec!["one", "Email"]);
198+
}
199+
200+
#[test]
201+
fn test_resolve_case_insensitive_no_match() {
202+
let schema = schema_with_columns(&["one", "two"]);
203+
let specs = vec![ColumnSpec::CaseInsensitive("missing".into())];
204+
let result = resolve_column_specs(&schema, &specs);
205+
assert!(result.is_err());
206+
}
207+
}

0 commit comments

Comments
 (0)