
Commit 3d39c2c

ARROW-74 Add support and testing for NumPy (#71)
1 parent 5bf829a commit 3d39c2c

6 files changed: +168 -19 lines changed


bindings/python/benchmark.py

Lines changed: 8 additions & 0 deletions
@@ -34,6 +34,7 @@

 arrow_tables = {}
 pandas_tables = {}
+numpy_arrays = {}


 def _setup():
@@ -95,6 +96,8 @@ def _setup():
     arrow_tables[LARGE] = find_arrow_all(db[collection_names[LARGE]], {}, schema=schemas[LARGE])
     pandas_tables[SMALL] = find_pandas_all(db[collection_names[SMALL]], {}, schema=schemas[SMALL])
     pandas_tables[LARGE] = find_pandas_all(db[collection_names[LARGE]], {}, schema=schemas[LARGE])
+    numpy_arrays[SMALL] = find_numpy_all(db[collection_names[SMALL]], {}, schema=schemas[SMALL])
+    numpy_arrays[LARGE] = find_numpy_all(db[collection_names[LARGE]], {}, schema=schemas[LARGE])


 def _teardown():
@@ -172,6 +175,11 @@ def insert_pandas(use_large):
     write(db[collection_names[use_large]], pandas_tables[use_large])


+@bench("insert_numpy")
+def insert_numpy(use_large):
+    write(db[collection_names[use_large]], numpy_arrays[use_large])
+
+
 parser = argparse.ArgumentParser(
     formatter_class=argparse.RawTextHelpFormatter,
     epilog="""

bindings/python/pymongoarrow/api.py

Lines changed: 36 additions & 10 deletions
@@ -13,8 +13,10 @@
 # limitations under the License.
 import warnings

+import numpy as np
 from bson import encode
 from bson.raw_bson import RawBSONDocument
+from numpy import ndarray
 from pandas import DataFrame
 from pyarrow import Schema as ArrowSchema
 from pyarrow import Table
@@ -25,7 +27,7 @@
 from pymongoarrow.lib import process_bson_stream
 from pymongoarrow.result import ArrowWriteResult
 from pymongoarrow.schema import Schema
-from pymongoarrow.types import _validate_schema
+from pymongoarrow.types import _validate_schema, get_numpy_type

 __all__ = [
     "aggregate_arrow_all",
@@ -187,7 +189,11 @@ def _arrow_to_numpy(arrow_table, schema):
     """
     container = {}
     for fname in schema:
-        container[fname] = arrow_table[fname].to_numpy()
+        dtype = get_numpy_type(schema.typemap[fname])
+        if dtype == np.str_:
+            container[fname] = arrow_table[fname].to_pandas().to_numpy(dtype=dtype)
+        else:
+            container[fname] = arrow_table[fname].to_numpy()
     return container


@@ -264,8 +270,15 @@ def _tabular_generator(tabular):
             for row in i.to_pylist():
                 yield row
     elif isinstance(tabular, DataFrame):
-        for i in tabular.to_dict("records"):
-            yield i
+        for row in tabular.to_dict("records"):
+            yield row
+    elif isinstance(tabular, dict):
+        iter_dict = {k: np.nditer(v) for k, v in tabular.items()}
+        try:
+            while True:
+                yield {k: next(i).item() for k, i in iter_dict.items()}
+        except StopIteration:
+            return


 def write(collection, tabular):
@@ -279,17 +292,30 @@ def write(collection, tabular):
     :Returns:
       An instance of :class:`result.ArrowWriteResult`.
     """
-
-    if isinstance(tabular, Table):
-        _validate_schema(tabular.schema.types)
-    elif isinstance(tabular, DataFrame):
-        _validate_schema(ArrowSchema.from_pandas(tabular).types)
     cur_offset = 0
     results = {
         "insertedCount": 0,
     }
-    tabular_gen = _tabular_generator(tabular)
     tab_size = len(tabular)
+    if isinstance(tabular, Table):
+        _validate_schema(tabular.schema.types)
+    elif isinstance(tabular, DataFrame):
+        _validate_schema(ArrowSchema.from_pandas(tabular).types)
+    elif (
+        isinstance(tabular, dict)
+        and len(tabular.values()) >= 1
+        and all([isinstance(i, ndarray) for i in tabular.values()])
+    ):
+        _validate_schema([i.dtype for i in tabular.values()])
+        tab_size = len(next(iter(tabular.values())))
+    else:
+        raise ValueError(
+            f"Invalid tabular data object of type {type(tabular)} \n"
+            "Please ensure that it is one of the supported types: "
+            "DataFrame, Table, or a dictionary containing NumPy arrays."
+        )
+
+    tabular_gen = _tabular_generator(tabular)
     while cur_offset < tab_size:
         cur_size = 0
         cur_batch = []
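
Taken together, the `write()` changes above mean a plain `dict` of NumPy `ndarray`s is now accepted alongside Arrow `Table`s and pandas `DataFrame`s. Below is a minimal sketch of how that new path might be exercised; it is not part of the commit, the database and collection names are placeholders, and a running local `mongod` is assumed.

```python
# Sketch only: exercises the new dict-of-ndarrays branch of write().
# "test_db" / "numpy_demo" and the default MongoClient() URI are placeholders.
import numpy as np
from pymongo import MongoClient
from pyarrow import float64, int64
from pymongoarrow.api import Schema, find_numpy_all, write

coll = MongoClient()["test_db"]["numpy_demo"]
coll.drop()

# Every value must be an ndarray; keys become field names, and the arrays
# are iterated row by row via np.nditer in _tabular_generator().
arrays = {
    "_id": np.arange(4, dtype=np.int64),
    "data": np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64),
}
result = write(coll, arrays)
print(result.raw_result["insertedCount"])  # 4, if all inserts succeed

# find_numpy_all() returns a dict of field name -> ndarray (see _arrow_to_numpy).
out = find_numpy_all(coll, {}, schema=Schema({"_id": int64(), "data": float64()}))
print(out["data"].dtype)  # float64
```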

bindings/python/pymongoarrow/types.py

Lines changed: 25 additions & 0 deletions
@@ -14,6 +14,8 @@
 import enum
 from datetime import datetime

+import numpy as np
+import pyarrow as pa
 import pyarrow.types as _atypes
 from bson import Int64, ObjectId
 from pyarrow import DataType as _ArrowDataType
@@ -64,6 +66,24 @@ def _is_objectid(obj):
 }


+_TYPE_CHECKER_TO_NUMPY = {
+    _atypes.is_int32: np.int32,
+    _atypes.is_int64: np.int64,
+    _atypes.is_float64: np.float64,
+    _atypes.is_timestamp: "datetime64[ms]",
+    _is_objectid: np.object,
+    _atypes.is_string: np.str_,
+    _atypes.is_boolean: np.bool_,
+}
+
+
+def get_numpy_type(type):
+    for checker, comp_type in _TYPE_CHECKER_TO_NUMPY.items():
+        if checker(type):
+            return comp_type
+    return None
+
+
 _TYPE_CHECKER_TO_INTERNAL_TYPE = {
     _atypes.is_int32: _BsonArrowTypes.int32,
     _atypes.is_int64: _BsonArrowTypes.int64,
@@ -105,6 +125,11 @@ def _get_internal_typemap(typemap):


 def _in_type_map(t):
+    if isinstance(t, np.dtype):
+        try:
+            t = pa.from_numpy_dtype(t)
+        except pa.lib.ArrowNotImplementedError:
+            return False
 for checker in _TYPE_CHECKER_TO_INTERNAL_TYPE.keys():
     if checker(t):
         return True
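
For reference, a small sketch of how the new `get_numpy_type()` helper maps pyarrow types onto NumPy dtypes, assuming the `_TYPE_CHECKER_TO_NUMPY` table above; this example is illustrative only and not part of the commit.

```python
# Sketch: get_numpy_type() walks _TYPE_CHECKER_TO_NUMPY and returns the first
# NumPy dtype whose pyarrow type checker matches, or None if nothing matches.
import numpy as np
import pyarrow as pa
from pymongoarrow.types import get_numpy_type

assert get_numpy_type(pa.int64()) == np.int64
assert get_numpy_type(pa.string()) == np.str_
assert get_numpy_type(pa.timestamp("ms")) == "datetime64[ms]"
# Types with no registered mapping come back as None.
```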

bindings/python/test/test_arrow.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 from test.utils import AllowListEventListener

 import pymongo
-from pyarrow import Table, bool_, decimal128, float64, int32, int64
+from pyarrow import Table, bool_, decimal256, float64, int32, int64
 from pyarrow import schema as ArrowSchema
 from pyarrow import string, timestamp
 from pyarrow.parquet import read_table, write_table
@@ -229,7 +229,7 @@ def test_write_schema_validation(self):
         )
         self.round_trip(data, Schema(schema))

-        schema = {"_id": int32(), "data": decimal128(2)}
+        schema = {"_id": int32(), "data": decimal256(2)}
         data = Table.from_pydict(
             {"_id": [i for i in range(2)], "data": [i for i in range(2)]},
             ArrowSchema(schema),

bindings/python/test/test_numpy.py

Lines changed: 95 additions & 5 deletions
@@ -15,11 +15,14 @@
 import unittest
 from test import client_context
 from test.utils import AllowListEventListener
+from unittest import mock

 import numpy as np
-from pyarrow import int32, int64
+from pyarrow import bool_, float64, int32, int64, string, timestamp
 from pymongo import DESCENDING, WriteConcern
-from pymongoarrow.api import Schema, aggregate_numpy_all, find_numpy_all
+from pymongo.collection import Collection
+from pymongoarrow.api import Schema, aggregate_numpy_all, find_numpy_all, write
+from pymongoarrow.errors import ArrowWriteError


 class TestExplicitNumPyApi(unittest.TestCase):
@@ -47,11 +50,11 @@ def setUp(self):

     def assert_numpy_equal(self, actual, expected):
         self.assertIsInstance(actual, dict)
-        for field in self.schema:
+        for field in expected:
             # workaround np.nan == np.nan evaluating to False
             a = np.nan_to_num(actual[field])
             e = np.nan_to_num(expected[field])
-            self.assertTrue(np.all(a == e))
+            np.testing.assert_array_equal(a, e)
             self.assertEqual(actual[field].dtype, expected[field].dtype)

     def test_find_simple(self):
@@ -80,7 +83,7 @@ def test_find_simple(self):
     def test_aggregate_simple(self):
         expected = {
             "_id": np.array([1, 2, 3, 4], dtype=np.int32),
-            "data": np.array([20, 40, 60, np.nan], dtype=np.float64),
+            "data": np.array([20, 40, 60, None], dtype=np.float64),
         }
         projection = {"_id": True, "data": {"$multiply": [2, "$data"]}}
         actual = aggregate_numpy_all(self.coll, [{"$project": projection}], schema=self.schema)
@@ -91,3 +94,90 @@
         assert len(agg_cmd.command["pipeline"]) == 2
         self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection)
         self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True})
+
+    def round_trip(self, data, schema, coll=None):
+        if coll is None:
+            coll = self.coll
+        coll.drop()
+        res = write(self.coll, data)
+        self.assertEqual(len(list(data.values())[0]), res.raw_result["insertedCount"])
+        self.assert_numpy_equal(find_numpy_all(coll, {}, schema=schema), data)
+        return res
+
+    def schemafied_ndarray_dict(self, dict, schema):
+        ret = {}
+        for k, v in dict.items():
+            ret[k] = np.array(v, dtype=schema[k])
+        return ret
+
+    def test_write_error(self):
+        schema = {"_id": "int32", "data": "int64"}
+        length = 10001
+        data = {"_id": [i for i in range(length)] * 2, "data": [i * 2 for i in range(length)] * 2}
+        data = self.schemafied_ndarray_dict(data, schema)
+        with self.assertRaises(ArrowWriteError):
+            try:
+                self.round_trip(data, Schema({"_id": int32(), "data": int64()}))
+            except ArrowWriteError as awe:
+                self.assertEqual(
+                    10001, awe.details["writeErrors"][0]["index"], awe.details["nInserted"]
+                )
+                raise awe
+
+    def test_write_schema_validation(self):
+        schema = {
+            "data": "int64",
+            "float": "float64",
+            "datetime": "datetime64[ms]",
+            "string": "str",
+            "bool": "bool",
+        }
+        data = {
+            "data": [i for i in range(2)],
+            "float": [i for i in range(2)],
+            "datetime": [i for i in range(2)],
+            "string": [str(i) for i in range(2)],
+            "bool": [True for _ in range(2)],
+        }
+        data = self.schemafied_ndarray_dict(data, schema)
+        self.round_trip(
+            data,
+            Schema(
+                {
+                    "data": int64(),
+                    "float": float64(),
+                    "datetime": timestamp("ms"),
+                    "string": string(),
+                    "bool": bool_(),
+                }
+            ),
+        )
+
+        schema = {"_id": "int32", "data": np.ubyte()}
+        data = {"_id": [i for i in range(2)], "data": [i for i in range(2)]}
+        data = self.schemafied_ndarray_dict(data, schema)
+        with self.assertRaises(ValueError):
+            self.round_trip(data, Schema({"_id": int32(), "data": np.ubyte()}))
+
+    @mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True)
+    def test_write_batching(self, mock):
+        schema = {"_id": "int64"}
+        data = {"_id": [i for i in range(100040)]}
+        data = self.schemafied_ndarray_dict(data, schema)
+
+        self.round_trip(
+            data,
+            Schema(
+                {
+                    "_id": int64(),
+                }
+            ),
+            coll=self.coll,
+        )
+        self.assertEqual(mock.call_count, 2)
+
+    def test_write_dictionaries(self):
+        with self.assertRaisesRegex(
+            ValueError, "Invalid tabular data object of type <class 'dict'>"
+        ):
+            write(self.coll, {"foo": 1})

bindings/python/test/test_pandas.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@

 import numpy as np
 import pandas as pd
-from pyarrow import bool_, decimal128, float64, int32, int64, string, timestamp
+from pyarrow import bool_, decimal256, float64, int32, int64, string, timestamp
 from pymongo import DESCENDING, WriteConcern
 from pymongo.collection import Collection
 from pymongoarrow.api import Schema, aggregate_pandas_all, find_pandas_all, write
@@ -140,7 +140,7 @@ def test_write_schema_validation(self):
             data={"_id": [i for i in range(2)], "data": [i for i in range(2)]}
         ).astype(schema)
         with self.assertRaises(ValueError):
-            self.round_trip(data, Schema({"_id": int32(), "data": decimal128(2)}))
+            self.round_trip(data, Schema({"_id": int32(), "data": decimal256(2)}))

     @mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True)
     def test_write_batching(self, mock):
