From f78eb05935d9a63e9309a653763aca67526d0050 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 28 Aug 2025 10:04:04 -0400 Subject: [PATCH 1/3] Allow passing a slice to and expression with the [] indexing --- python/datafusion/expr.py | 24 +++++++++++++++++++++++- python/tests/test_functions.py | 20 ++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index c0b495717..1ea2001b4 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -352,7 +352,7 @@ def __invert__(self) -> Expr: """Binary not (~).""" return Expr(self.expr.__invert__()) - def __getitem__(self, key: str | int) -> Expr: + def __getitem__(self, key: str | int | slice) -> Expr: """Retrieve sub-object. If ``key`` is a string, returns the subfield of the struct. @@ -363,6 +363,28 @@ def __getitem__(self, key: str | int) -> Expr: return Expr( functions_internal.array_element(self.expr, Expr.literal(key + 1).expr) ) + if isinstance(key, slice): + if isinstance(key.start, int): + start = Expr.literal(key.start + 1).expr + elif isinstance(key.start, Expr): + start = (key.start + Expr.literal(1)).expr + else: + # Default start at the first element, index 1 + start = Expr.literal(1).expr + + if isinstance(key.stop, int): + stop = Expr.literal(key.stop).expr + else: + stop = key.stop.expr + + if isinstance(key.step, int): + step = Expr.literal(key.step).expr + elif isinstance(key.step, Expr): + step = key.step.expr + else: + step = key.step + + return Expr(functions_internal.array_slice(self.expr, start, stop, step)) return Expr(self.expr.__getitem__(key)) def __eq__(self, rhs: object) -> Expr: diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 525915316..27bcbb050 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -494,6 +494,26 @@ def py_flatten(arr): lambda col: f.list_slice(col, literal(-1), literal(2)), lambda data: [arr[-1:2] for arr in data], ), + ( + lambda col: col[:3], + lambda data: [arr[:3] for arr in data], + ), + ( + lambda col: col[1:3], + lambda data: [arr[1:3] for arr in data], + ), + ( + lambda col: col[1:4:2], + lambda data: [arr[1:4:2] for arr in data], + ), + ( + lambda col: col[literal(1) : literal(4)], + lambda data: [arr[1:4] for arr in data], + ), + ( + lambda col: col[literal(1) : literal(4) : literal(2)], + lambda data: [arr[1:4:2] for arr in data], + ), ( lambda col: f.array_intersect(col, literal([3.0, 4.0])), lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data], From c4b1e1b87fcc78d01473590070db211c2cb2b93d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 Aug 2025 08:58:27 -0400 Subject: [PATCH 2/3] Update documentation --- docs/source/user-guide/common-operations/expressions.rst | 7 +++++++ python/datafusion/expr.py | 7 ++++++- python/datafusion/functions.py | 2 +- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/source/user-guide/common-operations/expressions.rst b/docs/source/user-guide/common-operations/expressions.rst index e94e1a6b5..72ad1c5e9 100644 --- a/docs/source/user-guide/common-operations/expressions.rst +++ b/docs/source/user-guide/common-operations/expressions.rst @@ -82,6 +82,13 @@ approaches. Indexing an element of an array via ``[]`` starts at index 0 whereas :py:func:`~datafusion.functions.array_element` starts at index 1. +Starting in DataFusion 49.0.0 you can also create slices of array elements using +slice syntax from Python. + +.. ipython:: python + + df.select(col("a")[1:3].alias("second_two_elements")) + To check if an array is empty, you can use the function :py:func:`datafusion.functions.array_empty` or `datafusion.functions.empty`. This function returns a boolean indicating whether the array is empty. diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 1ea2001b4..4b851cc7e 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -357,7 +357,12 @@ def __getitem__(self, key: str | int | slice) -> Expr: If ``key`` is a string, returns the subfield of the struct. If ``key`` is an integer, retrieves the element in the array. Note that the - element index begins at ``0``, unlike `array_element` which begins at ``1``. + element index begins at ``0``, unlike + :py:func:`~datafusion.functions.array_element` which begins at ``1``. + If ``key`` is a slice, returns an array that contains a slice of the + original array. Similar to integer indexing, this follows Python convention + where the index begins at ``0`` unlike + :py:func:`~datafusion.functions.array_slice` which begins at ``1``. """ if isinstance(key, int): return Expr( diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 34068805c..297034b1e 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -2655,7 +2655,7 @@ def cume_dist( """Create a cumulative distribution window function. This window function is similar to :py:func:`rank` except that the returned values - are the ratio of the row number to the total numebr of rows. Here is an example of a + are the ratio of the row number to the total number of rows. Here is an example of a dataframe with a window ordered by descending ``points`` and the associated cumulative distribution:: From c83fa8162f43ba9772b687533124ef701042a1ac Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 4 Sep 2025 15:06:08 -0400 Subject: [PATCH 3/3] Add unit test covering expressions in slice --- python/tests/test_functions.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 27bcbb050..ee19d021a 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -510,6 +510,10 @@ def py_flatten(arr): lambda col: col[literal(1) : literal(4)], lambda data: [arr[1:4] for arr in data], ), + ( + lambda col: col[column("indices") : column("indices") + literal(2)], + lambda data: [[2.0, 3.0], [], [6.0]], + ), ( lambda col: col[literal(1) : literal(4) : literal(2)], lambda data: [arr[1:4:2] for arr in data], @@ -554,8 +558,11 @@ def py_flatten(arr): ) def test_array_functions(stmt, py_expr): data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]] + indices = [1, 3, 0] ctx = SessionContext() - batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"]) + batch = pa.RecordBatch.from_arrays( + [np.array(data, dtype=object), indices], names=["arr", "indices"] + ) df = ctx.create_dataframe([[batch]]) col = column("arr")