|
| 1 | +import asyncio |
| 2 | +import inspect |
1 | 3 | import unittest |
2 | 4 | import struct |
3 | 5 | import os |
|
6 | 8 |
|
7 | 9 | from pyarrow import ( |
8 | 10 | array, |
| 11 | + concat_tables, |
9 | 12 | Table |
10 | 13 | ) |
11 | 14 |
|
|
14 | 17 | FlightServerBase, |
15 | 18 | FlightUnauthenticatedError, |
16 | 19 | GeneratorStream, |
| 20 | + RecordBatchStream, |
17 | 21 | ServerMiddleware, |
18 | 22 | ServerMiddlewareFactory, |
19 | 23 | ServerAuthHandler, |
|
25 | 29 | from influxdb_client_3.version import USER_AGENT |
26 | 30 |
|
27 | 31 |
|
| 32 | +def asyncio_run(async_func): |
| 33 | + def wrapper(*args, **kwargs): |
| 34 | + return asyncio.run(async_func(*args, **kwargs)) |
| 35 | + |
| 36 | + wrapper.__signature__ = inspect.signature(async_func) |
| 37 | + return wrapper |
| 38 | + |
| 39 | + |
28 | 40 | def case_insensitive_header_lookup(headers, lkey): |
29 | 41 | """Lookup the value of a given key in the given headers. |
30 | 42 | The lkey is case-insensitive. |
@@ -368,3 +380,106 @@ def test_prepare_query(self): |
368 | 380 | client.do_get(ticket, options) |
369 | 381 | assert _req_headers['authorization'] == [f"Bearer {token}"] |
370 | 382 | _req_headers = {} |
| 383 | + |
| 384 | + @asyncio_run |
| 385 | + async def test_query_async_pandas(self): |
| 386 | + with ConstantFlightServer() as server: |
| 387 | + connection_string = f"grpc://localhost:{server.port}" |
| 388 | + token = "my_token" |
| 389 | + database = "my_database" |
| 390 | + q_api = QueryApi( |
| 391 | + connection_string=connection_string, |
| 392 | + token=token, |
| 393 | + flight_client_options={"generic_options": [('Foo', 'Bar')]}, |
| 394 | + proxy=None, |
| 395 | + options=None |
| 396 | + ) |
| 397 | + |
| 398 | + query = "SELECT * FROM data" |
| 399 | + pndf = await q_api.query_async(query, "sql", "pandas", database) |
| 400 | + |
| 401 | + cd = ConstantData() |
| 402 | + numpy_array = pndf.T.to_numpy() |
| 403 | + tuples = [] |
| 404 | + for n in range(len(numpy_array[0])): |
| 405 | + tuples.append((numpy_array[0][n], numpy_array[1][n], numpy_array[2][n])) |
| 406 | + |
| 407 | + for constant in cd.to_tuples(): |
| 408 | + assert constant in tuples |
| 409 | + |
| 410 | + assert ('sql_query', query, -1.0) in tuples |
| 411 | + assert ('database', database, -1.0) in tuples |
| 412 | + assert ('query_type', 'sql', -1.0) in tuples |
| 413 | + |
| 414 | + @asyncio_run |
| 415 | + async def test_query_async_table(self): |
| 416 | + with ConstantFlightServer() as server: |
| 417 | + connection_string = f"grpc://localhost:{server.port}" |
| 418 | + token = "my_token" |
| 419 | + database = "my_database" |
| 420 | + q_api = QueryApi( |
| 421 | + connection_string=connection_string, |
| 422 | + token=token, |
| 423 | + flight_client_options={"generic_options": [('Foo', 'Bar')]}, |
| 424 | + proxy=None, |
| 425 | + options=None |
| 426 | + ) |
| 427 | + query = "SELECT * FROM data" |
| 428 | + table = await q_api.query_async(query, "sql", "", database) |
| 429 | + |
| 430 | + cd = ConstantData() |
| 431 | + |
| 432 | + result_list = table.to_pylist() |
| 433 | + for item in cd.to_list(): |
| 434 | + assert item in result_list |
| 435 | + |
| 436 | + assert {'data': 'database', 'reference': 'my_database', 'value': -1.0} in result_list |
| 437 | + assert {'data': 'sql_query', 'reference': 'SELECT * FROM data', 'value': -1.0} in result_list |
| 438 | + assert {'data': 'query_type', 'reference': 'sql', 'value': -1.0} in result_list |
| 439 | + |
| 440 | + |
| 441 | +class ConstantData: |
| 442 | + |
| 443 | + def __init__(self): |
| 444 | + self.data = [ |
| 445 | + array(['temp', 'temp', 'temp']), |
| 446 | + array(['kitchen', 'common', 'foyer']), |
| 447 | + array([36.9, 25.7, 9.8]) |
| 448 | + ] |
| 449 | + self.names = ['data', 'reference', 'value'] |
| 450 | + |
| 451 | + def to_tuples(self): |
| 452 | + response = [] |
| 453 | + for n in range(3): |
| 454 | + response.append((self.data[0][n].as_py(), self.data[1][n].as_py(), self.data[2][n].as_py())) |
| 455 | + return response |
| 456 | + |
| 457 | + def to_list(self): |
| 458 | + response = [] |
| 459 | + for it in range(len(self.data[0])): |
| 460 | + item = {} |
| 461 | + for o in range(len(self.names)): |
| 462 | + item[self.names[o]] = self.data[o][it].as_py() |
| 463 | + response.append(item) |
| 464 | + return response |
| 465 | + |
| 466 | + |
| 467 | +class ConstantFlightServer(FlightServerBase): |
| 468 | + |
| 469 | + def __init__(self, location=None, options=None, **kwargs): |
| 470 | + super().__init__(location, **kwargs) |
| 471 | + self.cd = ConstantData() |
| 472 | + self.options = options |
| 473 | + |
| 474 | + # respond with Constant Data plus fields from ticket |
| 475 | + def do_get(self, context, ticket): |
| 476 | + result_table = Table.from_arrays(self.cd.data, names=self.cd.names) |
| 477 | + tkt = json.loads(ticket.ticket.decode('utf-8')) |
| 478 | + for key in tkt.keys(): |
| 479 | + tkt_data = [ |
| 480 | + array([key]), |
| 481 | + array([tkt[key]]), |
| 482 | + array([-1.0]) |
| 483 | + ] |
| 484 | + result_table = concat_tables([result_table, Table.from_arrays(tkt_data, names=self.cd.names)]) |
| 485 | + return RecordBatchStream(result_table, options=self.options) |
0 commit comments