|
18 | 18 | from test.utils import AllowListEventListener
|
19 | 19 |
|
20 | 20 | import pymongo
|
21 |
| -from pyarrow import Table, bool_, decimal256, float64, int32, int64 |
| 21 | +from bson import Decimal128, ObjectId |
| 22 | +from pyarrow import Table, binary, bool_, decimal256, float64, int32, int64 |
22 | 23 | from pyarrow import schema as ArrowSchema
|
23 | 24 | from pyarrow import string, timestamp
|
24 | 25 | from pyarrow.parquet import read_table, write_table
|
|
27 | 28 | from pymongoarrow.api import Schema, aggregate_arrow_all, find_arrow_all, write
|
28 | 29 | from pymongoarrow.errors import ArrowWriteError
|
29 | 30 | from pymongoarrow.monkey import patch_all
|
| 31 | +from pymongoarrow.types import Decimal128StringType, ObjectIdType |
30 | 32 |
|
31 | 33 |
|
32 | 34 | class TestArrowApiMixin:
|
@@ -292,3 +294,42 @@ def run_find(self, *args, **kwargs):
|
292 | 294 |
|
293 | 295 | def run_aggregate(self, *args, **kwargs):
|
294 | 296 | return self.coll.aggregate_arrow_all(*args, **kwargs)
|
| 297 | + |
| 298 | + |
| 299 | +class TestBSONTypes(unittest.TestCase): |
| 300 | + @classmethod |
| 301 | + def setUpClass(cls): |
| 302 | + if not client_context.connected: |
| 303 | + raise unittest.SkipTest("cannot connect to MongoDB") |
| 304 | + cls.cmd_listener = AllowListEventListener("find", "aggregate") |
| 305 | + cls.getmore_listener = AllowListEventListener("getMore") |
| 306 | + cls.client = client_context.get_client( |
| 307 | + event_listeners=[cls.getmore_listener, cls.cmd_listener] |
| 308 | + ) |
| 309 | + |
| 310 | + def test_find_decimal128(self): |
| 311 | + oids = list(ObjectId() for i in range(4)) |
| 312 | + decs = [Decimal128(i) for i in ["0.1", "1.0", "1e-5"]] |
| 313 | + schema = Schema({"_id": ObjectIdType(), "data": Decimal128StringType()}) |
| 314 | + expected = Table.from_pydict( |
| 315 | + { |
| 316 | + "_id": [i.binary for i in oids], |
| 317 | + "data": [str(decs[0]), str(decs[1]), str(decs[2]), None], |
| 318 | + }, |
| 319 | + ArrowSchema([("_id", binary(12)), ("data", string())]), |
| 320 | + ) |
| 321 | + coll = self.client.pymongoarrow_test.get_collection( |
| 322 | + "test", write_concern=WriteConcern(w="majority") |
| 323 | + ) |
| 324 | + |
| 325 | + coll.drop() |
| 326 | + coll.insert_many( |
| 327 | + [ |
| 328 | + {"_id": oids[0], "data": decs[0]}, |
| 329 | + {"_id": oids[1], "data": decs[1]}, |
| 330 | + {"_id": oids[2], "data": decs[2]}, |
| 331 | + {"_id": oids[3]}, |
| 332 | + ] |
| 333 | + ) |
| 334 | + table = find_arrow_all(coll, {}, schema=schema) |
| 335 | + self.assertEqual(table, expected) |
0 commit comments