Skip to content

Commit c698c97

Browse files
authored
Merge branch 'main' into feat--cast-dataframe
2 parents b39a5f0 + 7cca028 commit c698c97

File tree

11 files changed

+270
-54
lines changed

11 files changed

+270
-54
lines changed

Cargo.lock

Lines changed: 10 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ build-backend = "maturin"
2323
name = "datafusion"
2424
description = "Build and run queries against data"
2525
readme = "README.md"
26-
license = {file = "LICENSE.txt"}
27-
requires-python = ">=3.6"
26+
license = { file = "LICENSE.txt" }
27+
requires-python = ">=3.7"
2828
keywords = ["datafusion", "dataframe", "rust", "query-engine"]
2929
classifier = [
3030
"Development Status :: 2 - Pre-Alpha",
@@ -42,10 +42,7 @@ classifier = [
4242
"Programming Language :: Python",
4343
"Programming Language :: Rust",
4444
]
45-
dependencies = [
46-
"pyarrow>=11.0.0",
47-
"typing-extensions;python_version<'3.13'",
48-
]
45+
dependencies = ["pyarrow>=11.0.0", "typing-extensions;python_version<'3.13'"]
4946

5047
[project.urls]
5148
homepage = "https://datafusion.apache.org/python"
@@ -58,9 +55,7 @@ profile = "black"
5855
[tool.maturin]
5956
python-source = "python"
6057
module-name = "datafusion._internal"
61-
include = [
62-
{ path = "Cargo.lock", format = "sdist" }
63-
]
58+
include = [{ path = "Cargo.lock", format = "sdist" }]
6459
exclude = [".github/**", "ci/**", ".asf.yaml"]
6560
# Require Cargo.lock is up to date
6661
locked = true

python/datafusion/context.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from datafusion.record_batch import RecordBatchStream
3131
from datafusion.udf import ScalarUDF, AggregateUDF, WindowUDF
3232

33-
from typing import Any, TYPE_CHECKING
33+
from typing import Any, TYPE_CHECKING, Protocol
3434
from typing_extensions import deprecated
3535

3636
if TYPE_CHECKING:
@@ -41,6 +41,28 @@
4141
from datafusion.plan import LogicalPlan, ExecutionPlan
4242

4343

44+
class ArrowStreamExportable(Protocol):
45+
"""Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface.
46+
47+
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
48+
"""
49+
50+
def __arrow_c_stream__( # noqa: D105
51+
self, requested_schema: object | None = None
52+
) -> object: ...
53+
54+
55+
class ArrowArrayExportable(Protocol):
56+
"""Type hint for object exporting Arrow C Array via Arrow PyCapsule Interface.
57+
58+
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
59+
"""
60+
61+
def __arrow_c_array__( # noqa: D105
62+
self, requested_schema: object | None = None
63+
) -> tuple[object, object]: ...
64+
65+
4466
class SessionConfig:
4567
"""Session configuration options."""
4668

@@ -592,12 +614,18 @@ def from_pydict(
592614
"""
593615
return DataFrame(self.ctx.from_pydict(data, name))
594616

595-
def from_arrow(self, data: Any, name: str | None = None) -> DataFrame:
617+
def from_arrow(
618+
self,
619+
data: ArrowStreamExportable | ArrowArrayExportable,
620+
name: str | None = None,
621+
) -> DataFrame:
596622
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow source.
597623
598624
The Arrow data source can be any object that implements either
599625
``__arrow_c_stream__`` or ``__arrow_c_array__``. For the latter, it must return
600-
a struct array. Common examples of sources from pyarrow include
626+
a struct array.
627+
628+
Arrow data can come from Polars, Pandas, PyArrow, etc.
601629
602630
Args:
603631
data: Arrow data source.

python/datafusion/dataframe.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
from __future__ import annotations
2323

24-
from typing import Any, Iterable, List, TYPE_CHECKING
24+
25+
from typing import Any, Iterable, List, Literal, TYPE_CHECKING
2526
from datafusion.record_batch import RecordBatchStream
2627
from typing_extensions import deprecated
2728
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -129,6 +130,17 @@ def select(self, *exprs: Expr | str) -> DataFrame:
129130
]
130131
return DataFrame(self.df.select(*exprs_internal))
131132

133+
def drop(self, *columns: str) -> DataFrame:
134+
"""Drop an arbitrary number of columns.
135+
136+
Args:
137+
columns: Column names to drop from the dataframe.
138+
139+
Returns:
140+
DataFrame with those columns removed in the projection.
141+
"""
142+
return DataFrame(self.df.drop(*columns))
143+
132144
def filter(self, *predicates: Expr) -> DataFrame:
133145
"""Return a DataFrame for which ``predicate`` evaluates to ``True``.
134146
@@ -163,14 +175,25 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
163175
def with_columns(
164176
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
165177
) -> DataFrame:
166-
"""Add an additional column to the DataFrame.
178+
"""Add columns to the DataFrame.
179+
180+
By passing expressions, iterables of expressions, or named expressions. To
181+
pass named expressions use the form name=Expr.
182+
183+
Example usage: The following will add 4 columns labeled a, b, c, and d::
184+
185+
df = df.with_columns(
186+
lit(0).alias('a'),
187+
[lit(1).alias('b'), lit(2).alias('c')],
188+
d=lit(3)
189+
)
167190
168191
Args:
169-
*exprs: Name of the column to add.
170-
**named_exprs: Expression to compute the column.
192+
exprs: Either a single expression or an iterable of expressions to add.
193+
named_exprs: Named expressions in the form of ``name=expr``
171194
172195
Returns:
173-
DataFrame with the new column.
196+
DataFrame with the new columns added.
174197
"""
175198

176199
def _simplify_expression(
@@ -339,6 +362,29 @@ def join(
339362
"""
340363
return DataFrame(self.df.join(right.df, join_keys, how))
341364

365+
def join_on(
366+
self,
367+
right: DataFrame,
368+
*on_exprs: Expr,
369+
how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
370+
) -> DataFrame:
371+
"""Join two :py:class:`DataFrame` using the specified expressions.
372+
373+
On expressions are used to support inequality predicates. Equality
374+
predicates are correctly optimized.
375+
376+
Args:
377+
right: Other DataFrame to join with.
378+
on_exprs: single or multiple (in)-equality predicates.
379+
how: Type of join to perform. Supported types are "inner", "left",
380+
"right", "full", "semi", "anti".
381+
382+
Returns:
383+
DataFrame after join.
384+
"""
385+
exprs = [expr.expr for expr in on_exprs]
386+
return DataFrame(self.df.join_on(right.df, exprs, how))
387+
342388
def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame:
343389
"""Return a DataFrame with the explanation of its plan so far.
344390

python/datafusion/expr.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,18 @@ def is_not_null(self) -> Expr:
406406
"""Returns ``True`` if this expression is not null."""
407407
return Expr(self.expr.is_not_null())
408408

409+
def fill_nan(self, value: Any | Expr | None = None) -> Expr:
410+
"""Fill NaN values with a provided value."""
411+
if not isinstance(value, Expr):
412+
value = Expr.literal(value)
413+
return Expr(functions_internal.nanvl(self.expr, value.expr))
414+
415+
def fill_null(self, value: Any | Expr | None = None) -> Expr:
416+
"""Fill NULL values with a provided value."""
417+
if not isinstance(value, Expr):
418+
value = Expr.literal(value)
419+
return Expr(functions_internal.nvl(self.expr, value.expr))
420+
409421
_to_pyarrow_types = {
410422
float: pa.float64(),
411423
int: pa.int64(),

python/datafusion/functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@
186186
"min",
187187
"named_struct",
188188
"nanvl",
189+
"nvl",
189190
"now",
190191
"nth_value",
191192
"nullif",
@@ -673,6 +674,11 @@ def nanvl(x: Expr, y: Expr) -> Expr:
673674
return Expr(f.nanvl(x.expr, y.expr))
674675

675676

677+
def nvl(x: Expr, y: Expr) -> Expr:
678+
"""Returns ``x`` if ``x`` is not ``NULL``. Otherwise returns ``y``."""
679+
return Expr(f.nvl(x.expr, y.expr))
680+
681+
676682
def octet_length(arg: Expr) -> Expr:
677683
"""Returns the number of bytes of a string."""
678684
return Expr(f.octet_length(arg.expr))

python/tests/test_dataframe.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,17 @@ def test_sort(df):
169169
assert table.to_pydict() == expected
170170

171171

172+
def test_drop(df):
173+
df = df.drop("c")
174+
175+
# execute and collect the first (and only) batch
176+
result = df.collect()[0]
177+
178+
assert df.schema().names == ["a", "b"]
179+
assert result.column(0) == pa.array([1, 2, 3])
180+
assert result.column(1) == pa.array([4, 5, 6])
181+
182+
172183
def test_limit(df):
173184
df = df.limit(1)
174185

@@ -299,6 +310,42 @@ def test_join():
299310
assert table.to_pydict() == expected
300311

301312

313+
def test_join_on():
314+
ctx = SessionContext()
315+
316+
batch = pa.RecordBatch.from_arrays(
317+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
318+
names=["a", "b"],
319+
)
320+
df = ctx.create_dataframe([[batch]], "l")
321+
322+
batch = pa.RecordBatch.from_arrays(
323+
[pa.array([1, 2]), pa.array([-8, 10])],
324+
names=["a", "c"],
325+
)
326+
df1 = ctx.create_dataframe([[batch]], "r")
327+
328+
df2 = df.join_on(df1, column("l.a").__eq__(column("r.a")), how="inner")
329+
df2.show()
330+
df2 = df2.sort(column("l.a"))
331+
table = pa.Table.from_batches(df2.collect())
332+
333+
expected = {"a": [1, 2], "c": [-8, 10], "b": [4, 5]}
334+
assert table.to_pydict() == expected
335+
336+
df3 = df.join_on(
337+
df1,
338+
column("l.a").__eq__(column("r.a")),
339+
column("l.a").__lt__(column("r.c")),
340+
how="inner",
341+
)
342+
df3.show()
343+
df3 = df3.sort(column("l.a"))
344+
table = pa.Table.from_batches(df3.collect())
345+
expected = {"a": [2], "c": [10], "b": [5]}
346+
assert table.to_pydict() == expected
347+
348+
302349
def test_distinct():
303350
ctx = SessionContext()
304351

python/tests/test_expr.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import pyarrow
18+
import pyarrow as pa
1919
import pytest
2020
from datafusion import SessionContext, col
2121
from datafusion.expr import (
@@ -125,8 +125,8 @@ def test_sort(test_ctx):
125125
def test_relational_expr(test_ctx):
126126
ctx = SessionContext()
127127

128-
batch = pyarrow.RecordBatch.from_arrays(
129-
[pyarrow.array([1, 2, 3]), pyarrow.array(["alpha", "beta", "gamma"])],
128+
batch = pa.RecordBatch.from_arrays(
129+
[pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
130130
names=["a", "b"],
131131
)
132132
df = ctx.create_dataframe([[batch]], name="batch_array")
@@ -216,3 +216,30 @@ def test_display_name_deprecation():
216216
# returns appropriate result
217217
assert name == expr.schema_name()
218218
assert name == "foo"
219+
220+
221+
@pytest.fixture
222+
def df():
223+
ctx = SessionContext()
224+
225+
# create a RecordBatch and a new DataFrame from it
226+
batch = pa.RecordBatch.from_arrays(
227+
[pa.array([1, 2, None]), pa.array([4, None, 6]), pa.array([None, None, 8])],
228+
names=["a", "b", "c"],
229+
)
230+
231+
return ctx.from_arrow(batch)
232+
233+
234+
def test_fill_null(df):
235+
df = df.select(
236+
col("a").fill_null(100).alias("a"),
237+
col("b").fill_null(25).alias("b"),
238+
col("c").fill_null(1234).alias("c"),
239+
)
240+
df.show()
241+
result = df.collect()[0]
242+
243+
assert result.column(0) == pa.array([1, 2, 100])
244+
assert result.column(1) == pa.array([4, 25, 6])
245+
assert result.column(2) == pa.array([1234, 1234, 8])

0 commit comments

Comments
 (0)