|
17 | 17 | from test.utils import AllowListEventListener
|
18 | 18 |
|
19 | 19 | import pymongo
|
20 |
| -from pyarrow import Table, int32, int64 |
| 20 | +from pyarrow import Table, bool_, decimal128, float64, int32, int64 |
21 | 21 | from pyarrow import schema as ArrowSchema
|
| 22 | +from pyarrow import string, timestamp |
22 | 23 | from pymongo import DESCENDING, WriteConcern
|
23 |
| -from pymongoarrow.api import Schema, aggregate_arrow_all, find_arrow_all |
| 24 | +from pymongo.collection import Collection |
| 25 | +from pymongoarrow.api import Schema, aggregate_arrow_all, find_arrow_all, write |
| 26 | +from pymongoarrow.errors import ArrowWriteError |
24 | 27 | from pymongoarrow.monkey import patch_all
|
25 | 28 |
|
26 | 29 |
|
@@ -180,6 +183,70 @@ def test_aggregate_omits_id_if_not_in_schema(self):
|
180 | 183 | self.assertFalse(stage[op_name]["_id"])
|
181 | 184 | break
|
182 | 185 |
|
| 186 | + def round_trip(self, data, schema, coll=None): |
| 187 | + if coll is None: |
| 188 | + coll = self.coll |
| 189 | + self.coll.drop() |
| 190 | + res = write(self.coll, data) |
| 191 | + self.assertEqual(len(data), res.raw_result["insertedCount"]) |
| 192 | + self.assertEqual(data, find_arrow_all(coll, {}, schema=schema)) |
| 193 | + return res |
| 194 | + |
| 195 | + def test_write_error(self): |
| 196 | + schema = {"_id": int32(), "data": int64()} |
| 197 | + data = Table.from_pydict( |
| 198 | + {"_id": [i for i in range(10001)] * 2, "data": [i * 2 for i in range(10001)] * 2}, |
| 199 | + ArrowSchema(schema), |
| 200 | + ) |
| 201 | + with self.assertRaises(ArrowWriteError): |
| 202 | + try: |
| 203 | + self.round_trip(data, Schema(schema)) |
| 204 | + except ArrowWriteError as awe: |
| 205 | + self.assertEqual( |
| 206 | + 10001, awe.details["writeErrors"][0]["index"], awe.details["nInserted"] |
| 207 | + ) |
| 208 | + raise awe |
| 209 | + |
| 210 | + def test_write_schema_validation(self): |
| 211 | + schema = { |
| 212 | + "data": int64(), |
| 213 | + "float": float64(), |
| 214 | + "datetime": timestamp("ms"), |
| 215 | + "string": string(), |
| 216 | + "bool": bool_(), |
| 217 | + } |
| 218 | + data = Table.from_pydict( |
| 219 | + { |
| 220 | + "data": [i for i in range(2)], |
| 221 | + "float": [i for i in range(2)], |
| 222 | + "datetime": [i for i in range(2)], |
| 223 | + "string": [str(i) for i in range(2)], |
| 224 | + "bool": [True for _ in range(2)], |
| 225 | + }, |
| 226 | + ArrowSchema(schema), |
| 227 | + ) |
| 228 | + self.round_trip(data, Schema(schema)) |
| 229 | + |
| 230 | + schema = {"_id": int32(), "data": decimal128(2)} |
| 231 | + data = Table.from_pydict( |
| 232 | + {"_id": [i for i in range(2)], "data": [i for i in range(2)]}, |
| 233 | + ArrowSchema(schema), |
| 234 | + ) |
| 235 | + with self.assertRaises(ValueError): |
| 236 | + self.round_trip(data, Schema(schema)) |
| 237 | + |
| 238 | + @mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True) |
| 239 | + def test_write_batching(self, mock): |
| 240 | + schema = { |
| 241 | + "_id": int64(), |
| 242 | + } |
| 243 | + data = Table.from_pydict( |
| 244 | + {"_id": [i for i in range(100040)]}, |
| 245 | + ArrowSchema(schema), |
| 246 | + ) |
| 247 | + self.round_trip(data, Schema(schema), coll=self.coll) |
| 248 | + self.assertEqual(mock.call_count, 2) |
| 249 | + |
183 | 250 |
|
184 | 251 | class TestArrowExplicitApi(TestArrowApiMixin, unittest.TestCase):
|
185 | 252 | def run_find(self, *args, **kwargs):
|
|
0 commit comments