Skip to content

Commit 225a9e4

Browse files
authored
ARROW-94 Test for nulls (#83)
1 parent 96477bc commit 225a9e4

File tree

4 files changed

+318
-4
lines changed

4 files changed

+318
-4
lines changed

bindings/python/test/test_arrow.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
import unittest
1616
import unittest.mock as mock
1717
from test import client_context
18-
from test.utils import AllowListEventListener
18+
from test.utils import AllowListEventListener, TestNullsBase
1919

20+
import pyarrow
2021
import pymongo
2122
from bson import Decimal128, ObjectId
2223
from pyarrow import Table, binary, bool_, decimal256, float64, int32, int64
@@ -384,3 +385,20 @@ def test_find_decimal128(self):
384385
)
385386
table = find_arrow_all(coll, {}, schema=schema)
386387
self.assertEqual(table, expected)
388+
389+
390+
class TestNulls(TestNullsBase):
    """Run the shared null-handling tests against the Arrow table API.

    Implements the backend hooks required by ``TestNullsBase`` using
    ``find_arrow_all`` for reads and ``pyarrow.Table`` for the tables that
    are written back.
    """

    def find_fn(self, coll, query, schema):
        """Return the result of ``find_arrow_all`` for ``query`` on ``coll``."""
        return find_arrow_all(coll, query, schema=schema)

    def equal_fn(self, left, right):
        """Assert that ``left`` and ``right`` compare equal."""
        self.assertEqual(left, right)

    # The parameter name shadows the builtin ``dict``; it is kept unchanged so
    # the signature stays identical to the hook declared on ``TestNullsBase``.
    def table_from_dict(self, dict, schema=None):
        """Build a ``pyarrow.Table`` from a column-name -> values mapping."""
        return pyarrow.Table.from_pydict(dict, schema)

    def assert_in_idx(self, table, col_name):
        """Assert that ``col_name`` is one of the columns of ``table``."""
        # assertIn reports both operands on failure, unlike assertTrue(x in y).
        self.assertIn(col_name, table.column_names)

    def na_safe(self, atype):
        """Return True for every type: this backend always reports nulls as NA."""
        return True

bindings/python/test/test_numpy.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# from datetime import datetime, timedelta
15+
import datetime
1516
import unittest
1617
from test import client_context
17-
from test.utils import AllowListEventListener
18+
from test.utils import AllowListEventListener, TestNullsBase
1819
from unittest import mock
1920

2021
import numpy as np
@@ -242,3 +243,50 @@ def test_find_decimal128(self):
242243
}
243244
actual = find_numpy_all(self.coll, {}, schema=self.schema)
244245
self.assert_numpy_equal(actual, expected)
246+
247+
248+
# The spec for pyarrow says to_numpy is experimental, so we should expect
249+
# this to change in the future.
250+
class TestNulls(TestNullsBase, NumpyTestBase):
    """Run the shared null-handling tests against the NumPy API.

    Tables in this backend are plain dicts mapping a column name to a
    ``numpy.ndarray``.
    """

    def table_from_dict(self, d, schema=None):
        """Convert a column-name -> list mapping into a dict of NumPy arrays.

        ``schema`` is accepted for signature compatibility with the base
        class hook and is not used.
        """
        out = {}
        for k, v in d.items():
            if any(isinstance(x, int) for x in v) and None in v:
                # Integer columns containing None are stored as float64.
                # ``np.float64`` replaces the ``np.float_`` alias, which was
                # removed in NumPy 2.0.
                out[k] = np.array(v, dtype=np.float64)
            else:
                out[k] = np.array(v, dtype=np.dtype(type(v[0])))  # Infer
        return out

    def equal_fn(self, left, right):
        """Assert that two columns, or two dict tables, hold equal data.

        NaN entries are mapped to 0 on both sides before comparing so that
        missing values in matching positions compare equal.
        """
        if isinstance(left, dict) and isinstance(right, dict):
            # Whole tables are dicts of arrays: compare column by column,
            # because ``dict == dict`` would try to take the truth value of
            # the element-wise array comparison and raise.
            self.assertEqual(set(left), set(right))
            for key in left:
                self.equal_fn(left[key], right[key])
            return
        left = np.nan_to_num(left)
        # Previously this normalised ``left`` a second time, so the assertion
        # compared ``left`` with itself and could never fail.
        right = np.nan_to_num(right)
        self.assertTrue(np.array_equal(left, right))

    def find_fn(self, coll, query, schema=None):
        """Return the result of ``find_numpy_all`` for ``query`` on ``coll``."""
        return find_numpy_all(coll, query, schema=schema)

    def assert_in_idx(self, table, col_name):
        """Assert that ``col_name`` is a key of the dict table."""
        self.assertIn(col_name, table)

    # Expected ``dtype.name`` of a column read back for each Python type.
    # A list means any of the listed names is accepted.
    pytype_tab_map = {
        str: "str128",
        int: ["int64", "float64"],
        float: "float64",
        datetime.datetime: "datetime64[ms]",
        ObjectId: "object",
        Decimal128: "object",
        bool: "object",
    }

    # Exception expected when writing a column of each type back to MongoDB
    # and re-reading it; None means the round trip must succeed.
    pytype_writeback_exc_map = {
        str: None,
        int: None,
        float: None,
        datetime.datetime: ValueError,  # TypeError,
        ObjectId: ValueError,  # TypeError,
        Decimal128: ValueError,  # TypeError,
        bool: None,
    }

    def na_safe(self, atype):
        """Return False only for the string dtype, True for every other type."""
        return atype != TestNulls.pytype_tab_map[str]

bindings/python/test/test_pandas.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# from datetime import datetime, timedelta
15+
import datetime
1516
import unittest
1617
import unittest.mock as mock
1718
from test import client_context
18-
from test.utils import AllowListEventListener
19+
from test.utils import AllowListEventListener, TestNullsBase
1920

2021
import numpy as np
2122
import pandas as pd
23+
import pandas.testing
24+
import pyarrow
2225
from bson import Decimal128, ObjectId
2326
from pyarrow import decimal256, int32, int64
2427
from pymongo import DESCENDING, WriteConcern
@@ -219,3 +222,45 @@ def test_find_decimal128(self):
219222

220223
table = find_pandas_all(self.coll, {}, schema=self.schema)
221224
pd.testing.assert_frame_equal(expected, table)
225+
226+
227+
class TestNulls(TestNullsBase):
    """Run the shared null-handling tests against the pandas API.

    Implements the backend hooks required by ``TestNullsBase`` using
    ``find_pandas_all`` for reads and ``pandas.DataFrame`` for the tables
    that are written back.
    """

    def find_fn(self, coll, query, schema):
        """Return the result of ``find_pandas_all`` for ``query`` on ``coll``."""
        return find_pandas_all(coll, query, schema=schema)

    def equal_fn(self, left, right):
        """Assert two frames or series are equal, treating missing values as 0.

        Dtypes are not compared (``check_dtype=False``).
        """
        int64_min = -(1 << 63)  # NaN is sometimes this
        left = left.fillna(0).replace(int64_min, 0)
        right = right.fillna(0).replace(int64_min, 0)
        # isinstance (rather than an exact type comparison) also routes
        # DataFrame subclasses to the frame assertion.
        if isinstance(left, pandas.DataFrame):
            pandas.testing.assert_frame_equal(left, right, check_dtype=False)
        else:
            pandas.testing.assert_series_equal(left, right, check_dtype=False)

    # The parameter name shadows the builtin ``dict``; it is kept unchanged so
    # the signature stays identical to the hook declared on ``TestNullsBase``.
    def table_from_dict(self, dict, schema=None):
        """Build a ``pandas.DataFrame`` from a column-name -> values mapping.

        ``schema`` is accepted for signature compatibility and is not used.
        """
        return pandas.DataFrame(dict)

    def assert_in_idx(self, table, col_name):
        """Assert that ``col_name`` is one of the columns of ``table``."""
        self.assertIn(col_name, table.columns)

    # Expected ``dtype.name`` of a column read back for each Python type.
    # A list means any of the listed names is accepted.
    pytype_tab_map = {
        str: "object",
        int: ["int64", "float64"],
        float: "float64",
        datetime.datetime: "datetime64[ns]",
        ObjectId: "object",
        Decimal128: "object",
        bool: "object",
    }

    # Exception expected when writing a column of each type back to MongoDB
    # and re-reading it; None means the round trip must succeed.
    pytype_writeback_exc_map = {
        str: None,
        int: None,
        float: None,
        datetime.datetime: ValueError,
        ObjectId: ValueError,
        Decimal128: pyarrow.lib.ArrowInvalid,
        bool: None,
    }

    def na_safe(self, atype):
        """Return True for every type: this backend always reports nulls as NA."""
        return True

bindings/python/test/utils.py

Lines changed: 204 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,25 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import datetime
15+
import unittest
1416
from collections import defaultdict
17+
from test import client_context
1518

16-
from pymongo import monitoring
19+
import numpy as np
20+
import pyarrow
21+
from bson import Decimal128, ObjectId
22+
from pandas import isna
23+
from pyarrow import bool_, float64, int64, string, timestamp
24+
from pymongo import WriteConcern, monitoring
25+
from pymongoarrow.api import write
26+
from pymongoarrow.schema import Schema
27+
from pymongoarrow.types import (
28+
_TYPE_NORMALIZER_FACTORY,
29+
Decimal128StringType,
30+
ObjectIdType,
31+
_in_type_map,
32+
)
1733

1834

1935
class EventListener(monitoring.CommandListener):
@@ -54,3 +70,190 @@ def succeeded(self, event):
5470
def failed(self, event):
5571
if event.command_name in self.commands:
5672
super(AllowListEventListener, self).failed(event)
73+
74+
75+
class TestNullsBase(unittest.TestCase):
    """Shared tests for how missing values survive a MongoDB round trip.

    Concrete subclasses provide the backend-specific hooks (``find_fn``,
    ``equal_fn``, ``table_from_dict``, ``assert_in_idx`` and ``na_safe``) and
    may override ``pytype_tab_map`` / ``pytype_writeback_exc_map`` with the
    types and exceptions their backend produces. The base class itself is
    skipped in ``setUpClass``.
    """

    def find_fn(self, coll, query, schema):
        """Hook: run ``query`` on ``coll`` and return a backend table."""
        raise NotImplementedError

    def equal_fn(self, left, right):
        """Hook: assert that two backend tables or columns are equal."""
        raise NotImplementedError

    # The parameter name shadows the builtin ``dict``; it is part of the hook
    # signature that subclasses mirror, so it is left unchanged.
    def table_from_dict(self, dict, schema=None):
        """Hook: build a backend table from a column-name -> values mapping."""
        raise NotImplementedError

    def assert_in_idx(self, table, col_name):
        """Hook: assert that ``col_name`` is a column of ``table``."""
        raise NotImplementedError

    def na_safe(self, atype):
        """Hook: return whether nulls of backend type ``atype`` read back as NA."""
        raise NotImplementedError

    # Map Python types to constructors
    pytype_cons_map = {
        str: str,
        int: int,
        float: float,
        datetime.datetime: lambda x: datetime.datetime(x + 1970, 1, 1),
        ObjectId: lambda _: ObjectId(),
        Decimal128: lambda x: Decimal128(str(x)),
    }

    # Map Python types to types for table we are comparing.
    pytype_tab_map = {
        str: string(),
        int: int64(),
        float: float64(),
        datetime.datetime: timestamp("ms"),
        ObjectId: ObjectIdType(),
        Decimal128: Decimal128StringType(),
        bool: bool_(),
    }

    # Exception expected from the write-back round trip for each Python type;
    # None means the round trip must succeed.
    pytype_writeback_exc_map = {
        str: None,
        int: None,
        float: None,
        datetime.datetime: None,
        ObjectId: pyarrow.lib.ArrowInvalid,
        Decimal128: pyarrow.lib.ArrowInvalid,
        bool: None,
    }

    @classmethod
    def setUpClass(cls):
        """Skip the abstract base, then connect and create shared fixtures."""
        if cls is TestNullsBase:
            raise unittest.SkipTest("Base class")

        if not client_context.connected:
            raise unittest.SkipTest("cannot connect to MongoDB")
        cls.cmd_listener = AllowListEventListener("find", "aggregate")
        cls.getmore_listener = AllowListEventListener("getMore")
        cls.client = client_context.get_client(
            event_listeners=[cls.getmore_listener, cls.cmd_listener]
        )

        cls.oids = [ObjectId() for _ in range(4)]
        cls.coll = cls.client.pymongoarrow_test.get_collection(
            "test", write_concern=WriteConcern(w="majority")
        )

    def setUp(self):
        """Start every test from an empty collection and clean listeners."""
        self.coll.drop()

        self.cmd_listener.reset()
        self.getmore_listener.reset()

    def assertType(self, obj1, arrow_type):
        """Assert that column ``obj1`` has the expected type.

        For a ``pyarrow.ChunkedArray`` the Arrow type is compared; otherwise
        ``obj1.dtype.name`` is compared against ``arrow_type``, which may be
        a single name or a list of acceptable names.
        """
        if isinstance(obj1, pyarrow.ChunkedArray):
            # Types exposing ``storage_type`` (presumably Arrow extension
            # types) are compared by their storage type.
            expected = getattr(arrow_type, "storage_type", arrow_type)
            self.assertEqual(obj1.type, expected)
        elif isinstance(arrow_type, list):
            self.assertIn(obj1.dtype.name, arrow_type)
        else:
            self.assertEqual(obj1.dtype.name, arrow_type)

    def test_int_handling(self):
        """Integers with nulls keep their NA positions on read and write-back."""
        # Default integral types
        int_schema = Schema({"_id": ObjectIdType(), "int64": int64()})
        # Even positions hold a value, odd positions are null.
        int64_arr = [(i if (i % 2 == 0) else None) for i in range(len(self.oids))]
        self.coll.insert_many(
            [{"_id": oid, "int64": value} for oid, value in zip(self.oids, int64_arr)]
        )

        table = self.find_fn(self.coll, {}, schema=int_schema)

        # Resulting datatype should be float64 according to the spec for numpy
        # and pandas.
        atype = self.pytype_tab_map[int]
        self.assertType(table["int64"], atype)

        # Does it contain NAs where we expect?
        self.assertTrue(np.all(np.equal(isna(int64_arr), isna(table["int64"]))))

        # Write
        self.coll.drop()
        table_write = self.table_from_dict({"int64": int64_arr})

        write(self.coll, table_write)
        res_table = self.find_fn(self.coll, {}, schema=int_schema)

        self.equal_fn(res_table["int64"], table_write["int64"])
        self.assert_in_idx(res_table, "_id")
        self.assertType(res_table["int64"], atype)

    def test_all_types(self):
        """Every Python type in ``pytype_tab_map`` normalizes to a known type."""
        for t in self.pytype_tab_map:
            self.assertTrue(_in_type_map(_TYPE_NORMALIZER_FACTORY[t](0)))

    def test_other_handling(self):
        """Non-integer types with nulls: check read type, NA positions, write-back."""
        # Tests other types, which are treated differently in
        # arrow/pandas/numpy.
        for gen in [str, float, datetime.datetime, ObjectId, Decimal128]:
            # subTest labels any failure with the Python type being exercised
            # and lets the remaining types still run.
            with self.subTest(gen=gen.__name__):
                con_type = self.pytype_tab_map[gen]  # Arrow/Pandas/Numpy
                pytype = TestNullsBase.pytype_tab_map[gen]  # Arrow type specifically

                other_schema = Schema({"_id": ObjectIdType(), "other": pytype})
                # Even positions hold a value, odd positions are null.
                others = [
                    self.pytype_cons_map[gen](i) if (i % 2 == 0) else None
                    for i in range(len(self.oids))
                ]

                # Reset the collection and listeners between types.
                self.setUp()
                self.coll.insert_many(
                    [{"_id": oid, "other": value} for oid, value in zip(self.oids, others)]
                )
                table = self.find_fn(self.coll, {}, schema=other_schema)

                # Resulting datatype should be the backend's type for ``gen``.
                self.assertType(table["other"], con_type)
                # NA positions must match exactly when the backend is NA-safe
                # for this type, and must NOT match when it is not.
                self.assertEqual(
                    self.na_safe(con_type),
                    np.all(np.equal(isna(others), isna(table["other"]))),
                )

                def writeback():
                    # Write
                    self.coll.drop()
                    table_write_schema = Schema({"other": pytype})
                    # An explicit Arrow schema is only passed for these types.
                    table_write_schema_arrow = (
                        pyarrow.schema([("other", pytype)])
                        if (gen in [str, float, datetime.datetime])
                        else None
                    )

                    table_write = self.table_from_dict(
                        {"other": others}, schema=table_write_schema_arrow
                    )

                    write(self.coll, table_write)
                    res_table = self.find_fn(self.coll, {}, schema=table_write_schema)
                    self.equal_fn(res_table, table_write)
                    self.assertType(res_table["other"], con_type)

                # Do we expect an exception to be raised?
                expected_exc = self.pytype_writeback_exc_map[gen]
                if expected_exc is not None:
                    with self.assertRaises(expected_exc):
                        writeback()
                else:
                    writeback()

    def test_bool_handling(self):
        """Booleans with nulls are read with the expected type and NA positions."""
        atype = self.pytype_tab_map[bool]
        bool_schema = Schema({"_id": ObjectIdType(), "bool_": bool_()})
        # Even positions hold True, odd positions are null.
        bools = [True if (i % 2 == 0) else None for i in range(len(self.oids))]

        self.coll.insert_many(
            [{"_id": oid, "bool_": value} for oid, value in zip(self.oids, bools)]
        )

        table = self.find_fn(self.coll, {}, schema=bool_schema)

        # Resulting datatype should be the backend's bool type (``object`` in
        # the numpy and pandas maps, Arrow ``bool_`` in the base map).
        self.assertType(table["bool_"], atype)

        # Does it contain Nones where expected?
        self.assertTrue(np.all(np.equal(isna(bools), isna(table["bool_"]))))

0 commit comments

Comments
 (0)