diff --git a/narwhals/functions.py b/narwhals/functions.py
index 6baef07688..bcc2a9b7c5 100644
--- a/narwhals/functions.py
+++ b/narwhals/functions.py
@@ -606,8 +606,50 @@ def show_versions() -> None:
         print(f"{k:>13}: {stat}")  # noqa: T201
 
 
+def _validate_separator(separator: str, native_separator: str, **kwargs: Any) -> Any:
+    """Check `separator` against its native alias and drop the alias from `kwargs`.
+
+    Raises `TypeError` when `kwargs[native_separator]` disagrees with `separator`.
+    The alias is removed from the returned kwargs so that callers can forward
+    `separator` under the native name without a duplicate-keyword error.
+    """
+    if native_separator in kwargs and kwargs[native_separator] != separator:
+        msg = (
+            f"`separator` and `{native_separator}` do not match: "
+            f"`separator`={separator} and `{native_separator}`={kwargs[native_separator]}."
+        )
+        raise TypeError(msg)
+    kwargs.pop(native_separator, None)
+    return kwargs
+
+
+def _validate_separator_pyarrow(separator: str, **kwargs: Any) -> Any:
+    """Return `kwargs` with a `parse_options` whose delimiter is `separator`.
+
+    A user-supplied `parse_options` is kept as-is (so its other settings still
+    apply) and must agree with `separator`, otherwise `TypeError` is raised.
+    All other keyword arguments are passed through untouched.
+    """
+    if "parse_options" in kwargs:
+        parse_options = kwargs["parse_options"]
+        if parse_options.delimiter != separator:
+            msg = (
+                "`separator` and `parse_options.delimiter` do not match: "
+                f"`separator`={separator} and `delimiter`={parse_options.delimiter}."
+            )
+            raise TypeError(msg)
+        return kwargs
+    from pyarrow import csv  # ignore-banned-import
+
+    return {**kwargs, "parse_options": csv.ParseOptions(delimiter=separator)}
+
+
 def read_csv(
-    source: FileSource, *, backend: IntoBackend[EagerAllowed], **kwargs: Any
+    source: FileSource,
+    *,
+    backend: IntoBackend[EagerAllowed],
+    separator: str = ",",
+    **kwargs: Any,
 ) -> DataFrame[Any]:
     """Read a CSV file into a DataFrame.
 
@@ -620,6 +662,7 @@ def read_csv(
                 `POLARS`, `MODIN` or `CUDF`.
             - As a string: `"pandas"`, `"pyarrow"`, `"polars"`, `"modin"` or `"cudf"`.
             - Directly as a module `pandas`, `pyarrow`, `polars`, `modin` or `cudf`.
+        separator: Single byte character to use as separator in the file.
         kwargs: Extra keyword arguments which are passed to the native CSV reader.
             For example, you could use
             `nw.read_csv('file.csv', backend='pandas', engine='pyarrow')`.
@@ -638,14 +681,17 @@ def read_csv(
     impl = Implementation.from_backend(backend)
     native_namespace = impl.to_native_namespace()
     native_frame: NativeDataFrame
-    if impl in {
-        Implementation.POLARS,
-        Implementation.PANDAS,
-        Implementation.MODIN,
-        Implementation.CUDF,
-    }:
-        native_frame = native_namespace.read_csv(normalize_path(source), **kwargs)
+    if impl in {Implementation.PANDAS, Implementation.MODIN, Implementation.CUDF}:
+        kwargs = _validate_separator(separator, "sep", **kwargs)
+        native_frame = native_namespace.read_csv(
+            normalize_path(source), sep=separator, **kwargs
+        )
+    elif impl is Implementation.POLARS:
+        native_frame = native_namespace.read_csv(
+            normalize_path(source), separator=separator, **kwargs
+        )
     elif impl is Implementation.PYARROW:
+        kwargs = _validate_separator_pyarrow(separator, **kwargs)
         from pyarrow import csv  # ignore-banned-import
 
         native_frame = csv.read_csv(source, **kwargs)
@@ -674,7 +720,11 @@ def read_csv(
 
 
 def scan_csv(
-    source: FileSource, *, backend: IntoBackend[Backend], **kwargs: Any
+    source: FileSource,
+    *,
+    backend: IntoBackend[Backend],
+    separator: str = ",",
+    **kwargs: Any,
 ) -> LazyFrame[Any]:
     """Lazily read from a CSV file.
 
@@ -690,6 +740,7 @@ def scan_csv(
                 `POLARS`, `MODIN` or `CUDF`.
             - As a string: `"pandas"`, `"pyarrow"`, `"polars"`, `"modin"` or `"cudf"`.
             - Directly as a module `pandas`, `pyarrow`, `polars`, `modin` or `cudf`.
+        separator: Single byte character to use as separator in the file.
         kwargs: Extra keyword arguments which are passed to the native CSV reader.
             For example, you could use
             `nw.scan_csv('file.csv', backend=pd, engine='pyarrow')`.
@@ -713,32 +764,39 @@ def scan_csv(
     native_frame: NativeDataFrame | NativeLazyFrame
     source = normalize_path(source)
     if implementation is Implementation.POLARS:
-        native_frame = native_namespace.scan_csv(source, **kwargs)
+        native_frame = native_namespace.scan_csv(source, separator=separator, **kwargs)
     elif implementation in {
         Implementation.PANDAS,
         Implementation.MODIN,
         Implementation.CUDF,
         Implementation.DASK,
-        Implementation.DUCKDB,
         Implementation.IBIS,
     }:
-        native_frame = native_namespace.read_csv(source, **kwargs)
+        kwargs = _validate_separator(separator, "sep", **kwargs)
+        native_frame = native_namespace.read_csv(source, sep=separator, **kwargs)
+    elif implementation is Implementation.DUCKDB:
+        kwargs = _validate_separator(separator, "delimiter", **kwargs)
+        kwargs = _validate_separator(separator, "delim", **kwargs)
+        native_frame = native_namespace.read_csv(source, delimiter=separator, **kwargs)
     elif implementation is Implementation.PYARROW:
+        kwargs = _validate_separator_pyarrow(separator, **kwargs)
         from pyarrow import csv  # ignore-banned-import
 
         native_frame = csv.read_csv(source, **kwargs)
     elif implementation.is_spark_like():
+        kwargs = _validate_separator(separator, "sep", **kwargs)
+        kwargs = _validate_separator(separator, "delimiter", **kwargs)
         if (session := kwargs.pop("session", None)) is None:
             msg = "Spark like backends require a session object to be passed in `kwargs`."
             raise ValueError(msg)
         csv_reader = session.read.format("csv")
         native_frame = (
-            csv_reader.load(source)
+            csv_reader.load(source, sep=separator)
             if (
                 implementation is Implementation.SQLFRAME
                 and implementation._backend_version() < (3, 27, 0)
             )
-            else csv_reader.options(**kwargs).load(source)
+            else csv_reader.options(sep=separator, **kwargs).load(source)
         )
     else:  # pragma: no cover
         try:
diff --git a/narwhals/stable/v2/__init__.py b/narwhals/stable/v2/__init__.py
index d85b129b8a..78190b4ce0 100644
--- a/narwhals/stable/v2/__init__.py
+++ b/narwhals/stable/v2/__init__.py
@@ -1093,7 +1093,11 @@ def from_numpy(
 
 
 def read_csv(
-    source: str, *, backend: IntoBackend[EagerAllowed], **kwargs: Any
+    source: str,
+    *,
+    backend: IntoBackend[EagerAllowed],
+    separator: str = ",",
+    **kwargs: Any,
 ) -> DataFrame[Any]:
     """Read a CSV file into a DataFrame.
 
@@ -1106,15 +1110,18 @@ def read_csv(
                 `POLARS`, `MODIN` or `CUDF`.
             - As a string: `"pandas"`, `"pyarrow"`, `"polars"`, `"modin"` or `"cudf"`.
             - Directly as a module `pandas`, `pyarrow`, `polars`, `modin` or `cudf`.
+        separator: Single byte character to use as separator in the file.
         kwargs: Extra keyword arguments which are passed to the native CSV reader.
             For example, you could use
             `nw.read_csv('file.csv', backend='pandas', engine='pyarrow')`.
     """
-    return _stableify(nw_f.read_csv(source, backend=backend, **kwargs))
+    return _stableify(
+        nw_f.read_csv(source, backend=backend, separator=separator, **kwargs)
+    )
 
 
 def scan_csv(
-    source: str, *, backend: IntoBackend[Backend], **kwargs: Any
+    source: str, *, backend: IntoBackend[Backend], separator: str = ",", **kwargs: Any
 ) -> LazyFrame[Any]:
     """Lazily read from a CSV file.
 
@@ -1130,11 +1137,14 @@ def scan_csv(
                 `POLARS`, `MODIN` or `CUDF`.
             - As a string: `"pandas"`, `"pyarrow"`, `"polars"`, `"modin"` or `"cudf"`.
             - Directly as a module `pandas`, `pyarrow`, `polars`, `modin` or `cudf`.
+        separator: Single byte character to use as separator in the file.
         kwargs: Extra keyword arguments which are passed to the native CSV reader.
             For example, you could use
             `nw.scan_csv('file.csv', backend=pd, engine='pyarrow')`.
     """
-    return _stableify(nw_f.scan_csv(source, backend=backend, **kwargs))
+    return _stableify(
+        nw_f.scan_csv(source, backend=backend, separator=separator, **kwargs)
+    )
 
 
 def read_parquet(
diff --git a/tests/read_scan_test.py b/tests/read_scan_test.py
index 7b8cfeaa23..a7c97fa0ad 100644
--- a/tests/read_scan_test.py
+++ b/tests/read_scan_test.py
@@ -65,6 +65,15 @@ def csv_path(
     return _into_file_source(fp, request.param)
 
 
+@pytest.fixture(scope="module", params=["str", "Path", "PathLike"])
+def csv_path_sep(
+    tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest
+) -> FileSource:
+    fp = tmp_path_factory.mktemp("data") / "file.csv"
+    pl.DataFrame(data).write_csv(fp, separator="|")
+    return _into_file_source(fp, request.param)
+
+
 @pytest.fixture(scope="module", params=["str", "Path", "PathLike"])
 def parquet_path(
     tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest
@@ -88,13 +97,24 @@ def native_namespace(cb: Constructor, /) -> ModuleType:
     return nw.get_native_namespace(nw.from_native(cb(data)))  # type: ignore[no-any-return]
 
 
-def test_read_csv(csv_path: FileSource, eager_backend: EagerAllowed) -> None:
+def test_read_csv(
+    csv_path: FileSource, csv_path_sep: FileSource, eager_backend: EagerAllowed
+) -> None:
     assert_equal_eager(nw.read_csv(csv_path, backend=eager_backend))
+    assert_equal_eager(nw.read_csv(csv_path_sep, backend=eager_backend, separator="|"))
 
 
 @skipif_pandas_lt_1_5
 def test_read_csv_kwargs(csv_path: FileSource) -> None:
+    pytest.importorskip("pyarrow")
+    from pyarrow import csv
+
     assert_equal_eager(nw.read_csv(csv_path, backend=pd, engine="pyarrow"))
+    assert_equal_eager(
+        nw.read_csv(
+            csv_path, backend="pyarrow", parse_options=csv.ParseOptions(delimiter=",")
+        )
+    )
 
 
 @lazy_core_backend
@@ -104,7 +124,9 @@ def test_read_csv_raise_with_lazy(backend: _LazyOnly) -> None:
         nw.read_csv("unused.csv", backend=backend)  # type: ignore[arg-type]
 
 
-def test_scan_csv(csv_path: FileSource, constructor: Constructor) -> None:
+def test_scan_csv(
+    csv_path: FileSource, csv_path_sep: FileSource, constructor: Constructor
+) -> None:
     kwargs: dict[str, Any]
     if "sqlframe" in str(constructor):
         kwargs = {"session": sqlframe_session(), "inferSchema": True, "header": True}
@@ -114,6 +136,7 @@ def test_scan_csv(csv_path: FileSource, constructor: Constructor) -> None:
         kwargs = {}
     backend = native_namespace(constructor)
     assert_equal_lazy(nw.scan_csv(csv_path, backend=backend, **kwargs))
+    assert_equal_lazy(nw.scan_csv(csv_path_sep, backend=backend, separator="|", **kwargs))
 
 
 @skipif_pandas_lt_1_5
@@ -165,3 +188,75 @@ def test_scan_fail_spark_like_without_session(
     pattern = re.compile(r"spark.+backend.+require.+session", re.IGNORECASE)
     with pytest.raises(ValueError, match=pattern):
         getattr(nw, scan_method)("unused.csv", backend=backend)
+
+
+def test_read_csv_raise_sep_multiple_lazy(csv_path: FileSource) -> None:
+    pytest.importorskip("duckdb")
+    pytest.importorskip("pandas")
+    pytest.importorskip("pyarrow")
+    pytest.importorskip("sqlframe")
+    import duckdb
+    import pandas as pd
+    import pyarrow as pa
+    import sqlframe
+    from pyarrow import csv
+    from sqlframe.duckdb import DuckDBSession
+
+    msg = "do not match:"
+    with pytest.raises(TypeError, match=msg):
+        nw.read_csv(
+            csv_path,
+            backend=pa,
+            separator="|",
+            parse_options=csv.ParseOptions(delimiter=";"),
+        )
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(
+            csv_path,
+            backend=pa,
+            separator="|",
+            parse_options=csv.ParseOptions(delimiter=";"),
+        )
+    with pytest.raises(TypeError, match=msg):
+        nw.read_csv(csv_path, backend=pd, separator="|", sep=";")
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(csv_path, backend=pd, separator="|", sep=";")
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(csv_path, backend=duckdb, separator="|", delimiter=";")
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(csv_path, backend=duckdb, separator="|", delim=";")
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(
+            csv_path,
+            backend=sqlframe,
+            separator="|",
+            sep=";",
+            session=DuckDBSession(),
+            inferSchema=True,
+        )
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(
+            csv_path,
+            backend=sqlframe,
+            separator="|",
+            delimiter=";",
+            session=DuckDBSession(),
+            inferSchema=True,
+        )
+
+
+def test_read_csv_matching_native_separator(csv_path_sep: FileSource) -> None:
+    pytest.importorskip("pandas")
+    pytest.importorskip("pyarrow")
+    import pandas as pd
+    from pyarrow import csv
+
+    assert_equal_eager(nw.read_csv(csv_path_sep, backend=pd, separator="|", sep="|"))
+    assert_equal_eager(
+        nw.read_csv(
+            csv_path_sep,
+            backend="pyarrow",
+            separator="|",
+            parse_options=csv.ParseOptions(delimiter="|"),
+        )
+    )