Skip to content

Commit 6a0824d

Browse files
update tests
Signed-off-by: Praateek <[email protected]>
1 parent 09a4766 commit 6a0824d

File tree

1 file changed

+26
-37
lines changed

1 file changed

+26
-37
lines changed

pandas/tests/series/test_arithmetic.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -155,42 +155,25 @@ def _check_fill(meth, op, a, b, fill_value=0):
155155
# should accept axis=0 or axis='rows'
156156
op(a, b, axis=0)
157157

158-
def test_extarray_rhs_datetime_sub_with_fill_value(self):
159-
# Ensure ExtensionArray (DatetimeArray) RHS is handled via array-like path
160-
# and does not hit scalar isna branch.
161-
left = Series(
162-
[
163-
pd.Timestamp("2025-08-20"),
164-
pd.Timestamp("2025-08-21"),
165-
pd.Timestamp("2025-08-22"),
166-
],
167-
dtype=np.dtype("datetime64[ns]"),
168-
)
169-
right = left._values # DatetimeArray
170-
171-
result = left.sub(right, fill_value=left.iloc[0])
172-
# result dtype may vary (e.g., seconds vs ns), build expected from result dtype
173-
expected = Series(np.zeros(3, dtype=np.dtype("timedelta64[ns]")))
174-
tm.assert_series_equal(result, expected)
158+
@pytest.mark.parametrize("kind", ["datetime", "timedelta"])
159+
def test_rhs_extension_array_sub_with_fill_value(self, kind):
160+
if kind == "datetime":
161+
left = Series(
162+
[pd.Timestamp("2025-08-20"), pd.Timestamp("2025-08-21")],
163+
dtype=np.dtype("datetime64[ns]"),
164+
)
165+
else:
166+
left = Series(
167+
[Timedelta(days=1), Timedelta(days=2)],
168+
dtype=np.dtype("timedelta64[ns]"),
169+
)
175170

176-
def test_extarray_rhs_timedelta_sub_with_fill_value(self):
177-
left = Series(
178-
[Timedelta(days=1), Timedelta(days=2), Timedelta(days=3)],
179-
dtype=np.dtype("timedelta64[ns]"),
180-
)
181-
right = left._values # TimedeltaArray
171+
right = (
172+
left._values
173+
) # DatetimeArray or TimedeltaArray which is an ExtensionArray
182174

183175
result = left.sub(right, fill_value=left.iloc[0])
184-
expected = Series(np.zeros(3, dtype=np.dtype("timedelta64[ns]")))
185-
tm.assert_series_equal(result, expected)
186-
187-
def test_extarray_rhs_period_eq_with_fill_value(self):
188-
# Use equality to validate ExtensionArray RHS path for PeriodArray
189-
left = Series(pd.period_range("2020Q1", periods=3, freq="Q"))
190-
right = left._values # PeriodArray
191-
192-
result = left.eq(right, fill_value=left.iloc[0])
193-
expected = Series([True, True, True])
176+
expected = Series(np.zeros(len(left), dtype=np.dtype("timedelta64[ns]")))
194177
tm.assert_series_equal(result, expected)
195178

196179

@@ -442,10 +425,16 @@ def test_comparison_flex_alignment(self, values, op):
442425
expected = Series(values, index=list("abcd"))
443426
tm.assert_series_equal(result, expected)
444427

445-
def test_extarray_rhs_categorical_eq_with_fill_value(self):
446-
# Categorical RHS should be treated as array-like, not as scalar
447-
left = Series(Categorical(["a", "b", "a"]))
448-
right = left._values # Categorical
428+
@pytest.mark.parametrize(
429+
"left",
430+
[
431+
Series(Categorical(["a", "b", "a"])),
432+
Series(pd.period_range("2020Q1", periods=3, freq="Q")),
433+
],
434+
ids=["categorical", "period"],
435+
)
436+
def test_rhs_extension_array_eq_with_fill_value(self, left):
437+
right = left._values # this is an ExtensionArray
449438

450439
result = left.eq(right, fill_value=left.iloc[0])
451440
expected = Series([True, True, True])

0 commit comments

Comments
 (0)