Skip to content

Commit adb6b7a

Browse files
authored
Merge branch 'main' into series-from-numpy
2 parents 4176a67 + df38225 commit adb6b7a

File tree

9 files changed

+89
-21
lines changed

9 files changed

+89
-21
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ Join the party!
114114
- [darts](https://github.com/unit8co/darts)
115115
- [hierarchicalforecast](https://github.com/Nixtla/hierarchicalforecast)
116116
- [marimo](https://github.com/marimo-team/marimo)
117+
- [metalearners](https://github.com/Quantco/metalearners)
117118
- [panel-graphic-walker](https://github.com/panel-extensions/panel-graphic-walker)
118119
- [plotly](https://plotly.com)
119120
- [pointblank](https://github.com/posit-dev/pointblank)

docs/ecosystem.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ The following is a non-exhaustive list of libraries and tools that choose to use
66
for their dataframe interoperability needs:
77

88
* [altair](https://github.com/vega/altair/)
9+
* [bokeh](https://github.com/bokeh/bokeh)
910
* [darts](https://github.com/unit8co/darts)
1011
* [hierarchicalforecast](https://github.com/Nixtla/hierarchicalforecast)
1112
* [marimo](https://github.com/marimo-team/marimo)
13+
* [metalearners](https://github.com/Quantco/metalearners)
1214
* [panel-graphic-walker](https://github.com/panel-extensions/panel-graphic-walker)
1315
* [plotly](https://github.com/plotly/plotly.py)
1416
* [pointblank](https://github.com/posit-dev/pointblank)

narwhals/_arrow/group_by.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from typing import TYPE_CHECKING
66
from typing import Any
77
from typing import Iterator
8-
from typing import cast
98

109
import pyarrow as pa
1110
import pyarrow.compute as pc
1211

12+
from narwhals._arrow.utils import cast_to_comparable_string_types
1313
from narwhals._arrow.utils import extract_py_scalar
1414
from narwhals._expression_parsing import evaluate_output_names_and_aliases
1515
from narwhals._expression_parsing import is_elementary_expression
@@ -146,15 +146,17 @@ def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
146146

147147
table = self._df._native_frame
148148
# NOTE: stubs fail in multiple places for `ChunkedArray`
149-
it = cast(
150-
"Iterator[pa.StringArray]",
151-
(table[key].cast(pa.string()) for key in self._keys),
149+
it, separator_scalar = cast_to_comparable_string_types(
150+
*(table[key] for key in self._keys), separator=""
152151
)
153152
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
154153
# Reality: `str` is fine
155154
concat_str: Incomplete = pc.binary_join_element_wise
156155
key_values = concat_str(
157-
*it, "", null_handling="replace", null_replacement=null_token
156+
*it,
157+
separator_scalar,
158+
null_handling="replace",
159+
null_replacement=null_token,
158160
)
159161
table = table.add_column(i=0, field_=col_token, column=key_values)
160162
for v in pc.unique(key_values):

narwhals/_arrow/namespace.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from narwhals._arrow.selectors import ArrowSelectorNamespace
2020
from narwhals._arrow.series import ArrowSeries
2121
from narwhals._arrow.utils import align_series_full_broadcast
22+
from narwhals._arrow.utils import cast_to_comparable_string_types
2223
from narwhals._arrow.utils import diagonal_concat
2324
from narwhals._arrow.utils import extract_dataframe_comparand
2425
from narwhals._arrow.utils import horizontal_concat
@@ -285,27 +286,27 @@ def concat_str(
285286
separator: str,
286287
ignore_nulls: bool,
287288
) -> ArrowExpr:
288-
dtypes = import_dtypes_module(self._version)
289-
290289
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
291290
compliant_series_list = align_series_full_broadcast(
292-
*(chain.from_iterable(expr.cast(dtypes.String())(df) for expr in exprs))
291+
*(chain.from_iterable(expr(df) for expr in exprs))
293292
)
293+
name = compliant_series_list[0].name
294294
null_handling: Literal["skip", "emit_null"] = (
295295
"skip" if ignore_nulls else "emit_null"
296296
)
297-
it = (s._native_series for s in compliant_series_list)
297+
it, separator_scalar = cast_to_comparable_string_types(
298+
*(s.native for s in compliant_series_list), separator=separator
299+
)
298300
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
299301
# Reality: `str` is fine
300302
concat_str: Incomplete = pc.binary_join_element_wise
301-
return [
302-
ArrowSeries(
303-
native_series=concat_str(*it, separator, null_handling=null_handling),
304-
name=compliant_series_list[0].name,
305-
backend_version=self._backend_version,
306-
version=self._version,
307-
)
308-
]
303+
compliant = self._series(
304+
concat_str(*it, separator_scalar, null_handling=null_handling),
305+
name=name,
306+
backend_version=self._backend_version,
307+
version=self._version,
308+
)
309+
return [compliant]
309310

310311
return self._expr._from_callable(
311312
func=func,

narwhals/_arrow/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING
66
from typing import Any
77
from typing import Iterable
8+
from typing import Iterator
89
from typing import Sequence
910
from typing import cast
1011
from typing import overload
@@ -543,6 +544,19 @@ def pad_series(
543544
return series._from_native_series(concat), offset_left + offset_right
544545

545546

547+
def cast_to_comparable_string_types(
548+
*chunked_arrays: ArrowChunkedArray,
549+
separator: str,
550+
) -> tuple[Iterator[ArrowChunkedArray], pa.Scalar[Any]]:
551+
# Ensure `chunked_arrays` are either all `string` or all `large_string`.
552+
dtype = (
553+
pa.string() # (PyArrow default)
554+
if not any(pa.types.is_large_string(ca.type) for ca in chunked_arrays)
555+
else pa.large_string()
556+
)
557+
return (ca.cast(dtype) for ca in chunked_arrays), lit(separator, dtype)
558+
559+
546560
class ArrowSeriesNamespace(_SeriesNamespace["ArrowSeries", "ArrowChunkedArray"]):
547561
def __init__(self: Self, series: ArrowSeries, /) -> None:
548562
self._compliant_series = series

narwhals/functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,8 @@ def from_dict(
309309
310310
Arguments:
311311
data: Dictionary to create DataFrame from.
312-
schema: The DataFrame schema as Schema or dict of {name: type}.
312+
schema: The DataFrame schema as Schema or dict of {name: type}. If not
313+
specified, the schema will be inferred by the native library.
313314
backend: specifies which eager backend instantiate to. Only
314315
necessary if inputs are not Narwhals Series.
315316
@@ -1593,7 +1594,7 @@ def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr:
15931594
Arguments:
15941595
value: The value to use as literal.
15951596
dtype: The data type of the literal value. If not provided, the data type will
1596-
be inferred.
1597+
be inferred by the native library.
15971598
15981599
Returns:
15991600
A new expression.

narwhals/stable/v1/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1869,7 +1869,7 @@ def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr:
18691869
Arguments:
18701870
value: The value to use as literal.
18711871
dtype: The data type of the literal value. If not provided, the data type will
1872-
be inferred.
1872+
be inferred by the native library.
18731873
18741874
Returns:
18751875
A new expression.
@@ -2228,7 +2228,8 @@ def from_dict(
22282228
22292229
Arguments:
22302230
data: Dictionary to create DataFrame from.
2231-
schema: The DataFrame schema as Schema or dict of {name: type}.
2231+
schema: The DataFrame schema as Schema or dict of {name: type}. If not
2232+
specified, the schema will be inferred by the native library.
22322233
backend: specifies which eager backend instantiate to. Only
22332234
necessary if inputs are not Narwhals Series.
22342235

tests/expr_and_series/concat_str_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
from typing import Callable
4+
5+
import pyarrow as pa
36
import pytest
47

58
import narwhals.stable.v1 as nw
@@ -71,3 +74,37 @@ def test_concat_str_with_lit(constructor: Constructor) -> None:
7174
result = df.with_columns(b=nw.concat_str("a", nw.lit("ab")))
7275
expected = {"a": ["cat", "dog", "pig"], "b": ["catab", "dogab", "pigab"]}
7376
assert_equal_data(result, expected)
77+
78+
79+
@pytest.mark.parametrize(
80+
("input_schema", "input_values", "expected_function"),
81+
[
82+
(
83+
[("store", pa.large_string()), ("item", pa.large_string())],
84+
["a", "b"],
85+
pa.types.is_large_string,
86+
),
87+
(
88+
[("store", pa.large_string()), ("item", pa.int32())],
89+
[0, 1],
90+
pa.types.is_large_string,
91+
),
92+
([("store", pa.string()), ("item", pa.int32())], [0, 1], pa.types.is_string),
93+
([("store", pa.string()), ("item", pa.string())], ["a", "b"], pa.types.is_string),
94+
],
95+
)
96+
def test_pyarrow_string_type(
97+
input_schema: list[tuple[str, pa.DataType]],
98+
input_values: list[object],
99+
expected_function: Callable[[pa.DataType], bool],
100+
) -> None:
101+
df = pa.table(
102+
{"store": ["foo", "bar"], "item": input_values}, schema=pa.schema(input_schema)
103+
)
104+
result = (
105+
nw.from_native(df)
106+
.with_columns(store_item=nw.concat_str("store", "item", separator="-"))
107+
.to_native()
108+
.schema
109+
)
110+
assert expected_function(result.field("store_item").type)

tests/expr_and_series/lit_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any
66

77
import numpy as np
8+
import pyarrow as pa
89
import pytest
910

1011
import narwhals.stable.v1 as nw
@@ -129,3 +130,11 @@ def test_date_lit(constructor: Constructor, request: pytest.FixtureRequest) -> N
129130
assert result == {"a": nw.Int64, "literal": nw.Datetime}
130131
else:
131132
assert result == {"a": nw.Int64, "literal": nw.Date}
133+
134+
135+
def test_pyarrow_lit_string() -> None:
136+
df = nw.from_native(pa.table({"a": [1, 2, 3]}))
137+
result = df.select(nw.lit("foo")).to_native().schema.field("literal")
138+
assert pa.types.is_string(result.type)
139+
result = df.select(nw.lit("foo", dtype=nw.String)).to_native().schema.field("literal")
140+
assert pa.types.is_string(result.type)

0 commit comments

Comments
 (0)