1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import collections
15+
1616import grpc
1717from unittest import mock
1818import os
1919import pytest
2020
21+ from typing import Sequence , Tuple
22+
2123from google .api_core .client_options import ClientOptions # type: ignore
24+ from google .showcase_v1beta1 .services .echo .transports import EchoRestInterceptor
2225
2326try :
2427 from google .auth .aio import credentials as ga_credentials_async
4245 try :
4346 from google .showcase_v1beta1 .services .echo .transports import (
4447 AsyncEchoRestTransport ,
48+ AsyncEchoRestInterceptor ,
4549 )
4650
4751 HAS_ASYNC_REST_ECHO_TRANSPORT = True
@@ -248,7 +252,51 @@ def messaging(use_mtls, request):
248252 return construct_client (MessagingClient , use_mtls , transport_name = request .param )
249253
250254
251- class MetadataClientInterceptor (
255+ class EchoMetadataClientRestInterceptor (EchoRestInterceptor ):
256+ request_metadata : Sequence [Tuple [str , str ]] = []
257+ response_metadata : Sequence [Tuple [str , str ]] = []
258+
259+ def pre_echo (self , request , metadata ):
260+ self .request_metadata = metadata
261+ return request , metadata
262+
263+ def post_echo_with_metadata (self , request , metadata ):
264+ self .response_metadata = metadata
265+ return request , metadata
266+
267+ def pre_expand (self , request , metadata ):
268+ self .request_metadata = metadata
269+ return request , metadata
270+
271+ def post_expand_with_metadata (self , request , metadata ):
272+ self .response_metadata = metadata
273+ return request , metadata
274+
275+
276+ if HAS_ASYNC_REST_ECHO_TRANSPORT :
277+
278+ class EchoMetadataClientRestAsyncInterceptor (AsyncEchoRestInterceptor ):
279+ request_metadata : Sequence [Tuple [str , str ]] = []
280+ response_metadata : Sequence [Tuple [str , str ]] = []
281+
282+ async def pre_echo (self , request , metadata ):
283+ self .request_metadata = metadata
284+ return request , metadata
285+
286+ async def post_echo_with_metadata (self , request , metadata ):
287+ self .response_metadata = metadata
288+ return request , metadata
289+
290+ async def pre_expand (self , request , metadata ):
291+ self .request_metadata = metadata
292+ return request , metadata
293+
294+ async def post_expand_with_metadata (self , request , metadata ):
295+ self .response_metadata = metadata
296+ return request , metadata
297+
298+
299+ class EchoMetadataClientGrpcInterceptor (
252300 grpc .UnaryUnaryClientInterceptor ,
253301 grpc .UnaryStreamClientInterceptor ,
254302 grpc .StreamUnaryClientInterceptor ,
@@ -257,42 +305,94 @@ class MetadataClientInterceptor(
257305 def __init__ (self , key , value ):
258306 self ._key = key
259307 self ._value = value
308+ self .request_metadata = []
309+ self .response_metadata = []
260310
261- def _add_metadata (self , client_call_details ):
311+ def _add_request_metadata (self , client_call_details ):
262312 if client_call_details .metadata is not None :
263313 client_call_details .metadata .append ((self ._key , self ._value ))
314+ self .request_metadata = client_call_details .metadata
264315
265316 def intercept_unary_unary (self , continuation , client_call_details , request ):
266- self ._add_metadata (client_call_details )
317+ self ._add_request_metadata (client_call_details )
267318 response = continuation (client_call_details , request )
319+ metadata = [(k , str (v )) for k , v in response .trailing_metadata ()]
320+ self .response_metadata = metadata
268321 return response
269322
270323 def intercept_unary_stream (self , continuation , client_call_details , request ):
271- self ._add_metadata (client_call_details )
324+ self ._add_request_metadata (client_call_details )
272325 response_it = continuation (client_call_details , request )
273326 return response_it
274327
275328 def intercept_stream_unary (
276329 self , continuation , client_call_details , request_iterator
277330 ):
278- self ._add_metadata (client_call_details )
331+ self ._add_request_metadata (client_call_details )
279332 response = continuation (client_call_details , request_iterator )
280333 return response
281334
282335 def intercept_stream_stream (
283336 self , continuation , client_call_details , request_iterator
284337 ):
285- self ._add_metadata (client_call_details )
338+ self ._add_request_metadata (client_call_details )
339+ response_it = continuation (client_call_details , request_iterator )
340+ return response_it
341+
342+
343+ class EchoMetadataClientGrpcAsyncInterceptor (
344+ grpc .aio .UnaryUnaryClientInterceptor ,
345+ grpc .aio .UnaryStreamClientInterceptor ,
346+ grpc .aio .StreamUnaryClientInterceptor ,
347+ grpc .aio .StreamStreamClientInterceptor ,
348+ ):
349+ def __init__ (self , key , value ):
350+ self ._key = key
351+ self ._value = value
352+ self .request_metadata = []
353+ self .response_metadata = []
354+
355+ async def _add_request_metadata (self , client_call_details ):
356+ if client_call_details .metadata is not None :
357+ client_call_details .metadata .append ((self ._key , self ._value ))
358+ self .request_metadata = client_call_details .metadata
359+
360+ async def intercept_unary_unary (self , continuation , client_call_details , request ):
361+ await self ._add_request_metadata (client_call_details )
362+ response = await continuation (client_call_details , request )
363+ metadata = [(k , str (v )) for k , v in await response .trailing_metadata ()]
364+ self .response_metadata = metadata
365+ return response
366+
367+ async def intercept_unary_stream (self , continuation , client_call_details , request ):
368+ self ._add_request_metadata (client_call_details )
369+ response_it = continuation (client_call_details , request )
370+ return response_it
371+
372+ async def intercept_stream_unary (
373+ self , continuation , client_call_details , request_iterator
374+ ):
375+ self ._add_request_metadata (client_call_details )
376+ response = continuation (client_call_details , request_iterator )
377+ return response
378+
379+ async def intercept_stream_stream (
380+ self , continuation , client_call_details , request_iterator
381+ ):
382+ self ._add_request_metadata (client_call_details )
286383 response_it = continuation (client_call_details , request_iterator )
287384 return response_it
288385
289386
290387@pytest .fixture
291- def intercepted_echo (use_mtls ):
388+ def intercepted_echo_grpc (use_mtls ):
292389 # The interceptor adds 'showcase-trailer' client metadata. Showcase server
293- # echos any metadata with key 'showcase-trailer', so the same metadata
390+ # echoes any metadata with key 'showcase-trailer', so the same metadata
294391 # should appear as trailing metadata in the response.
295- interceptor = MetadataClientInterceptor ("showcase-trailer" , "intercepted" )
392+ interceptor = EchoMetadataClientGrpcInterceptor (
393+ "showcase-trailer" ,
394+ "intercepted" ,
395+ )
296396 host = "localhost:7469"
297397 channel = (
298398 grpc .secure_channel (host , ssl_credentials )
@@ -304,4 +404,58 @@ def intercepted_echo(use_mtls):
304404 credentials = ga_credentials .AnonymousCredentials (),
305405 channel = intercept_channel ,
306406 )
307- return EchoClient (transport = transport )
407+ return EchoClient (transport = transport ), interceptor
408+
409+
410+ @pytest .fixture
411+ def intercepted_echo_grpc_async ():
412+ # The interceptor adds 'showcase-trailer' client metadata. Showcase server
413+ # echoes any metadata with key 'showcase-trailer', so the same metadata
414+ # should appear as trailing metadata in the response.
415+ interceptor = EchoMetadataClientGrpcAsyncInterceptor (
416+ "showcase-trailer" ,
417+ "intercepted" ,
418+ )
419+ host = "localhost:7469"
420+ channel = grpc .aio .insecure_channel (host , interceptors = [interceptor ])
421+ # intercept_channel = grpc.aio.intercept_channel(channel, interceptor)
422+ transport = EchoAsyncClient .get_transport_class ("grpc_asyncio" )(
423+ credentials = ga_credentials .AnonymousCredentials (),
424+ channel = channel ,
425+ )
426+ return EchoAsyncClient (transport = transport ), interceptor
427+
428+
429+ @pytest .fixture
430+ def intercepted_echo_rest ():
431+ transport_name = "rest"
432+ transport_cls = EchoClient .get_transport_class (transport_name )
433+ interceptor = EchoMetadataClientRestInterceptor ()
434+
435+ # The custom host explicitly bypasses https.
436+ transport = transport_cls (
437+ credentials = ga_credentials .AnonymousCredentials (),
438+ host = "localhost:7469" ,
439+ url_scheme = "http" ,
440+ interceptor = interceptor ,
441+ )
442+ return EchoClient (transport = transport ), interceptor
443+
444+
445+ @pytest .fixture
446+ def intercepted_echo_rest_async ():
447+ if not HAS_ASYNC_REST_ECHO_TRANSPORT :
448+ pytest .skip ("Skipping test with async rest." )
449+
450+ transport_name = "rest_asyncio"
451+ transport_cls = EchoAsyncClient .get_transport_class (transport_name )
452+ interceptor = EchoMetadataClientRestAsyncInterceptor ()
453+
454+ # The custom host explicitly bypasses https.
455+ transport = transport_cls (
456+ credentials = async_anonymous_credentials (),
457+ host = "localhost:7469" ,
458+ url_scheme = "http" ,
459+ interceptor = interceptor ,
460+ )
461+ return EchoAsyncClient (transport = transport ), interceptor
0 commit comments