Skip to content

Commit 0acea14

Browse files
authored
Fix running execute and subscribe of client in a Thread (#135)
1 parent 38d6c87 commit 0acea14

File tree

6 files changed

+141
-25
lines changed

6 files changed

+141
-25
lines changed

gql/client.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,13 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
110110

111111
if isinstance(self.transport, AsyncTransport):
112112

113-
loop = asyncio.get_event_loop()
113+
# Get the current asyncio event loop
114+
# Or create a new event loop if there isn't one (in a new Thread)
115+
try:
116+
loop = asyncio.get_event_loop()
117+
except RuntimeError:
118+
loop = asyncio.new_event_loop()
119+
asyncio.set_event_loop(loop)
114120

115121
assert not loop.is_running(), (
116122
"Cannot run client.execute(query) if an asyncio loop is running."
@@ -146,9 +152,15 @@ def subscribe(
146152
We need an async transport for this functionality.
147153
"""
148154

149-
async_generator = self.subscribe_async(document, *args, **kwargs)
155+
# Get the current asyncio event loop
156+
# Or create a new event loop if there isn't one (in a new Thread)
157+
try:
158+
loop = asyncio.get_event_loop()
159+
except RuntimeError:
160+
loop = asyncio.new_event_loop()
161+
asyncio.set_event_loop(loop)
150162

151-
loop = asyncio.get_event_loop()
163+
async_generator = self.subscribe_async(document, *args, **kwargs)
152164

153165
assert not loop.is_running(), (
154166
"Cannot run client.subscribe(query) if an asyncio loop is running."

gql/transport/websockets.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ def __init__(
128128
self.receive_data_task: Optional[asyncio.Future] = None
129129
self.close_task: Optional[asyncio.Future] = None
130130

131+
# We need to set an event loop here if there is none
132+
# Or else we will not be able to create an asyncio.Event()
133+
try:
134+
self._loop = asyncio.get_event_loop()
135+
except RuntimeError:
136+
self._loop = asyncio.new_event_loop()
137+
asyncio.set_event_loop(self._loop)
138+
131139
self._wait_closed: asyncio.Event = asyncio.Event()
132140
self._wait_closed.set()
133141

tests/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pathlib
66
import ssl
77
import types
8+
from concurrent.futures import ThreadPoolExecutor
89

910
import pytest
1011
import websockets
@@ -266,3 +267,21 @@ async def client_and_server(server):
266267

267268
# Yield both client session and server
268269
yield session, server
270+
271+
272+
@pytest.fixture
273+
async def run_sync_test():
274+
async def run_sync_test_inner(event_loop, server, test_function):
275+
"""This function will run the test in a different Thread.
276+
277+
This allows us to run sync code while aiohttp server can still run.
278+
"""
279+
executor = ThreadPoolExecutor(max_workers=2)
280+
test_task = event_loop.run_in_executor(executor, test_function)
281+
282+
await test_task
283+
284+
if hasattr(server, "close"):
285+
await server.close()
286+
287+
return run_sync_test_inner

tests/test_aiohttp.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,62 @@ async def handler(request):
262262
continent = result["continent"]
263263

264264
assert continent["name"] == "Europe"
265+
266+
267+
@pytest.mark.asyncio
268+
async def test_aiohttp_execute_running_in_thread(
269+
event_loop, aiohttp_server, run_sync_test
270+
):
271+
async def handler(request):
272+
return web.Response(text=query1_server_answer, content_type="application/json")
273+
274+
app = web.Application()
275+
app.router.add_route("POST", "/", handler)
276+
server = await aiohttp_server(app)
277+
278+
url = server.make_url("/")
279+
280+
def test_code():
281+
sample_transport = AIOHTTPTransport(url=url)
282+
283+
client = Client(transport=sample_transport)
284+
285+
query = gql(query1_str)
286+
287+
client.execute(query)
288+
289+
await run_sync_test(event_loop, server, test_code)
290+
291+
292+
@pytest.mark.asyncio
293+
async def test_aiohttp_subscribe_running_in_thread(
294+
event_loop, aiohttp_server, run_sync_test
295+
):
296+
async def handler(request):
297+
return web.Response(text=query1_server_answer, content_type="application/json")
298+
299+
app = web.Application()
300+
app.router.add_route("POST", "/", handler)
301+
server = await aiohttp_server(app)
302+
303+
url = server.make_url("/")
304+
305+
def test_code():
306+
sample_transport = AIOHTTPTransport(url=url)
307+
308+
client = Client(transport=sample_transport)
309+
310+
query = gql(query1_str)
311+
312+
# Note: subscriptions are not supported on the aiohttp transport
313+
# But we add this test in order to have 100% code coverage
314+
# It is to check that we will correctly set an event loop
315+
# in the subscribe function if there is none (in a Thread for example)
316+
# We cannot test this with the websockets transport because
317+
# the websockets transport will set an event loop in its init
318+
319+
with pytest.raises(NotImplementedError):
320+
for result in client.subscribe(query):
321+
pass
322+
323+
await run_sync_test(event_loop, server, test_code)

tests/test_requests.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
1-
from concurrent.futures import ThreadPoolExecutor
2-
31
import pytest
42
from aiohttp import web
53

6-
from gql import Client, gql
4+
from gql import Client, RequestsHTTPTransport, gql
75
from gql.transport.exceptions import (
86
TransportAlreadyConnected,
97
TransportClosed,
108
TransportProtocolError,
119
TransportQueryError,
1210
TransportServerError,
1311
)
14-
from gql.transport.requests import RequestsHTTPTransport
1512

1613
query1_str = """
1714
query getContinents {
@@ -31,20 +28,8 @@
3128
)
3229

3330

34-
async def run_sync_test(event_loop, server, test_function):
35-
"""This function will run the test in a different Thread.
36-
37-
This allows us to run sync code while aiohttp server can still run.
38-
"""
39-
executor = ThreadPoolExecutor(max_workers=2)
40-
test_task = event_loop.run_in_executor(executor, test_function)
41-
42-
await test_task
43-
await server.close()
44-
45-
4631
@pytest.mark.asyncio
47-
async def test_requests_query(event_loop, aiohttp_server):
32+
async def test_requests_query(event_loop, aiohttp_server, run_sync_test):
4833
async def handler(request):
4934
return web.Response(text=query1_server_answer, content_type="application/json")
5035

@@ -74,7 +59,7 @@ def test_code():
7459

7560

7661
@pytest.mark.asyncio
77-
async def test_requests_error_code_500(event_loop, aiohttp_server):
62+
async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test):
7863
async def handler(request):
7964
# Will generate http error code 500
8065
raise Exception("Server error")
@@ -102,7 +87,7 @@ def test_code():
10287

10388

10489
@pytest.mark.asyncio
105-
async def test_requests_error_code(event_loop, aiohttp_server):
90+
async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test):
10691
async def handler(request):
10792
return web.Response(
10893
text=query1_server_error_answer, content_type="application/json"
@@ -136,7 +121,9 @@ def test_code():
136121

137122
@pytest.mark.asyncio
138123
@pytest.mark.parametrize("response", invalid_protocol_responses)
139-
async def test_requests_invalid_protocol(event_loop, aiohttp_server, response):
124+
async def test_requests_invalid_protocol(
125+
event_loop, aiohttp_server, response, run_sync_test
126+
):
140127
async def handler(request):
141128
return web.Response(text=response, content_type="application/json")
142129

@@ -160,7 +147,7 @@ def test_code():
160147

161148

162149
@pytest.mark.asyncio
163-
async def test_requests_cannot_connect_twice(event_loop, aiohttp_server):
150+
async def test_requests_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test):
164151
async def handler(request):
165152
return web.Response(text=query1_server_answer, content_type="application/json")
166153

@@ -182,7 +169,9 @@ def test_code():
182169

183170

184171
@pytest.mark.asyncio
185-
async def test_requests_cannot_execute_if_not_connected(event_loop, aiohttp_server):
172+
async def test_requests_cannot_execute_if_not_connected(
173+
event_loop, aiohttp_server, run_sync_test
174+
):
186175
async def handler(request):
187176
return web.Response(text=query1_server_answer, content_type="application/json")
188177

tests/test_websocket_subscription.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,32 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str)
446446

447447
# Check that the server received a connection_terminate message last
448448
assert logged_messages.pop() == '{"type": "connection_terminate"}'
449+
450+
451+
@pytest.mark.asyncio
452+
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
453+
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
454+
async def test_websocket_subscription_running_in_thread(
455+
event_loop, server, subscription_str, run_sync_test
456+
):
457+
def test_code():
458+
path = "/graphql"
459+
url = f"ws://{server.hostname}:{server.port}{path}"
460+
sample_transport = WebsocketsTransport(url=url)
461+
462+
client = Client(transport=sample_transport)
463+
464+
count = 10
465+
subscription = gql(subscription_str.format(count=count))
466+
467+
for result in client.subscribe(subscription):
468+
469+
number = result["number"]
470+
print(f"Number received: {number}")
471+
472+
assert number == count
473+
count -= 1
474+
475+
assert count == -1
476+
477+
await run_sync_test(event_loop, server, test_code)

0 commit comments

Comments
 (0)