Skip to content

Commit 8d6101c

Browse files
authored
feat(expr-ir): Support DataFrame.join_asof (#3378)
1 parent 058782d commit 8d6101c

File tree

6 files changed

+494
-57
lines changed

6 files changed

+494
-57
lines changed

narwhals/_plan/arrow/acero.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from narwhals._plan.common import ensure_list_str, temp
2929
from narwhals._plan.typing import NonCrossJoinStrategy, OneOrSeq
3030
from narwhals._utils import check_column_names_are_unique
31-
from narwhals.typing import JoinStrategy, SingleColSelector
31+
from narwhals.typing import AsofJoinStrategy, JoinStrategy, SingleColSelector
3232

3333
if TYPE_CHECKING:
3434
from collections.abc import (
@@ -47,7 +47,12 @@
4747
Aggregation as _Aggregation,
4848
)
4949
from narwhals._plan.arrow.group_by import AggSpec
50-
from narwhals._plan.arrow.typing import ArrowAny, JoinTypeSubset, ScalarAny
50+
from narwhals._plan.arrow.typing import (
51+
ArrowAny,
52+
ChunkedArrayAny,
53+
JoinTypeSubset,
54+
ScalarAny,
55+
)
5156
from narwhals._plan.typing import OneOrIterable, Seq
5257
from narwhals.typing import NonNestedLiteral
5358

@@ -278,6 +283,53 @@ def _hashjoin(
278283
return Decl("hashjoin", options, [_into_decl(left), _into_decl(right)])
279284

280285

286+
def _join_asof_suffix_collisions(
287+
left: pa.Table, right: pa.Table, right_on: str, right_by: Sequence[str], suffix: str
288+
) -> pa.Table:
289+
"""Adapted from [upstream] to suffix colliding right-hand column names, instead of raising early.
290+
291+
[upstream]: https://github.com/apache/arrow/blob/9b03118e834dfdaa0cf9e03595477b499252a9cb/python/pyarrow/acero.py#L306-L316
292+
"""
293+
right_names = right.schema.names
294+
allowed = {right_on, *right_by}
295+
if collisions := set(right_names).difference(allowed).intersection(left.schema.names):
296+
renamed = [f"{nm}{suffix}" if nm in collisions else nm for nm in right_names]
297+
return right.rename_columns(renamed)
298+
return right
299+
300+
301+
def _join_asof_strategy_to_tolerance(
302+
left_on: ChunkedArrayAny, right_on: ChunkedArrayAny, /, strategy: AsofJoinStrategy
303+
) -> int:
304+
"""Calculate the **required** `tolerance` argument, from `*on` values and strategy.
305+
306+
For both `polars` and `pandas`, `tolerance` is optional; in `narwhals` it isn't supported (yet).
307+
308+
So we take the widest possible gap between the `on` keys (overall min to overall max, across both sides) and use that as an equivalent, effectively unbounded default.
309+
310+
`"backward"`:
311+
312+
tolerance <= right_on - left_on <= 0
313+
314+
`"forward"`:
315+
316+
0 <= right_on - left_on <= tolerance
317+
318+
Note:
319+
`tolerance` is interpreted in the same units as the `on` keys.
320+
"""
321+
import narwhals._plan.arrow.functions as fn
322+
323+
if strategy == "nearest":
324+
msg = "Only 'backward' and 'forward' strategies are currently supported for `pyarrow`"
325+
raise NotImplementedError(msg)
326+
lower = fn.min_horizontal(fn.min_(left_on), fn.min_(right_on))
327+
upper = fn.max_horizontal(fn.max_(left_on), fn.max_(right_on))
328+
scalar = fn.sub(lower, upper) if strategy == "backward" else fn.sub(upper, lower)
329+
tolerance: int = fn.cast(scalar, fn.I64).as_py()
330+
return tolerance
331+
332+
281333
def declare(*declarations: Decl) -> Decl:
282334
"""Compose one or more `Declaration` nodes for execution as a pipeline."""
283335
if len(declarations) == 1:
@@ -439,6 +491,31 @@ def join_cross_tables(
439491
return collect(decl, ensure_unique_column_names=True).remove_column(0)
440492

441493

494+
def join_asof_tables(
495+
left: pa.Table,
496+
right: pa.Table,
497+
left_on: str,
498+
right_on: str,
499+
*,
500+
left_by: Sequence[str] = (),
501+
right_by: Sequence[str] = (),
502+
strategy: AsofJoinStrategy = "backward",
503+
suffix: str = "_right",
504+
) -> pa.Table:
505+
"""Perform an inexact join between two tables, matching each left row to the closest right key in the direction given by `strategy`."""
506+
right = _join_asof_suffix_collisions(left, right, right_on, right_by, suffix=suffix)
507+
tolerance = _join_asof_strategy_to_tolerance(
508+
left.column(left_on), right.column(right_on), strategy
509+
)
510+
lb: list[Any] = [] if not left_by else list(left_by)
511+
rb: list[Any] = [] if not right_by else list(right_by)
512+
join_opts = pac.AsofJoinNodeOptions(
513+
left_on=left_on, right_on=right_on, left_by=lb, right_by=rb, tolerance=tolerance
514+
)
515+
inputs = [table_source(left), table_source(right)]
516+
return Decl("asofjoin", join_opts, inputs).to_table()
517+
518+
442519
def _add_column_table(
443520
native: pa.Table, index: int, name: str, values: IntoExpr | ArrowAny
444521
) -> pa.Table:

narwhals/_plan/arrow/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ class ArrowFrameSeries(Generic[NativeT]):
3131
_native: NativeT
3232
_version: Version
3333

34+
# NOTE: Aliases to integrate with `@requires.backend_version`
35+
_backend_version = compat.BACKEND_VERSION
36+
_implementation = implementation
37+
3438
@property
3539
def native(self) -> NativeT:
3640
return self._native

narwhals/_plan/arrow/dataframe.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from narwhals._plan.compliant.typing import LazyFrameAny, namespace
2121
from narwhals._plan.exceptions import shape_error
2222
from narwhals._plan.expressions import NamedIR, named_ir
23-
from narwhals._utils import Version, generate_repr
23+
from narwhals._utils import Version, generate_repr, requires
2424
from narwhals.schema import Schema
2525

2626
if TYPE_CHECKING:
@@ -37,7 +37,7 @@
3737
from narwhals._plan.typing import NonCrossJoinStrategy
3838
from narwhals._typing import _LazyAllowedImpl
3939
from narwhals.dtypes import DType
40-
from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy
40+
from narwhals.typing import AsofJoinStrategy, IntoSchema, PivotAgg, UniqueKeepStrategy
4141

4242
Incomplete: TypeAlias = Any
4343

@@ -308,6 +308,31 @@ def join_inner(self, other: Self, on: list[str], /) -> Self:
308308
"""Less flexible, but more direct equivalent to join(how="inner", left_on=...)`."""
309309
return self._with_native(acero.join_inner_tables(self.native, other.native, on))
310310

311+
@requires.backend_version((16,))
312+
def join_asof(
313+
self,
314+
other: Self,
315+
*,
316+
left_on: str,
317+
right_on: str,
318+
left_by: Sequence[str] = (),
319+
right_by: Sequence[str] = (),
320+
strategy: AsofJoinStrategy = "backward",
321+
suffix: str = "_right",
322+
) -> Self:
323+
return self._with_native(
324+
acero.join_asof_tables(
325+
self.native,
326+
other.native,
327+
left_on,
328+
right_on,
329+
left_by=left_by,
330+
right_by=right_by,
331+
strategy=strategy,
332+
suffix=suffix,
333+
)
334+
)
335+
311336
def _filter(self, predicate: Predicate | acero.Expr) -> Self:
312337
mask: Incomplete = predicate
313338
return self._with_native(self.native.filter(mask))

narwhals/_plan/compliant/dataframe.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
4646
from narwhals._utils import Implementation, Version
4747
from narwhals.dtypes import DType
48-
from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy
48+
from narwhals.typing import AsofJoinStrategy, IntoSchema, PivotAgg, UniqueKeepStrategy
4949

5050
Incomplete: TypeAlias = Any
5151

@@ -72,6 +72,27 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
7272
def explode(self, subset: Sequence[str], options: ExplodeOptions) -> Self: ...
7373
# Shouldn't *need* to be `NamedIR`, but current impl depends on a name being passed around
7474
def filter(self, predicate: NamedIR, /) -> Self: ...
75+
def join(
76+
self,
77+
other: Self,
78+
*,
79+
how: NonCrossJoinStrategy,
80+
left_on: Sequence[str],
81+
right_on: Sequence[str],
82+
suffix: str = "_right",
83+
) -> Self: ...
84+
def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ...
85+
def join_asof(
86+
self,
87+
other: Self,
88+
*,
89+
left_on: str,
90+
right_on: str,
91+
left_by: Sequence[str] = (), # https://github.com/pola-rs/polars/issues/18496
92+
right_by: Sequence[str] = (),
93+
strategy: AsofJoinStrategy = "backward",
94+
suffix: str = "_right",
95+
) -> Self: ...
7596
def rename(self, mapping: Mapping[str, str]) -> Self: ...
7697
@property
7798
def schema(self) -> Mapping[str, DType]: ...
@@ -213,16 +234,6 @@ def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Se
213234

214235
def filter(self, predicate: NamedIR, /) -> Self: ...
215236
def iter_columns(self) -> Iterator[SeriesT]: ...
216-
def join(
217-
self,
218-
other: Self,
219-
*,
220-
how: NonCrossJoinStrategy,
221-
left_on: Sequence[str],
222-
right_on: Sequence[str],
223-
suffix: str = "_right",
224-
) -> Self: ...
225-
def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ...
226237
def partition_by(
227238
self, by: Sequence[str], *, include_key: bool = True
228239
) -> list[Self]: ...

0 commit comments

Comments
 (0)