3
3
import numpy as np
4
4
import pytest
5
5
6
- from pandas ._config import using_string_dtype
7
-
8
6
from pandas ._libs .tslibs import Timestamp
9
7
10
8
import pandas as pd
26
24
27
25
pytestmark = [
28
26
pytest .mark .single_cpu ,
29
- pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" , strict = False ),
30
27
]
31
28
32
29
@@ -99,7 +96,7 @@ def test_api_default_format(tmp_path, setup_path):
99
96
assert store .get_storer ("df4" ).is_table
100
97
101
98
102
- def test_put (setup_path ):
99
+ def test_put (setup_path , using_infer_string ):
103
100
with ensure_clean_store (setup_path ) as store :
104
101
ts = Series (
105
102
np .arange (10 , dtype = np .float64 ), index = date_range ("2020-01-01" , periods = 10 )
@@ -133,7 +130,11 @@ def test_put(setup_path):
133
130
134
131
# overwrite table
135
132
store .put ("c" , df [:10 ], format = "table" , append = False )
136
- tm .assert_frame_equal (df [:10 ], store ["c" ])
133
+ expected = df [:10 ]
134
+ if using_infer_string :
135
+ expected .columns = expected .columns .astype ("str" )
136
+ result = store ["c" ]
137
+ tm .assert_frame_equal (result , expected )
137
138
138
139
139
140
def test_put_string_index (setup_path ):
@@ -162,7 +163,7 @@ def test_put_string_index(setup_path):
162
163
tm .assert_frame_equal (store ["b" ], df )
163
164
164
165
165
- def test_put_compression (setup_path ):
166
+ def test_put_compression (setup_path , using_infer_string ):
166
167
with ensure_clean_store (setup_path ) as store :
167
168
df = DataFrame (
168
169
np .random .default_rng (2 ).standard_normal ((10 , 4 )),
@@ -171,7 +172,11 @@ def test_put_compression(setup_path):
171
172
)
172
173
173
174
store .put ("c" , df , format = "table" , complib = "zlib" )
174
- tm .assert_frame_equal (store ["c" ], df )
175
+ expected = df
176
+ if using_infer_string :
177
+ expected .columns = expected .columns .astype ("str" )
178
+ result = store ["c" ]
179
+ tm .assert_frame_equal (result , expected )
175
180
176
181
# can't compress if format='fixed'
177
182
msg = "Compression not supported on Fixed format stores"
@@ -180,7 +185,7 @@ def test_put_compression(setup_path):
180
185
181
186
182
187
@td .skip_if_windows
183
- def test_put_compression_blosc (setup_path ):
188
+ def test_put_compression_blosc (setup_path , using_infer_string ):
184
189
df = DataFrame (
185
190
np .random .default_rng (2 ).standard_normal ((10 , 4 )),
186
191
columns = Index (list ("ABCD" ), dtype = object ),
@@ -194,10 +199,14 @@ def test_put_compression_blosc(setup_path):
194
199
store .put ("b" , df , format = "fixed" , complib = "blosc" )
195
200
196
201
store .put ("c" , df , format = "table" , complib = "blosc" )
197
- tm .assert_frame_equal (store ["c" ], df )
202
+ expected = df
203
+ if using_infer_string :
204
+ expected .columns = expected .columns .astype ("str" )
205
+ result = store ["c" ]
206
+ tm .assert_frame_equal (result , expected )
198
207
199
208
200
- def test_put_mixed_type (setup_path , performance_warning ):
209
+ def test_put_mixed_type (setup_path , performance_warning , using_infer_string ):
201
210
df = DataFrame (
202
211
np .random .default_rng (2 ).standard_normal ((10 , 4 )),
203
212
columns = Index (list ("ABCD" ), dtype = object ),
@@ -223,8 +232,11 @@ def test_put_mixed_type(setup_path, performance_warning):
223
232
with tm .assert_produces_warning (performance_warning ):
224
233
store .put ("df" , df )
225
234
226
- expected = store .get ("df" )
227
- tm .assert_frame_equal (expected , df )
235
+ expected = df
236
+ if using_infer_string :
237
+ expected .columns = expected .columns .astype ("str" )
238
+ result = store .get ("df" )
239
+ tm .assert_frame_equal (result , expected )
228
240
229
241
230
242
@pytest .mark .parametrize ("format" , ["table" , "fixed" ])
@@ -253,7 +265,7 @@ def test_store_index_types(setup_path, format, index):
253
265
tm .assert_frame_equal (df , store ["df" ])
254
266
255
267
256
- def test_column_multiindex (setup_path ):
268
+ def test_column_multiindex (setup_path , using_infer_string ):
257
269
# GH 4710
258
270
# recreate multi-indexes properly
259
271
@@ -264,6 +276,11 @@ def test_column_multiindex(setup_path):
264
276
expected = df .set_axis (df .index .to_numpy ())
265
277
266
278
with ensure_clean_store (setup_path ) as store :
279
+ if using_infer_string :
280
+ msg = "Saving a MultiIndex with an extension dtype is not supported."
281
+ with pytest .raises (NotImplementedError , match = msg ):
282
+ store .put ("df" , df )
283
+ return
267
284
store .put ("df" , df )
268
285
tm .assert_frame_equal (
269
286
store ["df" ], expected , check_index_type = True , check_column_type = True
0 commit comments