Skip to content

Commit 2a292f2

Browse files
committed
Merge branch 'str_hdf5_put_xfails' of https://github.com/rhshadrach/pandas into str_hdf5_round_trip
2 parents (ad6ed76 + 2be559a); commit 2a292f2

File tree

2 files changed: 42 additions & 17 deletions

pandas/io/pytables.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,7 @@
8686
PeriodArray,
8787
)
8888
from pandas.core.arrays.datetimes import tz_to_dtype
89+
from pandas.core.arrays.string_ import BaseStringArray
8990
import pandas.core.common as com
9091
from pandas.core.computation.pytables import (
9192
PyTablesExpr,
@@ -3185,6 +3186,8 @@ def write_array(
31853186
# both self._filters and EA
31863187

31873188
value = extract_array(obj, extract_numpy=True)
3189+
if isinstance(value, BaseStringArray):
3190+
value = value.to_numpy()
31883191

31893192
if key in self.group:
31903193
self._handle.remove_node(self.group, key)
@@ -3363,7 +3366,11 @@ def read(
33633366

33643367
columns = items[items.get_indexer(blk_items)]
33653368
df = DataFrame(values.T, columns=columns, index=axes[1], copy=False)
3366-
if using_string_dtype() and is_string_array(values, skipna=True):
3369+
if (
3370+
using_string_dtype()
3371+
and isinstance(values, np.ndarray)
3372+
and is_string_array(values, skipna=True)
3373+
):
33673374
df = df.astype(StringDtype(na_value=np.nan))
33683375
dfs.append(df)
33693376

@@ -4737,9 +4744,10 @@ def read(
47374744
df = DataFrame._from_arrays([values], columns=cols_, index=index_)
47384745
if not (using_string_dtype() and values.dtype.kind == "O"):
47394746
assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype)
4740-
if using_string_dtype() and is_string_array(
4741-
values, # type: ignore[arg-type]
4742-
skipna=True,
4747+
if (
4748+
using_string_dtype()
4749+
and isinstance(values, np.ndarray)
4750+
and is_string_array(values, skipna=True)
47434751
):
47444752
df = df.astype(StringDtype(na_value=np.nan))
47454753
frames.append(df)

pandas/tests/io/pytables/test_put.py

Lines changed: 30 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
from pandas._libs.tslibs import Timestamp
97

108
import pandas as pd
@@ -26,7 +24,6 @@
2624

2725
pytestmark = [
2826
pytest.mark.single_cpu,
29-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
3027
]
3128

3229

@@ -99,7 +96,7 @@ def test_api_default_format(tmp_path, setup_path):
9996
assert store.get_storer("df4").is_table
10097

10198

102-
def test_put(setup_path):
99+
def test_put(setup_path, using_infer_string):
103100
with ensure_clean_store(setup_path) as store:
104101
ts = Series(
105102
np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10)
@@ -133,7 +130,11 @@ def test_put(setup_path):
133130

134131
# overwrite table
135132
store.put("c", df[:10], format="table", append=False)
136-
tm.assert_frame_equal(df[:10], store["c"])
133+
expected = df[:10]
134+
if using_infer_string:
135+
expected.columns = expected.columns.astype("str")
136+
result = store["c"]
137+
tm.assert_frame_equal(result, expected)
137138

138139

139140
def test_put_string_index(setup_path):
@@ -162,7 +163,7 @@ def test_put_string_index(setup_path):
162163
tm.assert_frame_equal(store["b"], df)
163164

164165

165-
def test_put_compression(setup_path):
166+
def test_put_compression(setup_path, using_infer_string):
166167
with ensure_clean_store(setup_path) as store:
167168
df = DataFrame(
168169
np.random.default_rng(2).standard_normal((10, 4)),
@@ -171,7 +172,11 @@ def test_put_compression(setup_path):
171172
)
172173

173174
store.put("c", df, format="table", complib="zlib")
174-
tm.assert_frame_equal(store["c"], df)
175+
expected = df
176+
if using_infer_string:
177+
expected.columns = expected.columns.astype("str")
178+
result = store["c"]
179+
tm.assert_frame_equal(result, expected)
175180

176181
# can't compress if format='fixed'
177182
msg = "Compression not supported on Fixed format stores"
@@ -180,7 +185,7 @@ def test_put_compression(setup_path):
180185

181186

182187
@td.skip_if_windows
183-
def test_put_compression_blosc(setup_path):
188+
def test_put_compression_blosc(setup_path, using_infer_string):
184189
df = DataFrame(
185190
np.random.default_rng(2).standard_normal((10, 4)),
186191
columns=Index(list("ABCD"), dtype=object),
@@ -194,10 +199,14 @@ def test_put_compression_blosc(setup_path):
194199
store.put("b", df, format="fixed", complib="blosc")
195200

196201
store.put("c", df, format="table", complib="blosc")
197-
tm.assert_frame_equal(store["c"], df)
202+
expected = df
203+
if using_infer_string:
204+
expected.columns = expected.columns.astype("str")
205+
result = store["c"]
206+
tm.assert_frame_equal(result, expected)
198207

199208

200-
def test_put_mixed_type(setup_path, performance_warning):
209+
def test_put_mixed_type(setup_path, performance_warning, using_infer_string):
201210
df = DataFrame(
202211
np.random.default_rng(2).standard_normal((10, 4)),
203212
columns=Index(list("ABCD"), dtype=object),
@@ -223,8 +232,11 @@ def test_put_mixed_type(setup_path, performance_warning):
223232
with tm.assert_produces_warning(performance_warning):
224233
store.put("df", df)
225234

226-
expected = store.get("df")
227-
tm.assert_frame_equal(expected, df)
235+
expected = df
236+
if using_infer_string:
237+
expected.columns = expected.columns.astype("str")
238+
result = store.get("df")
239+
tm.assert_frame_equal(result, expected)
228240

229241

230242
@pytest.mark.parametrize("format", ["table", "fixed"])
@@ -253,7 +265,7 @@ def test_store_index_types(setup_path, format, index):
253265
tm.assert_frame_equal(df, store["df"])
254266

255267

256-
def test_column_multiindex(setup_path):
268+
def test_column_multiindex(setup_path, using_infer_string):
257269
# GH 4710
258270
# recreate multi-indexes properly
259271

@@ -264,6 +276,11 @@ def test_column_multiindex(setup_path):
264276
expected = df.set_axis(df.index.to_numpy())
265277

266278
with ensure_clean_store(setup_path) as store:
279+
if using_infer_string:
280+
msg = "Saving a MultiIndex with an extension dtype is not supported."
281+
with pytest.raises(NotImplementedError, match=msg):
282+
store.put("df", df)
283+
return
267284
store.put("df", df)
268285
tm.assert_frame_equal(
269286
store["df"], expected, check_index_type=True, check_column_type=True

0 commit comments

Comments
 (0)