Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions tests/expr_and_series/rolling_var_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import random
from typing import TYPE_CHECKING, Any
from typing import Any

import hypothesis.strategies as st
import pytest
Expand All @@ -17,9 +17,6 @@
assert_equal_data,
)

if TYPE_CHECKING:
from narwhals.typing import Frame

pytest.importorskip("pandas")
import pandas as pd

Expand Down Expand Up @@ -125,7 +122,7 @@ def test_rolling_var_hypothesis(center: bool, values: list[float]) -> None: # n
.to_frame("a")
)

result: Frame = nw.from_native(pa.Table.from_pandas(df)).select(
result = nw.from_native(pa.Table.from_pandas(df)).select(
nw.col("a").rolling_var(
window_size, center=center, min_samples=min_samples, ddof=ddof
)
Expand Down
12 changes: 4 additions & 8 deletions tests/frame/invalid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@
from narwhals.exceptions import MultiOutputExpressionError
from tests.utils import NUMPY_VERSION, POLARS_VERSION, Constructor

if TYPE_CHECKING:
from narwhals.typing import Frame


T = TypeVar("T")


Expand All @@ -21,14 +17,14 @@
)
def test_all_vs_all(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6]}
df: Frame = nw.from_native(constructor(data))
df = nw.from_native(constructor(data))
with pytest.raises(MultiOutputExpressionError):
df.lazy().select(nw.all() + nw.col("b", "a")).collect()


def test_invalid() -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
df: Frame = nw.from_native(pd.DataFrame(data))
df = nw.from_native(pd.DataFrame(data))
with pytest.raises(ValueError, match="Multi-output"):
df.select(nw.all() + nw.all())

Expand All @@ -37,7 +33,7 @@ def test_invalid_pyarrow() -> None:
pytest.importorskip("pyarrow")
import pyarrow as pa

df: Frame = nw.from_native(pa.table({"a": [1, 2], "b": [3, 4]}))
df = nw.from_native(pa.table({"a": [1, 2], "b": [3, 4]}))
with pytest.raises(MultiOutputExpressionError):
df.select(nw.all() + nw.all())

Expand All @@ -47,7 +43,7 @@ def test_invalid_polars() -> None:
import polars as pl

data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
df: Frame = nw.from_native(pd.DataFrame(data))
df = nw.from_native(pd.DataFrame(data))
with pytest.raises(TypeError, match="Perhaps you"):
df.select([pl.col("a")]) # type: ignore[list-item]
with pytest.raises(TypeError, match="Expected Narwhals dtype"):
Expand Down
2 changes: 0 additions & 2 deletions tests/series_only/is_ordered_categorical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from tests.utils import POLARS_VERSION

if TYPE_CHECKING:
from narwhals.typing import IntoSeries
from tests.utils import ConstructorEager


Expand All @@ -33,7 +32,6 @@ def test_is_ordered_categorical_polars() -> None:
pytest.importorskip("polars")
import polars as pl

s: IntoSeries | Any
s = pl.Series(["a", "b"], dtype=pl.Categorical)
if POLARS_VERSION < (1, 32): # pragma: no cover
assert nw.is_ordered_categorical(nw.from_native(s, series_only=True))
Expand Down
9 changes: 4 additions & 5 deletions tests/translate/get_native_namespace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import narwhals as nw

if TYPE_CHECKING:
from narwhals.typing import Frame
from tests.utils import Constructor
from tests.utils import Constructor, ConstructorEager


data = {"a": [1, 2, 3]}
Expand Down Expand Up @@ -62,18 +61,18 @@ def test_native_namespace_frame(constructor: Constructor) -> None:

expected_namespace = _get_expected_namespace(constructor_name=constructor_name)

df: Frame = nw.from_native(constructor(data))
df = nw.from_native(constructor(data))
assert nw.get_native_namespace(df) is expected_namespace
assert nw.get_native_namespace(df.to_native()) is expected_namespace
assert nw.get_native_namespace(df.lazy().to_native()) is expected_namespace


def test_native_namespace_series(constructor_eager: Constructor) -> None:
def test_native_namespace_series(constructor_eager: ConstructorEager) -> None:
constructor_name = constructor_eager.__name__

expected_namespace = _get_expected_namespace(constructor_name=constructor_name)

df: Frame = nw.from_native(constructor_eager(data), eager_only=True)
df = nw.from_native(constructor_eager(data), eager_only=True)

assert nw.get_native_namespace(df["a"].to_native()) is expected_namespace
assert nw.get_native_namespace(df, df["a"].to_native()) is expected_namespace
Expand Down
6 changes: 2 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,9 @@ def uses_pyarrow_backend(constructor: Constructor | ConstructorEager) -> bool:


def maybe_collect(df: Frame) -> Frame:
"""Collect the DataFrame if it is a LazyFrame.
"""Collect to DataFrame if it is a LazyFrame.

Use this function to test specific behaviors during collection.
For example, Polars only errors when we call `collect` in the lazy case.
"""
if isinstance(df, nw.LazyFrame):
return df.collect()
return df # pragma: no cover
return df.collect() if isinstance(df, nw.LazyFrame) else df
5 changes: 2 additions & 3 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

from narwhals._utils import _SupportsVersion
from narwhals.series import Series
from narwhals.typing import IntoSeries


@dataclass
Expand Down Expand Up @@ -126,7 +125,7 @@ def test_maybe_set_index_polars_column_names(
],
)
def test_maybe_set_index_pandas_direct_index(
narwhals_index: Series[IntoSeries] | list[Series[IntoSeries]],
narwhals_index: Series[pd.Series[Any]] | list[Series[pd.Series[Any]]],
pandas_index: pd.Series[Any] | list[pd.Series[Any]],
native_df_or_series: pd.DataFrame | pd.Series[Any],
) -> None:
Expand All @@ -151,7 +150,7 @@ def test_maybe_set_index_pandas_direct_index(
],
)
def test_maybe_set_index_polars_direct_index(
index: Series[IntoSeries] | list[Series[IntoSeries]] | None,
index: Series[pd.Series[Any]] | list[Series[pd.Series[Any]]] | None,
) -> None:
pytest.importorskip("polars")
import polars as pl
Expand Down
Loading