Skip to content

Commit 4d17837

Browse files
committed
feat: (WIP) adds client.query_async() plus first tests
1 parent 91004b1 commit 4d17837

File tree

2 files changed

+131
-0
lines changed

2 files changed

+131
-0
lines changed

influxdb_client_3/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,22 @@ def query(self, query: str, language: str = "sql", mode: str = "all", database:
278278
except InfluxDBError as e:
279279
raise e
280280

281+
async def query_async(self, query: str, language: str = "sql", mode: str = "all", database: str = None, **kwargs):
282+
if mode == "polars" and polars is False:
283+
raise ImportError("Polars is not installed. Please install it with `pip install polars`.")
284+
285+
if database is None:
286+
database = self._database
287+
288+
try:
289+
return await self._query_api.query_async(query=query,
290+
language=language,
291+
mode=mode,
292+
database=database,
293+
**kwargs)
294+
except InfluxDBError as e:
295+
raise e
296+
281297
def close(self):
282298
"""Close the client and clean up resources."""
283299
self._write_api.close()

tests/test_query.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import inspect
13
import unittest
24
import struct
35
import os
@@ -6,6 +8,7 @@
68

79
from pyarrow import (
810
array,
11+
concat_tables,
912
Table
1013
)
1114

@@ -14,6 +17,7 @@
1417
FlightServerBase,
1518
FlightUnauthenticatedError,
1619
GeneratorStream,
20+
RecordBatchStream,
1721
ServerMiddleware,
1822
ServerMiddlewareFactory,
1923
ServerAuthHandler,
@@ -25,6 +29,14 @@
2529
from influxdb_client_3.version import USER_AGENT
2630

2731

32+
def asyncio_run(async_func):
33+
def wrapper(*args, **kwargs):
34+
return asyncio.run(async_func(*args, **kwargs))
35+
36+
wrapper.__signature__ = inspect.signature(async_func)
37+
return wrapper
38+
39+
2840
def case_insensitive_header_lookup(headers, lkey):
2941
"""Lookup the value of a given key in the given headers.
3042
The lkey is case-insensitive.
@@ -368,3 +380,106 @@ def test_prepare_query(self):
368380
client.do_get(ticket, options)
369381
assert _req_headers['authorization'] == [f"Bearer {token}"]
370382
_req_headers = {}
383+
384+
@asyncio_run
385+
async def test_query_async_pandas(self):
386+
with ConstantFlightServer() as server:
387+
connection_string = f"grpc://localhost:{server.port}"
388+
token = "my_token"
389+
database = "my_database"
390+
q_api = QueryApi(
391+
connection_string=connection_string,
392+
token=token,
393+
flight_client_options={"generic_options": [('Foo', 'Bar')]},
394+
proxy=None,
395+
options=None
396+
)
397+
398+
query = "SELECT * FROM data"
399+
pndf = await q_api.query_async(query, "sql", "pandas", database)
400+
401+
cd = ConstantData()
402+
numpy_array = pndf.T.to_numpy()
403+
tuples = []
404+
for n in range(len(numpy_array[0])):
405+
tuples.append((numpy_array[0][n], numpy_array[1][n], numpy_array[2][n]))
406+
407+
for constant in cd.to_tuples():
408+
assert constant in tuples
409+
410+
assert ('sql_query', query, -1.0) in tuples
411+
assert ('database', database, -1.0) in tuples
412+
assert ('query_type', 'sql', -1.0) in tuples
413+
414+
@asyncio_run
415+
async def test_query_async_table(self):
416+
with ConstantFlightServer() as server:
417+
connection_string = f"grpc://localhost:{server.port}"
418+
token = "my_token"
419+
database = "my_database"
420+
q_api = QueryApi(
421+
connection_string=connection_string,
422+
token=token,
423+
flight_client_options={"generic_options": [('Foo', 'Bar')]},
424+
proxy=None,
425+
options=None
426+
)
427+
query = "SELECT * FROM data"
428+
table = await q_api.query_async(query, "sql", "", database)
429+
430+
cd = ConstantData()
431+
432+
result_list = table.to_pylist()
433+
for item in cd.to_list():
434+
assert item in result_list
435+
436+
assert {'data': 'database', 'reference': 'my_database', 'value': -1.0} in result_list
437+
assert {'data': 'sql_query', 'reference': 'SELECT * FROM data', 'value': -1.0} in result_list
438+
assert {'data': 'query_type', 'reference': 'sql', 'value': -1.0} in result_list
439+
440+
441+
class ConstantData:
442+
443+
def __init__(self):
444+
self.data = [
445+
array(['temp', 'temp', 'temp']),
446+
array(['kitchen', 'common', 'foyer']),
447+
array([36.9, 25.7, 9.8])
448+
]
449+
self.names = ['data', 'reference', 'value']
450+
451+
def to_tuples(self):
452+
response = []
453+
for n in range(3):
454+
response.append((self.data[0][n].as_py(), self.data[1][n].as_py(), self.data[2][n].as_py()))
455+
return response
456+
457+
def to_list(self):
458+
response = []
459+
for it in range(len(self.data[0])):
460+
item = {}
461+
for o in range(len(self.names)):
462+
item[self.names[o]] = self.data[o][it].as_py()
463+
response.append(item)
464+
return response
465+
466+
467+
class ConstantFlightServer(FlightServerBase):
468+
469+
def __init__(self, location=None, options=None, **kwargs):
470+
super().__init__(location, **kwargs)
471+
self.cd = ConstantData()
472+
self.options = options
473+
474+
# respond with Constant Data plus fields from ticket
475+
def do_get(self, context, ticket):
476+
result_table = Table.from_arrays(self.cd.data, names=self.cd.names)
477+
tkt = json.loads(ticket.ticket.decode('utf-8'))
478+
for key in tkt.keys():
479+
tkt_data = [
480+
array([key]),
481+
array([tkt[key]]),
482+
array([-1.0])
483+
]
484+
result_table = concat_tables([result_table, Table.from_arrays(tkt_data, names=self.cd.names)])
485+
return RecordBatchStream(result_table, options=self.options)

0 commit comments

Comments
 (0)