Commit aef26ce
feat: Allow os.PathLike[str] in {read,scan}_* functions (#3112)
* test: Update fixtures for `str | Path`
  - First step of #3100
  - Some of the error cases don't need to test both
* fix(typing): Resolve most new warnings
  - `pyarrow.parquet.read_table` is the only one left (2x warns)
  - Seems to be a stub issue
  - Runtime *does* check `__fspath__`:
    https://github.com/apache/arrow/blob/982d31f35fd2cfe87494698dae9ef67d3333658c/python/pyarrow/parquet/core.py#L1381-L1391
* chore(typing): Ignore `pyarrow-stubs` issue for now (see last commit description)
* fix: Normalize path for `duckdb`
* ci: Ensure `pyspark` gets triggered
  - None of the existing rules apply to IO, but this is pretty important to keep working
* fix: Use `normalize_path` for spark-like
  - Fixes https://github.com/narwhals-dev/narwhals/actions/runs/17552195156/job/49847143668?pr=3112
* test: Add failing `PathLike` tests
  - Towards #3112 (comment)
* fix: Support `__fspath__` everywhere
  - Resolves #3112 (comment)
* refactor(typing): Use `FileSource` in `normalize_path`
* test(perf): Don't parametrize paths on fail cases
  - All of these are validating something unrelated to `source`
* docs: Document `FileSource` alias
1 parent 4f4f713 commit aef26ce

File tree: 6 files changed (+87, -37)
.github/workflows/pytest-pyspark.yml

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,8 @@ on:
       - narwhals/_expression_parsing.py
       - narwhals/_spark_like/**
       - narwhals/_sql/**
+      - tests/*scan*.py
+      - tests/frame/*sink*.py
   schedule:
     - cron: 0 12 * * 0 # Sunday at mid-day

narwhals/_utils.py

Lines changed: 9 additions & 0 deletions

@@ -116,6 +116,7 @@
     CompliantLazyFrame,
     CompliantSeries,
     DTypes,
+    FileSource,
     IntoSeriesT,
     MultiIndexSelector,
     SingleIndexSelector,
@@ -2156,3 +2157,11 @@ def to_pyarrow_table(tbl: pa.Table | pa.RecordBatchReader) -> pa.Table:
     if isinstance(tbl, pa.RecordBatchReader):  # pragma: no cover
         return pa.Table.from_batches(tbl)
     return tbl
+
+
+def normalize_path(source: FileSource, /) -> str:
+    if isinstance(source, str):
+        return source
+    from pathlib import Path
+
+    return str(Path(source))
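The new `normalize_path` helper is the heart of the change: strings pass through untouched, and anything implementing `__fspath__` is round-tripped through `pathlib.Path`. A minimal sketch of its behaviour; the `ExoticPath` class is hypothetical, defined only to show that the protocol, not the concrete type, is what matters:

from pathlib import Path

from narwhals._utils import normalize_path


class ExoticPath:
    """Hypothetical os.PathLike[str]: neither a str nor a Path subclass."""

    def __init__(self, path: str) -> None:
        self._path = path

    def __fspath__(self) -> str:
        return self._path


# Strings are returned as-is; PathLike objects are stringified via pathlib.
assert normalize_path("data/file.csv") == "data/file.csv"
assert normalize_path(Path("data", "file.csv")) == str(Path("data", "file.csv"))
assert normalize_path(ExoticPath("data/file.csv")) == str(Path("data/file.csv"))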

narwhals/functions.py

Lines changed: 11 additions & 8 deletions

@@ -21,6 +21,7 @@
     is_compliant_expr,
     is_eager_allowed,
     is_sequence_but_not_str,
+    normalize_path,
     supports_arrow_c_stream,
     validate_laziness,
 )
@@ -46,6 +47,7 @@
     from narwhals.dataframe import DataFrame, LazyFrame
     from narwhals.typing import (
         ConcatMethod,
+        FileSource,
         FrameT,
         IntoDType,
         IntoExpr,
@@ -564,7 +566,7 @@ def show_versions() -> None:


 def read_csv(
-    source: str, *, backend: IntoBackend[EagerAllowed], **kwargs: Any
+    source: FileSource, *, backend: IntoBackend[EagerAllowed], **kwargs: Any
 ) -> DataFrame[Any]:
     """Read a CSV file into a DataFrame.

@@ -604,7 +606,7 @@ def read_csv(
         Implementation.MODIN,
         Implementation.CUDF,
     }:
-        native_frame = native_namespace.read_csv(source, **kwargs)
+        native_frame = native_namespace.read_csv(normalize_path(source), **kwargs)
     elif impl is Implementation.PYARROW:
         from pyarrow import csv  # ignore-banned-import

@@ -634,7 +636,7 @@


 def scan_csv(
-    source: str, *, backend: IntoBackend[Backend], **kwargs: Any
+    source: FileSource, *, backend: IntoBackend[Backend], **kwargs: Any
 ) -> LazyFrame[Any]:
     """Lazily read from a CSV file.

@@ -674,6 +676,7 @@ def scan_csv(
     implementation = Implementation.from_backend(backend)
     native_namespace = implementation.to_native_namespace()
     native_frame: NativeDataFrame | NativeLazyFrame
+    source = normalize_path(source)
     if implementation is Implementation.POLARS:
         native_frame = native_namespace.scan_csv(source, **kwargs)
     elif implementation in {
@@ -693,7 +696,6 @@ def scan_csv(
         if (session := kwargs.pop("session", None)) is None:
             msg = "Spark like backends require a session object to be passed in `kwargs`."
             raise ValueError(msg)
-
         csv_reader = session.read.format("csv")
         native_frame = (
             csv_reader.load(source)
@@ -715,7 +717,7 @@


 def read_parquet(
-    source: str, *, backend: IntoBackend[EagerAllowed], **kwargs: Any
+    source: FileSource, *, backend: IntoBackend[EagerAllowed], **kwargs: Any
 ) -> DataFrame[Any]:
     """Read into a DataFrame from a parquet file.

@@ -760,11 +762,12 @@ def read_parquet(
         Implementation.MODIN,
         Implementation.CUDF,
     }:
+        source = normalize_path(source)
         native_frame = native_namespace.read_parquet(source, **kwargs)
     elif impl is Implementation.PYARROW:
         import pyarrow.parquet as pq  # ignore-banned-import

-        native_frame = pq.read_table(source, **kwargs)
+        native_frame = pq.read_table(source, **kwargs)  # type: ignore[arg-type]
     elif impl in {
         Implementation.PYSPARK,
         Implementation.DASK,
@@ -790,7 +793,7 @@


 def scan_parquet(
-    source: str, *, backend: IntoBackend[Backend], **kwargs: Any
+    source: FileSource, *, backend: IntoBackend[Backend], **kwargs: Any
 ) -> LazyFrame[Any]:
     """Lazily read from a parquet file.

@@ -857,6 +860,7 @@ def scan_parquet(
     implementation = Implementation.from_backend(backend)
     native_namespace = implementation.to_native_namespace()
     native_frame: NativeDataFrame | NativeLazyFrame
+    source = normalize_path(source)
     if implementation is Implementation.POLARS:
         native_frame = native_namespace.scan_parquet(source, **kwargs)
     elif implementation in {
@@ -876,7 +880,6 @@ def scan_parquet(
         if (session := kwargs.pop("session", None)) is None:
             msg = "Spark like backends require a session object to be passed in `kwargs`."
             raise ValueError(msg)
-
         pq_reader = session.read.format("parquet")
         native_frame = (
             pq_reader.load(source)
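With the signatures widened and `normalize_path` applied on entry, callers no longer need to stringify paths themselves. A usage sketch, assuming a local `file.csv` and `file.parquet` exist and the polars backend is installed:

from pathlib import Path

import narwhals as nw

# Every {read,scan}_* entry point now accepts str, pathlib.Path, or any
# os.PathLike[str] interchangeably.
df = nw.read_csv(Path("file.csv"), backend="polars")
lf = nw.scan_parquet(Path("file.parquet"), backend="polars")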

narwhals/stable/v1/__init__.py

Lines changed: 5 additions & 4 deletions

@@ -80,6 +80,7 @@
     from narwhals.dataframe import MultiColSelector, MultiIndexSelector
     from narwhals.dtypes import DType
     from narwhals.typing import (
+        FileSource,
         IntoDType,
         IntoExpr,
         IntoFrame,
@@ -1280,7 +1281,7 @@ def from_numpy(

 @deprecate_native_namespace(required=True)
 def read_csv(
-    source: str,
+    source: FileSource,
     *,
     backend: IntoBackend[EagerAllowed] | None = None,
     native_namespace: ModuleType | None = None,  # noqa: ARG001
@@ -1298,7 +1299,7 @@ def read_csv(

 @deprecate_native_namespace(required=True)
 def scan_csv(
-    source: str,
+    source: FileSource,
     *,
     backend: IntoBackend[Backend] | None = None,
     native_namespace: ModuleType | None = None,  # noqa: ARG001
@@ -1316,7 +1317,7 @@ def scan_csv(

 @deprecate_native_namespace(required=True)
 def read_parquet(
-    source: str,
+    source: FileSource,
     *,
     backend: IntoBackend[EagerAllowed] | None = None,
     native_namespace: ModuleType | None = None,  # noqa: ARG001
@@ -1334,7 +1335,7 @@ def read_parquet(

 @deprecate_native_namespace(required=True)
 def scan_parquet(
-    source: str,
+    source: FileSource,
     *,
     backend: IntoBackend[Backend] | None = None,
     native_namespace: ModuleType | None = None,  # noqa: ARG001
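`narwhals.stable.v1` receives the same widening, so code pinned to the stable API benefits without changes. A sketch under the same assumptions as above (a local `file.parquet`, polars installed):

from pathlib import Path

import narwhals.stable.v1 as nw_v1

# The stable wrappers forward `source` to the main implementations,
# so PathLike sources behave identically here.
df = nw_v1.read_parquet(Path("file.parquet"), backend="polars")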

narwhals/typing.py

Lines changed: 10 additions & 0 deletions

@@ -8,6 +8,7 @@

 if TYPE_CHECKING:
     import datetime as dt
+    import os
     from collections.abc import Iterable, Sequence, Sized
     from decimal import Decimal
     from types import ModuleType
@@ -432,6 +433,15 @@ def Binary(self) -> type[dtypes.Binary]: ...
 IntoPolarsSchema: TypeAlias = "pl.Schema | Mapping[str, pl.DataType]"
 IntoPandasSchema: TypeAlias = Mapping[str, PandasLikeDType]

+FileSource: TypeAlias = "str | os.PathLike[str]"
+"""Path to a file.
+
+Either a string or an object that implements [`__fspath__`], such as [`pathlib.Path`].
+
+[`__fspath__`]: https://docs.python.org/3/library/os.html#os.PathLike
+[`pathlib.Path`]: https://docs.python.org/3/library/pathlib.html#pathlib.Path
+"""
+

 # Annotations for `__getitem__` methods
 _T = TypeVar("_T")
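`FileSource` is the static-typing mirror of the runtime `os.PathLike` protocol; `os.fspath` is the standard-library function that collapses either branch of the union to a plain string. A brief illustration of the contract the alias describes:

import os
from pathlib import Path

# os.fspath returns str arguments unchanged and calls __fspath__ otherwise.
assert os.fspath("data.csv") == "data.csv"
assert os.fspath(Path("data.csv")) == "data.csv"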

tests/read_scan_test.py

Lines changed: 50 additions & 25 deletions

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import re
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal

 import pandas as pd
 import pytest
@@ -20,9 +20,15 @@

 if TYPE_CHECKING:
     from collections.abc import Mapping
+    from pathlib import Path
     from types import ModuleType

+    from typing_extensions import TypeAlias
+
     from narwhals._typing import EagerAllowed, _LazyOnly, _SparkLike
+    from narwhals.typing import FileSource
+
+    IOSourceKind: TypeAlias = Literal["str", "Path", "PathLike"]

 data: Mapping[str, Any] = {"a": [1, 2, 3], "b": [4.5, 6.7, 8.9], "z": ["x", "y", "w"]}
 skipif_pandas_lt_1_5 = pytest.mark.skipif(
@@ -32,20 +38,39 @@
 spark_like_backend = pytest.mark.parametrize("backend", ["pyspark", "sqlframe"])


-@pytest.fixture(scope="module")
-def csv_path(tmp_path_factory: pytest.TempPathFactory) -> str:
+class MockPathLike:
+    def __init__(self, path: Path) -> None:
+        self._super_secret: Path = path
+
+    def __fspath__(self) -> str:
+        return self._super_secret.__fspath__()
+
+
+def _into_file_source(source: Path, which: IOSourceKind, /) -> FileSource:
+    mapping: Mapping[IOSourceKind, FileSource] = {
+        "str": str(source),
+        "Path": source,
+        "PathLike": MockPathLike(source),
+    }
+    return mapping[which]
+
+
+@pytest.fixture(scope="module", params=["str", "Path", "PathLike"])
+def csv_path(
+    tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest
+) -> FileSource:
     fp = tmp_path_factory.mktemp("data") / "file.csv"
-    filepath = str(fp)
-    pl.DataFrame(data).write_csv(filepath)
-    return filepath
+    pl.DataFrame(data).write_csv(fp)
+    return _into_file_source(fp, request.param)


-@pytest.fixture(scope="module")
-def parquet_path(tmp_path_factory: pytest.TempPathFactory) -> str:
+@pytest.fixture(scope="module", params=["str", "Path", "PathLike"])
+def parquet_path(
+    tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest
+) -> FileSource:
     fp = tmp_path_factory.mktemp("data") / "file.parquet"
-    filepath = str(fp)
-    pl.DataFrame(data).write_parquet(filepath)
-    return filepath
+    pl.DataFrame(data).write_parquet(fp)
+    return _into_file_source(fp, request.param)


 def assert_equal_eager(result: nw.DataFrame[Any]) -> None:
@@ -62,23 +87,23 @@ def native_namespace(cb: Constructor, /) -> ModuleType:
     return nw.get_native_namespace(nw.from_native(cb(data)))  # type: ignore[no-any-return]


-def test_read_csv(csv_path: str, eager_backend: EagerAllowed) -> None:
+def test_read_csv(csv_path: FileSource, eager_backend: EagerAllowed) -> None:
     assert_equal_eager(nw.read_csv(csv_path, backend=eager_backend))


 @skipif_pandas_lt_1_5
-def test_read_csv_kwargs(csv_path: str) -> None:
+def test_read_csv_kwargs(csv_path: FileSource) -> None:
     assert_equal_eager(nw.read_csv(csv_path, backend=pd, engine="pyarrow"))


 @lazy_core_backend
-def test_read_csv_raise_with_lazy(csv_path: str, backend: _LazyOnly) -> None:
+def test_read_csv_raise_with_lazy(backend: _LazyOnly) -> None:
     pytest.importorskip(backend)
     with pytest.raises(ValueError, match="Expected eager backend, found"):
-        nw.read_csv(csv_path, backend=backend)  # type: ignore[arg-type]
+        nw.read_csv("unused.csv", backend=backend)  # type: ignore[arg-type]


-def test_scan_csv(csv_path: str, constructor: Constructor) -> None:
+def test_scan_csv(csv_path: FileSource, constructor: Constructor) -> None:
     kwargs: dict[str, Any]
     if "sqlframe" in str(constructor):
         kwargs = {"session": sqlframe_session(), "inferSchema": True, "header": True}
@@ -91,29 +116,29 @@ def test_scan_csv(csv_path: str, constructor: Constructor) -> None:


 @skipif_pandas_lt_1_5
-def test_scan_csv_kwargs(csv_path: str) -> None:
+def test_scan_csv_kwargs(csv_path: FileSource) -> None:
     assert_equal_data(nw.scan_csv(csv_path, backend=pd, engine="pyarrow"), data)


 @skipif_pandas_lt_1_5
-def test_read_parquet(parquet_path: str, eager_backend: EagerAllowed) -> None:
+def test_read_parquet(parquet_path: FileSource, eager_backend: EagerAllowed) -> None:
     assert_equal_eager(nw.read_parquet(parquet_path, backend=eager_backend))


 @skipif_pandas_lt_1_5
-def test_read_parquet_kwargs(parquet_path: str) -> None:
+def test_read_parquet_kwargs(parquet_path: FileSource) -> None:
     assert_equal_eager(nw.read_parquet(parquet_path, backend=pd, engine="pyarrow"))


 @lazy_core_backend
-def test_read_parquet_raise_with_lazy(parquet_path: str, backend: _LazyOnly) -> None:
+def test_read_parquet_raise_with_lazy(backend: _LazyOnly) -> None:
     pytest.importorskip(backend)
     with pytest.raises(ValueError, match="Expected eager backend, found"):
-        nw.read_parquet(parquet_path, backend=backend)  # type: ignore[arg-type]
+        nw.read_parquet("unused.parquet", backend=backend)  # type: ignore[arg-type]


 @skipif_pandas_lt_1_5
-def test_scan_parquet(parquet_path: str, constructor: Constructor) -> None:
+def test_scan_parquet(parquet_path: FileSource, constructor: Constructor) -> None:
     kwargs: dict[str, Any]
     if "sqlframe" in str(constructor):
         kwargs = {"session": sqlframe_session(), "inferSchema": True}
@@ -126,16 +151,16 @@ def test_scan_parquet(parquet_path: str, constructor: Constructor) -> None:


 @skipif_pandas_lt_1_5
-def test_scan_parquet_kwargs(parquet_path: str) -> None:
+def test_scan_parquet_kwargs(parquet_path: FileSource) -> None:
     assert_equal_lazy(nw.scan_parquet(parquet_path, backend=pd, engine="pyarrow"))


 @spark_like_backend
 @pytest.mark.parametrize("scan_method", ["scan_csv", "scan_parquet"])
 def test_scan_fail_spark_like_without_session(
-    parquet_path: str, backend: _SparkLike, scan_method: str
+    backend: _SparkLike, scan_method: str
 ) -> None:
     pytest.importorskip(backend)
     pattern = re.compile(r"spark.+backend.+require.+session", re.IGNORECASE)
     with pytest.raises(ValueError, match=pattern):
-        getattr(nw, scan_method)(parquet_path, backend=backend)
+        getattr(nw, scan_method)("unused.csv", backend=backend)
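`MockPathLike` deliberately hides its `Path` behind a private attribute: it is neither a `str` nor a `Path` subclass, so any backend that passes these parametrized fixtures must be resolving the source through `__fspath__` alone. A quick check of that property, reusing the class from the diff above:

import os
from pathlib import Path

mock = MockPathLike(Path("file.csv"))  # class defined in the test diff above
assert not isinstance(mock, (str, Path))
assert os.fspath(mock) == os.fspath(Path("file.csv"))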
