Skip to content

Commit 6cbc3e8

Browse files
parametrize pyarrow tests
1 parent d8149e6 commit 6cbc3e8

File tree

1 file changed

+17 -100 lines changed

1 file changed

+17 -100 lines changed

pandas/tests/strings/test_get_dummies.py

Lines changed: 17 additions & 100 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import pandas.util._test_decorators as td
55

66
from pandas import (
7-
ArrowDtype,
87
DataFrame,
98
Index,
109
MultiIndex,
@@ -69,108 +68,26 @@ def test_get_dummies_with_dtype(any_string_dtype, dtype):
6968

7069

7170
@td.skip_if_no("pyarrow")
72-
def test_get_dummies_with_pyarrow_dtype_int8(any_string_dtype):
73-
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
74-
result = s.str.get_dummies("|", dtype=ArrowDtype(pa.int8()))
75-
expected = DataFrame(
76-
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
77-
columns=list("abc"),
78-
dtype=ArrowDtype(pa.int8()),
79-
)
80-
tm.assert_frame_equal(result, expected)
81-
82-
83-
@td.skip_if_no("pyarrow")
84-
def test_get_dummies_with_pyarrow_dtype_uint8(any_string_dtype):
85-
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
86-
result = s.str.get_dummies("|", dtype=ArrowDtype(pa.uint8()))
87-
expected = DataFrame(
88-
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
89-
columns=list("abc"),
90-
dtype=ArrowDtype(pa.uint8()),
91-
)
92-
tm.assert_frame_equal(result, expected)
93-
94-
95-
@td.skip_if_no("pyarrow")
96-
def test_get_dummies_with_pyarrow_dtype_int16(any_string_dtype):
97-
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
98-
result = s.str.get_dummies("|", dtype=ArrowDtype(pa.int16()))
99-
expected = DataFrame(
100-
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
101-
columns=list("abc"),
102-
dtype=ArrowDtype(pa.int16()),
103-
)
104-
tm.assert_frame_equal(result, expected)
105-
106-
107-
@td.skip_if_no("pyarrow")
108-
def test_get_dummies_with_pyarrow_dtype_uint16(any_string_dtype):
109-
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
110-
result = s.str.get_dummies("|", dtype=ArrowDtype(pa.uint16()))
111-
expected = DataFrame(
112-
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
113-
columns=list("abc"),
114-
dtype=ArrowDtype(pa.uint16()),
115-
)
116-
tm.assert_frame_equal(result, expected)
117-
118-
119-
@td.skip_if_no("pyarrow")
120-
def test_get_dummies_with_pyarrow_dtype_int32(any_string_dtype):
121-
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
122-
result = s.str.get_dummies("|", dtype=ArrowDtype(pa.int32()))
123-
expected = DataFrame(
124-
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
125-
columns=list("abc"),
126-
dtype=ArrowDtype(pa.int32()),
127-
)
128-
tm.assert_frame_equal(result, expected)
129-
130-
131-
@td.skip_if_no("pyarrow")
132-
def test_get_dummies_with_pyarrow_dtype_uint32(any_string_dtype):
133-
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
134-
result = s.str.get_dummies("|", dtype=ArrowDtype(pa.uint32()))
135-
expected = DataFrame(
136-
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
137-
columns=list("abc"),
138-
dtype=ArrowDtype(pa.uint32()),
139-
)
140-
tm.assert_frame_equal(result, expected)
141-
142-
143-
@td.skip_if_no("pyarrow")
144-
def test_get_dummies_with_pyarrow_dtype_int64(any_string_dtype):
145-
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
146-
result = s.str.get_dummies("|", dtype=ArrowDtype(pa.int64()))
147-
expected = DataFrame(
148-
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
149-
columns=list("abc"),
150-
dtype=ArrowDtype(pa.int64()),
151-
)
152-
tm.assert_frame_equal(result, expected)
153-
154-
155-
@td.skip_if_no("pyarrow")
156-
def test_get_dummies_with_pyarrow_dtype_uint64(any_string_dtype):
157-
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
158-
result = s.str.get_dummies("|", dtype=ArrowDtype(pa.uint64()))
159-
expected = DataFrame(
160-
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
161-
columns=list("abc"),
162-
dtype=ArrowDtype(pa.uint64()),
163-
)
164-
tm.assert_frame_equal(result, expected)
165-
166-
167-
@td.skip_if_no("pyarrow")
168-
def test_get_dummies_with_pyarrow_dtype_bool(any_string_dtype):
71+
@pytest.mark.parametrize(
72+
"dtype",
73+
[
74+
"int8[pyarrow]",
75+
"uint8[pyarrow]",
76+
"int16[pyarrow]",
77+
"uint16[pyarrow]",
78+
"int32[pyarrow]",
79+
"uint32[pyarrow]",
80+
"int64[pyarrow]",
81+
"uint64[pyarrow]",
82+
"bool[pyarrow]",
83+
],
84+
)
85+
def test_get_dummies_with_pyarrow_dtype(any_string_dtype, dtype):
16986
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
170-
result = s.str.get_dummies("|", dtype=ArrowDtype(pa.bool_()))
87+
result = s.str.get_dummies("|", dtype=dtype)
17188
expected = DataFrame(
17289
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
17390
columns=list("abc"),
174-
dtype=ArrowDtype(pa.bool_()),
91+
dtype=dtype,
17592
)
17693
tm.assert_frame_equal(result, expected)

0 commit comments

Comments (0)