Skip to content

Commit 1480d3b

Browse files
committed
Implementation of automatic batching for async
1 parent 77a3a40 commit 1480d3b

File tree

5 files changed

+437
-77
lines changed

5 files changed

+437
-77
lines changed

gql/client.py

Lines changed: 160 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -829,15 +829,11 @@ async def connect_async(self, reconnecting=False, **kwargs):
829829

830830
if reconnecting:
831831
self.session = ReconnectingAsyncClientSession(client=self, **kwargs)
832-
await self.session.start_connecting_task()
833832
else:
834-
try:
835-
await self.transport.connect()
836-
except Exception as e:
837-
await self.transport.close()
838-
raise e
839833
self.session = AsyncClientSession(client=self)
840834

835+
await self.session.connect()
836+
841837
# Get schema from transport if needed
842838
try:
843839
if self.fetch_schema_from_transport and not self.schema:
@@ -846,18 +842,15 @@ async def connect_async(self, reconnecting=False, **kwargs):
846842
# we don't know what type of exception is thrown here because it
847843
# depends on the underlying transport; we just make sure that the
848844
# transport is closed and re-raise the exception
849-
await self.transport.close()
845+
await self.session.close()
850846
raise
851847

852848
return self.session
853849

854850
async def close_async(self):
855851
"""Close the async transport and stop the optional reconnecting task."""
856852

857-
if isinstance(self.session, ReconnectingAsyncClientSession):
858-
await self.session.stop_connecting_task()
859-
860-
await self.transport.close()
853+
await self.session.close()
861854

862855
async def __aenter__(self):
863856
return await self.connect_async()
@@ -1564,12 +1557,17 @@ async def _execute(
15641557
):
15651558
request = request.serialize_variable_values(self.client.schema)
15661559

1567-
# Execute the query with the transport with a timeout
1568-
with fail_after(self.client.execute_timeout):
1569-
result = await self.transport.execute(
1570-
request,
1571-
**kwargs,
1572-
)
1560+
# Check if batching is enabled
1561+
if self.client.batching_enabled:
1562+
future_result = await self._execute_future(request)
1563+
result = await future_result
1564+
else:
1565+
# Execute the query with the transport with a timeout
1566+
with fail_after(self.client.execute_timeout):
1567+
result = await self.transport.execute(
1568+
request,
1569+
**kwargs,
1570+
)
15731571

15741572
# Unserialize the result if requested
15751573
if self.client.schema:
@@ -1828,6 +1826,134 @@ async def execute_batch(
18281826

18291827
return cast(List[Dict[str, Any]], [result.data for result in results])
18301828

1829+
async def _batch_loop(self) -> None:
1830+
"""Main loop of the task used to wait for requests
1831+
to execute them in a batch"""
1832+
1833+
stop_loop = False
1834+
1835+
while not stop_loop:
1836+
# First wait for a first request in from the batch queue
1837+
requests_and_futures: List[Tuple[GraphQLRequest, asyncio.Future]] = []
1838+
1839+
# Wait for the first request
1840+
request_and_future: Optional[Tuple[GraphQLRequest, asyncio.Future]] = (
1841+
await self.batch_queue.get()
1842+
)
1843+
1844+
if request_and_future is None:
1845+
# None is our sentinel value to stop the loop
1846+
break
1847+
1848+
requests_and_futures.append(request_and_future)
1849+
1850+
# Then wait the requested batch interval except if we already
1851+
# have the maximum number of requests in the queue
1852+
if self.batch_queue.qsize() < self.client.batch_max - 1:
1853+
# Wait for the batch interval
1854+
await asyncio.sleep(self.client.batch_interval)
1855+
1856+
# Then get the requests which had been made during that wait interval
1857+
for _ in range(self.client.batch_max - 1):
1858+
try:
1859+
# Use get_nowait since we don't want to wait here
1860+
request_and_future = self.batch_queue.get_nowait()
1861+
1862+
if request_and_future is None:
1863+
# Sentinel value - stop after processing current batch
1864+
stop_loop = True
1865+
break
1866+
1867+
requests_and_futures.append(request_and_future)
1868+
1869+
except asyncio.QueueEmpty:
1870+
# No more requests in queue, that's fine
1871+
break
1872+
1873+
# Extract requests and futures
1874+
requests = [request for request, _ in requests_and_futures]
1875+
futures = [future for _, future in requests_and_futures]
1876+
1877+
# Execute the batch
1878+
try:
1879+
results: List[ExecutionResult] = await self._execute_batch(
1880+
requests,
1881+
serialize_variables=False, # already done
1882+
parse_result=False, # will be done later
1883+
validate_document=False, # already validated
1884+
)
1885+
1886+
# Set the result for each future
1887+
for result, future in zip(results, futures):
1888+
if not future.cancelled():
1889+
future.set_result(result)
1890+
1891+
except Exception as exc:
1892+
# If batch execution fails, propagate the error to all futures
1893+
for future in futures:
1894+
if not future.cancelled():
1895+
future.set_exception(exc)
1896+
1897+
# Signal that the task has stopped
1898+
self._batch_task_stopped_event.set()
1899+
1900+
async def _execute_future(
1901+
self,
1902+
request: GraphQLRequest,
1903+
) -> asyncio.Future:
1904+
"""If batching is enabled, this method will put a request in the batching queue
1905+
instead of executing it directly so that the requests could be put in a batch.
1906+
"""
1907+
1908+
assert hasattr(self, "batch_queue"), "Batching is not enabled"
1909+
assert not self._batch_task_stop_requested, "Batching task has been stopped"
1910+
1911+
future: asyncio.Future = asyncio.Future()
1912+
await self.batch_queue.put((request, future))
1913+
1914+
return future
1915+
1916+
async def _batch_init(self):
1917+
"""Initialize the batch task loop if batching is enabled."""
1918+
if self.client.batching_enabled:
1919+
self.batch_queue: asyncio.Queue = asyncio.Queue()
1920+
self._batch_task_stop_requested = False
1921+
self._batch_task_stopped_event = asyncio.Event()
1922+
self._batch_task = asyncio.create_task(self._batch_loop())
1923+
1924+
async def _batch_cleanup(self):
1925+
"""Cleanup the batching task if batching is enabled."""
1926+
if hasattr(self, "_batch_task_stopped_event"):
1927+
# Send a None in the queue to indicate that the batching task must stop
1928+
# after having processed the remaining requests in the queue
1929+
self._batch_task_stop_requested = True
1930+
await self.batch_queue.put(None)
1931+
1932+
# Wait for the task to process remaining requests and stop
1933+
await self._batch_task_stopped_event.wait()
1934+
1935+
async def connect(self):
    """Connect the transport and initialize the batch task loop if batching
    is enabled.

    If the transport fails to connect, the batching task (if any) is
    stopped and the transport is closed before the exception is re-raised,
    so no background task is leaked.
    """

    await self._batch_init()

    try:
        await self.transport.connect()
    except Exception:
        # Stop the batching task started above — otherwise it would keep
        # waiting on the queue forever after a failed connect.
        await self._batch_cleanup()
        await self.transport.close()
        # Bare raise keeps the original exception and traceback intact.
        raise
1946+
1947+
async def close(self):
    """Shut down this session.

    First drains and stops the batching task (waiting until every request
    still in the batch processing queue has been executed), then closes
    the transport.
    """
    # Stop batching first so queued requests still go out over the
    # (still open) transport.
    await self._batch_cleanup()

    await self.transport.close()
1956+
18311957
async def fetch_schema(self) -> None:
18321958
"""Fetch the GraphQL schema explicitly using introspection.
18331959
@@ -1954,6 +2080,23 @@ async def stop_connecting_task(self):
19542080
self._connect_task.cancel()
19552081
self._connect_task = None
19562082

2083+
async def connect(self):
    """Begin the reconnecting session: set up batching (when enabled) and
    launch the background task that keeps the transport connected."""

    # Batching must be ready before any request can be queued.
    await self._batch_init()

    await self.start_connecting_task()
2090+
2091+
async def close(self):
    """Shut down the reconnecting session.

    Drains and stops the batching task (when enabled), cancels the
    reconnection task, and finally closes the transport.
    """
    await self._batch_cleanup()

    await self.stop_connecting_task()

    await self.transport.close()
2099+
19572100
async def _execute_once(
19582101
self,
19592102
request: GraphQLRequest,

gql/transport/aiohttp.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -274,22 +274,35 @@ def _prepare_file_uploads(
274274

275275
return post_args
276276

277-
async def raise_response_error(
278-
self,
277+
@staticmethod
def _raise_transport_server_error_if_status_more_than_400(
    resp: "aiohttp.ClientResponse",
) -> None:
    # Map an HTTP error status onto a TransportServerError.
    # NOTE: despite the name, this fires for status 400 itself as well
    # (raise_for_status triggers on any 4xx/5xx status).
    try:
        # aiohttp raises ClientResponseError when status is 400 or higher.
        resp.raise_for_status()
    except ClientResponseError as e:
        raise TransportServerError(str(e), e.status) from e

289+
@classmethod
async def _raise_response_error(
    cls,
    resp: "aiohttp.ClientResponse",
    reason: str,
) -> None:
    # Always raises: TransportServerError when the HTTP status is 400 or
    # higher, TransportProtocolError otherwise.

    cls._raise_transport_server_error_if_status_more_than_400(resp)

    # Status was OK, so the body itself is the problem: report it.
    result_text = await resp.text()
    raise TransportProtocolError(
        f"Server did not return a valid GraphQL result: "
        f"{reason}: "
        f"{result_text}"
    )
293306

294307
async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any:
295308

@@ -304,10 +317,10 @@ async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any:
304317
log.debug("<<< %s", result_text)
305318

306319
except Exception:
307-
await self.raise_response_error(response, "Not a JSON answer")
320+
await self._raise_response_error(response, "Not a JSON answer")
308321

309322
if result is None:
310-
await self.raise_response_error(response, "Not a JSON answer")
323+
await self._raise_response_error(response, "Not a JSON answer")
311324

312325
return result
313326

@@ -318,7 +331,7 @@ async def _prepare_result(
318331
result = await self._get_json_result(response)
319332

320333
if "errors" not in result and "data" not in result:
321-
await self.raise_response_error(
334+
await self._raise_response_error(
322335
response, 'No "data" or "errors" keys in answer'
323336
)
324337

@@ -336,14 +349,13 @@ async def _prepare_batch_result(
336349

337350
answers = await self._get_json_result(response)
338351

339-
return get_batch_execution_result_list(reqs, answers)
340-
341-
def _raise_invalid_result(self, result_text: str, reason: str) -> None:
342-
raise TransportProtocolError(
343-
f"Server did not return a valid GraphQL result: "
344-
f"{reason}: "
345-
f"{result_text}"
346-
)
352+
try:
353+
return get_batch_execution_result_list(reqs, answers)
354+
except TransportProtocolError:
355+
# Raise a TransportServerError if status > 400
356+
self._raise_transport_server_error_if_status_more_than_400(response)
357+
# In other cases, raise a TransportProtocolError
358+
raise
347359

348360
async def execute(
349361
self,

gql/transport/httpx.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,18 +195,33 @@ def _prepare_batch_result(
195195

196196
answers = self._get_json_result(response)
197197

198-
return get_batch_execution_result_list(reqs, answers)
199-
200-
def _raise_response_error(self, response: httpx.Response, reason: str) -> NoReturn:
201-
# We raise a TransportServerError if the status code is 400 or higher
202-
# We raise a TransportProtocolError in the other cases
203-
204198
try:
205-
# Raise a HTTPError if response status is 400 or higher
199+
return get_batch_execution_result_list(reqs, answers)
200+
except TransportProtocolError:
201+
# Raise a TransportServerError if status > 400
202+
self._raise_transport_server_error_if_status_more_than_400(response)
203+
# In other cases, raise a TransportProtocolError
204+
raise
205+
206+
@staticmethod
def _raise_transport_server_error_if_status_more_than_400(
    response: "httpx.Response",
) -> None:
    # Map an HTTP error status onto a TransportServerError.
    # NOTE: despite the name, this fires for status 400 itself as well
    # (raise_for_status triggers on any 4xx/5xx status).
    try:
        # httpx raises HTTPStatusError when status is 400 or higher.
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        raise TransportServerError(str(e), e.response.status_code) from e
209217

218+
@classmethod
def _raise_response_error(cls, response: "httpx.Response", reason: str) -> "NoReturn":
    # Always raises: TransportServerError when the HTTP status is 400 or
    # higher, TransportProtocolError otherwise.

    cls._raise_transport_server_error_if_status_more_than_400(response)

    # Status was OK, so the body itself is the problem: report it.
    raise TransportProtocolError(
        f"Server did not return a GraphQL result: " f"{reason}: " f"{response.text}"
    )

0 commit comments

Comments
 (0)