Skip to content

Commit 7247b8a

Browse files
committed
add threading tests
1 parent 05cf96d commit 7247b8a

File tree

5 files changed

+907
-817
lines changed

5 files changed

+907
-817
lines changed

bindings/python/pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ testpaths = ["test"]
101101
log_cli_level = "INFO"
102102
faulthandler_timeout = 1500
103103
xfail_strict = true
104-
filterwarnings = [
105-
"error",
106-
]
104+
# filterwarnings = [
105+
# "error",
106+
# ]
107107

108108
[tool.ruff]
109109
line-length = 100

bindings/python/test/test_arrow.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import concurrent.futures
1516
import io
1617
import json
1718
import tempfile
19+
import threading
1820
import unittest
1921
import unittest.mock as mock
2022
from datetime import date, datetime
@@ -244,8 +246,8 @@ def test_aggregate_omits_id_if_not_in_schema(self):
244246
def round_trip(self, data, schema, coll=None):
245247
if coll is None:
246248
coll = self.coll
247-
self.coll.drop()
248-
res = write(self.coll, data)
249+
coll.drop()
250+
res = write(coll, data)
249251
self.assertEqual(len(data), res.raw_result["insertedCount"])
250252
self.assertEqual(data, find_arrow_all(coll, {}, schema=schema))
251253
return res
@@ -1052,6 +1054,29 @@ def test_empty_embedded_array(self):
10521054
assert pmapatable2.to_pylist()[0] == doc2
10531055
write_table(pmapatable2, io.BytesIO())
10541056

1057+
def test_threading(self):
1058+
schema, data = self._create_data()
1059+
1060+
def run_test():
1061+
client = client_context.get_client(
1062+
event_listeners=[self.getmore_listener, self.cmd_listener]
1063+
)
1064+
name = f"test-{threading.current_thread().name}"
1065+
coll = client.pymongoarrow_test.get_collection(
1066+
name, write_concern=WriteConcern(w="majority")
1067+
)
1068+
coll.drop()
1069+
self.round_trip(data, Schema(schema), coll=coll)
1070+
client.close()
1071+
1072+
with concurrent.futures.ThreadPoolExecutor() as executor:
1073+
futures = []
1074+
for i in range(5):
1075+
futures.append(executor.submit(run_test))
1076+
concurrent.futures.wait(futures)
1077+
for future in futures:
1078+
future.result()
1079+
10551080

10561081
class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
10571082
def run_find(self, *args, **kwargs):

bindings/python/test/test_pandas.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# from datetime import datetime, timedelta
15+
import concurrent.futures
1516
import datetime
1617
import tempfile
18+
import threading
1719
import unittest
1820
import unittest.mock as mock
1921
import warnings
@@ -266,18 +268,22 @@ def inner(i):
266268
raw_data["nested"] = [inner(i) for i in range(3)]
267269
return pd.DataFrame(data=raw_data).astype(schema)
268270

269-
def test_auto_schema(self):
271+
def _check_auto_schema(self, coll):
272+
coll.drop()
270273
data = self._create_nested_data()
271-
self.coll.drop()
272-
res = write(self.coll, data)
274+
res = write(coll, data)
273275
self.assertEqual(len(data), res.raw_result["insertedCount"])
274276
for func in [find_pandas_all, aggregate_pandas_all]:
275-
out = func(self.coll, {} if func == find_pandas_all else []).drop(columns=["_id"])
277+
out = func(coll, {} if func == find_pandas_all else []).drop(columns=["_id"])
276278
for name in data.columns:
277279
val = out[name]
278280
if str(val.dtype) == "object":
279281
val = val.astype(data[name].dtype)
280282
pd.testing.assert_series_equal(data[name], val)
283+
coll.drop()
284+
285+
def test_auto_schema(self):
286+
self._check_auto_schema(self.coll)
281287

282288
def test_auto_schema_heterogeneous(self):
283289
vals = [1, "2", True, 4]
@@ -345,6 +351,26 @@ def test_exclude_none(self):
345351
col_data = list(self.coll.find({}))
346352
assert "b" not in col_data[3]
347353

354+
def test_threading(self):
355+
def run_test():
356+
client = client_context.get_client(
357+
event_listeners=[self.getmore_listener, self.cmd_listener]
358+
)
359+
name = f"test-{threading.current_thread().name}"
360+
coll = client.pymongoarrow_test.get_collection(
361+
name, write_concern=WriteConcern(w="majority")
362+
)
363+
self._check_auto_schema(coll)
364+
client.close()
365+
366+
with concurrent.futures.ThreadPoolExecutor() as executor:
367+
futures = []
368+
for i in range(5):
369+
futures.append(executor.submit(run_test))
370+
concurrent.futures.wait(futures)
371+
for future in futures:
372+
future.result()
373+
348374

349375
class TestBSONTypes(PandasTestBase):
350376
@classmethod

bindings/python/test/test_polars.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import concurrent.futures
16+
import threading
1517
import unittest
1618
import unittest.mock as mock
1719
import uuid
@@ -65,14 +67,13 @@ def setUpClass(cls):
6567
def setUp(self):
6668
"""Insert simple use case data."""
6769
self.coll.drop()
68-
self.coll.insert_many(
69-
[
70-
{"_id": 1, "data": 10},
71-
{"_id": 2, "data": 20},
72-
{"_id": 3, "data": 30},
73-
{"_id": 4},
74-
]
75-
)
70+
self.data = [
71+
{"_id": 1, "data": 10},
72+
{"_id": 2, "data": 20},
73+
{"_id": 3, "data": 30},
74+
{"_id": 4},
75+
]
76+
self.coll.insert_many(self.data)
7677
self.cmd_listener.reset()
7778
self.getmore_listener.reset()
7879

@@ -115,15 +116,17 @@ def test_find_simple(self):
115116
self.assertEqual(find_cmd.command_name, "find")
116117
self.assertEqual(find_cmd.command["projection"], {"_id": True, "data": True})
117118

118-
def test_aggregate_simple(self):
119+
def _check_aggregation_simple(self, coll):
119120
expected = pl.DataFrame(
120121
data={
121122
"_id": pl.Series(values=[1, 2, 3, 4], dtype=pl.Int32),
122123
"data": pl.Series(values=[20, 40, 60, None], dtype=pl.Int64),
123124
}
124125
)
126+
coll.drop()
127+
coll.insert_many(self.data)
125128
projection = {"_id": True, "data": {"$multiply": [2, "$data"]}}
126-
table = aggregate_polars_all(self.coll, [{"$project": projection}], schema=self.schema)
129+
table = aggregate_polars_all(coll, [{"$project": projection}], schema=self.schema)
127130
self.assertTrue(table.equals(expected))
128131

129132
agg_cmd = self.cmd_listener.results["started"][-1]
@@ -132,6 +135,9 @@ def test_aggregate_simple(self):
132135
self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection)
133136
self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True})
134137

138+
def test_aggregate_simple(self):
139+
self._check_aggregation_simple(self.coll)
140+
135141
@mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True)
136142
def test_write_batching(self, mock):
137143
data = pl.DataFrame(data={"_id": pl.Series(values=range(100040), dtype=pl.Int64)})
@@ -395,3 +401,23 @@ def test_bson_types(self):
395401
assert dfpl["value"].dtype == data_type["ptype"]
396402
except pl.exceptions.ComputeError:
397403
assert isinstance(table["value"].type, pa.ExtensionType)
404+
405+
def test_threading(self):
406+
def run_test():
407+
client = client_context.get_client(
408+
event_listeners=[self.getmore_listener, self.cmd_listener]
409+
)
410+
name = f"test-{threading.current_thread().name}"
411+
coll = client.pymongoarrow_test.get_collection(
412+
name, write_concern=WriteConcern(w="majority")
413+
)
414+
self._check_aggregation_simple(coll)
415+
client.close()
416+
417+
with concurrent.futures.ThreadPoolExecutor() as executor:
418+
futures = []
419+
for i in range(5):
420+
futures.append(executor.submit(run_test))
421+
concurrent.futures.wait(futures)
422+
for future in futures:
423+
future.result()

0 commit comments

Comments
 (0)