Skip to content

Commit ddc172f

Browse files
authored
ARROW-73 Add support and testing for Pandas (#69)
1 parent 0213b2e commit ddc172f

File tree

4 files changed

+108
-7
lines changed

4 files changed

+108
-7
lines changed

bindings/python/benchmark.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
dtypes = {}
3232
schemas = {}
3333
raw_bsons = {}
34+
3435
arrow_tables = {}
36+
pandas_tables = {}
3537

3638

3739
def _setup():
@@ -91,6 +93,8 @@ def _setup():
9193
raw_bsons[LARGE] = raw_bson_large
9294
arrow_tables[SMALL] = find_arrow_all(db[collection_names[SMALL]], {}, schema=schemas[SMALL])
9395
arrow_tables[LARGE] = find_arrow_all(db[collection_names[LARGE]], {}, schema=schemas[LARGE])
96+
pandas_tables[SMALL] = find_pandas_all(db[collection_names[SMALL]], {}, schema=schemas[SMALL])
97+
pandas_tables[LARGE] = find_pandas_all(db[collection_names[LARGE]], {}, schema=schemas[LARGE])
9498

9599

96100
def _teardown():
@@ -163,6 +167,11 @@ def insert_conventional(use_large):
163167
db[collection_names[use_large]].insert_many(tab)
164168

165169

170+
@bench("insert_pandas")
171+
def insert_pandas(use_large):
172+
write(db[collection_names[use_large]], pandas_tables[use_large])
173+
174+
166175
parser = argparse.ArgumentParser(
167176
formatter_class=argparse.RawTextHelpFormatter,
168177
epilog="""

bindings/python/pymongoarrow/api.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
from bson import encode
1717
from bson.raw_bson import RawBSONDocument
18+
from pandas import DataFrame
19+
from pyarrow import Schema as ArrowSchema
20+
from pyarrow import Table
1821
from pymongo.bulk import BulkWriteError
1922
from pymongo.common import MAX_WRITE_BATCH_SIZE
2023
from pymongoarrow.context import PyMongoArrowContext
@@ -256,9 +259,13 @@ def _transform_bwe(bwe, offset):
256259

257260

258261
def _tabular_generator(tabular):
259-
for i in tabular.to_batches():
260-
for row in i.to_pylist():
261-
yield row
262+
if isinstance(tabular, Table):
263+
for i in tabular.to_batches():
264+
for row in i.to_pylist():
265+
yield row
266+
elif isinstance(tabular, DataFrame):
267+
for i in tabular.to_dict("records"):
268+
yield i
262269

263270

264271
def write(collection, tabular):
@@ -273,7 +280,10 @@ def write(collection, tabular):
273280
An instance of :class:`result.ArrowWriteResult`.
274281
"""
275282

276-
_validate_schema(tabular.schema)
283+
if isinstance(tabular, Table):
284+
_validate_schema(tabular.schema.types)
285+
elif isinstance(tabular, DataFrame):
286+
_validate_schema(ArrowSchema.from_pandas(tabular).types)
277287
cur_offset = 0
278288
results = {
279289
"insertedCount": 0,

bindings/python/pymongoarrow/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,6 @@ def _in_type_map(t):
112112

113113

114114
def _validate_schema(schema):
115-
for i in schema.types:
115+
for i in schema:
116116
if not _in_type_map(i):
117117
raise ValueError(f'Unsupported data type "{i}" in schema')

bindings/python/test/test_pandas.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@
1313
# limitations under the License.
1414
# from datetime import datetime, timedelta
1515
import unittest
16+
import unittest.mock as mock
1617
from test import client_context
1718
from test.utils import AllowListEventListener
1819

20+
import numpy as np
1921
import pandas as pd
20-
from pyarrow import int32, int64
22+
from pyarrow import bool_, decimal128, float64, int32, int64, string, timestamp
2123
from pymongo import DESCENDING, WriteConcern
22-
from pymongoarrow.api import Schema, aggregate_pandas_all, find_pandas_all
24+
from pymongo.collection import Collection
25+
from pymongoarrow.api import Schema, aggregate_pandas_all, find_pandas_all, write
26+
from pymongoarrow.errors import ArrowWriteError
2327

2428

2529
class TestExplicitPandasApi(unittest.TestCase):
@@ -76,3 +80,81 @@ def test_aggregate_simple(self):
7680
assert len(agg_cmd.command["pipeline"]) == 2
7781
self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection)
7882
self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True})
83+
84+
def round_trip(self, data, schema, coll=None):
85+
if coll is None:
86+
coll = self.coll
87+
coll.drop()
88+
res = write(self.coll, data)
89+
self.assertEqual(len(data), res.raw_result["insertedCount"])
90+
pd.testing.assert_frame_equal(data, find_pandas_all(coll, {}, schema=schema))
91+
return res
92+
93+
def test_write_error(self):
94+
schema = {"_id": "int32", "data": "int64"}
95+
96+
data = pd.DataFrame(
97+
data={"_id": [i for i in range(10001)] * 2, "data": [i * 2 for i in range(10001)] * 2}
98+
).astype(schema)
99+
with self.assertRaises(ArrowWriteError):
100+
try:
101+
self.round_trip(data, Schema({"_id": int32(), "data": int64()}))
102+
except ArrowWriteError as awe:
103+
self.assertEqual(
104+
10001, awe.details["writeErrors"][0]["index"], awe.details["nInserted"]
105+
)
106+
raise awe
107+
108+
def test_write_schema_validation(self):
109+
schema = {
110+
"data": "int64",
111+
"float": "float64",
112+
"datetime": "datetime64[ms]",
113+
"string": "object",
114+
"bool": "bool",
115+
}
116+
data = pd.DataFrame(
117+
data={
118+
"data": [i for i in range(2)],
119+
"float": [i for i in range(2)],
120+
"datetime": [i for i in range(2)],
121+
"string": [str(i) for i in range(2)],
122+
"bool": [True for _ in range(2)],
123+
}
124+
).astype(schema)
125+
self.round_trip(
126+
data,
127+
Schema(
128+
{
129+
"data": int64(),
130+
"float": float64(),
131+
"datetime": timestamp("ms"),
132+
"string": string(),
133+
"bool": bool_(),
134+
}
135+
),
136+
)
137+
138+
schema = {"_id": "int32", "data": np.ubyte()}
139+
data = pd.DataFrame(
140+
data={"_id": [i for i in range(2)], "data": [i for i in range(2)]}
141+
).astype(schema)
142+
with self.assertRaises(ValueError):
143+
self.round_trip(data, Schema({"_id": int32(), "data": decimal128(2)}))
144+
145+
@mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True)
146+
def test_write_batching(self, mock):
147+
schema = {"_id": "int64"}
148+
data = pd.DataFrame(
149+
data={"_id": [i for i in range(100040)]},
150+
).astype(schema)
151+
self.round_trip(
152+
data,
153+
Schema(
154+
{
155+
"_id": int64(),
156+
}
157+
),
158+
coll=self.coll,
159+
)
160+
self.assertEqual(mock.call_count, 2)

0 commit comments

Comments
 (0)