diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 16765656a..86131c45f 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -482,6 +482,28 @@ def filter(self, *predicates: Expr) -> DataFrame: df = df.filter(ensure_expr(p)) return DataFrame(df) + def parse_sql_expr(self, expr: str) -> Expr: + """Creates logical expression from a SQL query text. + + The expression is created and processed against the current schema. + + Example:: + + from datafusion import col, lit + df.parse_sql_expr("a > 1") + + should produce: + + col("a") > lit(1) + + Args: + expr: Expression string to be converted to datafusion expression + + Returns: + Logical expression . + """ + return Expr(self.df.parse_sql_expr(expr)) + def with_column(self, name: str, expr: Expr) -> DataFrame: """Add an additional column to the DataFrame. diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index cd85221c5..76b808038 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -274,6 +274,36 @@ def test_filter(df): assert result.column(2) == pa.array([5]) +def test_parse_sql_expr(df): + plan1 = df.filter(df.parse_sql_expr("a > 2")).logical_plan() + plan2 = df.filter(column("a") > literal(2)).logical_plan() + # object equality not implemented but string representation should match + assert str(plan1) == str(plan2) + + df1 = df.filter(df.parse_sql_expr("a > 2")).select( + column("a") + column("b"), + column("a") - column("b"), + ) + + # execute and collect the first (and only) batch + result = df1.collect()[0] + + assert result.column(0) == pa.array([9]) + assert result.column(1) == pa.array([-3]) + + df.show() + # verify that if there is no filter applied, internal dataframe is unchanged + df2 = df.filter() + assert df.df == df2.df + + df3 = df.filter(df.parse_sql_expr("a > 1"), df.parse_sql_expr("b != 6")) + result = df3.collect()[0] + + assert result.column(0) == pa.array([2]) + assert result.column(1) == pa.array([5]) + assert result.column(2) == pa.array([5]) + + def test_show_empty(df, capsys): df_empty = df.filter(column("a") > literal(3)) df_empty.show() diff --git a/src/dataframe.rs b/src/dataframe.rs index c23c0c97f..34da87443 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -454,6 +454,14 @@ impl PyDataFrame { Ok(Self::new(df)) } + fn parse_sql_expr(&self, expr: PyBackedStr) -> PyDataFusionResult { + self.df + .as_ref() + .parse_sql_expr(&expr) + .map(|e| PyExpr::from(e)) + .map_err(PyDataFusionError::from) + } + fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult { let df = self.df.as_ref().clone().with_column(name, expr.into())?; Ok(Self::new(df))