Skip to content

Commit 7db8deb

Browse files
committed
TST(string dtype): Resolve HDF5 xfails in test_round_trip.py
1 parent 2a292f2 commit 7db8deb

File tree

2 files changed

+99
-16
lines changed

2 files changed

+99
-16
lines changed

pandas/io/pytables.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1950,6 +1950,7 @@ def _write_to_group(
19501950
def _read_group(self, group: Node):
19511951
s = self._create_storer(group)
19521952
s.infer_axes()
1953+
print(type(s), s)
19531954
return s.read()
19541955

19551956
def _identify_group(self, key: str, append: bool) -> Node:
@@ -3297,7 +3298,12 @@ def read(
32973298
index = self.read_index("index", start=start, stop=stop)
32983299
values = self.read_array("values", start=start, stop=stop)
32993300
result = Series(values, index=index, name=self.name, copy=False)
3300-
if using_string_dtype() and is_string_array(values, skipna=True):
3301+
if (
3302+
using_string_dtype()
3303+
and isinstance(values, np.ndarray)
3304+
and len(values) > 0
3305+
and is_string_array(values, skipna=True)
3306+
):
33013307
result = result.astype(StringDtype(na_value=np.nan))
33023308
return result
33033309

@@ -3369,6 +3375,7 @@ def read(
33693375
if (
33703376
using_string_dtype()
33713377
and isinstance(values, np.ndarray)
3378+
and len(df) > 0
33723379
and is_string_array(values, skipna=True)
33733380
):
33743381
df = df.astype(StringDtype(na_value=np.nan))
@@ -4747,6 +4754,7 @@ def read(
47474754
if (
47484755
using_string_dtype()
47494756
and isinstance(values, np.ndarray)
4757+
and len(df) > 0
47504758
and is_string_array(values, skipna=True)
47514759
):
47524760
df = df.astype(StringDtype(na_value=np.nan))

pandas/tests/io/pytables/test_round_trip.py

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import numpy as np
55
import pytest
66

7+
from pandas._config import using_string_dtype
8+
79
from pandas._libs.tslibs import Timestamp
810
from pandas.compat import is_platform_windows
911

@@ -66,6 +68,7 @@ def roundtrip(key, obj, **kwargs):
6668
tm.assert_frame_equal(result, expected)
6769

6870

71+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
6972
def test_long_strings(setup_path):
7073
# GH6166
7174
data = ["a" * 50] * 10
@@ -206,6 +209,7 @@ def test_put_integer(setup_path):
206209
_check_roundtrip(df, tm.assert_frame_equal, setup_path)
207210

208211

212+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
209213
def test_table_values_dtypes_roundtrip(setup_path):
210214
with ensure_clean_store(setup_path) as store:
211215
df1 = DataFrame({"a": [1, 2, 3]}, dtype="f8")
@@ -375,7 +379,7 @@ def test_timeseries_preepoch(setup_path, request):
375379
@pytest.mark.parametrize(
376380
"compression", [False, pytest.param(True, marks=td.skip_if_windows)]
377381
)
378-
def test_frame(compression, setup_path):
382+
def test_frame(compression, setup_path, using_infer_string):
379383
df = DataFrame(
380384
1.1 * np.arange(120).reshape((30, 4)),
381385
columns=Index(list("ABCD"), dtype=object),
@@ -386,20 +390,40 @@ def test_frame(compression, setup_path):
386390
df.iloc[0, 0] = np.nan
387391
df.iloc[5, 3] = np.nan
388392

393+
expected = df.copy()
394+
if using_infer_string:
395+
expected.index = expected.index.astype("str")
396+
expected.columns = expected.columns.astype("str")
397+
389398
_check_roundtrip_table(
390-
df, tm.assert_frame_equal, path=setup_path, compression=compression
399+
df,
400+
tm.assert_frame_equal,
401+
path=setup_path,
402+
compression=compression,
403+
expected=expected,
391404
)
392405
_check_roundtrip(
393-
df, tm.assert_frame_equal, path=setup_path, compression=compression
406+
df,
407+
tm.assert_frame_equal,
408+
path=setup_path,
409+
compression=compression,
410+
expected=expected,
394411
)
395412

396413
tdf = DataFrame(
397414
np.random.default_rng(2).standard_normal((10, 4)),
398415
columns=Index(list("ABCD"), dtype=object),
399416
index=date_range("2000-01-01", periods=10, freq="B"),
400417
)
418+
expected = tdf.copy()
419+
if using_infer_string:
420+
expected.columns = expected.columns.astype("str")
401421
_check_roundtrip(
402-
tdf, tm.assert_frame_equal, path=setup_path, compression=compression
422+
tdf,
423+
tm.assert_frame_equal,
424+
path=setup_path,
425+
compression=compression,
426+
expected=expected,
403427
)
404428

405429
with ensure_clean_store(setup_path) as store:
@@ -410,7 +434,10 @@ def test_frame(compression, setup_path):
410434
assert recons._mgr.is_consolidated()
411435

412436
# empty
413-
_check_roundtrip(df[:0], tm.assert_frame_equal, path=setup_path)
437+
expected = df[:0]
438+
if using_infer_string:
439+
expected.columns = expected.columns.astype("str")
440+
_check_roundtrip(df[:0], tm.assert_frame_equal, path=setup_path, expected=expected)
414441

415442

416443
def test_empty_series_frame(setup_path):
@@ -442,9 +469,21 @@ def test_can_serialize_dates(setup_path):
442469
_check_roundtrip(frame, tm.assert_frame_equal, path=setup_path)
443470

444471

445-
def test_store_hierarchical(setup_path, multiindex_dataframe_random_data):
472+
def test_store_hierarchical(
473+
setup_path, multiindex_dataframe_random_data, using_infer_string
474+
):
446475
frame = multiindex_dataframe_random_data
447476

477+
if using_infer_string:
478+
msg = "Saving a MultiIndex with an extension dtype is not supported."
479+
with pytest.raises(NotImplementedError, match=msg):
480+
_check_roundtrip(frame, tm.assert_frame_equal, path=setup_path)
481+
with pytest.raises(NotImplementedError, match=msg):
482+
_check_roundtrip(frame.T, tm.assert_frame_equal, path=setup_path)
483+
with pytest.raises(NotImplementedError, match=msg):
484+
_check_roundtrip(frame["A"], tm.assert_series_equal, path=setup_path)
485+
return
486+
448487
_check_roundtrip(frame, tm.assert_frame_equal, path=setup_path)
449488
_check_roundtrip(frame.T, tm.assert_frame_equal, path=setup_path)
450489
_check_roundtrip(frame["A"], tm.assert_series_equal, path=setup_path)
@@ -459,7 +498,7 @@ def test_store_hierarchical(setup_path, multiindex_dataframe_random_data):
459498
@pytest.mark.parametrize(
460499
"compression", [False, pytest.param(True, marks=td.skip_if_windows)]
461500
)
462-
def test_store_mixed(compression, setup_path):
501+
def test_store_mixed(compression, setup_path, using_infer_string):
463502
def _make_one():
464503
df = DataFrame(
465504
1.1 * np.arange(120).reshape((30, 4)),
@@ -477,57 +516,91 @@ def _make_one():
477516
df1 = _make_one()
478517
df2 = _make_one()
479518

480-
_check_roundtrip(df1, tm.assert_frame_equal, path=setup_path)
481-
_check_roundtrip(df2, tm.assert_frame_equal, path=setup_path)
519+
expected = df1.copy()
520+
if using_infer_string:
521+
expected.index = expected.index.astype("str")
522+
expected.columns = expected.columns.astype("str")
523+
_check_roundtrip(df1, tm.assert_frame_equal, path=setup_path, expected=expected)
524+
525+
expected = df2.copy()
526+
if using_infer_string:
527+
expected.index = expected.index.astype("str")
528+
expected.columns = expected.columns.astype("str")
529+
_check_roundtrip(df2, tm.assert_frame_equal, path=setup_path, expected=expected)
482530

483531
with ensure_clean_store(setup_path) as store:
484532
store["obj"] = df1
485-
tm.assert_frame_equal(store["obj"], df1)
533+
expected = df1.copy()
534+
if using_infer_string:
535+
expected.index = expected.index.astype("str")
536+
expected.columns = expected.columns.astype("str")
537+
tm.assert_frame_equal(store["obj"], expected)
538+
486539
store["obj"] = df2
487-
tm.assert_frame_equal(store["obj"], df2)
540+
expected = df2.copy()
541+
if using_infer_string:
542+
expected.index = expected.index.astype("str")
543+
expected.columns = expected.columns.astype("str")
544+
tm.assert_frame_equal(store["obj"], expected)
488545

489546
# check that can store Series of all of these types
547+
expected = df1["obj1"]
548+
if using_infer_string:
549+
expected.index = expected.index.astype("str")
490550
_check_roundtrip(
491551
df1["obj1"],
492552
tm.assert_series_equal,
493553
path=setup_path,
494554
compression=compression,
555+
expected=expected,
495556
)
557+
expected = df1["bool1"]
558+
if using_infer_string:
559+
expected.index = expected.index.astype("str")
496560
_check_roundtrip(
497561
df1["bool1"],
498562
tm.assert_series_equal,
499563
path=setup_path,
500564
compression=compression,
565+
expected=expected,
501566
)
567+
expected = df1["int1"]
568+
if using_infer_string:
569+
expected.index = expected.index.astype("str")
502570
_check_roundtrip(
503571
df1["int1"],
504572
tm.assert_series_equal,
505573
path=setup_path,
506574
compression=compression,
575+
expected=expected,
507576
)
508577

509578

510-
def _check_roundtrip(obj, comparator, path, compression=False, **kwargs):
579+
def _check_roundtrip(obj, comparator, path, compression=False, expected=None, **kwargs):
511580
options = {}
512581
if compression:
513582
options["complib"] = "blosc"
583+
if expected is None:
584+
expected = obj
514585

515586
with ensure_clean_store(path, "w", **options) as store:
516587
store["obj"] = obj
517588
retrieved = store["obj"]
518-
comparator(retrieved, obj, **kwargs)
589+
comparator(retrieved, expected, **kwargs)
519590

520591

521-
def _check_roundtrip_table(obj, comparator, path, compression=False):
592+
def _check_roundtrip_table(obj, comparator, path, compression=False, expected=None):
522593
options = {}
523594
if compression:
524595
options["complib"] = "blosc"
596+
if expected is None:
597+
expected = obj
525598

526599
with ensure_clean_store(path, "w", **options) as store:
527600
store.put("obj", obj, format="table")
528601
retrieved = store["obj"]
529602

530-
comparator(retrieved, obj)
603+
comparator(retrieved, expected)
531604

532605

533606
def test_unicode_index(setup_path):
@@ -540,6 +613,7 @@ def test_unicode_index(setup_path):
540613
_check_roundtrip(s, tm.assert_series_equal, path=setup_path)
541614

542615

616+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
543617
def test_unicode_longer_encoded(setup_path):
544618
# GH 11234
545619
char = "\u0394"
@@ -565,6 +639,7 @@ def test_store_datetime_mixed(setup_path):
565639
_check_roundtrip(df, tm.assert_frame_equal, path=setup_path)
566640

567641

642+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
568643
def test_round_trip_equals(tmp_path, setup_path):
569644
# GH 9330
570645
df = DataFrame({"B": [1, 2], "A": ["x", "y"]})

0 commit comments

Comments
 (0)