Skip to content

Commit 25e31db

Browse files
committed
Move ConcurrentRunner into helpers.py
1 parent 9fddf51 commit 25e31db

File tree

4 files changed

+76
-66
lines changed

4 files changed

+76
-66
lines changed

test/asynchronous/helpers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import base64
1920
import gc
2021
import multiprocessing
@@ -30,6 +31,8 @@
3031
import warnings
3132
from asyncio import iscoroutinefunction
3233

34+
from pymongo._asyncio_task import create_task
35+
3336
try:
3437
import ipaddress
3538

@@ -369,3 +372,37 @@ def disable(self):
369372
os.environ.pop("SSL_CERT_FILE")
370373
else:
371374
os.environ["SSL_CERT_FILE"] = self.original_certs
375+
376+
377+
if _IS_SYNC:
378+
PARENT = threading.Thread
379+
else:
380+
PARENT = object
381+
382+
383+
class ConcurrentRunner(PARENT):
384+
def __init__(self, name, *args, **kwargs):
385+
if _IS_SYNC:
386+
super().__init__(*args, **kwargs)
387+
self.name = name
388+
self.stopped = False
389+
self.task = None
390+
if "target" in kwargs:
391+
self.target = kwargs["target"]
392+
393+
if not _IS_SYNC:
394+
395+
async def start(self):
396+
self.task = create_task(self.run(), name=self.name)
397+
398+
async def join(self, timeout: float | None = 0): # type: ignore[override]
399+
if self.task is not None:
400+
await asyncio.wait([self.task], timeout=timeout)
401+
402+
def is_alive(self):
403+
return not self.stopped
404+
405+
async def run(self):
406+
if self.target:
407+
await self.target()
408+
self.stopped = True

test/asynchronous/utils_spec_runner.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
import asyncio
1919
import functools
2020
import os
21-
import threading
2221
import unittest
2322
from asyncio import iscoroutinefunction
2423
from collections import abc
2524
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
25+
from test.asynchronous.helpers import ConcurrentRunner
2626
from test.utils import (
2727
CMAPListener,
2828
CompareType,
@@ -55,38 +55,6 @@
5555

5656
_IS_SYNC = False
5757

58-
if _IS_SYNC:
59-
PARENT = threading.Thread
60-
else:
61-
PARENT = object
62-
63-
64-
class ConcurrentRunner(PARENT):
65-
def __init__(self, name, *args, **kwargs):
66-
if _IS_SYNC:
67-
super().__init__(*args, **kwargs)
68-
self.name = name
69-
self.stopped = False
70-
self.task = None
71-
if "target" in kwargs:
72-
self.target = kwargs["target"]
73-
74-
if not _IS_SYNC:
75-
76-
async def start(self):
77-
self.task = asyncio.create_task(self.run(), name=self.name)
78-
79-
async def join(self, timeout: float | None = 0): # type: ignore[override]
80-
if self.task is not None:
81-
await asyncio.wait([self.task], timeout=timeout)
82-
83-
def is_alive(self):
84-
return not self.stopped
85-
86-
async def run(self):
87-
if self.target:
88-
await self.target()
89-
9058

9159
class SpecRunnerTask(ConcurrentRunner):
9260
def __init__(self, name):

test/helpers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import base64
1920
import gc
2021
import multiprocessing
@@ -30,6 +31,8 @@
3031
import warnings
3132
from asyncio import iscoroutinefunction
3233

34+
from pymongo._asyncio_task import create_task
35+
3336
try:
3437
import ipaddress
3538

@@ -369,3 +372,37 @@ def disable(self):
369372
os.environ.pop("SSL_CERT_FILE")
370373
else:
371374
os.environ["SSL_CERT_FILE"] = self.original_certs
375+
376+
377+
if _IS_SYNC:
378+
PARENT = threading.Thread
379+
else:
380+
PARENT = object
381+
382+
383+
class ConcurrentRunner(PARENT):
384+
def __init__(self, name, *args, **kwargs):
385+
if _IS_SYNC:
386+
super().__init__(*args, **kwargs)
387+
self.name = name
388+
self.stopped = False
389+
self.task = None
390+
if "target" in kwargs:
391+
self.target = kwargs["target"]
392+
393+
if not _IS_SYNC:
394+
395+
def start(self):
396+
self.task = create_task(self.run(), name=self.name)
397+
398+
def join(self, timeout: float | None = 0): # type: ignore[override]
399+
if self.task is not None:
400+
asyncio.wait([self.task], timeout=timeout)
401+
402+
def is_alive(self):
403+
return not self.stopped
404+
405+
def run(self):
406+
if self.target:
407+
self.target()
408+
self.stopped = True

test/utils_spec_runner.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
import asyncio
1919
import functools
2020
import os
21-
import threading
2221
import unittest
2322
from asyncio import iscoroutinefunction
2423
from collections import abc
2524
from test import IntegrationTest, client_context, client_knobs
25+
from test.helpers import ConcurrentRunner
2626
from test.utils import (
2727
CMAPListener,
2828
CompareType,
@@ -55,38 +55,6 @@
5555

5656
_IS_SYNC = True
5757

58-
if _IS_SYNC:
59-
PARENT = threading.Thread
60-
else:
61-
PARENT = object
62-
63-
64-
class ConcurrentRunner(PARENT):
65-
def __init__(self, name, *args, **kwargs):
66-
if _IS_SYNC:
67-
super().__init__(*args, **kwargs)
68-
self.name = name
69-
self.stopped = False
70-
self.task = None
71-
if "target" in kwargs:
72-
self.target = kwargs["target"]
73-
74-
if not _IS_SYNC:
75-
76-
def start(self):
77-
self.task = asyncio.create_task(self.run(), name=self.name)
78-
79-
def join(self, timeout: float | None = 0): # type: ignore[override]
80-
if self.task is not None:
81-
asyncio.wait([self.task], timeout=timeout)
82-
83-
def is_alive(self):
84-
return not self.stopped
85-
86-
def run(self):
87-
if self.target:
88-
self.target()
89-
9058

9159
class SpecRunnerThread(ConcurrentRunner):
9260
def __init__(self, name):

0 commit comments

Comments
 (0)