Commit e243c67

ARROW-71 Add basic support for writing to MongoDB from PyArrow (#64)
1 parent ca97bb9 commit e243c67

File tree

5 files changed: +210 −2 lines

bindings/python/pymongoarrow/api.py

Lines changed: 70 additions & 0 deletions
@@ -13,9 +13,16 @@
 # limitations under the License.
 import warnings
 
+from bson import encode
+from bson.raw_bson import RawBSONDocument
+from pymongo.bulk import BulkWriteError
+from pymongo.common import MAX_WRITE_BATCH_SIZE
 from pymongoarrow.context import PyMongoArrowContext
+from pymongoarrow.errors import ArrowWriteError
 from pymongoarrow.lib import process_bson_stream
+from pymongoarrow.result import ArrowWriteResult
 from pymongoarrow.schema import Schema
+from pymongoarrow.types import _validate_schema
 
 __all__ = [
     "aggregate_arrow_all",
@@ -37,6 +44,12 @@
     "find_numpy_all",
 ]
 
+# MongoDB 3.6's maxMessageSizeBytes minus some overhead to account
+# for the command plus OP_MSG.
+_MAX_MESSAGE_SIZE = 48000000 - 16 * 1024
+# The maximum number of bulk write operations in one batch.
+_MAX_WRITE_BATCH_SIZE = max(100000, MAX_WRITE_BATCH_SIZE)
+
 
 def find_arrow_all(collection, query, *, schema, **kwargs):
     """Method that returns the results of a find query as a
@@ -233,3 +246,60 @@ def aggregate_numpy_all(collection, pipeline, *, schema, **kwargs):
     return _arrow_to_numpy(
         aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs), schema
     )
+
+
+def _transform_bwe(bwe, offset):
+    bwe["nInserted"] += offset
+    for i in bwe["writeErrors"]:
+        i["index"] += offset
+    return bwe
+
+
+def _tabular_generator(tabular):
+    for i in tabular.to_batches():
+        for row in i.to_pylist():
+            yield row
+
+
+def write(collection, tabular):
+    """Write data from `tabular` into the given MongoDB `collection`.
+
+    :Parameters:
+      - `collection`: Instance of :class:`~pymongo.collection.Collection`.
+        against which to run the operation.
+      - `tabular`: A tabular data store to use for the write operation.
+
+    :Returns:
+      An instance of :class:`result.ArrowWriteResult`.
+    """
+
+    _validate_schema(tabular.schema)
+    cur_offset = 0
+    results = {
+        "insertedCount": 0,
+    }
+    tabular_gen = _tabular_generator(tabular)
+    while cur_offset < len(tabular):
+        cur_size = 0
+        cur_batch = []
+        i = 0
+        while (
+            cur_size <= _MAX_MESSAGE_SIZE
+            and len(cur_batch) <= _MAX_WRITE_BATCH_SIZE
+            and cur_offset + i < len(tabular)
+        ):
+            enc_tab = RawBSONDocument(
+                encode(next(tabular_gen), codec_options=collection.codec_options)
+            )
+            cur_batch.append(enc_tab)
+            cur_size += len(enc_tab)
+            i += 1
+        try:
+            collection.insert_many(cur_batch)
+        except BulkWriteError as bwe:
+            raise ArrowWriteError(_transform_bwe(dict(bwe.details), cur_offset)) from bwe
+
+        results["insertedCount"] += i
+        cur_offset += i
+
+    return ArrowWriteResult(results)
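
A quick usage sketch of the new API (not part of the commit): the connection string, database, collection, and sample values below are illustrative. write() validates the table's schema, encodes each row to raw BSON, and inserts the rows in batches capped at roughly the 48 MB message limit or _MAX_WRITE_BATCH_SIZE operations, whichever comes first.

# Hypothetical example of the write() helper added in this commit.
from pyarrow import Table, int64
from pyarrow import schema as ArrowSchema
from pymongo import MongoClient
from pymongoarrow.api import Schema, find_arrow_all, write

client = MongoClient("mongodb://localhost:27017")  # assumed local server
coll = client.test_db.test_coll                    # assumed namespace

arrow_schema = {"_id": int64(), "qty": int64()}
data = Table.from_pydict(
    {"_id": [1, 2, 3], "qty": [10, 20, 30]},
    ArrowSchema(arrow_schema),
)

result = write(coll, data)  # returns an ArrowWriteResult
print(result.raw_result)    # {'insertedCount': 3}

# Round-trip the rows back out as an Arrow Table.
table = find_arrow_all(coll, {}, schema=Schema(arrow_schema))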

bindings/python/pymongoarrow/errors.py

Lines changed: 27 additions & 0 deletions
@@ -21,3 +21,30 @@ class PyMongoArrowError(Exception):
     """Base class for all PyMongoArrow exceptions."""
 
     pass
+
+
+class ArrowWriteError(PyMongoArrowError):
+    """Error raised when we encounter an exception writing into MongoDB"""
+
+    def __init__(self, details):
+        self._details = details
+
+    @property
+    def details(self):
+        """Details for the error.
+
+        It is a dictionary of key-value pairs giving diagnostic information about what went wrong. To see the entire dictionary simply use `print(awe.details)`.
+
+        Details will have the following format:
+        {
+            'writeErrors': [...],
+            'writeConcernErrors': [...],
+            'nInserted': ...,
+            'nUpserted': ...,
+            'nMatched': ...,
+            'nModified': ...,
+            'nRemoved': ...,
+            'upserted': [...]
+        }
+        """
+        return self._details
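
When insert_many fails part-way through, write() converts PyMongo's BulkWriteError into this ArrowWriteError, first passing the details through _transform_bwe so that nInserted and each writeErrors[...]["index"] account for rows already written in earlier batches. A hedged handling sketch, with coll and data assumed to exist as in the earlier example:

# Hypothetical example of inspecting ArrowWriteError.details.
from pymongoarrow.api import write
from pymongoarrow.errors import ArrowWriteError

try:
    write(coll, data)
except ArrowWriteError as awe:
    # Indexes refer to rows of the whole input table, not just the failed batch.
    print(awe.details["nInserted"])
    for err in awe.details["writeErrors"]:
        print(err["index"], err.get("errmsg"))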

bindings/python/pymongoarrow/result.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+# Copyright 2022-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Results returned by PyMongoArrow."""
+
+
+class ArrowWriteResult:
+    def __init__(self, result_dict):
+        self._result = result_dict
+
+    def __repr__(self):
+        return repr(self._result)
+
+    @property
+    def inserted_count(self):
+        return self._result.get("insertedCount", 0)
+
+    @property
+    def raw_result(self):
+        return self._result

bindings/python/pymongoarrow/types.py

Lines changed: 13 additions & 0 deletions
@@ -102,3 +102,16 @@ def _get_internal_typemap(typemap):
     )
 
     return internal_typemap
+
+
+def _in_type_map(t):
+    for checker in _TYPE_CHECKER_TO_INTERNAL_TYPE.keys():
+        if checker(t):
+            return True
+    return False
+
+
+def _validate_schema(schema):
+    for i in schema.types:
+        if not _in_type_map(i):
+            raise ValueError(f'Unsupported data type "{i}" in schema')
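
_validate_schema() walks schema.types and raises for any Arrow type that has no registered checker in _TYPE_CHECKER_TO_INTERNAL_TYPE; per the tests below, int64, float64, timestamp("ms"), string, and bool pass while decimal128 does not. A small sketch calling the private helper directly, for illustration only:

# Hypothetical example; _validate_schema is an internal helper.
import pyarrow as pa
from pymongoarrow.types import _validate_schema

_validate_schema(pa.schema({"qty": pa.int64(), "price": pa.float64()}))  # accepted

try:
    _validate_schema(pa.schema({"price": pa.decimal128(2)}))
except ValueError as exc:
    print(exc)  # Unsupported data type "decimal128(2, 0)" in schema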

bindings/python/test/test_arrow.py

Lines changed: 69 additions & 2 deletions
@@ -17,10 +17,13 @@
 from test.utils import AllowListEventListener
 
 import pymongo
-from pyarrow import Table, int32, int64
+from pyarrow import Table, bool_, decimal128, float64, int32, int64
 from pyarrow import schema as ArrowSchema
+from pyarrow import string, timestamp
 from pymongo import DESCENDING, WriteConcern
-from pymongoarrow.api import Schema, aggregate_arrow_all, find_arrow_all
+from pymongo.collection import Collection
+from pymongoarrow.api import Schema, aggregate_arrow_all, find_arrow_all, write
+from pymongoarrow.errors import ArrowWriteError
 from pymongoarrow.monkey import patch_all
 
 
@@ -180,6 +183,70 @@ def test_aggregate_omits_id_if_not_in_schema(self):
                 self.assertFalse(stage[op_name]["_id"])
                 break
 
+    def round_trip(self, data, schema, coll=None):
+        if coll is None:
+            coll = self.coll
+        self.coll.drop()
+        res = write(self.coll, data)
+        self.assertEqual(len(data), res.raw_result["insertedCount"])
+        self.assertEqual(data, find_arrow_all(coll, {}, schema=schema))
+        return res
+
+    def test_write_error(self):
+        schema = {"_id": int32(), "data": int64()}
+        data = Table.from_pydict(
+            {"_id": [i for i in range(10001)] * 2, "data": [i * 2 for i in range(10001)] * 2},
+            ArrowSchema(schema),
+        )
+        with self.assertRaises(ArrowWriteError):
+            try:
+                self.round_trip(data, Schema(schema))
+            except ArrowWriteError as awe:
+                self.assertEqual(
+                    10001, awe.details["writeErrors"][0]["index"], awe.details["nInserted"]
+                )
+                raise awe
+
+    def test_write_schema_validation(self):
+        schema = {
+            "data": int64(),
+            "float": float64(),
+            "datetime": timestamp("ms"),
+            "string": string(),
+            "bool": bool_(),
+        }
+        data = Table.from_pydict(
+            {
+                "data": [i for i in range(2)],
+                "float": [i for i in range(2)],
+                "datetime": [i for i in range(2)],
+                "string": [str(i) for i in range(2)],
+                "bool": [True for _ in range(2)],
+            },
+            ArrowSchema(schema),
+        )
+        self.round_trip(data, Schema(schema))
+
+        schema = {"_id": int32(), "data": decimal128(2)}
+        data = Table.from_pydict(
+            {"_id": [i for i in range(2)], "data": [i for i in range(2)]},
+            ArrowSchema(schema),
+        )
+        with self.assertRaises(ValueError):
+            self.round_trip(data, Schema(schema))
+
+    @mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True)
+    def test_write_batching(self, mock):
+        schema = {
+            "_id": int64(),
+        }
+        data = Table.from_pydict(
+            {"_id": [i for i in range(100040)]},
+            ArrowSchema(schema),
+        )
+        self.round_trip(data, Schema(schema), coll=self.coll)
+        self.assertEqual(mock.call_count, 2)
+
 
 class TestArrowExplicitApi(TestArrowApiMixin, unittest.TestCase):
     def run_find(self, *args, **kwargs):
