11import asyncio
22import inspect
33import unittest
4- import struct
54import os
65import json
76from unittest .mock import Mock , ANY
87
9- from pyarrow import (
10- array ,
11- concat_tables ,
12- Table
13- )
14-
158from pyarrow .flight import (
169 FlightClient ,
17- FlightServerBase ,
18- FlightUnauthenticatedError ,
19- GeneratorStream ,
20- RecordBatchStream ,
21- ServerMiddleware ,
22- ServerMiddlewareFactory ,
23- ServerAuthHandler ,
2410 Ticket
2511)
2612
2713from influxdb_client_3 import InfluxDBClient3
2814from influxdb_client_3 .query .query_api import QueryApiOptionsBuilder , QueryApi
2915from influxdb_client_3 .version import USER_AGENT
3016
17+ from tests .util .mocks import (
18+ ConstantData ,
19+ ConstantFlightServer ,
20+ HeaderCheckFlightServer ,
21+ HeaderCheckServerMiddlewareFactory ,
22+ NoopAuthHandler ,
23+ get_req_headers ,
24+ set_req_headers
25+ )
26+
3127
3228def asyncio_run (async_func ):
3329 def wrapper (* args , ** kwargs ):
@@ -46,87 +42,24 @@ def case_insensitive_header_lookup(headers, lkey):
4642 return headers .get (key )
4743
4844
49- class NoopAuthHandler (ServerAuthHandler ):
50- """A no-op auth handler - as seen in pyarrow tests"""
51-
52- def authenticate (self , outgoing , incoming ):
53- """Do nothing"""
54-
55- def is_valid (self , token ):
56- """
57- Return an empty string
58- N.B. Returning None causes Type error
59- :param token:
60- :return:
61- """
62- return ""
63-
64-
65- _req_headers = {}
66-
67-
68- class HeaderCheckServerMiddlewareFactory (ServerMiddlewareFactory ):
69- """Factory to create HeaderCheckServerMiddleware and check header values"""
70- def start_call (self , info , headers ):
71- auth_header = case_insensitive_header_lookup (headers , "Authorization" )
72- values = auth_header [0 ].split (' ' )
73- if values [0 ] != 'Bearer' :
74- raise FlightUnauthenticatedError ("Token required" )
75- global _req_headers
76- _req_headers = headers
77- return HeaderCheckServerMiddleware (values [1 ])
78-
79-
80- class HeaderCheckServerMiddleware (ServerMiddleware ):
81- """
82- Middleware needed to catch request headers via factory
83- N.B. As found in pyarrow tests
84- """
85- def __init__ (self , token , * args , ** kwargs ):
86- super ().__init__ (* args , ** kwargs )
87- self .token = token
88-
89- def sending_headers (self ):
90- return {'authorization' : 'Bearer ' + self .token }
91-
92-
93- class HeaderCheckFlightServer (FlightServerBase ):
94- """Mock server handle gRPC do_get calls"""
95- def do_get (self , context , ticket ):
96- """Return something to avoid needless errors"""
97- data = [
98- array ([b"Vltava" , struct .pack ('<i' , 105 ), b"FM" ])
99- ]
100- table = Table .from_arrays (data , names = ['a' ])
101- return GeneratorStream (
102- table .schema ,
103- self .number_batches (table ),
104- options = {})
105-
106- @staticmethod
107- def number_batches (table ):
108- for idx , batch in enumerate (table .to_batches ()):
109- buf = struct .pack ('<i' , idx )
110- yield batch , buf
111-
112-
11345def test_influx_default_query_headers ():
11446 with HeaderCheckFlightServer (
11547 auth_handler = NoopAuthHandler (),
11648 middleware = {"check" : HeaderCheckServerMiddlewareFactory ()}) as server :
11749 global _req_headers
118- _req_headers = {}
50+ set_req_headers ({})
11951 client = InfluxDBClient3 (
12052 host = f'http://localhost:{ server .port } ' ,
12153 org = 'test_org' ,
12254 databse = 'test_db' ,
12355 token = 'TEST_TOKEN'
12456 )
12557 client .query ('SELECT * FROM test' )
58+ _req_headers = get_req_headers ()
12659 assert len (_req_headers ) > 0
12760 assert _req_headers ['authorization' ][0 ] == "Bearer TEST_TOKEN"
12861 assert _req_headers ['user-agent' ][0 ].find (USER_AGENT ) > - 1
129- _req_headers = {}
62+ set_req_headers ({})
13063
13164
13265class TestQuery (unittest .TestCase ):
@@ -351,7 +284,7 @@ def test_secondary_user_agent_in_options(self):
351284 q_api ._flight_client_options ['generic_options' ])
352285
353286 def test_prepare_query (self ):
354- global _req_headers
287+ set_req_headers ({})
355288 token = 'my_token'
356289 q_api = QueryApi (
357290 connection_string = "grpc+tls://my-server.org" ,
@@ -378,8 +311,9 @@ def test_prepare_query(self):
378311 middleware = {"check" : HeaderCheckServerMiddlewareFactory ()}) as server :
379312 with FlightClient (('localhost' , server .port )) as client :
380313 client .do_get (ticket , options )
314+ _req_headers = get_req_headers ()
381315 assert _req_headers ['authorization' ] == [f"Bearer { token } " ]
382- _req_headers = {}
316+ set_req_headers ({})
383317
384318 @asyncio_run
385319 async def test_query_async_pandas (self ):
@@ -436,50 +370,3 @@ async def test_query_async_table(self):
436370 assert {'data' : 'database' , 'reference' : 'my_database' , 'value' : - 1.0 } in result_list
437371 assert {'data' : 'sql_query' , 'reference' : 'SELECT * FROM data' , 'value' : - 1.0 } in result_list
438372 assert {'data' : 'query_type' , 'reference' : 'sql' , 'value' : - 1.0 } in result_list
439-
440-
441- class ConstantData :
442-
443- def __init__ (self ):
444- self .data = [
445- array (['temp' , 'temp' , 'temp' ]),
446- array (['kitchen' , 'common' , 'foyer' ]),
447- array ([36.9 , 25.7 , 9.8 ])
448- ]
449- self .names = ['data' , 'reference' , 'value' ]
450-
451- def to_tuples (self ):
452- response = []
453- for n in range (3 ):
454- response .append ((self .data [0 ][n ].as_py (), self .data [1 ][n ].as_py (), self .data [2 ][n ].as_py ()))
455- return response
456-
457- def to_list (self ):
458- response = []
459- for it in range (len (self .data [0 ])):
460- item = {}
461- for o in range (len (self .names )):
462- item [self .names [o ]] = self .data [o ][it ].as_py ()
463- response .append (item )
464- return response
465-
466-
467- class ConstantFlightServer (FlightServerBase ):
468-
469- def __init__ (self , location = None , options = None , ** kwargs ):
470- super ().__init__ (location , ** kwargs )
471- self .cd = ConstantData ()
472- self .options = options
473-
474- # respond with Constant Data plus fields from ticket
475- def do_get (self , context , ticket ):
476- result_table = Table .from_arrays (self .cd .data , names = self .cd .names )
477- tkt = json .loads (ticket .ticket .decode ('utf-8' ))
478- for key in tkt .keys ():
479- tkt_data = [
480- array ([key ]),
481- array ([tkt [key ]]),
482- array ([- 1.0 ])
483- ]
484- result_table = concat_tables ([result_table , Table .from_arrays (tkt_data , names = self .cd .names )])
485- return RecordBatchStream (result_table , options = self .options )
0 commit comments