33import numpy as np
44import pytest
55
6- from pandas ._config import using_string_dtype
7-
86from pandas ._libs .tslibs import Timestamp
97
108import pandas as pd
2624
2725pytestmark = [
2826 pytest .mark .single_cpu ,
29- pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" , strict = False ),
3027]
3128
3229
@@ -99,7 +96,7 @@ def test_api_default_format(tmp_path, setup_path):
9996 assert store .get_storer ("df4" ).is_table
10097
10198
102- def test_put (setup_path ):
99+ def test_put (setup_path , using_infer_string ):
103100 with ensure_clean_store (setup_path ) as store :
104101 ts = Series (
105102 np .arange (10 , dtype = np .float64 ), index = date_range ("2020-01-01" , periods = 10 )
@@ -133,7 +130,11 @@ def test_put(setup_path):
133130
134131 # overwrite table
135132 store .put ("c" , df [:10 ], format = "table" , append = False )
136- tm .assert_frame_equal (df [:10 ], store ["c" ])
133+ expected = df [:10 ]
134+ if using_infer_string :
135+ expected .columns = expected .columns .astype ("str" )
136+ result = store ["c" ]
137+ tm .assert_frame_equal (result , expected )
137138
138139
139140def test_put_string_index (setup_path ):
@@ -162,7 +163,7 @@ def test_put_string_index(setup_path):
162163 tm .assert_frame_equal (store ["b" ], df )
163164
164165
165- def test_put_compression (setup_path ):
166+ def test_put_compression (setup_path , using_infer_string ):
166167 with ensure_clean_store (setup_path ) as store :
167168 df = DataFrame (
168169 np .random .default_rng (2 ).standard_normal ((10 , 4 )),
@@ -171,7 +172,11 @@ def test_put_compression(setup_path):
171172 )
172173
173174 store .put ("c" , df , format = "table" , complib = "zlib" )
174- tm .assert_frame_equal (store ["c" ], df )
175+ expected = df
176+ if using_infer_string :
177+ expected .columns = expected .columns .astype ("str" )
178+ result = store ["c" ]
179+ tm .assert_frame_equal (result , expected )
175180
176181 # can't compress if format='fixed'
177182 msg = "Compression not supported on Fixed format stores"
@@ -180,7 +185,7 @@ def test_put_compression(setup_path):
180185
181186
182187@td .skip_if_windows
183- def test_put_compression_blosc (setup_path ):
188+ def test_put_compression_blosc (setup_path , using_infer_string ):
184189 df = DataFrame (
185190 np .random .default_rng (2 ).standard_normal ((10 , 4 )),
186191 columns = Index (list ("ABCD" ), dtype = object ),
@@ -194,10 +199,14 @@ def test_put_compression_blosc(setup_path):
194199 store .put ("b" , df , format = "fixed" , complib = "blosc" )
195200
196201 store .put ("c" , df , format = "table" , complib = "blosc" )
197- tm .assert_frame_equal (store ["c" ], df )
202+ expected = df
203+ if using_infer_string :
204+ expected .columns = expected .columns .astype ("str" )
205+ result = store ["c" ]
206+ tm .assert_frame_equal (result , expected )
198207
199208
200- def test_put_mixed_type (setup_path , performance_warning ):
209+ def test_put_mixed_type (setup_path , performance_warning , using_infer_string ):
201210 df = DataFrame (
202211 np .random .default_rng (2 ).standard_normal ((10 , 4 )),
203212 columns = Index (list ("ABCD" ), dtype = object ),
@@ -223,8 +232,11 @@ def test_put_mixed_type(setup_path, performance_warning):
223232 with tm .assert_produces_warning (performance_warning ):
224233 store .put ("df" , df )
225234
226- expected = store .get ("df" )
227- tm .assert_frame_equal (expected , df )
235+ expected = df
236+ if using_infer_string :
237+ expected .columns = expected .columns .astype ("str" )
238+ result = store .get ("df" )
239+ tm .assert_frame_equal (result , expected )
228240
229241
230242@pytest .mark .parametrize ("format" , ["table" , "fixed" ])
@@ -253,7 +265,7 @@ def test_store_index_types(setup_path, format, index):
253265 tm .assert_frame_equal (df , store ["df" ])
254266
255267
256- def test_column_multiindex (setup_path ):
268+ def test_column_multiindex (setup_path , using_infer_string ):
257269 # GH 4710
258270 # recreate multi-indexes properly
259271
@@ -264,6 +276,11 @@ def test_column_multiindex(setup_path):
264276 expected = df .set_axis (df .index .to_numpy ())
265277
266278 with ensure_clean_store (setup_path ) as store :
279+ if using_infer_string :
280+ msg = "Saving a MultiIndex with an extension dtype is not supported."
281+ with pytest .raises (NotImplementedError , match = msg ):
282+ store .put ("df" , df )
283+ return
267284 store .put ("df" , df )
268285 tm .assert_frame_equal (
269286 store ["df" ], expected , check_index_type = True , check_column_type = True
0 commit comments