Skip to content

Commit 8e3140d

Browse files
authored
ARROW-87 Test find_arrow/pandas/numpy with all supported types (#80)
1 parent 58890e9 commit 8e3140d

File tree

3 files changed

+112
-49
lines changed

3 files changed

+112
-49
lines changed

bindings/python/test/test_arrow.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828
from pymongoarrow.api import Schema, aggregate_arrow_all, find_arrow_all, write
2929
from pymongoarrow.errors import ArrowWriteError
3030
from pymongoarrow.monkey import patch_all
31-
from pymongoarrow.types import Decimal128StringType, ObjectIdType
31+
from pymongoarrow.types import (
32+
_TYPE_NORMALIZER_FACTORY,
33+
Decimal128StringType,
34+
ObjectIdType,
35+
)
3236

3337

3438
class TestArrowApiMixin:
@@ -42,6 +46,7 @@ def setUpClass(cls):
4246
event_listeners=[cls.getmore_listener, cls.cmd_listener]
4347
)
4448
cls.schema = Schema({"_id": int32(), "data": int64()})
49+
4550
cls.coll = cls.client.pymongoarrow_test.get_collection(
4651
"test", write_concern=WriteConcern(w="majority")
4752
)
@@ -237,19 +242,18 @@ def test_pymongo_error(self):
237242

238243
def test_write_schema_validation(self):
239244
schema = {
240-
"data": int64(),
241-
"float": float64(),
242-
"datetime": timestamp("ms"),
243-
"string": string(),
244-
"bool": bool_(),
245+
k.__name__: v(True)
246+
for k, v in _TYPE_NORMALIZER_FACTORY.items()
247+
if k.__name__ not in ("ObjectId", "Decimal128")
245248
}
246249
data = Table.from_pydict(
247250
{
248-
"data": [i for i in range(2)],
251+
"Int64": [i for i in range(2)],
249252
"float": [i for i in range(2)],
250253
"datetime": [i for i in range(2)],
251-
"string": [str(i) for i in range(2)],
252-
"bool": [True for _ in range(2)],
254+
"str": [str(i) for i in range(2)],
255+
"int": [i for i in range(2)],
256+
"bool": [True, False],
253257
},
254258
ArrowSchema(schema),
255259
)
@@ -298,6 +302,29 @@ def test_parquet(self):
298302
self.round_trip(data, Schema(schema))
299303
os.remove("test.parquet")
300304

305+
def test_string_bool(self):
306+
data = Table.from_pydict(
307+
{
308+
"string": [str(i) for i in range(2)],
309+
"bool": [True for _ in range(2)],
310+
},
311+
ArrowSchema(
312+
{
313+
"string": string(),
314+
"bool": bool_(),
315+
}
316+
),
317+
)
318+
self.round_trip(
319+
data,
320+
Schema(
321+
{
322+
"string": str,
323+
"bool": bool,
324+
}
325+
),
326+
)
327+
301328

302329
class TestArrowExplicitApi(TestArrowApiMixin, unittest.TestCase):
303330
def run_find(self, *args, **kwargs):

bindings/python/test/test_numpy.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@
1919

2020
import numpy as np
2121
from bson import Decimal128, ObjectId
22-
from pyarrow import bool_, float64, int32, int64, string, timestamp
22+
from pyarrow import int32, int64
2323
from pymongo import DESCENDING, WriteConcern
2424
from pymongo.collection import Collection
2525
from pymongoarrow.api import Schema, aggregate_numpy_all, find_numpy_all, write
2626
from pymongoarrow.errors import ArrowWriteError
27-
from pymongoarrow.types import Decimal128StringType, ObjectIdType
27+
from pymongoarrow.types import (
28+
_TYPE_NORMALIZER_FACTORY,
29+
Decimal128StringType,
30+
ObjectIdType,
31+
)
2832

2933

3034
class NumpyTestBase(unittest.TestCase):
@@ -134,32 +138,26 @@ def test_write_error(self):
134138
raise awe
135139

136140
def test_write_schema_validation(self):
137-
schema = {
138-
"data": "int64",
139-
"float": "float64",
140-
"datetime": "datetime64[ms]",
141-
"string": "str",
142-
"bool": "bool",
141+
arrow_schema = {
142+
k.__name__: v(True)
143+
for k, v in _TYPE_NORMALIZER_FACTORY.items()
144+
if k.__name__ not in ("ObjectId", "Decimal128")
143145
}
146+
schema = {k: v.to_pandas_dtype() for k, v in arrow_schema.items()}
147+
schema["str"] = "str"
148+
schema["datetime"] = "datetime64[ms]"
144149
data = {
145-
"data": [i for i in range(2)],
150+
"Int64": [i for i in range(2)],
146151
"float": [i for i in range(2)],
147152
"datetime": [i for i in range(2)],
148-
"string": [str(i) for i in range(2)],
149-
"bool": [True for _ in range(2)],
153+
"str": [str(i) for i in range(2)],
154+
"int": [i for i in range(2)],
155+
"bool": [True, False],
150156
}
151157
data = self.schemafied_ndarray_dict(data, schema)
152158
self.round_trip(
153159
data,
154-
Schema(
155-
{
156-
"data": int64(),
157-
"float": float64(),
158-
"datetime": timestamp("ms"),
159-
"string": string(),
160-
"bool": bool_(),
161-
}
162-
),
160+
Schema(arrow_schema),
163161
)
164162

165163
schema = {"_id": "int32", "data": np.ubyte()}
@@ -191,6 +189,26 @@ def test_write_dictionaries(self):
191189
):
192190
write(self.coll, {"foo": 1})
193191

192+
def test_string_bool(self):
193+
schema = {
194+
"string": "str",
195+
"bool": "bool",
196+
}
197+
data = {
198+
"string": [str(i) for i in range(2)],
199+
"bool": [True for _ in range(2)],
200+
}
201+
data = self.schemafied_ndarray_dict(data, schema)
202+
self.round_trip(
203+
data,
204+
Schema(
205+
{
206+
"string": str,
207+
"bool": bool,
208+
}
209+
),
210+
)
211+
194212

195213
class TestBSONTypes(NumpyTestBase):
196214
@classmethod

bindings/python/test/test_pandas.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,16 @@
2020
import numpy as np
2121
import pandas as pd
2222
from bson import Decimal128, ObjectId
23-
from pyarrow import bool_, decimal256, float64, int32, int64, string, timestamp
23+
from pyarrow import decimal256, int32, int64
2424
from pymongo import DESCENDING, WriteConcern
2525
from pymongo.collection import Collection
2626
from pymongoarrow.api import Schema, aggregate_pandas_all, find_pandas_all, write
2727
from pymongoarrow.errors import ArrowWriteError
28-
from pymongoarrow.types import Decimal128StringType, ObjectIdType
28+
from pymongoarrow.types import (
29+
_TYPE_NORMALIZER_FACTORY,
30+
Decimal128StringType,
31+
ObjectIdType,
32+
)
2933

3034

3135
class PandasTestBase(unittest.TestCase):
@@ -114,33 +118,28 @@ def test_write_error(self):
114118
raise awe
115119

116120
def test_write_schema_validation(self):
117-
schema = {
118-
"data": "int64",
119-
"float": "float64",
120-
"datetime": "datetime64[ms]",
121-
"string": "object",
122-
"bool": "bool",
121+
arrow_schema = {
122+
k.__name__: v(True)
123+
for k, v in _TYPE_NORMALIZER_FACTORY.items()
124+
if k.__name__ not in ("ObjectId", "Decimal128")
123125
}
126+
schema = {k: v.to_pandas_dtype() for k, v in arrow_schema.items()}
127+
schema["str"] = "str"
128+
schema["datetime"] = "datetime64[ms]"
129+
124130
data = pd.DataFrame(
125131
data={
126-
"data": [i for i in range(2)],
132+
"Int64": [i for i in range(2)],
127133
"float": [i for i in range(2)],
134+
"int": [i for i in range(2)],
128135
"datetime": [i for i in range(2)],
129-
"string": [str(i) for i in range(2)],
130-
"bool": [True for _ in range(2)],
136+
"str": [str(i) for i in range(2)],
137+
"bool": [True, False],
131138
}
132139
).astype(schema)
133140
self.round_trip(
134141
data,
135-
Schema(
136-
{
137-
"data": int64(),
138-
"float": float64(),
139-
"datetime": timestamp("ms"),
140-
"string": string(),
141-
"bool": bool_(),
142-
}
143-
),
142+
Schema(arrow_schema),
144143
)
145144

146145
schema = {"_id": "int32", "data": np.ubyte()}
@@ -167,6 +166,25 @@ def test_write_batching(self, mock):
167166
)
168167
self.assertEqual(mock.call_count, 2)
169168

169+
def test_string_bool(self):
170+
schema = {
171+
"string": "str",
172+
"bool": "bool",
173+
}
174+
data = pd.DataFrame(
175+
data=[{"string": [str(i) for i in range(2)], "bool": [True for _ in range(2)]}],
176+
).astype(schema)
177+
178+
self.round_trip(
179+
data,
180+
Schema(
181+
{
182+
"string": str,
183+
"bool": bool,
184+
}
185+
),
186+
)
187+
170188

171189
class TestBSONTypes(PandasTestBase):
172190
@classmethod

0 commit comments

Comments
 (0)