
Commit aedc330

shift -> kernel, add fancy test
1 parent 1f830bc commit aedc330

3 files changed: +38 −15 lines


narwhals/_plan/arrow/expr.py

Lines changed: 1 addition & 15 deletions
@@ -398,22 +398,8 @@ def rolling_expr(self, node: ir.RollingExpr, frame: Frame, name: str) -> Self:
         raise NotImplementedError

     def shift(self, node: ir.FunctionExpr[Shift], frame: Frame, name: str) -> Self:
-        n = node.function.n
         series = self._dispatch_expr(node.input[0], frame, name)
-        native = series.native
-        if n == 0:
-            return self._with_native(native, name)
-        # NOTE: This might be zero-copy and the original wasn't?
-        # here, the non-nulls stay in their own chunk
-        # on main, the two (or more) chunks are merged
-        if n > 0:
-            arrays = [
-                pa.nulls(n, native.type),
-                *native.slice(length=native.length() - n).chunks,
-            ]
-        else:
-            arrays = [*native.slice(offset=-n).chunks, pa.nulls(-n, native.type)]
-        return self._with_native(fn.chunked_array(arrays), name)
+        return self._with_native(fn.shift(series.native, node.function.n), name)

     def diff(self, node: ir.FunctionExpr[Diff], frame: Frame, name: str) -> Self:
         series = self._dispatch_expr(node.input[0], frame, name)

narwhals/_plan/arrow/functions.py

Lines changed: 19 additions & 0 deletions
@@ -262,6 +262,17 @@ def diff(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
     )


+def shift(native: ChunkedArrayAny, n: int) -> ChunkedArrayAny:
+    if n == 0:
+        return native
+    arr = native
+    if n > 0:
+        arrays = [nulls_like(n, arr), *arr.slice(length=arr.length() - n).chunks]
+    else:
+        arrays = [*arr.slice(offset=-n).chunks, nulls_like(-n, arr)]
+    return pa.chunked_array(arrays)
+
+
 def is_between(
     native: ChunkedOrScalar[ScalarT],
     lower: ChunkedOrScalar[ScalarT],

@@ -325,6 +336,14 @@ def int_range(
     return pa.chunked_array([pa.array(np.arange(start, end, step), dtype)])


+def nulls_like(n: int, native: ArrowAny) -> ArrayAny:
+    """Create a strongly-typed Array instance with all elements null.
+
+    Uses the type of `native`.
+    """
+    return pa.nulls(n, native.type)  # type: ignore[no-any-return]
+
+
 def lit(value: Any, dtype: DataType | None = None) -> NativeScalar:
     return pa.scalar(value) if dtype is None else pa.scalar(value, dtype)

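For context, a minimal standalone sketch of what the new fn.shift kernel does, written against plain pyarrow (shift_sketch is a hypothetical name; the internal ChunkedArrayAny alias and nulls_like helper are replaced by their pyarrow equivalents). The surviving values keep their own chunk(s), which the removed NOTE in expr.py suggests may be zero-copy:

import pyarrow as pa

def shift_sketch(native: pa.ChunkedArray, n: int) -> pa.ChunkedArray:
    # Positive n: prepend n typed nulls and drop the last n values.
    # Negative n: drop the first -n values and append -n typed nulls.
    if n == 0:
        return native
    if n > 0:
        arrays = [pa.nulls(n, native.type), *native.slice(length=native.length() - n).chunks]
    else:
        arrays = [*native.slice(offset=-n).chunks, pa.nulls(-n, native.type)]
    return pa.chunked_array(arrays)

# shift_sketch(pa.chunked_array([[1, 2, 3]]), 1)  -> [null, 1, 2]
# shift_sketch(pa.chunked_array([[1, 2, 3]]), -1) -> [2, 3, null]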

tests/plan/over_test.py

Lines changed: 18 additions & 0 deletions
@@ -27,6 +27,11 @@ def data() -> Data:
     }


+@pytest.fixture
+def data_alt() -> Data:
+    return {"a": [3, 5, 1, 2, None], "b": [0, 1, 3, 2, 1], "c": [9, 1, 2, 1, 1]}
+
+
 @pytest.mark.parametrize(
     "partition_by",
     [

@@ -177,3 +182,16 @@ def test_len_over_2369() -> None:
     result = df.with_columns(a_len_per_group=nwp.len().over("b")).sort("a")
     expected = {"a": [1, 2, 4], "b": ["x", "x", "y"], "a_len_per_group": [2, 2, 1]}
     assert_equal_data(result, expected)
+
+
+def test_shift_kitchen_sink(data_alt: Data) -> None:
+    result = dataframe(data_alt).select(
+        nwp.nth(1, 2)
+        .shift(-1)
+        .over(order_by=nwp.nth(0))
+        .sort(nulls_last=True)
+        .fill_null(100)
+        * 5
+    )
+    expected = {"b": [0, 5, 10, 15, 500], "c": [5, 5, 10, 45, 500]}
+    assert_equal_data(result, expected)

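The expected values take a moment to derive by hand. Here is a plain-Python walkthrough of the "b" column, assuming over(order_by=...) orders rows by column "a" with nulls first (which is what the expected output implies) and restores the original row order before the sort/fill/multiply steps:

a = [3, 5, 1, 2, None]
b = [0, 1, 3, 2, 1]

# Row order when sorting by "a", nulls first: [4, 2, 3, 0, 1]
order = sorted(range(len(a)), key=lambda i: (a[i] is not None, a[i] or 0))
shifted = [b[i] for i in order][1:] + [None]   # shift(-1) within that order
by_row = [None] * len(b)
for pos, i in enumerate(order):                # restore the original row order
    by_row[i] = shifted[pos]                   # -> [1, None, 2, 0, 3]

out = sorted(v for v in by_row if v is not None) + [None]   # .sort(nulls_last=True)
out = [(100 if v is None else v) * 5 for v in out]          # .fill_null(100) * 5
assert out == [0, 5, 10, 15, 500]              # matches expected["b"]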
