Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/user-guide/common-operations/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ approaches.
Indexing an element of an array via ``[]`` starts at index 0 whereas
:py:func:`~datafusion.functions.array_element` starts at index 1.

Starting in DataFusion 49.0.0 you can also create slices of array elements using
slice syntax from Python.

.. ipython:: python

df.select(col("a")[1:3].alias("second_two_elements"))

To check if an array is empty, you can use the function :py:func:`datafusion.functions.array_empty` or :py:func:`datafusion.functions.empty`.
Either function returns a boolean indicating whether the array is empty.

Expand Down
31 changes: 29 additions & 2 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,17 +352,44 @@ def __invert__(self) -> Expr:
"""Binary not (~)."""
return Expr(self.expr.__invert__())

def __getitem__(self, key: str | int | slice) -> Expr:
    """Retrieve sub-object.

    If ``key`` is a string, returns the subfield of the struct.
    If ``key`` is an integer, retrieves the element in the array. Note that the
    element index begins at ``0``, unlike
    :py:func:`~datafusion.functions.array_element` which begins at ``1``.
    If ``key`` is a slice, returns an array that contains a slice of the
    original array. Similar to integer indexing, this follows Python convention
    where the index begins at ``0`` unlike
    :py:func:`~datafusion.functions.array_slice` which begins at ``1``.

    Slice bounds and the step may each be given as a Python ``int`` or as an
    expression, so ``col[1:3]`` and ``col[lit(1):lit(3)]`` are both accepted.

    Args:
        key: Struct field name, 0-based array index, or a slice whose
            ``stop`` bound is set.

    Returns:
        A new expression selecting the requested field, element or slice.

    Raises:
        ValueError: If ``key`` is a slice without a ``stop`` bound, for
            example ``col[1:]``.
    """
    if isinstance(key, int):
        # array_element is 1-based, so shift the 0-based Python index.
        return Expr(
            functions_internal.array_element(self.expr, Expr.literal(key + 1).expr)
        )
    if isinstance(key, slice):
        # Reject open-ended slices up front. Without this check the code
        # below would fail on ``None.expr`` with an opaque AttributeError.
        if key.stop is None:
            msg = (
                "Slicing an array expression requires an explicit stop, "
                "e.g. expr[1:3]; open-ended slices such as expr[1:] are "
                "not supported."
            )
            raise ValueError(msg)

        if isinstance(key.start, int):
            # NOTE(review): negative starts are shifted too (-1 becomes 0);
            # confirm array_slice treats that as intended.
            start = Expr.literal(key.start + 1).expr
        elif isinstance(key.start, Expr):
            start = (key.start + Expr.literal(1)).expr
        else:
            # Default start at the first element, index 1
            start = Expr.literal(1).expr

        # The stop value is forwarded unshifted; presumably array_slice
        # treats its 1-based end as inclusive, which matches Python's
        # exclusive 0-based stop - TODO confirm.
        if isinstance(key.stop, int):
            stop = Expr.literal(key.stop).expr
        else:
            stop = key.stop.expr

        if isinstance(key.step, int):
            step = Expr.literal(key.step).expr
        elif isinstance(key.step, Expr):
            step = key.step.expr
        else:
            # No step given: pass the value (normally None) straight through.
            step = key.step

        return Expr(functions_internal.array_slice(self.expr, start, stop, step))
    # Anything else (e.g. a struct field name) is delegated to the wrapped
    # expression's own __getitem__.
    return Expr(self.expr.__getitem__(key))

def __eq__(self, rhs: object) -> Expr:
Expand Down
2 changes: 1 addition & 1 deletion python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2655,7 +2655,7 @@ def cume_dist(
"""Create a cumulative distribution window function.

This window function is similar to :py:func:`rank` except that the returned values
are the ratio of the row number to the total numebr of rows. Here is an example of a
are the ratio of the row number to the total number of rows. Here is an example of a
dataframe with a window ordered by descending ``points`` and the associated
cumulative distribution::

Expand Down
29 changes: 28 additions & 1 deletion python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,30 @@ def py_flatten(arr):
lambda col: f.list_slice(col, literal(-1), literal(2)),
lambda data: [arr[-1:2] for arr in data],
),
(
lambda col: col[:3],
lambda data: [arr[:3] for arr in data],
),
(
lambda col: col[1:3],
lambda data: [arr[1:3] for arr in data],
),
(
lambda col: col[1:4:2],
lambda data: [arr[1:4:2] for arr in data],
),
(
lambda col: col[literal(1) : literal(4)],
lambda data: [arr[1:4] for arr in data],
),
(
lambda col: col[column("indices") : column("indices") + literal(2)],
lambda data: [[2.0, 3.0], [], [6.0]],
),
(
lambda col: col[literal(1) : literal(4) : literal(2)],
lambda data: [arr[1:4:2] for arr in data],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also test non-literal expressions, even if they are supposed to error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They will not error. Unit test added to demonstrate.

),
(
lambda col: f.array_intersect(col, literal([3.0, 4.0])),
lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
Expand Down Expand Up @@ -534,8 +558,11 @@ def py_flatten(arr):
)
def test_array_functions(stmt, py_expr):
data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
indices = [1, 3, 0]
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
batch = pa.RecordBatch.from_arrays(
[np.array(data, dtype=object), indices], names=["arr", "indices"]
)
df = ctx.create_dataframe([[batch]])

col = column("arr")
Expand Down
Loading