15
15
import unittest
16
16
from test import client_context
17
17
from test .utils import AllowListEventListener
18
+ from unittest import mock
18
19
19
20
import numpy as np
20
- from pyarrow import int32 , int64
21
+ from pyarrow import bool_ , float64 , int32 , int64 , string , timestamp
21
22
from pymongo import DESCENDING , WriteConcern
22
- from pymongoarrow .api import Schema , aggregate_numpy_all , find_numpy_all
23
+ from pymongo .collection import Collection
24
+ from pymongoarrow .api import Schema , aggregate_numpy_all , find_numpy_all , write
25
+ from pymongoarrow .errors import ArrowWriteError
23
26
24
27
25
28
class TestExplicitNumPyApi (unittest .TestCase ):
@@ -47,11 +50,11 @@ def setUp(self):
47
50
48
51
def assert_numpy_equal (self , actual , expected ):
49
52
self .assertIsInstance (actual , dict )
50
- for field in self . schema :
53
+ for field in expected :
51
54
# workaround np.nan == np.nan evaluating to False
52
55
a = np .nan_to_num (actual [field ])
53
56
e = np .nan_to_num (expected [field ])
54
- self . assertTrue ( np .all ( a == e ) )
57
+ np .testing . assert_array_equal ( a , e )
55
58
self .assertEqual (actual [field ].dtype , expected [field ].dtype )
56
59
57
60
def test_find_simple (self ):
@@ -80,7 +83,7 @@ def test_find_simple(self):
80
83
def test_aggregate_simple (self ):
81
84
expected = {
82
85
"_id" : np .array ([1 , 2 , 3 , 4 ], dtype = np .int32 ),
83
- "data" : np .array ([20 , 40 , 60 , np . nan ], dtype = np .float64 ),
86
+ "data" : np .array ([20 , 40 , 60 , None ], dtype = np .float64 ),
84
87
}
85
88
projection = {"_id" : True , "data" : {"$multiply" : [2 , "$data" ]}}
86
89
actual = aggregate_numpy_all (self .coll , [{"$project" : projection }], schema = self .schema )
@@ -91,3 +94,90 @@ def test_aggregate_simple(self):
91
94
assert len (agg_cmd .command ["pipeline" ]) == 2
92
95
self .assertEqual (agg_cmd .command ["pipeline" ][0 ]["$project" ], projection )
93
96
self .assertEqual (agg_cmd .command ["pipeline" ][1 ]["$project" ], {"_id" : True , "data" : True })
97
+
98
def round_trip(self, data, schema, coll=None):
    """Write ``data`` to a collection, read it back, and check it is unchanged.

    :param data: mapping of field name to array; the length of the first
        value is compared against the reported ``insertedCount``.
    :param schema: schema passed to ``find_numpy_all`` when reading back.
    :param coll: collection to use; defaults to ``self.coll``.
    :return: the result object returned by ``write``.
    """
    if coll is None:
        coll = self.coll
    coll.drop()
    # Write to the same collection that was dropped and is read back below;
    # previously this always targeted self.coll, ignoring the ``coll`` argument.
    res = write(coll, data)
    expected_count = len(next(iter(data.values())))
    self.assertEqual(expected_count, res.raw_result["insertedCount"])
    self.assert_numpy_equal(find_numpy_all(coll, {}, schema=schema), data)
    return res
106
+
107
def schemafied_ndarray_dict(self, dict, schema):
    """Return a new dict mapping each key of ``dict`` to a NumPy array.

    Each value is converted with ``np.array`` using the dtype found under
    the same key in ``schema``.
    """
    return {name: np.array(values, dtype=schema[name]) for name, values in dict.items()}
112
+
113
def test_write_error(self):
    """Check that a failing write raises ArrowWriteError with the expected error index."""
    schema = {"_id": "int32", "data": "int64"}
    length = 10001
    # Each column is the same sequence repeated twice, so every _id value
    # appears two times (presumably what makes the write fail -- confirm).
    ids = list(range(length)) * 2
    values = [i * 2 for i in range(length)] * 2
    data = self.schemafied_ndarray_dict({"_id": ids, "data": values}, schema)
    with self.assertRaises(ArrowWriteError) as raised:
        self.round_trip(data, Schema({"_id": int32(), "data": int64()}))
    awe = raised.exception
    # The third argument is the failure message: nInserted is reported if the index differs.
    self.assertEqual(
        10001, awe.details["writeErrors"][0]["index"], awe.details["nInserted"]
    )
126
+
127
def test_write_schema_validation(self):
    """Round-trip several column types, then check that a ``np.ubyte`` column raises ValueError."""
    numpy_types = {
        "data": "int64",
        "float": "float64",
        "datetime": "datetime64[ms]",
        "string": "str",
        "bool": "bool",
    }
    raw_columns = {
        "data": list(range(2)),
        "float": list(range(2)),
        "datetime": list(range(2)),
        "string": [str(i) for i in range(2)],
        "bool": [True] * 2,
    }
    arrow_schema = Schema(
        {
            "data": int64(),
            "float": float64(),
            "datetime": timestamp("ms"),
            "string": string(),
            "bool": bool_(),
        }
    )
    self.round_trip(self.schemafied_ndarray_dict(raw_columns, numpy_types), arrow_schema)

    # Second case: a np.ubyte() entry in the schema is expected to raise ValueError.
    numpy_types = {"_id": "int32", "data": np.ubyte()}
    raw_columns = {"_id": list(range(2)), "data": list(range(2))}
    data = self.schemafied_ndarray_dict(raw_columns, numpy_types)
    with self.assertRaises(ValueError):
        self.round_trip(data, Schema({"_id": int32(), "data": np.ubyte()}))
161
+
162
@mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True)
def test_write_batching(self, mock):
    """Round-trip 100040 ``_id`` values and expect exactly two ``insert_many`` calls.

    ``Collection.insert_many`` is patched with an autospec mock whose side
    effect is the real method, so calls are counted while still executing.
    """
    ids = {"_id": list(range(100040))}
    data = self.schemafied_ndarray_dict(ids, {"_id": "int64"})
    arrow_schema = Schema({"_id": int64()})

    self.round_trip(data, arrow_schema, coll=self.coll)
    self.assertEqual(mock.call_count, 2)
178
+
179
def test_write_dictionaries(self):
    """Check that writing a dict with a scalar value raises ValueError with the expected message."""
    expected_message = "Invalid tabular data object of type <class 'dict'>"
    with self.assertRaisesRegex(ValueError, expected_message):
        write(self.coll, {"foo": 1})
0 commit comments