Commit bc26492

feat(expr-ir): Add DataFrame.write_{csv,parquet}
Child of #2572
1 parent d47a7fa commit bc26492

File tree: 4 files changed (+126 −2 lines)

narwhals/_plan/arrow/dataframe.py

Lines changed: 5 additions & 1 deletion

@@ -14,7 +14,7 @@
 from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
 from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy, partition_by
 from narwhals._plan.arrow.series import ArrowSeries as Series
-from narwhals._plan.common import temp
+from narwhals._plan.common import temp, todo
 from narwhals._plan.compliant.dataframe import EagerDataFrame
 from narwhals._plan.compliant.typing import namespace
 from narwhals._plan.exceptions import shape_error
@@ -191,6 +191,10 @@ def with_row_index_by(
         column = fn.unsort_indices(indices)
         return self._with_native(self.native.add_column(0, name, column))
 
+    write_csv = todo()
+    write_parquet = todo()
+    sink_parquet = todo()
+
     def to_struct(self, name: str = "") -> Series:
         native = self.native
         if fn.TO_STRUCT_ARRAY_ACCEPTS_EMPTY:
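In this commit all three writers are stubbed with `todo()`; the new tests are marked `xfail` with `raises=NotImplementedError`, so the placeholders evidently raise that when invoked. For orientation only, a minimal sketch of what the eventual PyArrow-backed implementations might look like — hypothetical, not part of this commit — using the existing `pyarrow.csv.write_csv` and `pyarrow.parquet.write_table` APIs:

import pyarrow as pa
import pyarrow.csv as pacsv
import pyarrow.parquet as pq

def write_csv(self, file: FileSource | BytesIO | None = None) -> str | None:
    # No target given: write into an in-memory buffer and return the text.
    if file is None:
        buffer = pa.BufferOutputStream()
        pacsv.write_csv(self.native, buffer)
        return buffer.getvalue().to_pybytes().decode("utf-8")
    # Path-like or buffer target: write directly, return nothing.
    pacsv.write_csv(self.native, file)
    return None

def write_parquet(self, file: FileSource | BytesIO) -> None:
    pq.write_table(self.native, file)

Note that PyArrow's CSV writer quotes string column names by default, which is consistent with the `'"a"\n1\n2\n3\n'` expectation in the new pyarrow test branch.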

narwhals/_plan/compliant/dataframe.py

Lines changed: 11 additions & 1 deletion

@@ -17,6 +17,7 @@
 
 if TYPE_CHECKING:
     from collections.abc import Iterator, Mapping, Sequence
+    from io import BytesIO
 
     import polars as pl
     from typing_extensions import Self, TypeAlias
@@ -36,7 +37,7 @@
     from narwhals._typing import _EagerAllowedImpl
     from narwhals._utils import Implementation, Version
     from narwhals.dtypes import DType
-    from narwhals.typing import IntoSchema, UniqueKeepStrategy
+    from narwhals.typing import FileSource, IntoSchema, UniqueKeepStrategy
 
 Incomplete: TypeAlias = Any
 
@@ -208,6 +209,12 @@ def unique_by(
         maintain_order: bool = False,
     ) -> Self: ...
    def with_row_index(self, name: str) -> Self: ...
+    @overload
+    def write_csv(self, file: None) -> str: ...
+    @overload
+    def write_csv(self, file: FileSource | BytesIO) -> None: ...
+    def write_csv(self, file: FileSource | BytesIO | None) -> str | None: ...
+    def write_parquet(self, file: FileSource | BytesIO) -> None: ...
     def slice(self, offset: int, length: int | None = None) -> Self: ...
     def sample_frac(
         self, fraction: float, *, with_replacement: bool = False, seed: int | None = None
@@ -246,3 +253,6 @@ def with_columns(self, irs: Seq[NamedIR]) -> Self:
 
     def to_series(self, index: int = 0) -> SeriesT:
         return self.get_column(self.columns[index])
+
+    # TODO @dangotbanned: Move to `CompliantLazyFrame` once that's added
+    def sink_parquet(self, file: FileSource | BytesIO) -> None: ...
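The paired `@overload` stubs mirror the Polars convention: `write_csv(None)` is typed as returning the CSV text, while a path or buffer target is typed as returning `None`. A standalone, runnable demo of the same pattern (the names here are hypothetical, not narwhals API):

from __future__ import annotations

from io import BytesIO
from typing import overload

@overload
def write_csv(file: None) -> str: ...
@overload
def write_csv(file: str | BytesIO) -> None: ...
def write_csv(file: str | BytesIO | None) -> str | None:
    csv = "a\n1\n2\n3\n"
    if file is None:
        return csv  # no target: hand the CSV text back to the caller
    if isinstance(file, BytesIO):
        file.write(csv.encode())  # in-memory buffer target
    else:
        with open(file, "w") as f:  # path-like target
            f.write(csv)
    return None

text = write_csv(None)       # type checker infers `str`
none = write_csv("out.csv")  # type checker infers `None`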

narwhals/_plan/dataframe.py

Lines changed: 12 additions & 0 deletions

@@ -30,6 +30,7 @@
 from narwhals.schema import Schema
 from narwhals.typing import (
     EagerAllowed,
+    FileSource,
     IntoBackend,
     IntoDType,
     IntoSchema,
@@ -39,6 +40,7 @@
 
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator, Mapping, Sequence
+    from io import BytesIO
 
     import polars as pl
     import pyarrow as pa
@@ -482,6 +484,16 @@ def with_row_index(
             return self._with_compliant(self._compliant.with_row_index(name))
         return super().with_row_index(name, order_by=order_by)
 
+    @overload
+    def write_csv(self, file: None = None) -> str: ...
+    @overload
+    def write_csv(self, file: FileSource | BytesIO) -> None: ...
+    def write_csv(self, file: FileSource | BytesIO | None = None) -> str | None:
+        return self._compliant.write_csv(file)
+
+    def write_parquet(self, file: FileSource | BytesIO) -> None:
+        return self._compliant.write_parquet(file)
+
     def slice(self, offset: int, length: int | None = None) -> Self:
         return type(self)(self._compliant.slice(offset=offset, length=length))
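The public methods are thin delegations to the compliant layer, so per-backend behaviour (including the `file=None` → `str` round trip) comes from whatever the backend implements. A hypothetical call pattern, mirroring the tests below — these calls raise `NotImplementedError` until the backend `todo()` stubs are replaced:

from io import BytesIO
from pathlib import Path

from tests.plan.utils import dataframe  # test helper used by the new suite

df = dataframe({"a": [1, 2, 3]})
text = df.write_csv()              # file=None -> CSV returned as `str`
df.write_csv(Path("out.csv"))      # path target -> writes to disk, returns None
df.write_parquet(BytesIO())        # buffer target -> writes Parquet bytes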

tests/plan/frame_export_test.py

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
+
+import pytest
+
+from tests.plan.utils import dataframe
+from tests.utils import is_windows
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    from typing_extensions import TypeAlias
+
+    from narwhals.typing import FileSource
+    from tests.conftest import Data
+
+pytest.importorskip("pyarrow")
+
+IOTargetKind: TypeAlias = Literal["str", "Path", "PathLike"]
+"""Duplicated from `tests.read_scan_test.py`.
+
+Needs extending for `BytesIO`.
+"""
+
+
+class MockPathLike:
+    def __init__(self, path: Path) -> None:
+        self._super_secret: Path = path
+
+    def __fspath__(self) -> str:
+        return self._super_secret.__fspath__()
+
+
+def _into_file_source(source: Path, which: IOTargetKind, /) -> FileSource:
+    mapping: Mapping[IOTargetKind, FileSource] = {
+        "str": str(source),
+        "Path": source,
+        "PathLike": MockPathLike(source),
+    }
+    return mapping[which]
+
+
+@pytest.fixture(params=["str", "Path", "PathLike"])
+def csv_path(
+    tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest
+) -> FileSource:
+    fp = tmp_path_factory.mktemp("data") / "file.csv"
+    return _into_file_source(fp, request.param)
+
+
+@pytest.fixture(params=["str", "Path", "PathLike"])
+def parquet_path(
+    tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest
+) -> FileSource:
+    fp = tmp_path_factory.mktemp("data") / "file.parquet"
+    return _into_file_source(fp, request.param)
+
+
+@pytest.fixture(scope="module")
+def data() -> Data:
+    return {"a": [1, 2, 3]}
+
+
+XFAIL_DATAFRAME_EXPORT = pytest.mark.xfail(
+    reason="TODO: `DataFrame.write_{csv,parquet}`()", raises=NotImplementedError
+)
+
+
+@XFAIL_DATAFRAME_EXPORT
+def test_write_csv(data: Data, csv_path: FileSource) -> None:  # pragma: no cover
+    df = dataframe(data)
+    result_none = df.write_csv(csv_path)
+    assert Path(csv_path).exists()
+    assert result_none is None
+    result = dataframe(data).write_csv()
+    if is_windows():  # pragma: no cover
+        result = result.replace("\r\n", "\n")
+    if df.implementation.is_pyarrow():
+        assert result == '"a"\n1\n2\n3\n'
+    else:  # pragma: no cover
+        assert result == "a\n1\n2\n3\n"
+
+
+@XFAIL_DATAFRAME_EXPORT
+def test_write_parquet(data: Data, parquet_path: FileSource) -> None:  # pragma: no cover
+    dataframe(data).write_parquet(parquet_path)
+    assert Path(parquet_path).exists()
+
+
+@pytest.mark.xfail(
+    reason="TODO: `DataFrame.lazy()`, `LazyFrame.sink_parquet()`", raises=AttributeError
+)
+def test_sink_parquet(data: Data, parquet_path: FileSource) -> None:  # pragma: no cover
+    df = dataframe(data)
+    df.lazy().sink_parquet(parquet_path)  # type: ignore[attr-defined]
+    assert Path(parquet_path).exists()
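The module docstring flags that `IOTargetKind` still needs a `BytesIO` variant. A hypothetical follow-up test for that target — not in this commit — could round-trip through `pyarrow.parquet`, which the suite already depends on:

from io import BytesIO

import pyarrow.parquet as pq

def test_write_parquet_bytesio(data: Data) -> None:  # hypothetical follow-up
    buffer = BytesIO()
    dataframe(data).write_parquet(buffer)
    buffer.seek(0)  # rewind so the reader sees the written bytes
    assert pq.read_table(buffer).column_names == ["a"]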
