Skip to content

Commit df50bb1

Browse files
committed
test: adds test of async behavior of query_api.query_async()
1 parent 99b4721 commit df50bb1

File tree

2 files changed

+83
-1
lines changed

2 files changed

+83
-1
lines changed

tests/test_query.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import asyncio
22
import inspect
3+
import sys
4+
import time
5+
import traceback
36
import unittest
47
import os
58
import json
@@ -17,6 +20,7 @@
1720
from tests.util.mocks import (
1821
ConstantData,
1922
ConstantFlightServer,
23+
ConstantFlightServerDelayed,
2024
HeaderCheckFlightServer,
2125
HeaderCheckServerMiddlewareFactory,
2226
NoopAuthHandler,
@@ -27,7 +31,11 @@
2731

2832
def asyncio_run(async_func):
2933
def wrapper(*args, **kwargs):
30-
return asyncio.run(async_func(*args, **kwargs))
34+
try:
35+
return asyncio.run(async_func(*args, **kwargs))
36+
except Exception as e:
37+
print(traceback.format_exc(), file=sys.stderr)
38+
raise e
3139

3240
wrapper.__signature__ = inspect.signature(async_func)
3341
return wrapper
@@ -370,3 +378,65 @@ async def test_query_async_table(self):
370378
assert {'data': 'database', 'reference': 'my_database', 'value': -1.0} in result_list
371379
assert {'data': 'sql_query', 'reference': 'SELECT * FROM data', 'value': -1.0} in result_list
372380
assert {'data': 'query_type', 'reference': 'sql', 'value': -1.0} in result_list
381+
382+
@asyncio_run
383+
async def test_query_async_delayed(self):
384+
events = dict()
385+
with ConstantFlightServerDelayed(delay=1) as server:
386+
connection_string = f"grpc://localhost:{server.port}"
387+
token = "my_token"
388+
database = "my_database"
389+
q_api = QueryApi(
390+
connection_string=connection_string,
391+
token=token,
392+
flight_client_options={"generic_options": [('Foo', 'Bar')]},
393+
proxy=None,
394+
options=None
395+
)
396+
query = "SELECT * FROM data"
397+
398+
# coroutine to handle query_async
399+
async def local_query(query_api):
400+
events['query_start'] = time.time_ns()
401+
t_result = await query_api.query_async(query, "sql", "", database)
402+
# t_result = query_api.query(query, "sql", "", database)
403+
events['query_result'] = time.time_ns()
404+
return t_result
405+
406+
# second coroutine to run in "parallel"
407+
async def fibo(iters):
408+
events['fibo_start'] = time.time_ns()
409+
await asyncio.sleep(0.5)
410+
n0 = 1
411+
n1 = 1
412+
result = n1 + n0
413+
for _ in range(iters):
414+
n0 = n1
415+
n1 = result
416+
result = n1 + n0
417+
events['fibo_end'] = time.time_ns()
418+
return result
419+
420+
results = await asyncio.gather(local_query(q_api), fibo(50))
421+
422+
table = results[0]
423+
fibo_num = results[1]
424+
425+
# verify fibo calculation
426+
assert fibo_num == 53316291173
427+
428+
# verify constant data
429+
cd = ConstantData()
430+
431+
result_list = table.to_pylist()
432+
for item in cd.to_list():
433+
assert item in result_list
434+
435+
# verify that fibo coroutine was processed while query_async was processing
436+
# i.e. query call does not block event_loop
437+
# fibo started after query_async
438+
assert events['query_start'] < events['fibo_start'], (f"query_start: {events['query_start']} should start "
439+
f"before fibo_start: {events['fibo_start']}")
440+
# fibo ended before query_async
441+
assert events['query_result'] > events['fibo_end'], (f"query_result: {events['query_result']} should occur "
442+
f"after fibo_end: {events['fibo_end']}")

tests/util/mocks.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import struct
3+
import time
34

45
from pyarrow import (
56
array,
@@ -103,6 +104,17 @@ def do_get(self, context, ticket):
103104
return RecordBatchStream(result_table, options=self.options)
104105

105106

107+
class ConstantFlightServerDelayed(ConstantFlightServer):
108+
109+
def __init__(self, location=None, options=None, delay=0.5, **kwargs):
110+
super().__init__(location, **kwargs)
111+
self.delay = delay
112+
113+
def do_get(self, context, ticket):
114+
time.sleep(self.delay)
115+
return super().do_get(context, ticket)
116+
117+
106118
class HeaderCheckServerMiddlewareFactory(ServerMiddlewareFactory):
107119
"""Factory to create HeaderCheckServerMiddleware and check header values"""
108120
def start_call(self, info, headers):

0 commit comments

Comments
 (0)