Skip to content

Commit 89a2c38

Browse files
committed
Add support and test for pd.Series
1 parent 99f1d4e commit 89a2c38

File tree

3 files changed

+133
-0
lines changed

3 files changed

+133
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Any, List, Tuple, TypeVar, Union
2+
3+
import pandas as pd
4+
from pydantic import GetCoreSchemaHandler
5+
from pydantic_core import core_schema
6+
7+
T = TypeVar('T', str, bytes, bool, int, float, complex, pd.Timestamp, pd.Timedelta, pd.Period)
8+
9+
10+
class Series:
11+
def __init__(self, value: Any) -> None:
12+
self.value = pd.Series(value)
13+
14+
@classmethod
15+
def __get_pydantic_core_schema__(
16+
cls, source: type[Any], handler: GetCoreSchemaHandler
17+
) -> core_schema.AfterValidatorFunctionSchema:
18+
return core_schema.general_after_validator_function(
19+
cls._validate,
20+
core_schema.any_schema(),
21+
)
22+
23+
@classmethod
24+
def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> 'Series':
25+
if isinstance(__input_value, pd.Series):
26+
return cls(__input_value)
27+
return cls(pd.Series(__input_value))
28+
29+
def __repr__(self) -> str:
30+
return repr(self.value)
31+
32+
def __getattr__(self, name: str) -> Any:
33+
return getattr(self.value, name)
34+
35+
def __eq__(self, __value: object) -> bool:
36+
return isinstance(__value, pd.Series) or isinstance(__value, Series)
37+
38+
def __add__(self, other: Union['Series', List[Any], Tuple[Any], T]) -> 'Series':
39+
if isinstance(other, Series):
40+
result_val = self.value + other.value
41+
else:
42+
result_val = self.value + other
43+
return Series(result_val)

requirements/linting.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,5 @@ virtualenv==20.23.0
5050
# via pre-commit
5151

5252
# The following packages are considered to be unsafe in a requirements file:
53+
pandas-stubs==2.0.2.230605
5354
# setuptools

tests/test_pandas_types.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import pandas as pd
2+
import pytest
3+
4+
from pydantic_extra_types.pandas_types import Series
5+
6+
7+
@pytest.mark.parametrize(
8+
'data, expected',
9+
[
10+
([1, 2, 3], [1, 2, 3]),
11+
([], []),
12+
([10, 20, 30, 40], [10, 20, 30, 40]),
13+
],
14+
)
15+
def test_series_creation(data, expected):
16+
s = Series(data)
17+
assert isinstance(s, Series)
18+
assert isinstance(s.value, pd.Series)
19+
assert s.value.tolist() == expected
20+
21+
22+
def test_series_repr():
23+
data = [1, 2, 3]
24+
s = Series(data)
25+
assert repr(s) == repr(pd.Series(data))
26+
27+
28+
def test_series_attribute_access():
29+
data = [1, 2, 3]
30+
s = Series(data)
31+
assert s.sum() == pd.Series(data).sum()
32+
33+
34+
def test_series_equality():
35+
data = [1, 2, 3]
36+
s1 = Series(data)
37+
s2 = Series(data)
38+
assert s1 == s2
39+
assert s2 == pd.Series(data)
40+
41+
42+
def test_series_addition():
43+
data1 = [1, 2, 3]
44+
data2 = [4, 5, 6]
45+
s1 = Series(data1)
46+
s2 = Series(data2)
47+
s3 = s1 + s2
48+
assert isinstance(s3, Series)
49+
assert s3.value.tolist() == [5, 7, 9]
50+
51+
52+
@pytest.mark.parametrize(
53+
'data, other, expected',
54+
[
55+
([1, 2, 3], [4, 5, 6], [5, 7, 9]),
56+
([10, 20, 30], (1, 2, 3), [11, 22, 33]),
57+
([5, 10, 15], pd.Series([1, 2, 3]), [6, 12, 18]),
58+
],
59+
)
60+
def test_series_addition_with_types(data, other, expected):
61+
s = Series(data)
62+
result = s + other
63+
assert isinstance(result, Series)
64+
assert result.value.tolist() == expected
65+
66+
67+
@pytest.mark.parametrize(
68+
'data, other',
69+
[
70+
([1, 2, 3], 'invalid'), # Invalid type for addition
71+
([1, 2, 3], {'a': 1, 'b': 2}), # Invalid type for addition
72+
],
73+
)
74+
def test_series_addition_invalid_type_error(data, other) -> None:
75+
s = Series(data)
76+
with pytest.raises(TypeError):
77+
s + other
78+
79+
80+
@pytest.mark.parametrize(
81+
'data, other',
82+
[
83+
([1, 2, 3], []),
84+
],
85+
)
86+
def test_series_addition_invalid_value_error(data, other) -> None:
87+
s = Series(data)
88+
with pytest.raises(ValueError):
89+
s + other

0 commit comments

Comments
 (0)