|
1 | 1 | import asyncio |
2 | 2 | import inspect |
| 3 | +import sys |
| 4 | +import time |
| 5 | +import traceback |
3 | 6 | import unittest |
4 | 7 | import os |
5 | 8 | import json |
|
17 | 20 | from tests.util.mocks import ( |
18 | 21 | ConstantData, |
19 | 22 | ConstantFlightServer, |
| 23 | + ConstantFlightServerDelayed, |
20 | 24 | HeaderCheckFlightServer, |
21 | 25 | HeaderCheckServerMiddlewareFactory, |
22 | 26 | NoopAuthHandler, |
|
27 | 31 |
|
28 | 32 | def asyncio_run(async_func): |
29 | 33 | def wrapper(*args, **kwargs): |
30 | | - return asyncio.run(async_func(*args, **kwargs)) |
| 34 | + try: |
| 35 | + return asyncio.run(async_func(*args, **kwargs)) |
| 36 | + except Exception as e: |
| 37 | + print(traceback.format_exc(), file=sys.stderr) |
| 38 | + raise e |
31 | 39 |
|
32 | 40 | wrapper.__signature__ = inspect.signature(async_func) |
33 | 41 | return wrapper |
@@ -370,3 +378,65 @@ async def test_query_async_table(self): |
370 | 378 | assert {'data': 'database', 'reference': 'my_database', 'value': -1.0} in result_list |
371 | 379 | assert {'data': 'sql_query', 'reference': 'SELECT * FROM data', 'value': -1.0} in result_list |
372 | 380 | assert {'data': 'query_type', 'reference': 'sql', 'value': -1.0} in result_list |
| 381 | + |
| 382 | + @asyncio_run |
| 383 | + async def test_query_async_delayed(self): |
| 384 | + events = dict() |
| 385 | + with ConstantFlightServerDelayed(delay=1) as server: |
| 386 | + connection_string = f"grpc://localhost:{server.port}" |
| 387 | + token = "my_token" |
| 388 | + database = "my_database" |
| 389 | + q_api = QueryApi( |
| 390 | + connection_string=connection_string, |
| 391 | + token=token, |
| 392 | + flight_client_options={"generic_options": [('Foo', 'Bar')]}, |
| 393 | + proxy=None, |
| 394 | + options=None |
| 395 | + ) |
| 396 | + query = "SELECT * FROM data" |
| 397 | + |
| 398 | + # coroutine to handle query_async |
| 399 | + async def local_query(query_api): |
| 400 | + events['query_start'] = time.time_ns() |
| 401 | + t_result = await query_api.query_async(query, "sql", "", database) |
| 402 | + # t_result = query_api.query(query, "sql", "", database) |
| 403 | + events['query_result'] = time.time_ns() |
| 404 | + return t_result |
| 405 | + |
| 406 | + # second coroutine to run in "parallel" |
| 407 | + async def fibo(iters): |
| 408 | + events['fibo_start'] = time.time_ns() |
| 409 | + await asyncio.sleep(0.5) |
| 410 | + n0 = 1 |
| 411 | + n1 = 1 |
| 412 | + result = n1 + n0 |
| 413 | + for _ in range(iters): |
| 414 | + n0 = n1 |
| 415 | + n1 = result |
| 416 | + result = n1 + n0 |
| 417 | + events['fibo_end'] = time.time_ns() |
| 418 | + return result |
| 419 | + |
| 420 | + results = await asyncio.gather(local_query(q_api), fibo(50)) |
| 421 | + |
| 422 | + table = results[0] |
| 423 | + fibo_num = results[1] |
| 424 | + |
| 425 | + # verify fibo calculation |
| 426 | + assert fibo_num == 53316291173 |
| 427 | + |
| 428 | + # verify constant data |
| 429 | + cd = ConstantData() |
| 430 | + |
| 431 | + result_list = table.to_pylist() |
| 432 | + for item in cd.to_list(): |
| 433 | + assert item in result_list |
| 434 | + |
| 435 | + # verify that fibo coroutine was processed while query_async was processing |
| 436 | + # i.e. query call does not block event_loop |
| 437 | + # fibo started after query_async |
| 438 | + assert events['query_start'] < events['fibo_start'], (f"query_start: {events['query_start']} should start " |
| 439 | + f"before fibo_start: {events['fibo_start']}") |
| 440 | + # fibo ended before query_async |
| 441 | + assert events['query_result'] > events['fibo_end'], (f"query_result: {events['query_result']} should occur " |
| 442 | + f"after fibo_end: {events['fibo_end']}") |
0 commit comments