Skip to content

Commit acc8bdc

Browse files
committed
Add async TestSpecCreator
1 parent 4158326 commit acc8bdc

File tree

9 files changed

+345
-170
lines changed

9 files changed

+345
-170
lines changed

test/asynchronous/test_encryption.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import uuid
3030
import warnings
3131
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context
32-
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
32+
from test.asynchronous.utils_spec_runner import AsyncSpecRunner, AsyncSpecTestCreator
3333
from threading import Thread
3434
from typing import Any, Dict, Mapping
3535

@@ -58,14 +58,12 @@
5858
from test.utils import (
5959
AllowListEventListener,
6060
OvertCommandListener,
61-
SpecTestCreator,
6261
TopologyEventListener,
6362
async_rs_or_single_client,
6463
async_wait_until,
6564
camel_to_snake_args,
6665
is_greenthread_patched,
6766
)
68-
from test.utils_spec_runner import SpecRunner
6967

7068
from bson import DatetimeMS, Decimal128, encode, json_util
7169
from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation
@@ -718,15 +716,15 @@ def allowable_errors(self, op):
718716
return errors
719717

720718

721-
def create_test(scenario_def, test, name):
719+
async def create_test(scenario_def, test, name):
722720
@async_client_context.require_test_commands
723-
def run_scenario(self):
724-
self.run_scenario(scenario_def, test)
721+
async def run_scenario(self):
722+
await self.run_scenario(scenario_def, test)
725723

726724
return run_scenario
727725

728726

729-
test_creator = SpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy"))
727+
test_creator = AsyncSpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy"))
730728
test_creator.create_tests()
731729

732730
if _HAVE_PYMONGOCRYPT:

test/asynchronous/utils_spec_runner.py

Lines changed: 168 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,20 @@
1515
"""Utilities for testing driver specs."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import functools
20+
import os
1921
import threading
22+
import unittest
23+
from asyncio import iscoroutinefunction
2024
from collections import abc
2125
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
2226
from test.utils import (
2327
CMAPListener,
2428
CompareType,
2529
EventListener,
2630
OvertCommandListener,
31+
ScenarioDict,
2732
ServerAndTopologyEventListener,
2833
async_rs_client,
2934
camel_to_snake,
@@ -33,11 +38,12 @@
3338
)
3439
from typing import List
3540

36-
from bson import ObjectId, decode, encode
41+
from bson import ObjectId, decode, encode, json_util
3742
from bson.binary import Binary
3843
from bson.int64 import Int64
3944
from bson.son import SON
4045
from gridfs import GridFSBucket
46+
from gridfs.asynchronous.grid_file import AsyncGridFSBucket
4147
from pymongo.asynchronous import client_session
4248
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
4349
from pymongo.asynchronous.cursor import AsyncCursor
@@ -84,6 +90,161 @@ def run(self):
8490
self.stop()
8591

8692

93+
class AsyncSpecTestCreator:
94+
"""Class to create test cases from specifications."""
95+
96+
def __init__(self, create_test, test_class, test_path):
97+
"""Create a TestCreator object.
98+
99+
:Parameters:
100+
- `create_test`: callback that returns a test case. The callback
101+
must accept the following arguments - a dictionary containing the
102+
entire test specification (the `scenario_def`), a dictionary
103+
containing the specification for which the test case will be
104+
generated (the `test_def`).
105+
- `test_class`: the unittest.TestCase class in which to create the
106+
test case.
107+
- `test_path`: path to the directory containing the JSON files with
108+
the test specifications.
109+
"""
110+
self._create_test = create_test
111+
self._test_class = test_class
112+
self.test_path = test_path
113+
114+
def _ensure_min_max_server_version(self, scenario_def, method):
115+
"""Test modifier that enforces a version range for the server on a
116+
test case.
117+
"""
118+
if "minServerVersion" in scenario_def:
119+
min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split("."))
120+
if min_ver is not None:
121+
method = async_client_context.require_version_min(*min_ver)(method)
122+
123+
if "maxServerVersion" in scenario_def:
124+
max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split("."))
125+
if max_ver is not None:
126+
method = async_client_context.require_version_max(*max_ver)(method)
127+
128+
if "serverless" in scenario_def:
129+
serverless = scenario_def["serverless"]
130+
if serverless == "require":
131+
serverless_satisfied = async_client_context.serverless
132+
elif serverless == "forbid":
133+
serverless_satisfied = not async_client_context.serverless
134+
else: # unset or "allow"
135+
serverless_satisfied = True
136+
method = unittest.skipUnless(
137+
serverless_satisfied, "Serverless requirement not satisfied"
138+
)(method)
139+
140+
return method
141+
142+
@staticmethod
143+
async def valid_topology(run_on_req):
144+
return await async_client_context.is_topology_type(
145+
run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"])
146+
)
147+
148+
@staticmethod
149+
def min_server_version(run_on_req):
150+
version = run_on_req.get("minServerVersion")
151+
if version:
152+
min_ver = tuple(int(elt) for elt in version.split("."))
153+
return async_client_context.version >= min_ver
154+
return True
155+
156+
@staticmethod
157+
def max_server_version(run_on_req):
158+
version = run_on_req.get("maxServerVersion")
159+
if version:
160+
max_ver = tuple(int(elt) for elt in version.split("."))
161+
return async_client_context.version <= max_ver
162+
return True
163+
164+
@staticmethod
165+
def valid_auth_enabled(run_on_req):
166+
if "authEnabled" in run_on_req:
167+
if run_on_req["authEnabled"]:
168+
return async_client_context.auth_enabled
169+
return not async_client_context.auth_enabled
170+
return True
171+
172+
@staticmethod
173+
def serverless_ok(run_on_req):
174+
serverless = run_on_req["serverless"]
175+
if serverless == "require":
176+
return async_client_context.serverless
177+
elif serverless == "forbid":
178+
return not async_client_context.serverless
179+
else: # unset or "allow"
180+
return True
181+
182+
async def should_run_on(self, scenario_def):
183+
run_on = scenario_def.get("runOn", [])
184+
if not run_on:
185+
# Always run these tests.
186+
return True
187+
188+
for req in run_on:
189+
if (
190+
await self.valid_topology(req)
191+
and self.min_server_version(req)
192+
and self.max_server_version(req)
193+
and self.valid_auth_enabled(req)
194+
and self.serverless_ok(req)
195+
):
196+
return True
197+
return False
198+
199+
def ensure_run_on(self, scenario_def, method):
200+
"""Test modifier that enforces a 'runOn' on a test case."""
201+
202+
async def predicate():
203+
return await self.should_run_on(scenario_def)
204+
205+
return async_client_context._require(predicate, "runOn not satisfied", method)
206+
207+
def tests(self, scenario_def):
208+
"""Allow CMAP spec test to override the location of test."""
209+
return scenario_def["tests"]
210+
211+
async def _create_tests(self):
212+
for dirpath, _, filenames in os.walk(self.test_path):
213+
dirname = os.path.split(dirpath)[-1]
214+
215+
for filename in filenames:
216+
with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100
217+
# Use tz_aware=False to match how CodecOptions decodes
218+
# dates.
219+
opts = json_util.JSONOptions(tz_aware=False)
220+
scenario_def = ScenarioDict(
221+
json_util.loads(scenario_stream.read(), json_options=opts)
222+
)
223+
224+
test_type = os.path.splitext(filename)[0]
225+
226+
# Construct test from scenario.
227+
for test_def in self.tests(scenario_def):
228+
test_name = "test_{}_{}_{}".format(
229+
dirname,
230+
test_type.replace("-", "_").replace(".", "_"),
231+
str(test_def["description"].replace(" ", "_").replace(".", "_")),
232+
)
233+
234+
new_test = await self._create_test(scenario_def, test_def, test_name)
235+
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
236+
new_test = self.ensure_run_on(scenario_def, new_test)
237+
238+
new_test.__name__ = test_name
239+
setattr(self._test_class, new_test.__name__, new_test)
240+
241+
def create_tests(self):
242+
if _IS_SYNC:
243+
self._create_tests()
244+
else:
245+
asyncio.run(self._create_tests())
246+
247+
87248
class AsyncSpecRunner(AsyncIntegrationTest):
88249
mongos_clients: List
89250
knobs: client_knobs
@@ -283,7 +444,7 @@ async def run_operation(self, sessions, collection, operation):
283444
if object_name == "gridfsbucket":
284445
# Only create the GridFSBucket when we need it (for the gridfs
285446
# retryable reads tests).
286-
obj = GridFSBucket(database, bucket_name=collection.name)
447+
obj = AsyncGridFSBucket(database, bucket_name=collection.name)
287448
else:
288449
objects = {
289450
"client": database.client,
@@ -311,7 +472,10 @@ async def run_operation(self, sessions, collection, operation):
311472
args.update(arguments)
312473
arguments = args
313474

314-
result = cmd(**dict(arguments))
475+
if not _IS_SYNC and iscoroutinefunction(cmd):
476+
result = await cmd(**dict(arguments))
477+
else:
478+
result = cmd(**dict(arguments))
315479
# Cleanup open change stream cursors.
316480
if name == "watch":
317481
self.addAsyncCleanup(result.close)
@@ -587,7 +751,7 @@ async def run_scenario(self, scenario_def, test):
587751
read_preference=ReadPreference.PRIMARY,
588752
read_concern=ReadConcern("local"),
589753
)
590-
actual_data = await (await outcome_coll.find(sort=[("_id", 1)])).to_list()
754+
actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list()
591755

592756
# The expected data needs to be the left hand side here otherwise
593757
# CompareType(Binary) doesn't work.

test/test_connection_monitoring.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from test.pymongo_mocks import DummyMonitor
2626
from test.utils import (
2727
CMAPListener,
28-
SpecTestCreator,
2928
camel_to_snake,
3029
client_context,
3130
get_pool,
@@ -35,7 +34,7 @@
3534
single_client_noauth,
3635
wait_until,
3736
)
38-
from test.utils_spec_runner import SpecRunnerThread
37+
from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator
3938

4039
from bson.objectid import ObjectId
4140
from bson.son import SON

test/test_encryption.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import uuid
3030
import warnings
3131
from test import IntegrationTest, PyMongoTestCase, client_context
32-
from test.utils_spec_runner import SpecRunner
32+
from test.utils_spec_runner import SpecRunner, SpecTestCreator
3333
from threading import Thread
3434
from typing import Any, Dict, Mapping
3535

@@ -58,14 +58,12 @@
5858
from test.utils import (
5959
AllowListEventListener,
6060
OvertCommandListener,
61-
SpecTestCreator,
6261
TopologyEventListener,
6362
camel_to_snake_args,
6463
is_greenthread_patched,
6564
rs_or_single_client,
6665
wait_until,
6766
)
68-
from test.utils_spec_runner import SpecRunner
6967

7068
from bson import DatetimeMS, Decimal128, encode, json_util
7169
from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation

test/test_retryable_reads.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,11 @@
3636
CMAPListener,
3737
EventListener,
3838
OvertCommandListener,
39-
SpecTestCreator,
4039
rs_client,
4140
rs_or_single_client,
4241
set_fail_point,
4342
)
44-
from test.utils_spec_runner import SpecRunner
43+
from test.utils_spec_runner import SpecRunner, SpecTestCreator
4544

4645
from pymongo.monitoring import (
4746
ConnectionCheckedOutEvent,

test/test_retryable_writes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@
2929
DeprecationFilter,
3030
EventListener,
3131
OvertCommandListener,
32-
SpecTestCreator,
3332
rs_or_single_client,
3433
set_fail_point,
3534
)
36-
from test.utils_spec_runner import SpecRunner
35+
from test.utils_spec_runner import SpecRunner, SpecTestCreator
3736
from test.version import Version
3837

3938
from bson.codec_options import DEFAULT_CODEC_OPTIONS

0 commit comments

Comments
 (0)