|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import re |
| 4 | +from typing import TYPE_CHECKING, Any, cast |
| 5 | + |
| 6 | +import pytest |
| 7 | + |
| 8 | +pytest.importorskip("numpy") |
| 9 | +import numpy as np |
| 10 | + |
| 11 | +import narwhals as nw |
| 12 | +from tests.utils import assert_equal_data |
| 13 | + |
| 14 | +if TYPE_CHECKING: |
| 15 | + from collections.abc import Sequence |
| 16 | + |
| 17 | + from narwhals._namespace import EagerAllowed |
| 18 | + from narwhals.dtypes import NestedType |
| 19 | + from narwhals.typing import IntoDType, _1DArray |
| 20 | + |
| 21 | + |
| 22 | +arr: _1DArray = cast("_1DArray", np.array([5, 2, 0, 1])) |
| 23 | +NAME = "a" |
| 24 | + |
| 25 | + |
| 26 | +def assert_equal_series( |
| 27 | + result: nw.Series[Any], expected: Sequence[Any], name: str |
| 28 | +) -> None: |
| 29 | + assert_equal_data(result.to_frame(), {name: expected}) |
| 30 | + |
| 31 | + |
| 32 | +def test_series_from_numpy(eager_backend: EagerAllowed) -> None: |
| 33 | + expected = [5, 2, 0, 1] |
| 34 | + result = nw.Series.from_numpy(NAME, arr, backend=eager_backend) |
| 35 | + assert isinstance(result, nw.Series) |
| 36 | + assert_equal_series(result, expected, NAME) |
| 37 | + |
| 38 | + |
| 39 | +@pytest.mark.parametrize( |
| 40 | + ("dtype", "expected"), |
| 41 | + [ |
| 42 | + (nw.Int16, [5, 2, 0, 1]), |
| 43 | + (nw.Int32(), [5, 2, 0, 1]), |
| 44 | + (nw.Float64, [5.0, 2.0, 0.0, 1.0]), |
| 45 | + (nw.Float32(), [5.0, 2.0, 0.0, 1.0]), |
| 46 | + ], |
| 47 | + ids=str, |
| 48 | +) |
| 49 | +def test_series_from_numpy_dtype( |
| 50 | + eager_backend: EagerAllowed, dtype: IntoDType, expected: Sequence[Any] |
| 51 | +) -> None: |
| 52 | + result = nw.Series.from_numpy(NAME, arr, backend=eager_backend, dtype=dtype) |
| 53 | + assert result.dtype == dtype |
| 54 | + assert_equal_series(result, expected, NAME) |
| 55 | + |
| 56 | + |
| 57 | +@pytest.mark.parametrize( |
| 58 | + ("bad_dtype", "message"), |
| 59 | + [ |
| 60 | + (nw.List, r"nw.List.+not.+valid.+hint"), |
| 61 | + (nw.Struct, r"nw.Struct.+not.+valid.+hint"), |
| 62 | + (nw.Array, r"nw.Array.+not.+valid.+hint"), |
| 63 | + (np.floating, r"expected.+narwhals.+dtype.+floating"), |
| 64 | + (list[int], r"expected.+narwhals.+dtype.+(types.GenericAlias|list)"), |
| 65 | + ], |
| 66 | + ids=str, |
| 67 | +) |
| 68 | +def test_series_from_numpy_not_init_dtype( |
| 69 | + eager_backend: EagerAllowed, bad_dtype: type[NestedType] | object, message: str |
| 70 | +) -> None: |
| 71 | + with pytest.raises(TypeError, match=re.compile(message, re.IGNORECASE | re.DOTALL)): |
| 72 | + nw.Series.from_numpy(NAME, arr, bad_dtype, backend=eager_backend) # type: ignore[arg-type] |
| 73 | + |
| 74 | + |
| 75 | +def test_series_from_numpy_not_eager() -> None: |
| 76 | + pytest.importorskip("ibis") |
| 77 | + with pytest.raises(ValueError, match="lazy-only"): |
| 78 | + nw.Series.from_numpy(NAME, arr, backend="ibis") |
| 79 | + |
| 80 | + |
| 81 | +def test_series_from_numpy_not_1d(eager_backend: EagerAllowed) -> None: |
| 82 | + with pytest.raises(ValueError, match="`from_numpy` only accepts 1D numpy arrays"): |
| 83 | + nw.Series.from_numpy(NAME, np.array([[0], [2]]), backend=eager_backend) # pyright: ignore[reportArgumentType] |
0 commit comments