Commit 428b209

Merge branch 'main' into feat--add-with_columns
2 parents: 7228ca8 + fc7e3e5

File tree

11 files changed: +253 −49 lines

Cargo.lock

Lines changed: 10 additions & 10 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 4 additions & 9 deletions
@@ -23,8 +23,8 @@ build-backend = "maturin"
 name = "datafusion"
 description = "Build and run queries against data"
 readme = "README.md"
-license = {file = "LICENSE.txt"}
-requires-python = ">=3.6"
+license = { file = "LICENSE.txt" }
+requires-python = ">=3.7"
 keywords = ["datafusion", "dataframe", "rust", "query-engine"]
 classifier = [
     "Development Status :: 2 - Pre-Alpha",
@@ -42,10 +42,7 @@ classifier = [
     "Programming Language :: Python",
     "Programming Language :: Rust",
 ]
-dependencies = [
-    "pyarrow>=11.0.0",
-    "typing-extensions;python_version<'3.13'",
-]
+dependencies = ["pyarrow>=11.0.0", "typing-extensions;python_version<'3.13'"]

 [project.urls]
 homepage = "https://datafusion.apache.org/python"
@@ -58,9 +55,7 @@ profile = "black"
 [tool.maturin]
 python-source = "python"
 module-name = "datafusion._internal"
-include = [
-    { path = "Cargo.lock", format = "sdist" }
-]
+include = [{ path = "Cargo.lock", format = "sdist" }]
 exclude = [".github/**", "ci/**", ".asf.yaml"]
 # Require Cargo.lock is up to date
 locked = true

python/datafusion/context.py

Lines changed: 31 additions & 3 deletions
@@ -30,7 +30,7 @@
 from datafusion.record_batch import RecordBatchStream
 from datafusion.udf import ScalarUDF, AggregateUDF, WindowUDF

-from typing import Any, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING, Protocol
 from typing_extensions import deprecated

 if TYPE_CHECKING:
@@ -41,6 +41,28 @@
     from datafusion.plan import LogicalPlan, ExecutionPlan


+class ArrowStreamExportable(Protocol):
+    """Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface.
+
+    https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+    """
+
+    def __arrow_c_stream__(  # noqa: D105
+        self, requested_schema: object | None = None
+    ) -> object: ...
+
+
+class ArrowArrayExportable(Protocol):
+    """Type hint for object exporting Arrow C Array via Arrow PyCapsule Interface.
+
+    https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+    """
+
+    def __arrow_c_array__(  # noqa: D105
+        self, requested_schema: object | None = None
+    ) -> tuple[object, object]: ...
+
+
 class SessionConfig:
     """Session configuration options."""

@@ -592,12 +614,18 @@ def from_pydict(
         """
         return DataFrame(self.ctx.from_pydict(data, name))

-    def from_arrow(self, data: Any, name: str | None = None) -> DataFrame:
+    def from_arrow(
+        self,
+        data: ArrowStreamExportable | ArrowArrayExportable,
+        name: str | None = None,
+    ) -> DataFrame:
         """Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow source.

         The Arrow data source can be any object that implements either
         ``__arrow_c_stream__`` or ``__arrow_c_array__``. For the latter, it must return
-        a struct array. Common examples of sources from pyarrow include
+        a struct array.
+
+        Common sources include PyArrow, Pandas, and Polars objects.

         Args:
             data: Arrow data source.
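A usage sketch (not part of the commit): recent PyArrow objects implement the Arrow PyCapsule interface, so a RecordBatch (via ``__arrow_c_array__``) or a Table (via ``__arrow_c_stream__``) satisfies the new Protocol types directly. This assumes a pyarrow version new enough to export the capsule methods; table names are illustrative.

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
# pa.RecordBatch exposes __arrow_c_array__, matching ArrowArrayExportable
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["a"])
df = ctx.from_arrow(batch, name="t")
# pa.Table exposes __arrow_c_stream__, matching ArrowStreamExportable
df2 = ctx.from_arrow(pa.table({"b": [4, 5, 6]}))
df.show()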

python/datafusion/dataframe.py

Lines changed: 34 additions & 0 deletions
@@ -129,6 +129,17 @@ def select(self, *exprs: Expr | str) -> DataFrame:
         ]
         return DataFrame(self.df.select(*exprs_internal))

+    def drop(self, *columns: str) -> DataFrame:
+        """Drop an arbitrary number of columns.
+
+        Args:
+            columns: Column names to drop from the dataframe.
+
+        Returns:
+            DataFrame with those columns removed in the projection.
+        """
+        return DataFrame(self.df.drop(*columns))
+
     def filter(self, *predicates: Expr) -> DataFrame:
         """Return a DataFrame for which ``predicate`` evaluates to ``True``.

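A hedged usage sketch of the new ``drop`` (column names are illustrative; this mirrors the test added later in this commit):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array([1]), pa.array([2]), pa.array([3])], names=["a", "b", "c"]
)
df = ctx.create_dataframe([[batch]], "t")
df = df.drop("b", "c")  # projection now contains only "a"
assert df.schema().names == ["a"]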
@@ -338,6 +349,29 @@ def join(
         """
         return DataFrame(self.df.join(right.df, join_keys, how))

+    def join_on(
+        self,
+        right: DataFrame,
+        *on_exprs: Expr,
+        how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
+    ) -> DataFrame:
+        """Join two :py:class:`DataFrame` objects using the specified expressions.
+
+        On expressions are used to support inequality predicates. Equality
+        predicates are correctly optimized.
+
+        Args:
+            right: Other DataFrame to join with.
+            on_exprs: Single or multiple (in)equality predicates.
+            how: Type of join to perform. Supported types are "inner", "left",
+                "right", "full", "semi", "anti".
+
+        Returns:
+            DataFrame after join.
+        """
+        exprs = [expr.expr for expr in on_exprs]
+        return DataFrame(self.df.join_on(right.df, exprs, how))
+
     def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame:
         """Return a DataFrame with the explanation of its plan so far.

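A hedged sketch of ``join_on`` with a non-equi predicate (the table names "l" and "r" are illustrative). Since ``Expr`` overloads the comparison operators, ``<`` can be written directly instead of the ``__lt__`` spelling the tests use:

import pyarrow as pa
from datafusion import SessionContext, column

ctx = SessionContext()
left = ctx.create_dataframe(
    [[pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["a"])]], "l"
)
right = ctx.create_dataframe(
    [[pa.RecordBatch.from_arrays([pa.array([2, 2, 4])], names=["c"])]], "r"
)
# keep rows where l.a < r.c; any equality predicates in the list are
# still recognized and optimized as equi-join keys
out = left.join_on(right, column("l.a") < column("r.c"), how="inner")
out.show()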

python/datafusion/expr.py

Lines changed: 12 additions & 0 deletions
@@ -406,6 +406,18 @@ def is_not_null(self) -> Expr:
         """Returns ``True`` if this expression is not null."""
         return Expr(self.expr.is_not_null())

+    def fill_nan(self, value: Any | Expr | None = None) -> Expr:
+        """Fill NaN values with a provided value."""
+        if not isinstance(value, Expr):
+            value = Expr.literal(value)
+        return Expr(functions_internal.nanvl(self.expr, value.expr))
+
+    def fill_null(self, value: Any | Expr | None = None) -> Expr:
+        """Fill NULL values with a provided value."""
+        if not isinstance(value, Expr):
+            value = Expr.literal(value)
+        return Expr(functions_internal.nvl(self.expr, value.expr))
+
 _to_pyarrow_types = {
     float: pa.float64(),
     int: pa.int64(),
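A hedged sketch of the new ``Expr.fill_null`` (mirroring the test added in this commit): a bare Python literal is wrapped with ``Expr.literal`` before being handed to ``nvl``, so either a literal or an ``Expr`` works.

import pyarrow as pa
from datafusion import SessionContext, col

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([pa.array([1, None, 3])], names=["a"])
df = ctx.create_dataframe([[batch]], "t")
df = df.select(col("a").fill_null(0).alias("a"))  # NULLs become 0
assert df.collect()[0].column(0) == pa.array([1, 0, 3])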

python/datafusion/functions.py

Lines changed: 6 additions & 0 deletions
@@ -186,6 +186,7 @@
     "min",
     "named_struct",
     "nanvl",
+    "nvl",
     "now",
     "nth_value",
     "nullif",
@@ -673,6 +674,11 @@ def nanvl(x: Expr, y: Expr) -> Expr:
     return Expr(f.nanvl(x.expr, y.expr))


+def nvl(x: Expr, y: Expr) -> Expr:
+    """Returns ``x`` if ``x`` is not ``NULL``. Otherwise returns ``y``."""
+    return Expr(f.nvl(x.expr, y.expr))
+
+
 def octet_length(arg: Expr) -> Expr:
     """Returns the number of bytes of a string."""
     return Expr(f.octet_length(arg.expr))
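The free-function form can also be called directly; a hedged sketch (the column name is illustrative, and ``lit`` is assumed to be available from the top-level ``datafusion`` namespace):

import pyarrow as pa
from datafusion import SessionContext, col, lit
from datafusion import functions as f

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([pa.array([None, 2])], names=["x"])
df = ctx.create_dataframe([[batch]], "t")
df = df.select(f.nvl(col("x"), lit(0)).alias("x"))  # -> [0, 2]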

python/tests/test_dataframe.py

Lines changed: 47 additions & 0 deletions
@@ -169,6 +169,17 @@ def test_sort(df):
     assert table.to_pydict() == expected


+def test_drop(df):
+    df = df.drop("c")
+
+    # execute and collect the first (and only) batch
+    result = df.collect()[0]
+
+    assert df.schema().names == ["a", "b"]
+    assert result.column(0) == pa.array([1, 2, 3])
+    assert result.column(1) == pa.array([4, 5, 6])
+
+
 def test_limit(df):
     df = df.limit(1)

@@ -290,6 +301,42 @@ def test_join():
     assert table.to_pydict() == expected


+def test_join_on():
+    ctx = SessionContext()
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    df = ctx.create_dataframe([[batch]], "l")
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2]), pa.array([-8, 10])],
+        names=["a", "c"],
+    )
+    df1 = ctx.create_dataframe([[batch]], "r")
+
+    df2 = df.join_on(df1, column("l.a").__eq__(column("r.a")), how="inner")
+    df2.show()
+    df2 = df2.sort(column("l.a"))
+    table = pa.Table.from_batches(df2.collect())
+
+    expected = {"a": [1, 2], "c": [-8, 10], "b": [4, 5]}
+    assert table.to_pydict() == expected
+
+    df3 = df.join_on(
+        df1,
+        column("l.a").__eq__(column("r.a")),
+        column("l.a").__lt__(column("r.c")),
+        how="inner",
+    )
+    df3.show()
+    df3 = df3.sort(column("l.a"))
+    table = pa.Table.from_batches(df3.collect())
+    expected = {"a": [2], "c": [10], "b": [5]}
+    assert table.to_pydict() == expected
+
+
 def test_distinct():
     ctx = SessionContext()


python/tests/test_expr.py

Lines changed: 30 additions & 3 deletions
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-import pyarrow
+import pyarrow as pa
 import pytest
 from datafusion import SessionContext, col
 from datafusion.expr import (
@@ -125,8 +125,8 @@ def test_sort(test_ctx):
 def test_relational_expr(test_ctx):
     ctx = SessionContext()

-    batch = pyarrow.RecordBatch.from_arrays(
-        [pyarrow.array([1, 2, 3]), pyarrow.array(["alpha", "beta", "gamma"])],
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
         names=["a", "b"],
     )
     df = ctx.create_dataframe([[batch]], name="batch_array")
@@ -216,3 +216,30 @@ def test_display_name_deprecation():
     # returns appropriate result
     assert name == expr.schema_name()
     assert name == "foo"
+
+
+@pytest.fixture
+def df():
+    ctx = SessionContext()
+
+    # create a RecordBatch and a new DataFrame from it
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, None]), pa.array([4, None, 6]), pa.array([None, None, 8])],
+        names=["a", "b", "c"],
+    )
+
+    return ctx.from_arrow(batch)
+
+
+def test_fill_null(df):
+    df = df.select(
+        col("a").fill_null(100).alias("a"),
+        col("b").fill_null(25).alias("b"),
+        col("c").fill_null(1234).alias("c"),
+    )
+    df.show()
+    result = df.collect()[0]
+
+    assert result.column(0) == pa.array([1, 2, 100])
+    assert result.column(1) == pa.array([4, 25, 6])
+    assert result.column(2) == pa.array([1234, 1234, 8])

src/dataframe.rs

Lines changed: 32 additions & 0 deletions
@@ -170,6 +170,13 @@ impl PyDataFrame {
         Ok(Self::new(df))
     }

+    #[pyo3(signature = (*args))]
+    fn drop(&self, args: Vec<PyBackedStr>) -> PyResult<Self> {
+        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
+        let df = self.df.as_ref().clone().drop_columns(&cols)?;
+        Ok(Self::new(df))
+    }
+
     fn filter(&self, predicate: PyExpr) -> PyResult<Self> {
         let df = self.df.as_ref().clone().filter(predicate.into())?;
         Ok(Self::new(df))
@@ -303,6 +310,31 @@ impl PyDataFrame {
         Ok(Self::new(df))
     }

+    fn join_on(&self, right: PyDataFrame, on_exprs: Vec<PyExpr>, how: &str) -> PyResult<Self> {
+        let join_type = match how {
+            "inner" => JoinType::Inner,
+            "left" => JoinType::Left,
+            "right" => JoinType::Right,
+            "full" => JoinType::Full,
+            "semi" => JoinType::LeftSemi,
+            "anti" => JoinType::LeftAnti,
+            how => {
+                return Err(DataFusionError::Common(format!(
+                    "The join type {how} does not exist or is not implemented"
+                ))
+                .into());
+            }
+        };
+        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
+
+        let df = self
+            .df
+            .as_ref()
+            .clone()
+            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
+        Ok(Self::new(df))
+    }
+
     /// Print the query plan
     #[pyo3(signature = (verbose=false, analyze=false))]
     fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> {
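On the Python side the ``how`` string maps onto DataFusion's ``JoinType`` exactly as in the match above ("semi" becomes ``LeftSemi``, "anti" becomes ``LeftAnti``). A hedged sketch of a semi join (table and column names illustrative):

import pyarrow as pa
from datafusion import SessionContext, column

ctx = SessionContext()
left = ctx.create_dataframe(
    [[pa.RecordBatch.from_arrays([pa.array([1, 2])], names=["a"])]], "l"
)
right = ctx.create_dataframe(
    [[pa.RecordBatch.from_arrays([pa.array([2, 3])], names=["c"])]], "r"
)
# keeps only left rows with at least one match; right columns are not emitted
semi = left.join_on(right, column("l.a") == column("r.c"), how="semi")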
