|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import datetime |
| 15 | +import unittest |
14 | 16 | from collections import defaultdict
|
| 17 | +from test import client_context |
15 | 18 |
|
16 |
| -from pymongo import monitoring |
| 19 | +import numpy as np |
| 20 | +import pyarrow |
| 21 | +from bson import Decimal128, ObjectId |
| 22 | +from pandas import isna |
| 23 | +from pyarrow import bool_, float64, int64, string, timestamp |
| 24 | +from pymongo import WriteConcern, monitoring |
| 25 | +from pymongoarrow.api import write |
| 26 | +from pymongoarrow.schema import Schema |
| 27 | +from pymongoarrow.types import ( |
| 28 | + _TYPE_NORMALIZER_FACTORY, |
| 29 | + Decimal128StringType, |
| 30 | + ObjectIdType, |
| 31 | + _in_type_map, |
| 32 | +) |
17 | 33 |
|
18 | 34 |
|
19 | 35 | class EventListener(monitoring.CommandListener):
|
@@ -54,3 +70,190 @@ def succeeded(self, event):
|
54 | 70 | def failed(self, event):
|
55 | 71 | if event.command_name in self.commands:
|
56 | 72 | super(AllowListEventListener, self).failed(event)
|
| 73 | + |
| 74 | + |
| 75 | +class TestNullsBase(unittest.TestCase): |
| 76 | + def find_fn(self, coll, query, schema): |
| 77 | + raise NotImplementedError |
| 78 | + |
| 79 | + def equal_fn(self, left, right): |
| 80 | + raise NotImplementedError |
| 81 | + |
| 82 | + def table_from_dict(self, dict, schema=None): |
| 83 | + raise NotImplementedError |
| 84 | + |
| 85 | + def assert_in_idx(self, table, col_name): |
| 86 | + raise NotImplementedError |
| 87 | + |
| 88 | + def na_safe(self, atype): |
| 89 | + raise NotImplementedError |
| 90 | + |
| 91 | + # Map Python types to constructors |
| 92 | + pytype_cons_map = { |
| 93 | + str: str, |
| 94 | + int: int, |
| 95 | + float: float, |
| 96 | + datetime.datetime: lambda x: datetime.datetime(x + 1970, 1, 1), |
| 97 | + ObjectId: lambda _: ObjectId(), |
| 98 | + Decimal128: lambda x: Decimal128(str(x)), |
| 99 | + } |
| 100 | + |
| 101 | + # Map Python types to types for table we are comparing. |
| 102 | + pytype_tab_map = { |
| 103 | + str: string(), |
| 104 | + int: int64(), |
| 105 | + float: float64(), |
| 106 | + datetime.datetime: timestamp("ms"), |
| 107 | + ObjectId: ObjectIdType(), |
| 108 | + Decimal128: Decimal128StringType(), |
| 109 | + bool: bool_(), |
| 110 | + } |
| 111 | + |
| 112 | + pytype_writeback_exc_map = { |
| 113 | + str: None, |
| 114 | + int: None, |
| 115 | + float: None, |
| 116 | + datetime.datetime: None, |
| 117 | + ObjectId: pyarrow.lib.ArrowInvalid, |
| 118 | + Decimal128: pyarrow.lib.ArrowInvalid, |
| 119 | + bool: None, |
| 120 | + } |
| 121 | + |
| 122 | + @classmethod |
| 123 | + def setUpClass(cls): |
| 124 | + if cls is TestNullsBase: |
| 125 | + raise unittest.SkipTest("Base class") |
| 126 | + |
| 127 | + if not client_context.connected: |
| 128 | + raise unittest.SkipTest("cannot connect to MongoDB") |
| 129 | + cls.cmd_listener = AllowListEventListener("find", "aggregate") |
| 130 | + cls.getmore_listener = AllowListEventListener("getMore") |
| 131 | + cls.client = client_context.get_client( |
| 132 | + event_listeners=[cls.getmore_listener, cls.cmd_listener] |
| 133 | + ) |
| 134 | + |
| 135 | + cls.oids = [ObjectId() for _ in range(4)] |
| 136 | + cls.coll = cls.client.pymongoarrow_test.get_collection( |
| 137 | + "test", write_concern=WriteConcern(w="majority") |
| 138 | + ) |
| 139 | + |
| 140 | + def setUp(self): |
| 141 | + self.coll.drop() |
| 142 | + |
| 143 | + self.cmd_listener.reset() |
| 144 | + self.getmore_listener.reset() |
| 145 | + |
| 146 | + def assertType(self, obj1, arrow_type): |
| 147 | + if isinstance(obj1, pyarrow.ChunkedArray): |
| 148 | + if "storage_type" in dir(arrow_type): |
| 149 | + self.assertEqual(obj1.type, arrow_type.storage_type) |
| 150 | + else: |
| 151 | + self.assertEqual(obj1.type, arrow_type) |
| 152 | + else: |
| 153 | + if isinstance(arrow_type, list): |
| 154 | + self.assertTrue(obj1.dtype.name in arrow_type) |
| 155 | + else: |
| 156 | + self.assertEqual(obj1.dtype.name, arrow_type) |
| 157 | + |
| 158 | + def test_int_handling(self): |
| 159 | + # Default integral types |
| 160 | + int_schema = Schema({"_id": ObjectIdType(), "int64": int64()}) |
| 161 | + int64_arr = [(i if (i % 2 == 0) else None) for i in range(len(self.oids))] |
| 162 | + self.coll.insert_many( |
| 163 | + [{"_id": self.oids[i], "int64": int64_arr[i]} for i in range(len(self.oids))] |
| 164 | + ) |
| 165 | + |
| 166 | + table = self.find_fn(self.coll, {}, schema=int_schema) |
| 167 | + |
| 168 | + # Resulting datatype should be float64 according to the spec for numpy |
| 169 | + # and pandas. |
| 170 | + atype = self.pytype_tab_map[int] |
| 171 | + self.assertType(table["int64"], atype) |
| 172 | + |
| 173 | + # Does it contain NAs where we expect? |
| 174 | + self.assertTrue(np.all(np.equal(isna(int64_arr), isna(table["int64"])))) |
| 175 | + |
| 176 | + # Write |
| 177 | + self.coll.drop() |
| 178 | + table_write = self.table_from_dict({"int64": int64_arr}) |
| 179 | + |
| 180 | + write(self.coll, table_write) |
| 181 | + res_table = self.find_fn(self.coll, {}, schema=int_schema) |
| 182 | + |
| 183 | + self.equal_fn(res_table["int64"], table_write["int64"]) |
| 184 | + self.assert_in_idx(res_table, "_id") |
| 185 | + self.assertType(res_table["int64"], atype) |
| 186 | + |
| 187 | + def test_all_types(self): |
| 188 | + for t in self.pytype_tab_map.keys(): |
| 189 | + self.assertTrue(_in_type_map(_TYPE_NORMALIZER_FACTORY[t](0))) |
| 190 | + |
| 191 | + def test_other_handling(self): |
| 192 | + # Tests other types, which are treated differently in |
| 193 | + # arrow/pandas/numpy. |
| 194 | + for gen in [str, float, datetime.datetime, ObjectId, Decimal128]: |
| 195 | + con_type = self.pytype_tab_map[gen] # Arrow/Pandas/Numpy |
| 196 | + pytype = TestNullsBase.pytype_tab_map[gen] # Arrow type specifically |
| 197 | + |
| 198 | + other_schema = Schema({"_id": ObjectIdType(), "other": pytype}) |
| 199 | + others = [ |
| 200 | + self.pytype_cons_map[gen](i) if (i % 2 == 0) else None |
| 201 | + for i in range(len(self.oids)) |
| 202 | + ] |
| 203 | + |
| 204 | + self.setUp() |
| 205 | + self.coll.insert_many( |
| 206 | + [{"_id": self.oids[i], "other": others[i]} for i in range(len(self.oids))] |
| 207 | + ) |
| 208 | + table = self.find_fn(self.coll, {}, schema=other_schema) |
| 209 | + |
| 210 | + # Resulting datatype should be str in this case |
| 211 | + |
| 212 | + self.assertType(table["other"], con_type) |
| 213 | + self.assertEqual( |
| 214 | + self.na_safe(con_type), np.all(np.equal(isna(others), isna(table["other"]))) |
| 215 | + ) |
| 216 | + |
| 217 | + def writeback(): |
| 218 | + # Write |
| 219 | + self.coll.drop() |
| 220 | + table_write_schema = Schema({"other": pytype}) |
| 221 | + table_write_schema_arrow = ( |
| 222 | + pyarrow.schema([("other", pytype)]) |
| 223 | + if (gen in [str, float, datetime.datetime]) |
| 224 | + else None |
| 225 | + ) |
| 226 | + |
| 227 | + table_write = self.table_from_dict( |
| 228 | + {"other": others}, schema=table_write_schema_arrow |
| 229 | + ) |
| 230 | + |
| 231 | + write(self.coll, table_write) |
| 232 | + res_table = self.find_fn(self.coll, {}, schema=table_write_schema) |
| 233 | + self.equal_fn(res_table, table_write) |
| 234 | + self.assertType(res_table["other"], con_type) |
| 235 | + |
| 236 | + # Do we expect an exception to be raised? |
| 237 | + if self.pytype_writeback_exc_map[gen] is not None: |
| 238 | + expected_exc = self.pytype_writeback_exc_map[gen] |
| 239 | + with self.assertRaises(expected_exc): |
| 240 | + writeback() |
| 241 | + else: |
| 242 | + writeback() |
| 243 | + |
| 244 | + def test_bool_handling(self): |
| 245 | + atype = self.pytype_tab_map[bool] |
| 246 | + bool_schema = Schema({"_id": ObjectIdType(), "bool_": bool_()}) |
| 247 | + bools = [True if (i % 2 == 0) else None for i in range(len(self.oids))] |
| 248 | + |
| 249 | + self.coll.insert_many( |
| 250 | + [{"_id": self.oids[i], "bool_": bools[i]} for i in range(len(self.oids))] |
| 251 | + ) |
| 252 | + |
| 253 | + table = self.find_fn(self.coll, {}, schema=bool_schema) |
| 254 | + |
| 255 | + # Resulting datatype should be object |
| 256 | + self.assertType(table["bool_"], atype) |
| 257 | + |
| 258 | + # Does it contain Nones where expected? |
| 259 | + self.assertTrue(np.all(np.equal(isna(bools), isna(table["bool_"])))) |
0 commit comments