|
13 | 13 | # limitations under the License.
|
14 | 14 | # from datetime import datetime, timedelta
|
15 | 15 | import unittest
|
| 16 | +import unittest.mock as mock |
16 | 17 | from test import client_context
|
17 | 18 | from test.utils import AllowListEventListener
|
18 | 19 |
|
| 20 | +import numpy as np |
19 | 21 | import pandas as pd
|
20 |
| -from pyarrow import int32, int64 |
| 22 | +from pyarrow import bool_, decimal128, float64, int32, int64, string, timestamp |
21 | 23 | from pymongo import DESCENDING, WriteConcern
|
22 |
| -from pymongoarrow.api import Schema, aggregate_pandas_all, find_pandas_all |
| 24 | +from pymongo.collection import Collection |
| 25 | +from pymongoarrow.api import Schema, aggregate_pandas_all, find_pandas_all, write |
| 26 | +from pymongoarrow.errors import ArrowWriteError |
23 | 27 |
|
24 | 28 |
|
25 | 29 | class TestExplicitPandasApi(unittest.TestCase):
|
@@ -76,3 +80,81 @@ def test_aggregate_simple(self):
|
76 | 80 | assert len(agg_cmd.command["pipeline"]) == 2
|
77 | 81 | self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection)
|
78 | 82 | self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True})
|
| 83 | + |
def round_trip(self, data, schema, coll=None):
    """Write ``data`` to a collection and check it reads back unchanged.

    Args:
        data: The pandas DataFrame to insert.
        schema: The pymongoarrow ``Schema`` used when reading the data back.
        coll: The collection to use; defaults to ``self.coll`` when ``None``.

    Returns:
        The result object returned by ``write``.
    """
    if coll is None:
        coll = self.coll
    coll.drop()
    # Write to the same collection that was just dropped and is read back
    # below, so a caller-supplied ``coll`` is honoured end to end.
    res = write(coll, data)
    self.assertEqual(len(data), res.raw_result["insertedCount"])
    pd.testing.assert_frame_equal(data, find_pandas_all(coll, {}, schema=schema))
    return res
| 92 | + |
def test_write_error(self):
    """A write containing duplicate ``_id`` values raises ArrowWriteError."""
    dtypes = {"_id": "int32", "data": "int64"}

    # Each _id in 0..10000 appears twice, so the second half of the frame
    # repeats the _id values of the first half.
    ids = list(range(10001)) * 2
    values = list(range(0, 20002, 2)) * 2
    data = pd.DataFrame(data={"_id": ids, "data": values}).astype(dtypes)

    with self.assertRaises(ArrowWriteError) as ctx:
        self.round_trip(data, Schema({"_id": int32(), "data": int64()}))

    details = ctx.exception.details
    # The third argument is the failure message: nInserted is reported
    # only if the index comparison fails.
    self.assertEqual(10001, details["writeErrors"][0]["index"], details["nInserted"])
| 107 | + |
def test_write_schema_validation(self):
    """Round-trip several column dtypes, then expect ValueError for a ubyte column."""
    # Phase 1: int64, float64, datetime64[ms], object (string) and bool
    # columns must round-trip unchanged.
    dtypes = {
        "data": "int64",
        "float": "float64",
        "datetime": "datetime64[ms]",
        "string": "object",
        "bool": "bool",
    }
    columns = {
        "data": list(range(2)),
        "float": list(range(2)),
        "datetime": list(range(2)),
        "string": [str(i) for i in range(2)],
        "bool": [True] * 2,
    }
    frame = pd.DataFrame(data=columns).astype(dtypes)
    arrow_schema = Schema(
        {
            "data": int64(),
            "float": float64(),
            "datetime": timestamp("ms"),
            "string": string(),
            "bool": bool_(),
        }
    )
    self.round_trip(frame, arrow_schema)

    # Phase 2: a frame whose "data" column is cast with np.ubyte() is
    # expected to make the round trip raise ValueError.
    dtypes = {"_id": "int32", "data": np.ubyte()}
    frame = pd.DataFrame(data={"_id": list(range(2)), "data": list(range(2))}).astype(dtypes)
    with self.assertRaises(ValueError):
        self.round_trip(frame, Schema({"_id": int32(), "data": decimal128(2)}))
| 144 | + |
@mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True)
def test_write_batching(self, mock):
    """Writing 100040 rows is expected to call ``insert_many`` exactly twice."""
    # NOTE: inside this method ``mock`` is the patched ``insert_many`` passed
    # in by the decorator, not the ``unittest.mock`` module alias.
    dtypes = {"_id": "int64"}
    frame = pd.DataFrame(data={"_id": list(range(100040))}).astype(dtypes)
    arrow_schema = Schema({"_id": int64()})

    # side_effect delegates to the real insert_many, so the data is actually
    # written and round_trip can verify it while calls are counted.
    self.round_trip(frame, arrow_schema, coll=self.coll)
    self.assertEqual(mock.call_count, 2)
0 commit comments