Skip to content

Commit 8a6c312

Browse files
authored
fix: spark-like pass_through=True, eager_only=True interaction (#2606)
1 parent a0e6b9e commit 8a6c312

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

narwhals/translate.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -524,14 +524,14 @@ def _from_native_impl( # noqa: C901, PLR0911, PLR0912, PLR0915
524524
# PySpark
525525
elif is_native_spark_like(native_object): # pragma: no cover
526526
ns_spark = version.namespace.from_native_object(native_object)
527-
if series_only:
528-
msg = (
529-
f"Cannot only use `series_only` with {ns_spark.implementation} DataFrame"
530-
)
531-
raise TypeError(msg)
532-
if eager_only or eager_or_interchange_only:
533-
msg = f"Cannot only use `eager_only` or `eager_or_interchange_only` with {ns_spark.implementation} DataFrame"
534-
raise TypeError(msg)
527+
if series_only or eager_only or eager_or_interchange_only:
528+
if not pass_through:
529+
msg = (
530+
"Cannot only use `series_only`, `eager_only` or `eager_or_interchange_only` "
531+
f"with {ns_spark.implementation} DataFrame"
532+
)
533+
raise TypeError(msg)
534+
return native_object
535535
return ns_spark.compliant.from_native(native_object).to_narwhals()
536536

537537
# Interchange protocol

tests/translate/from_native_test.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
import narwhals as nw
3333
import narwhals.stable.v1 as nw_v1
34-
from tests.utils import maybe_get_modin_df
34+
from tests.utils import Constructor, maybe_get_modin_df
3535

3636
if TYPE_CHECKING:
3737
from _pytest.mark import ParameterSet
@@ -286,7 +286,13 @@ def test_series_only_sqlframe() -> None: # pragma: no cover
286286
("eager_only", "context"),
287287
[
288288
(False, does_not_raise()),
289-
(True, pytest.raises(TypeError, match="Cannot only use `eager_only`")),
289+
(
290+
True,
291+
pytest.raises(
292+
TypeError,
293+
match="Cannot only use `series_only`, `eager_only` or `eager_or_interchange_only` with sqlframe DataFrame",
294+
),
295+
),
290296
],
291297
)
292298
@pytest.mark.skipif(sys.version_info < (3, 9), reason="too old for sqlframe")
@@ -548,3 +554,28 @@ def test_pyspark_connect_deps_2517() -> None: # pragma: no cover
548554
spark = SparkSession.builder.getOrCreate()
549555
# Check this doesn't raise
550556
nw.from_native(spark.createDataFrame([(1,)], ["a"]))
557+
558+
559+
@pytest.mark.parametrize(
560+
("eager_only", "pass_through", "context"),
561+
[
562+
(False, False, does_not_raise()),
563+
(False, True, does_not_raise()),
564+
(True, True, does_not_raise()),
565+
(True, False, pytest.raises(TypeError, match="Cannot only use")),
566+
],
567+
)
568+
def test_eager_only_pass_through_main(
569+
constructor: Constructor, *, eager_only: bool, pass_through: bool, context: Any
570+
) -> None:
571+
if not any(s in str(constructor) for s in ("pyspark", "dask", "ibis", "duckdb")):
572+
pytest.skip(reason="Non lazy or polars")
573+
574+
df = constructor(data)
575+
576+
with context:
577+
res = nw.from_native(df, eager_only=eager_only, pass_through=pass_through) # type: ignore[call-overload]
578+
if eager_only and pass_through:
579+
assert not isinstance(res, nw.LazyFrame)
580+
else:
581+
assert isinstance(res, nw.LazyFrame)

0 commit comments

Comments
 (0)