|
28 | 28 | from narwhals._plan.common import ensure_list_str, temp |
29 | 29 | from narwhals._plan.typing import NonCrossJoinStrategy, OneOrSeq |
30 | 30 | from narwhals._utils import check_column_names_are_unique |
31 | | -from narwhals.typing import JoinStrategy, SingleColSelector |
| 31 | +from narwhals.typing import AsofJoinStrategy, JoinStrategy, SingleColSelector |
32 | 32 |
|
33 | 33 | if TYPE_CHECKING: |
34 | 34 | from collections.abc import ( |
|
47 | 47 | Aggregation as _Aggregation, |
48 | 48 | ) |
49 | 49 | from narwhals._plan.arrow.group_by import AggSpec |
50 | | - from narwhals._plan.arrow.typing import ArrowAny, JoinTypeSubset, ScalarAny |
| 50 | + from narwhals._plan.arrow.typing import ( |
| 51 | + ArrowAny, |
| 52 | + ChunkedArrayAny, |
| 53 | + JoinTypeSubset, |
| 54 | + ScalarAny, |
| 55 | + ) |
51 | 56 | from narwhals._plan.typing import OneOrIterable, Seq |
52 | 57 | from narwhals.typing import NonNestedLiteral |
53 | 58 |
|
@@ -278,6 +283,53 @@ def _hashjoin( |
278 | 283 | return Decl("hashjoin", options, [_into_decl(left), _into_decl(right)]) |
279 | 284 |
|
280 | 285 |
|
def _join_asof_suffix_collisions(
    left: pa.Table, right: pa.Table, right_on: str, right_by: Sequence[str], suffix: str
) -> pa.Table:
    """Append `suffix` to the `right` column names that clash with `left`.

    A column of `right` is renamed when its name also appears in `left` and it
    is neither `right_on` nor one of `right_by`.

    Adapted from [upstream], so that colliding names are renamed here rather
    than raising early.

    Arguments:
        left: Left-hand table; only its column names are inspected.
        right: Right-hand table whose colliding columns get renamed.
        right_on: Name of the `on` key in `right` (never renamed).
        right_by: Names of the `by` keys in `right` (never renamed).
        suffix: Text appended to each colliding column name.

    Returns:
        `right` itself when nothing collides, otherwise the result of
        `right.rename_columns(...)` with the suffixed names.

    [upstream]: https://github.com/apache/arrow/blob/9b03118e834dfdaa0cf9e03595477b499252a9cb/python/pyarrow/acero.py#L306-L316
    """
    names = right.schema.names
    join_keys = {right_on, *right_by}
    left_names = set(left.schema.names)
    clashing = {
        name for name in names if name not in join_keys and name in left_names
    }
    # Guard clause: hand back the very same table when there is nothing to rename.
    if not clashing:
        return right
    return right.rename_columns(
        [f"{name}{suffix}" if name in clashing else name for name in names]
    )
| 299 | + |
| 300 | + |
def _join_asof_strategy_to_tolerance(
    left_on: ChunkedArrayAny, right_on: ChunkedArrayAny, /, strategy: AsofJoinStrategy
) -> int:
    """Derive the **required** `tolerance` argument from the `*on` values and strategy.

    `polars` and `pandas` treat tolerance as optional, and `narwhals` does not
    support it (yet). To emulate "no tolerance", the full span between the
    smallest and largest value found across both `on` columns is used, signed
    according to the strategy.

    `"backward"` (negative span, `lowest - highest`):

        tolerance <= right_on - left_on <= 0

    `"forward"` (positive span, `highest - lowest`):

        0 <= right_on - left_on <= tolerance

    Arguments:
        left_on: Values of the `on` key from the left table.
        right_on: Values of the `on` key from the right table.
        strategy: Join strategy; only `"backward"` and `"forward"` are handled.

    Returns:
        The span as a Python value, after casting to `fn.I64`.

    Raises:
        NotImplementedError: If `strategy` is `"nearest"`.

    Note:
        `tolerance` is interpreted in the same units as the `on` keys.
    """
    import narwhals._plan.arrow.functions as fn

    if strategy == "nearest":
        msg = "Only 'backward' and 'forward' strategies are currently supported for `pyarrow`"
        raise NotImplementedError(msg)

    # Extremes taken over *both* key columns, not each one separately.
    lowest = fn.min_horizontal(fn.min_(left_on), fn.min_(right_on))
    highest = fn.max_horizontal(fn.max_(left_on), fn.max_(right_on))

    if strategy == "backward":
        span = fn.sub(lowest, highest)
    else:
        span = fn.sub(highest, lowest)

    result: int = fn.cast(span, fn.I64).as_py()
    return result
| 331 | + |
| 332 | + |
281 | 333 | def declare(*declarations: Decl) -> Decl: |
282 | 334 | """Compose one or more `Declaration` nodes for execution as a pipeline.""" |
283 | 335 | if len(declarations) == 1: |
@@ -439,6 +491,31 @@ def join_cross_tables( |
439 | 491 | return collect(decl, ensure_unique_column_names=True).remove_column(0) |
440 | 492 |
|
441 | 493 |
|
def join_asof_tables(
    left: pa.Table,
    right: pa.Table,
    left_on: str,
    right_on: str,
    *,
    left_by: Sequence[str] = (),
    right_by: Sequence[str] = (),
    strategy: AsofJoinStrategy = "backward",
    suffix: str = "_right",
) -> pa.Table:
    """Asof-join `right` onto `left` through an `"asofjoin"` declaration.

    Steps:

    1. Colliding non-key columns of `right` are renamed with `suffix`.
    2. A `tolerance` is computed from both `on` columns and `strategy`.
    3. The `"asofjoin"` declaration is built from both tables and collected.

    Arguments:
        left: Left-hand table.
        right: Right-hand table.
        left_on: Name of the `on` key in `left`.
        right_on: Name of the `on` key in `right`.
        left_by: Names of the `by` keys in `left`.
        right_by: Names of the `by` keys in `right`.
        strategy: Direction used to derive the tolerance.
        suffix: Appended to `right` column names that collide with `left`.

    Returns:
        The table produced by executing the declaration.
    """
    right = _join_asof_suffix_collisions(left, right, right_on, right_by, suffix=suffix)
    tolerance = _join_asof_strategy_to_tolerance(
        left.column(left_on), right.column(right_on), strategy
    )
    # Pass plain lists for the `by` keys; any falsy value becomes an empty list.
    by_left: list[Any] = list(left_by) if left_by else []
    by_right: list[Any] = list(right_by) if right_by else []
    options = pac.AsofJoinNodeOptions(
        left_on=left_on,
        right_on=right_on,
        left_by=by_left,
        right_by=by_right,
        tolerance=tolerance,
    )
    declaration = Decl("asofjoin", options, [table_source(left), table_source(right)])
    return declaration.to_table()
| 517 | + |
| 518 | + |
442 | 519 | def _add_column_table( |
443 | 520 | native: pa.Table, index: int, name: str, values: IntoExpr | ArrowAny |
444 | 521 | ) -> pa.Table: |
|
0 commit comments