diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 16765656a..5b4a8aeaf 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -405,6 +405,20 @@ def select_columns(self, *args: str) -> DataFrame:
         """
         return self.select(*args)
 
+    def select_exprs(self, *args: str) -> DataFrame:
+        """Project arbitrary list of expression strings into a new DataFrame.
+
+        This method will parse string expressions into logical plan expressions.
+        The output DataFrame has one column for each expression.
+
+        Args:
+            args: Expression strings to parse and project, e.g. ``"a + b"``.
+
+        Returns:
+            DataFrame with one column per expression, in the given order.
+        """
+        return DataFrame(self.df.select_exprs(*args))
+
     def select(self, *exprs: Expr | str) -> DataFrame:
         """Project arbitrary expressions into a new :py:class:`DataFrame`.
 
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index cd85221c5..a420e8f21 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -221,6 +221,38 @@ def test_select(df):
     assert result.column(1) == pa.array([1, 2, 3])
 
 
+def test_select_exprs(df):
+    df_1 = df.select_exprs(
+        "a + b",
+        "a - b",
+    )
+
+    # execute and collect the first (and only) batch
+    result = df_1.collect()[0]
+
+    assert result.column(0) == pa.array([5, 7, 9])
+    assert result.column(1) == pa.array([-3, -3, -3])
+
+    df_2 = df.select_exprs("b", "a")
+
+    # execute and collect the first (and only) batch
+    result = df_2.collect()[0]
+
+    assert result.column(0) == pa.array([4, 5, 6])
+    assert result.column(1) == pa.array([1, 2, 3])
+
+    df_3 = df.select_exprs(
+        "abs(a + b)",
+        "abs(a - b)",
+    )
+
+    # execute and collect the first (and only) batch
+    result = df_3.collect()[0]
+
+    assert result.column(0) == pa.array([5, 7, 9])
+    assert result.column(1) == pa.array([3, 3, 3])
+
+
 def test_drop_quoted_columns():
     ctx = SessionContext()
     batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["ID_For_Students"])
diff --git a/src/dataframe.rs b/src/dataframe.rs
index c23c0c97f..f603c28b4 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -435,6 +435,14 @@ impl PyDataFrame {
         Ok(Self::new(df))
     }
 
+    /// Parse each string in `args` as an expression and project the results into a new DataFrame.
+    #[pyo3(signature = (*args))]
+    fn select_exprs(&self, args: Vec<String>) -> PyDataFusionResult<Self> {
+        let args = args.iter().map(String::as_str).collect::<Vec<&str>>();
+        let df = self.df.as_ref().clone().select_exprs(&args)?;
+        Ok(Self::new(df))
+    }
+
     #[pyo3(signature = (*args))]
     fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
         let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();