Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-cli/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ impl TableFunctionImpl for ParquetMetadataFunc {
compression_arr.push(format!("{:?}", column.compression()));
// need to collect into Vec to format
let encodings: Vec<_> = column.encodings().collect();
encodings_arr.push(format!("{:?}", encodings));
encodings_arr.push(format!("{encodings:?}"));
index_page_offset_arr.push(column.index_page_offset());
dictionary_page_offset_arr.push(column.dictionary_page_offset());
data_page_offset_arr.push(column.data_page_offset());
Expand Down
70 changes: 64 additions & 6 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,15 +445,31 @@ impl DataFrame {
/// # Ok(())
/// # }
/// ```
pub fn drop_columns(self, columns: &[&str]) -> Result<DataFrame> {
pub fn drop_columns<T>(self, columns: &[T]) -> Result<DataFrame>
where
T: Into<Column> + Clone,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering why the Clone bound is needed. In case anyone else is curious -- it turns out it needed because what is passed in is &[T] (vs something like [T])

error[E0277]: the trait bound `datafusion_common::Column: std::convert::From<&T>` is not satisfied
   --> datafusion/core/src/dataframe/mod.rs:455:42
    |
455 |                 let column: Column = col.into();
    |                                          ^^^^ the trait `std::convert::From<&T>` is not implemented for `datafusion_common::Column`
    |
    = note: required for `&T` to implement `std::convert::Into<datafusion_common::Column>`
help: consider extending the `where` clause, but there might be an alternative better way to express this requirement
    |
450 |         T: Into<Column>, datafusion_common::Column: std::convert::From<&T>
    |                          +++++++++++++++++++++++++++++++++++++++++++++++++


{
let fields_to_drop = columns
.iter()
.flat_map(|name| {
self.plan
.schema()
.qualified_fields_with_unqualified_name(name)
.flat_map(|col| {
let column: Column = col.clone().into();
match column.relation.as_ref() {
Some(_) => {
// qualified_field_from_column returns Result<(Option<&TableReference>, &FieldRef)>
vec![self.plan.schema().qualified_field_from_column(&column)]
}
None => {
// qualified_fields_with_unqualified_name returns Vec<(Option<&TableReference>, &FieldRef)>
self.plan
.schema()
.qualified_fields_with_unqualified_name(&column.name)
.into_iter()
.map(Ok)
.collect::<Vec<_>>()
}
}
})
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()?;
let expr: Vec<Expr> = self
.plan
.schema()
Expand Down Expand Up @@ -2463,6 +2479,48 @@ impl DataFrame {
.collect()
}

/// Find qualified columns for this dataframe from names
///
/// # Arguments
/// * `names` - Unqualified names to find.
///
/// # Example
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # use datafusion_common::ScalarValue;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// ctx.register_csv("first_table", "tests/data/example.csv", CsvReadOptions::new())
/// .await?;
/// let df = ctx.table("first_table").await?;
/// ctx.register_csv("second_table", "tests/data/example.csv", CsvReadOptions::new())
/// .await?;
/// let df2 = ctx.table("second_table").await?;
/// let join_expr = df.find_qualified_columns(&["a"])?.iter()
/// .zip(df2.find_qualified_columns(&["a"])?.iter())
/// .map(|(col1, col2)| col(*col1).eq(col(*col2)))
/// .collect::<Vec<Expr>>();
/// let df3 = df.join_on(df2, JoinType::Inner, join_expr)?;
/// # Ok(())
/// # }
/// ```
pub fn find_qualified_columns(
&self,
names: &[&str],
) -> Result<Vec<(Option<&TableReference>, &FieldRef)>> {
let schema = self.logical_plan().schema();
names
.iter()
.map(|name| {
schema
.qualified_field_from_column(&Column::from_name(*name))
.map_err(|_| plan_datafusion_err!("Column '{}' not found", name))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much better than silently ignoring the error 👍

})
.collect()
}

/// Helper for creating DataFrame.
/// # Example
/// ```
Expand Down
106 changes: 104 additions & 2 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,8 @@ async fn drop_columns_with_nonexistent_columns() -> Result<()> {
async fn drop_columns_with_empty_array() -> Result<()> {
// build plan using Table API
let t = test_table().await?;
let t2 = t.drop_columns(&[])?;
let drop_columns = vec![] as Vec<&str>;
let t2 = t.drop_columns(&drop_columns)?;
let plan = t2.logical_plan().clone();

// build query using SQL
Expand All @@ -549,6 +550,107 @@ async fn drop_columns_with_empty_array() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn drop_columns_qualified() -> Result<()> {
    // Build the plan with the DataFrame API: project three columns from each
    // table, join them, then drop the qualified columns from the right side.
    let left = test_table().await?.select_columns(&["c1", "c2", "c11"])?;
    let right = test_table_with_name("another_table")
        .await?
        .select_columns(&["c1", "c2", "c11"])?;
    let joined = left.join_on(
        right,
        JoinType::Inner,
        [col("aggregate_test_100.c1").eq(col("another_table.c1"))],
    )?;
    let dropped = joined.drop_columns(&["another_table.c2", "another_table.c11"])?;

    let plan = dropped.logical_plan().clone();

    // Build the equivalent plan from SQL against freshly registered tables.
    let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1";
    let ctx = SessionContext::new();
    register_aggregate_csv(&ctx, "aggregate_test_100").await?;
    register_aggregate_csv(&ctx, "another_table").await?;
    let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan();

    // The DataFrame-built plan and the SQL-built plan must be identical.
    assert_same_plan(&plan, &sql_plan);

    Ok(())
}

#[tokio::test]
async fn drop_columns_qualified_find_qualified() -> Result<()> {
    // Build the plan with the DataFrame API, resolving the columns to drop
    // via `find_qualified_columns` on the right-hand table rather than by
    // spelling out qualified names as strings.
    let left = test_table().await?.select_columns(&["c1", "c2", "c11"])?;
    let right = test_table_with_name("another_table")
        .await?
        .select_columns(&["c1", "c2", "c11"])?;
    let joined = left.join_on(
        right.clone(),
        JoinType::Inner,
        [col("aggregate_test_100.c1").eq(col("another_table.c1"))],
    )?;
    let dropped = joined.drop_columns(&right.find_qualified_columns(&["c2", "c11"])?)?;

    let plan = dropped.logical_plan().clone();

    // Build the equivalent plan from SQL against freshly registered tables.
    let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1";
    let ctx = SessionContext::new();
    register_aggregate_csv(&ctx, "aggregate_test_100").await?;
    register_aggregate_csv(&ctx, "another_table").await?;
    let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan();

    // The DataFrame-built plan and the SQL-built plan must be identical.
    assert_same_plan(&plan, &sql_plan);

    Ok(())
}

#[tokio::test]
async fn test_find_qualified_names() -> Result<()> {
let t = test_table().await?;
let column_names = ["c1", "c2", "c3"];
let columns = t.find_qualified_columns(&column_names)?;

// Expected results for each column
let binding = TableReference::bare("aggregate_test_100");
let expected = [
(Some(&binding), "c1"),
(Some(&binding), "c2"),
(Some(&binding), "c3"),
];

// Verify we got the expected number of results
assert_eq!(
columns.len(),
expected.len(),
"Expected {} columns, got {}",
expected.len(),
columns.len()
);

// Iterate over the results and check each one individually
for (i, (actual, expected)) in columns.iter().zip(expected.iter()).enumerate() {
let (actual_table_ref, actual_field_ref) = actual;
let (expected_table_ref, expected_field_name) = expected;

// Check table reference
assert_eq!(
actual_table_ref, expected_table_ref,
"Column {i}: expected table reference {expected_table_ref:?}, got {actual_table_ref:?}"
);

// Check field name
assert_eq!(
actual_field_ref.name(),
*expected_field_name,
"Column {i}: expected field name '{expected_field_name}', got '{actual_field_ref}'"
);
}

Ok(())
}

#[tokio::test]
async fn drop_with_quotes() -> Result<()> {
// define data with a column name that has a "." in it:
Expand Down Expand Up @@ -594,7 +696,7 @@ async fn drop_with_periods() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;

let df = ctx.table("t").await?.drop_columns(&["f.c1"])?;
let df = ctx.table("t").await?.drop_columns(&["\"f.c1\""])?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while this is technically a breaking API change, I think it is reasonable to treat it as a bug fix


let df_results = df.collect().await?;

Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ pub mod test {

#[test]
fn test_decimal32_to_i32() {
let cases: [(i32, i8, Either<i32, String>); _] = [
let cases: [(i32, i8, Either<i32, String>); 10] = [
(123, 0, Either::Left(123)),
(1230, 1, Either::Left(123)),
(123000, 3, Either::Left(123)),
Expand Down Expand Up @@ -456,7 +456,7 @@ pub mod test {

#[test]
fn test_decimal64_to_i64() {
let cases: [(i64, i8, Either<i64, String>); _] = [
let cases: [(i64, i8, Either<i64, String>); 8] = [
(123, 0, Either::Left(123)),
(1234567890, 2, Either::Left(12345678)),
(-1234567890, 2, Either::Left(-12345678)),
Expand Down