Skip to content

Commit 99b4721

Browse files
committed
test: refactor - move arrow flight server mocks to util/mocks.py package
1 parent 4d17837 commit 99b4721

File tree

3 files changed

+165
-129
lines changed

3 files changed

+165
-129
lines changed

tests/test_query.py

Lines changed: 16 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,29 @@
11
import asyncio
22
import inspect
33
import unittest
4-
import struct
54
import os
65
import json
76
from unittest.mock import Mock, ANY
87

9-
from pyarrow import (
10-
array,
11-
concat_tables,
12-
Table
13-
)
14-
158
from pyarrow.flight import (
169
FlightClient,
17-
FlightServerBase,
18-
FlightUnauthenticatedError,
19-
GeneratorStream,
20-
RecordBatchStream,
21-
ServerMiddleware,
22-
ServerMiddlewareFactory,
23-
ServerAuthHandler,
2410
Ticket
2511
)
2612

2713
from influxdb_client_3 import InfluxDBClient3
2814
from influxdb_client_3.query.query_api import QueryApiOptionsBuilder, QueryApi
2915
from influxdb_client_3.version import USER_AGENT
3016

17+
from tests.util.mocks import (
18+
ConstantData,
19+
ConstantFlightServer,
20+
HeaderCheckFlightServer,
21+
HeaderCheckServerMiddlewareFactory,
22+
NoopAuthHandler,
23+
get_req_headers,
24+
set_req_headers
25+
)
26+
3127

3228
def asyncio_run(async_func):
3329
def wrapper(*args, **kwargs):
@@ -46,87 +42,24 @@ def case_insensitive_header_lookup(headers, lkey):
4642
return headers.get(key)
4743

4844

49-
class NoopAuthHandler(ServerAuthHandler):
50-
"""A no-op auth handler - as seen in pyarrow tests"""
51-
52-
def authenticate(self, outgoing, incoming):
53-
"""Do nothing"""
54-
55-
def is_valid(self, token):
56-
"""
57-
Return an empty string
58-
N.B. Returning None causes Type error
59-
:param token:
60-
:return:
61-
"""
62-
return ""
63-
64-
65-
_req_headers = {}
66-
67-
68-
class HeaderCheckServerMiddlewareFactory(ServerMiddlewareFactory):
69-
"""Factory to create HeaderCheckServerMiddleware and check header values"""
70-
def start_call(self, info, headers):
71-
auth_header = case_insensitive_header_lookup(headers, "Authorization")
72-
values = auth_header[0].split(' ')
73-
if values[0] != 'Bearer':
74-
raise FlightUnauthenticatedError("Token required")
75-
global _req_headers
76-
_req_headers = headers
77-
return HeaderCheckServerMiddleware(values[1])
78-
79-
80-
class HeaderCheckServerMiddleware(ServerMiddleware):
81-
"""
82-
Middleware needed to catch request headers via factory
83-
N.B. As found in pyarrow tests
84-
"""
85-
def __init__(self, token, *args, **kwargs):
86-
super().__init__(*args, **kwargs)
87-
self.token = token
88-
89-
def sending_headers(self):
90-
return {'authorization': 'Bearer ' + self.token}
91-
92-
93-
class HeaderCheckFlightServer(FlightServerBase):
94-
"""Mock server handle gRPC do_get calls"""
95-
def do_get(self, context, ticket):
96-
"""Return something to avoid needless errors"""
97-
data = [
98-
array([b"Vltava", struct.pack('<i', 105), b"FM"])
99-
]
100-
table = Table.from_arrays(data, names=['a'])
101-
return GeneratorStream(
102-
table.schema,
103-
self.number_batches(table),
104-
options={})
105-
106-
@staticmethod
107-
def number_batches(table):
108-
for idx, batch in enumerate(table.to_batches()):
109-
buf = struct.pack('<i', idx)
110-
yield batch, buf
111-
112-
11345
def test_influx_default_query_headers():
11446
with HeaderCheckFlightServer(
11547
auth_handler=NoopAuthHandler(),
11648
middleware={"check": HeaderCheckServerMiddlewareFactory()}) as server:
11749
global _req_headers
118-
_req_headers = {}
50+
set_req_headers({})
11951
client = InfluxDBClient3(
12052
host=f'http://localhost:{server.port}',
12153
org='test_org',
12254
databse='test_db',
12355
token='TEST_TOKEN'
12456
)
12557
client.query('SELECT * FROM test')
58+
_req_headers = get_req_headers()
12659
assert len(_req_headers) > 0
12760
assert _req_headers['authorization'][0] == "Bearer TEST_TOKEN"
12861
assert _req_headers['user-agent'][0].find(USER_AGENT) > -1
129-
_req_headers = {}
62+
set_req_headers({})
13063

13164

13265
class TestQuery(unittest.TestCase):
@@ -351,7 +284,7 @@ def test_secondary_user_agent_in_options(self):
351284
q_api._flight_client_options['generic_options'])
352285

353286
def test_prepare_query(self):
354-
global _req_headers
287+
set_req_headers({})
355288
token = 'my_token'
356289
q_api = QueryApi(
357290
connection_string="grpc+tls://my-server.org",
@@ -378,8 +311,9 @@ def test_prepare_query(self):
378311
middleware={"check": HeaderCheckServerMiddlewareFactory()}) as server:
379312
with FlightClient(('localhost', server.port)) as client:
380313
client.do_get(ticket, options)
314+
_req_headers = get_req_headers()
381315
assert _req_headers['authorization'] == [f"Bearer {token}"]
382-
_req_headers = {}
316+
set_req_headers({})
383317

384318
@asyncio_run
385319
async def test_query_async_pandas(self):
@@ -436,50 +370,3 @@ async def test_query_async_table(self):
436370
assert {'data': 'database', 'reference': 'my_database', 'value': -1.0} in result_list
437371
assert {'data': 'sql_query', 'reference': 'SELECT * FROM data', 'value': -1.0} in result_list
438372
assert {'data': 'query_type', 'reference': 'sql', 'value': -1.0} in result_list
439-
440-
441-
class ConstantData:
442-
443-
def __init__(self):
444-
self.data = [
445-
array(['temp', 'temp', 'temp']),
446-
array(['kitchen', 'common', 'foyer']),
447-
array([36.9, 25.7, 9.8])
448-
]
449-
self.names = ['data', 'reference', 'value']
450-
451-
def to_tuples(self):
452-
response = []
453-
for n in range(3):
454-
response.append((self.data[0][n].as_py(), self.data[1][n].as_py(), self.data[2][n].as_py()))
455-
return response
456-
457-
def to_list(self):
458-
response = []
459-
for it in range(len(self.data[0])):
460-
item = {}
461-
for o in range(len(self.names)):
462-
item[self.names[o]] = self.data[o][it].as_py()
463-
response.append(item)
464-
return response
465-
466-
467-
class ConstantFlightServer(FlightServerBase):
468-
469-
def __init__(self, location=None, options=None, **kwargs):
470-
super().__init__(location, **kwargs)
471-
self.cd = ConstantData()
472-
self.options = options
473-
474-
# respond with Constant Data plus fields from ticket
475-
def do_get(self, context, ticket):
476-
result_table = Table.from_arrays(self.cd.data, names=self.cd.names)
477-
tkt = json.loads(ticket.ticket.decode('utf-8'))
478-
for key in tkt.keys():
479-
tkt_data = [
480-
array([key]),
481-
array([tkt[key]]),
482-
array([-1.0])
483-
]
484-
result_table = concat_tables([result_table, Table.from_arrays(tkt_data, names=self.cd.names)])
485-
return RecordBatchStream(result_table, options=self.options)

tests/util/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Package for tests/util module."""

tests/util/mocks.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import json
2+
import struct
3+
4+
from pyarrow import (
5+
array,
6+
Table,
7+
concat_tables
8+
)
9+
10+
from pyarrow.flight import (
11+
FlightServerBase,
12+
RecordBatchStream,
13+
ServerMiddlewareFactory,
14+
FlightUnauthenticatedError,
15+
ServerMiddleware,
16+
GeneratorStream,
17+
ServerAuthHandler
18+
)
19+
20+
21+
class NoopAuthHandler(ServerAuthHandler):
22+
"""A no-op auth handler - as seen in pyarrow tests"""
23+
24+
def authenticate(self, outgoing, incoming):
25+
"""Do nothing"""
26+
27+
def is_valid(self, token):
28+
"""
29+
Return an empty string
30+
N.B. Returning None causes Type error
31+
:param token:
32+
:return:
33+
"""
34+
return ""
35+
36+
37+
def case_insensitive_header_lookup(headers, lkey):
38+
"""Lookup the value of a given key in the given headers.
39+
The lkey is case-insensitive.
40+
"""
41+
for key in headers:
42+
if key.lower() == lkey.lower():
43+
return headers.get(key)
44+
45+
46+
req_headers = {}
47+
48+
49+
def set_req_headers(headers):
50+
global req_headers
51+
req_headers = headers
52+
53+
54+
def get_req_headers():
55+
global req_headers
56+
return req_headers
57+
58+
59+
class ConstantData:
60+
61+
def __init__(self):
62+
self.data = [
63+
array(['temp', 'temp', 'temp']),
64+
array(['kitchen', 'common', 'foyer']),
65+
array([36.9, 25.7, 9.8])
66+
]
67+
self.names = ['data', 'reference', 'value']
68+
69+
def to_tuples(self):
70+
response = []
71+
for n in range(3):
72+
response.append((self.data[0][n].as_py(), self.data[1][n].as_py(), self.data[2][n].as_py()))
73+
return response
74+
75+
def to_list(self):
76+
response = []
77+
for it in range(len(self.data[0])):
78+
item = {}
79+
for o in range(len(self.names)):
80+
item[self.names[o]] = self.data[o][it].as_py()
81+
response.append(item)
82+
return response
83+
84+
85+
class ConstantFlightServer(FlightServerBase):
86+
87+
def __init__(self, location=None, options=None, **kwargs):
88+
super().__init__(location, **kwargs)
89+
self.cd = ConstantData()
90+
self.options = options
91+
92+
# respond with Constant Data plus fields from ticket
93+
def do_get(self, context, ticket):
94+
result_table = Table.from_arrays(self.cd.data, names=self.cd.names)
95+
tkt = json.loads(ticket.ticket.decode('utf-8'))
96+
for key in tkt.keys():
97+
tkt_data = [
98+
array([key]),
99+
array([tkt[key]]),
100+
array([-1.0])
101+
]
102+
result_table = concat_tables([result_table, Table.from_arrays(tkt_data, names=self.cd.names)])
103+
return RecordBatchStream(result_table, options=self.options)
104+
105+
106+
class HeaderCheckServerMiddlewareFactory(ServerMiddlewareFactory):
107+
"""Factory to create HeaderCheckServerMiddleware and check header values"""
108+
def start_call(self, info, headers):
109+
auth_header = case_insensitive_header_lookup(headers, "Authorization")
110+
values = auth_header[0].split(' ')
111+
if values[0] != 'Bearer':
112+
raise FlightUnauthenticatedError("Token required")
113+
global req_headers
114+
req_headers = headers
115+
return HeaderCheckServerMiddleware(values[1])
116+
117+
118+
class HeaderCheckServerMiddleware(ServerMiddleware):
119+
"""
120+
Middleware needed to catch request headers via factory
121+
N.B. As found in pyarrow tests
122+
"""
123+
def __init__(self, token, *args, **kwargs):
124+
super().__init__(*args, **kwargs)
125+
self.token = token
126+
127+
def sending_headers(self):
128+
return {'authorization': 'Bearer ' + self.token}
129+
130+
131+
class HeaderCheckFlightServer(FlightServerBase):
132+
"""Mock server handle gRPC do_get calls"""
133+
def do_get(self, context, ticket):
134+
"""Return something to avoid needless errors"""
135+
data = [
136+
array([b"Vltava", struct.pack('<i', 105), b"FM"])
137+
]
138+
table = Table.from_arrays(data, names=['a'])
139+
return GeneratorStream(
140+
table.schema,
141+
self.number_batches(table),
142+
options={})
143+
144+
@staticmethod
145+
def number_batches(table):
146+
for idx, batch in enumerate(table.to_batches()):
147+
buf = struct.pack('<i', idx)
148+
yield batch, buf

0 commit comments

Comments
 (0)