Skip to content

Commit 292da03

Browse files
authored
TST: share transform tests (#38304)
1 parent 007128c commit 292da03

File tree

3 files changed: 77 additions & 186 deletions

pandas/tests/frame/apply/test_frame_transform.py

Lines changed: 71 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,41 @@
1010
from pandas.core.groupby.base import transformation_kernels
1111
from pandas.tests.frame.common import zip_frames
1212

13-
14-
def test_transform_ufunc(axis, float_frame):
13+
# tshift only works on time index and is deprecated
14+
# There is no DataFrame.cumcount
15+
frame_kernels = [
16+
x for x in sorted(transformation_kernels) if x not in ["tshift", "cumcount"]
17+
]
18+
19+
20+
def unpack_obj(obj, klass, axis):
21+
"""
22+
Return the object a test parametrized over frame_or_series should use:
23+
obj itself for DataFrame, otherwise column "A" (skipping unless axis is 0).
24+
"""
25+
if klass is not DataFrame:
26+
obj = obj["A"]
27+
if axis != 0:
28+
pytest.skip(f"Test is only for DataFrame with axis={axis}")
29+
return obj
30+
31+
32+
def test_transform_ufunc(axis, float_frame, frame_or_series):
1533
# GH 35964
34+
obj = unpack_obj(float_frame, frame_or_series, axis)
35+
1636
with np.errstate(all="ignore"):
17-
f_sqrt = np.sqrt(float_frame)
18-
result = float_frame.transform(np.sqrt, axis=axis)
37+
f_sqrt = np.sqrt(obj)
38+
39+
# ufunc
40+
result = obj.transform(np.sqrt, axis=axis)
1941
expected = f_sqrt
20-
tm.assert_frame_equal(result, expected)
42+
tm.assert_equal(result, expected)
2143

2244

23-
@pytest.mark.parametrize("op", sorted(transformation_kernels))
45+
@pytest.mark.parametrize("op", frame_kernels)
2446
def test_transform_groupby_kernel(axis, float_frame, op):
2547
# GH 35964
26-
if op == "cumcount":
27-
pytest.xfail("DataFrame.cumcount does not exist")
28-
if op == "tshift":
29-
pytest.xfail("Only works on time index and is deprecated")
3048

3149
args = [0.0] if op == "fillna" else []
3250
if axis == 0 or axis == "index":
@@ -61,9 +79,11 @@ def test_transform_listlike(axis, float_frame, ops, names):
6179

6280

6381
@pytest.mark.parametrize("ops", [[], np.array([])])
64-
def test_transform_empty_listlike(float_frame, ops):
82+
def test_transform_empty_listlike(float_frame, ops, frame_or_series):
83+
obj = unpack_obj(float_frame, frame_or_series, 0)
84+
6585
with pytest.raises(ValueError, match="No transform functions were provided"):
66-
float_frame.transform(ops)
86+
obj.transform(ops)
6787

6888

6989
@pytest.mark.parametrize("box", [dict, Series])
@@ -90,25 +110,29 @@ def test_transform_dictlike(axis, float_frame, box):
90110
{"A": ["cumsum"], "B": []},
91111
],
92112
)
93-
def test_transform_empty_dictlike(float_frame, ops):
113+
def test_transform_empty_dictlike(float_frame, ops, frame_or_series):
114+
obj = unpack_obj(float_frame, frame_or_series, 0)
115+
94116
with pytest.raises(ValueError, match="No transform functions were provided"):
95-
float_frame.transform(ops)
117+
obj.transform(ops)
96118

97119

98120
@pytest.mark.parametrize("use_apply", [True, False])
99-
def test_transform_udf(axis, float_frame, use_apply):
121+
def test_transform_udf(axis, float_frame, use_apply, frame_or_series):
100122
# GH 35964
123+
obj = unpack_obj(float_frame, frame_or_series, axis)
124+
101125
# transform uses UDF either via apply or by passing the entire object (DataFrame or Series)
102126
def func(x):
103127
# transform is using apply iff x is not an instance of frame_or_series
104-
if use_apply == isinstance(x, DataFrame):
128+
if use_apply == isinstance(x, frame_or_series):
105129
# Force transform to fallback
106130
raise ValueError
107131
return x + 1
108132

109-
result = float_frame.transform(func, axis=axis)
110-
expected = float_frame + 1
111-
tm.assert_frame_equal(result, expected)
133+
result = obj.transform(func, axis=axis)
134+
expected = obj + 1
135+
tm.assert_equal(result, expected)
112136

113137

114138
@pytest.mark.parametrize("method", ["abs", "shift", "pct_change", "cumsum", "rank"])
@@ -142,54 +166,56 @@ def test_agg_dict_nested_renaming_depr():
142166
df.transform({"A": {"foo": "min"}, "B": {"bar": "max"}})
143167

144168

145-
def test_transform_reducer_raises(all_reductions):
169+
def test_transform_reducer_raises(all_reductions, frame_or_series):
146170
# GH 35964
147171
op = all_reductions
148-
df = DataFrame({"A": [1, 2, 3]})
172+
173+
obj = DataFrame({"A": [1, 2, 3]})
174+
if frame_or_series is not DataFrame:
175+
obj = obj["A"]
176+
149177
msg = "Function did not transform"
150178
with pytest.raises(ValueError, match=msg):
151-
df.transform(op)
179+
obj.transform(op)
152180
with pytest.raises(ValueError, match=msg):
153-
df.transform([op])
181+
obj.transform([op])
154182
with pytest.raises(ValueError, match=msg):
155-
df.transform({"A": op})
183+
obj.transform({"A": op})
156184
with pytest.raises(ValueError, match=msg):
157-
df.transform({"A": [op]})
185+
obj.transform({"A": [op]})
186+
187+
188+
wont_fail = ["ffill", "bfill", "fillna", "pad", "backfill", "shift"]
189+
frame_kernels_raise = [x for x in frame_kernels if x not in wont_fail]
158190

159191

160192
# mypy doesn't allow adding lists of different types
161193
# https://github.com/python/mypy/issues/5492
162-
@pytest.mark.parametrize("op", [*sorted(transformation_kernels), lambda x: x + 1])
163-
def test_transform_bad_dtype(op):
194+
@pytest.mark.parametrize("op", [*frame_kernels_raise, lambda x: x + 1])
195+
def test_transform_bad_dtype(op, frame_or_series):
164196
# GH 35964
165-
df = DataFrame({"A": 3 * [object]}) # DataFrame that will fail on most transforms
166-
if op in ("backfill", "shift", "pad", "bfill", "ffill"):
167-
pytest.xfail("Transform function works on any datatype")
197+
obj = DataFrame({"A": 3 * [object]}) # DataFrame that will fail on most transforms
198+
if frame_or_series is not DataFrame:
199+
obj = obj["A"]
200+
168201
msg = "Transform function failed"
169202

170203
# tshift is deprecated
171204
warn = None if op != "tshift" else FutureWarning
172205
with tm.assert_produces_warning(warn, check_stacklevel=False):
173206
with pytest.raises(ValueError, match=msg):
174-
df.transform(op)
207+
obj.transform(op)
175208
with pytest.raises(ValueError, match=msg):
176-
df.transform([op])
209+
obj.transform([op])
177210
with pytest.raises(ValueError, match=msg):
178-
df.transform({"A": op})
211+
obj.transform({"A": op})
179212
with pytest.raises(ValueError, match=msg):
180-
df.transform({"A": [op]})
213+
obj.transform({"A": [op]})
181214

182215

183-
@pytest.mark.parametrize("op", sorted(transformation_kernels))
216+
@pytest.mark.parametrize("op", frame_kernels_raise)
184217
def test_transform_partial_failure(op):
185218
# GH 35964
186-
wont_fail = ["ffill", "bfill", "fillna", "pad", "backfill", "shift"]
187-
if op in wont_fail:
188-
pytest.xfail("Transform kernel is successful on all dtypes")
189-
if op == "cumcount":
190-
pytest.xfail("transform('cumcount') not implemented")
191-
if op == "tshift":
192-
pytest.xfail("Only works on time index; deprecated")
193219

194220
# Using object makes most transform kernels fail
195221
df = DataFrame({"A": 3 * [object], "B": [1, 2, 3]})
@@ -208,22 +234,22 @@ def test_transform_partial_failure(op):
208234

209235

210236
@pytest.mark.parametrize("use_apply", [True, False])
211-
def test_transform_passes_args(use_apply):
237+
def test_transform_passes_args(use_apply, frame_or_series):
212238
# GH 35964
213239
# transform uses UDF either via apply or by passing the entire object (DataFrame or Series)
214240
expected_args = [1, 2]
215241
expected_kwargs = {"c": 3}
216242

217243
def f(x, a, b, c):
218244
# transform is using apply iff x is not an instance of frame_or_series
219-
if use_apply == isinstance(x, DataFrame):
245+
if use_apply == isinstance(x, frame_or_series):
220246
# Force transform to fallback
221247
raise ValueError
222248
assert [a, b] == expected_args
223249
assert c == expected_kwargs["c"]
224250
return x
225251

226-
DataFrame([1]).transform(f, 0, *expected_args, **expected_kwargs)
252+
frame_or_series([1]).transform(f, 0, *expected_args, **expected_kwargs)
227253

228254

229255
def test_transform_missing_columns(axis):

pandas/tests/series/apply/test_series_transform.py

Lines changed: 6 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,16 @@
66
from pandas.core.base import SpecificationError
77
from pandas.core.groupby.base import transformation_kernels
88

9-
10-
def test_transform_ufunc(string_series):
11-
# GH 35964
12-
with np.errstate(all="ignore"):
13-
f_sqrt = np.sqrt(string_series)
14-
15-
# ufunc
16-
result = string_series.transform(np.sqrt)
17-
expected = f_sqrt.copy()
18-
tm.assert_series_equal(result, expected)
9+
# tshift only works on time index and is deprecated
10+
# There is no Series.cumcount
11+
series_kernels = [
12+
x for x in sorted(transformation_kernels) if x not in ["tshift", "cumcount"]
13+
]
1914

2015

21-
@pytest.mark.parametrize("op", sorted(transformation_kernels))
16+
@pytest.mark.parametrize("op", series_kernels)
2217
def test_transform_groupby_kernel(string_series, op):
2318
# GH 35964
24-
if op == "cumcount":
25-
pytest.xfail("Series.cumcount does not exist")
26-
if op == "tshift":
27-
pytest.xfail("Only works on time index and is deprecated")
2819

2920
args = [0.0] if op == "fillna" else []
3021
ones = np.ones(string_series.shape[0])
@@ -51,12 +42,6 @@ def test_transform_listlike(string_series, ops, names):
5142
tm.assert_frame_equal(result, expected)
5243

5344

54-
@pytest.mark.parametrize("ops", [[], np.array([])])
55-
def test_transform_empty_listlike(string_series, ops):
56-
with pytest.raises(ValueError, match="No transform functions were provided"):
57-
string_series.transform(ops)
58-
59-
6045
@pytest.mark.parametrize("box", [dict, Series])
6146
def test_transform_dictlike(string_series, box):
6247
# GH 35964
@@ -67,45 +52,6 @@ def test_transform_dictlike(string_series, box):
6752
tm.assert_frame_equal(result, expected)
6853

6954

70-
@pytest.mark.parametrize(
71-
"ops",
72-
[
73-
{},
74-
{"A": []},
75-
{"A": [], "B": ["cumsum"]},
76-
{"A": ["cumsum"], "B": []},
77-
{"A": [], "B": "cumsum"},
78-
{"A": "cumsum", "B": []},
79-
],
80-
)
81-
def test_transform_empty_dictlike(string_series, ops):
82-
with pytest.raises(ValueError, match="No transform functions were provided"):
83-
string_series.transform(ops)
84-
85-
86-
def test_transform_udf(axis, string_series):
87-
# GH 35964
88-
# via apply
89-
def func(x):
90-
if isinstance(x, Series):
91-
raise ValueError
92-
return x + 1
93-
94-
result = string_series.transform(func)
95-
expected = string_series + 1
96-
tm.assert_series_equal(result, expected)
97-
98-
# via map Series -> Series
99-
def func(x):
100-
if not isinstance(x, Series):
101-
raise ValueError
102-
return x + 1
103-
104-
result = string_series.transform(func)
105-
expected = string_series + 1
106-
tm.assert_series_equal(result, expected)
107-
108-
10955
def test_transform_wont_agg(string_series):
11056
# GH 35964
11157
# we are trying to transform with an aggregator
@@ -127,64 +73,6 @@ def test_transform_none_to_type():
12773
df.transform({"a": int})
12874

12975

130-
def test_transform_reducer_raises(all_reductions):
131-
# GH 35964
132-
op = all_reductions
133-
s = Series([1, 2, 3])
134-
msg = "Function did not transform"
135-
with pytest.raises(ValueError, match=msg):
136-
s.transform(op)
137-
with pytest.raises(ValueError, match=msg):
138-
s.transform([op])
139-
with pytest.raises(ValueError, match=msg):
140-
s.transform({"A": op})
141-
with pytest.raises(ValueError, match=msg):
142-
s.transform({"A": [op]})
143-
144-
145-
# mypy doesn't allow adding lists of different types
146-
# https://github.com/python/mypy/issues/5492
147-
@pytest.mark.parametrize("op", [*sorted(transformation_kernels), lambda x: x + 1])
148-
def test_transform_bad_dtype(op):
149-
# GH 35964
150-
s = Series(3 * [object]) # Series that will fail on most transforms
151-
if op in ("backfill", "shift", "pad", "bfill", "ffill"):
152-
pytest.xfail("Transform function works on any datatype")
153-
154-
msg = "Transform function failed"
155-
156-
# tshift is deprecated
157-
warn = None if op != "tshift" else FutureWarning
158-
with tm.assert_produces_warning(warn, check_stacklevel=False):
159-
with pytest.raises(ValueError, match=msg):
160-
s.transform(op)
161-
with pytest.raises(ValueError, match=msg):
162-
s.transform([op])
163-
with pytest.raises(ValueError, match=msg):
164-
s.transform({"A": op})
165-
with pytest.raises(ValueError, match=msg):
166-
s.transform({"A": [op]})
167-
168-
169-
@pytest.mark.parametrize("use_apply", [True, False])
170-
def test_transform_passes_args(use_apply):
171-
# GH 35964
172-
# transform uses UDF either via apply or passing the entire Series
173-
expected_args = [1, 2]
174-
expected_kwargs = {"c": 3}
175-
176-
def f(x, a, b, c):
177-
# transform is using apply iff x is not a Series
178-
if use_apply == isinstance(x, Series):
179-
# Force transform to fallback
180-
raise ValueError
181-
assert [a, b] == expected_args
182-
assert c == expected_kwargs["c"]
183-
return x
184-
185-
Series([1]).transform(f, 0, *expected_args, **expected_kwargs)
186-
187-
18876
def test_transform_axis_1_raises():
18977
# GH 35964
19078
msg = "No axis named 1 for object type Series"

pandas/tests/series/conftest.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

0 commit comments

Comments
 (0)