|
22 | 22 | import threading
|
23 | 23 | import time
|
24 | 24 | from io import BytesIO
|
| 25 | +from test.asynchronous.helpers import ConcurrentRunner |
25 | 26 | from unittest.mock import patch
|
26 | 27 |
|
27 | 28 | sys.path[0:0] = [""]
|
28 | 29 |
|
29 | 30 | from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
30 |
| -from test.utils import asyncjoinall, joinall, one |
| 31 | +from test.utils import joinall, one |
31 | 32 |
|
32 | 33 | import gridfs
|
33 | 34 | from bson.binary import Binary
|
|
44 | 45 |
|
45 | 46 | _IS_SYNC = False
|
46 | 47 |
|
47 |
| -if _IS_SYNC: |
48 |
| - |
49 |
| - class JustWrite(threading.Thread): |
50 |
| - def __init__(self, fs, n): |
51 |
| - threading.Thread.__init__(self) |
52 |
| - self.fs = fs |
53 |
| - self.n = n |
54 |
| - self.daemon = True |
55 |
| - |
56 |
| - def run(self): |
57 |
| - for _ in range(self.n): |
58 |
| - file = self.fs.new_file(filename="test") |
59 |
| - file.write(b"hello") |
60 |
| - file.close() |
61 |
| - |
62 |
| - class JustRead(threading.Thread): |
63 |
| - def __init__(self, fs, n, results): |
64 |
| - threading.Thread.__init__(self) |
65 |
| - self.fs = fs |
66 |
| - self.n = n |
67 |
| - self.results = results |
68 |
| - self.daemon = True |
69 |
| - |
70 |
| - def run(self): |
71 |
| - for _ in range(self.n): |
72 |
| - file = self.fs.get("test") |
73 |
| - data = file.read() |
74 |
| - self.results.append(data) |
75 |
| - assert data == b"hello" |
76 |
| -else: |
77 |
| - |
78 |
| - class JustWrite: |
79 |
| - def __init__(self, fs, n): |
80 |
| - self.task = asyncio.create_task(self.run()) |
81 |
| - self.fs = fs |
82 |
| - self.n = n |
83 |
| - self.daemon = True |
84 |
| - |
85 |
| - async def run(self): |
86 |
| - for _ in range(self.n): |
87 |
| - file = self.fs.new_file(filename="test") |
88 |
| - await file.write(b"hello") |
89 |
| - await file.close() |
90 |
| - |
91 |
| - class JustRead: |
92 |
| - def __init__(self, fs, n, results): |
93 |
| - self.task = asyncio.create_task(self.run()) |
94 |
| - self.fs = fs |
95 |
| - self.n = n |
96 |
| - self.results = results |
97 |
| - self.daemon = True |
98 |
| - |
99 |
| - async def run(self): |
100 |
| - for _ in range(self.n): |
101 |
| - file = await self.fs.get("test") |
102 |
| - data = await file.read() |
103 |
| - self.results.append(data) |
104 |
| - assert data == b"hello" |
| 48 | + |
| 49 | +class JustWrite(ConcurrentRunner): |
| 50 | + def __init__(self, fs, n): |
| 51 | + super().__init__() |
| 52 | + self.fs = fs |
| 53 | + self.n = n |
| 54 | + self.daemon = True |
| 55 | + |
| 56 | + async def run(self): |
| 57 | + for _ in range(self.n): |
| 58 | + file = self.fs.new_file(filename="test") |
| 59 | + await file.write(b"hello") |
| 60 | + await file.close() |
| 61 | + |
| 62 | + |
| 63 | +class JustRead(ConcurrentRunner): |
| 64 | + def __init__(self, fs, n, results): |
| 65 | + super().__init__() |
| 66 | + self.fs = fs |
| 67 | + self.n = n |
| 68 | + self.results = results |
| 69 | + self.daemon = True |
| 70 | + |
| 71 | + async def run(self): |
| 72 | + for _ in range(self.n): |
| 73 | + file = await self.fs.get("test") |
| 74 | + data = await file.read() |
| 75 | + self.results.append(data) |
| 76 | + assert data == b"hello" |
105 | 77 |
|
106 | 78 |
|
107 | 79 | class TestGridfsNoConnect(unittest.IsolatedAsyncioTestCase):
|
@@ -252,25 +224,29 @@ async def test_alt_collection(self):
|
252 | 224 | async def test_threaded_reads(self):
|
253 | 225 | await self.fs.put(b"hello", _id="test")
|
254 | 226 |
|
255 |
| - threads = [] |
| 227 | + tasks = [] |
256 | 228 | results: list = []
|
257 | 229 | for i in range(10):
|
258 |
| - threads.append(JustRead(self.fs, 10, results)) |
259 |
| - if _IS_SYNC: |
260 |
| - threads[i].start() |
| 230 | + tasks.append(JustRead(self.fs, 10, results)) |
| 231 | + await tasks[i].start() |
261 | 232 |
|
262 |
| - await asyncjoinall(threads) |
| 233 | + if _IS_SYNC: |
| 234 | + joinall(tasks) |
| 235 | + else: |
| 236 | + await asyncio.wait([t.task for t in tasks]) |
263 | 237 |
|
264 | 238 | self.assertEqual(100 * [b"hello"], results)
|
265 | 239 |
|
266 | 240 | async def test_threaded_writes(self):
|
267 |
| - threads = [] |
| 241 | + tasks = [] |
268 | 242 | for i in range(10):
|
269 |
| - threads.append(JustWrite(self.fs, 10)) |
270 |
| - if _IS_SYNC: |
271 |
| - threads[i].start() |
| 243 | + tasks.append(JustWrite(self.fs, 10)) |
| 244 | + await tasks[i].start() |
272 | 245 |
|
273 |
| - await asyncjoinall(threads) |
| 246 | + if _IS_SYNC: |
| 247 | + joinall(tasks) |
| 248 | + else: |
| 249 | + await asyncio.wait([t.task for t in tasks]) |
274 | 250 |
|
275 | 251 | f = await self.fs.get_last_version("test")
|
276 | 252 | self.assertEqual(await f.read(), b"hello")
|
|
0 commit comments