diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c844470138..42cc8dc1b7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -248,3 +248,10 @@ you are attempting to validate new spec tests in PyMongo. ## Making a Release Follow the [Python Driver Release Process Wiki](https://wiki.corp.mongodb.com/display/DRIVERS/Python+Driver+Release+Process). + +## Converting a test to async +The `tools/convert_test_to_async.py` script takes in an existing synchronous test file and outputs a +partially-converted asynchronous version of the same name to the `test/asynchronous` directory. +Use this generated file as a starting point for the completed conversion. + +The script is used like so: `python tools/convert_test_to_async.py [test_file.py]` diff --git a/tools/convert_test_to_async.py b/tools/convert_test_to_async.py new file mode 100644 index 0000000000..dbdb217c84 --- /dev/null +++ b/tools/convert_test_to_async.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import asyncio +import sys + +from pymongo import AsyncMongoClient +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.cursor import AsyncCursor +from pymongo.asynchronous.database import AsyncDatabase + +replacements = { + "Collection": "AsyncCollection", + "Database": "AsyncDatabase", + "Cursor": "AsyncCursor", + "MongoClient": "AsyncMongoClient", + "CommandCursor": "AsyncCommandCursor", + "RawBatchCursor": "AsyncRawBatchCursor", + "RawBatchCommandCursor": "AsyncRawBatchCommandCursor", + "ClientSession": "AsyncClientSession", + "ChangeStream": "AsyncChangeStream", + "CollectionChangeStream": "AsyncCollectionChangeStream", + "DatabaseChangeStream": "AsyncDatabaseChangeStream", + "ClusterChangeStream": "AsyncClusterChangeStream", + "_Bulk": "_AsyncBulk", + "_ClientBulk": "_AsyncClientBulk", + "Connection": "AsyncConnection", + "synchronous": "asynchronous", + "Synchronous": "Asynchronous", + "next": "await anext", + "_Lock": "_ALock", + "_Condition": "_ACondition", + "GridFS": "AsyncGridFS", + "GridFSBucket": "AsyncGridFSBucket", + "GridIn": "AsyncGridIn", + "GridOut": "AsyncGridOut", + "GridOutCursor": "AsyncGridOutCursor", + "GridOutIterator": "AsyncGridOutIterator", + "GridOutChunkIterator": "_AsyncGridOutChunkIterator", + "_grid_in_property": "_a_grid_in_property", + "_grid_out_property": "_a_grid_out_property", + "ClientEncryption": "AsyncClientEncryption", + "MongoCryptCallback": "AsyncMongoCryptCallback", + "ExplicitEncrypter": "AsyncExplicitEncrypter", + "AutoEncrypter": "AsyncAutoEncrypter", + "ContextManager": "AsyncContextManager", + "ClientContext": "AsyncClientContext", + "TestCollection": "AsyncTestCollection", + "IntegrationTest": "AsyncIntegrationTest", + "PyMongoTestCase": "AsyncPyMongoTestCase", + "MockClientTest": "AsyncMockClientTest", + "client_context": "async_client_context", + "setUp": "asyncSetUp", + "tearDown": "asyncTearDown", + "wait_until": "await async_wait_until", + "addCleanup": "addAsyncCleanup", + "TestCase": "IsolatedAsyncioTestCase", + "UnitTest": "AsyncUnitTest", + "MockClient": "AsyncMockClient", + "SpecRunner": "AsyncSpecRunner", + "TransactionsBase": "AsyncTransactionsBase", + "get_pool": "await async_get_pool", + "is_mongos": "await async_is_mongos", + "rs_or_single_client": "await async_rs_or_single_client", + "rs_or_single_client_noauth": "await async_rs_or_single_client_noauth", + "rs_client": "await async_rs_client", + "single_client": "await async_single_client", + "from_client": "await async_from_client", + "closing": "aclosing", + "assertRaisesExactly": "asyncAssertRaisesExactly", + "get_mock_client": "await get_async_mock_client", + "close": "await aclose", +} + +async_classes = [AsyncMongoClient, AsyncDatabase, AsyncCollection, AsyncCursor, AsyncCommandCursor] + + +def get_async_methods() -> set[str]: + result: set[str] = set() + for x in async_classes: + methods = { + k + for k, v in vars(x).items() + if callable(v) + and not isinstance(v, classmethod) + and asyncio.iscoroutinefunction(v) + and v.__name__[0] != "_" + } + result = result | methods + return result + + +async_methods = get_async_methods() + + +def apply_replacements(lines: list[str]) -> list[str]: + for i in range(len(lines)): + if "_IS_SYNC = True" in lines[i]: + lines[i] = "_IS_SYNC = False" + if "def test" in lines[i]: + lines[i] = lines[i].replace("def test", "async def test") + for k in replacements: + if k in lines[i]: + lines[i] = lines[i].replace(k, replacements[k]) + for k in async_methods: + if k + "(" in lines[i]: + tokens = lines[i].split(" ") + for j in range(len(tokens)): + if k + "(" in tokens[j]: + if j < 2: + tokens.insert(0, "await") + else: + tokens.insert(j, "await") + break + new_line = " ".join(tokens) + + lines[i] = new_line + + return lines + + +def process_file(input_file: str, output_file: str) -> None: + with open(input_file, "r+") as f: + lines = f.readlines() + lines = apply_replacements(lines) + + with open(output_file, "w+") as f2: + f2.seek(0) + f2.writelines(lines) + f2.truncate() + + +def main() -> None: + args = sys.argv[1:] + sync_file = "./test/" + args[0] + async_file = "./" + args[0] + + process_file(sync_file, async_file) + + +main()