@@ -58,10 +58,9 @@ async def ServerStreamingMethod(self, request, context):
5858
5959
6060async def run_with_test_server (
61- runnable , servicer = Servicer (), add_interceptor = True
61+ runnable , servicer = Servicer (), interceptors = None
6262):
63- if add_interceptor :
64- interceptors = [aio_server_interceptor ()]
63+ if interceptors is not None :
6564 server = grpc .aio .server (interceptors = interceptors )
6665 else :
6766 server = grpc .aio .server ()
@@ -95,7 +94,7 @@ async def request(channel):
9594 msg = request .SerializeToString ()
9695 return await channel .unary_unary (rpc_call )(msg )
9796
98- await run_with_test_server (request , add_interceptor = False )
97+ await run_with_test_server (request )
9998
10099 spans_list = self .memory_exporter .get_finished_spans ()
101100 self .assertEqual (len (spans_list ), 1 )
@@ -140,7 +139,7 @@ async def request(channel):
140139 msg = request .SerializeToString ()
141140 return await channel .unary_unary (rpc_call )(msg )
142141
143- await run_with_test_server (request , add_interceptor = False )
142+ await run_with_test_server (request )
144143
145144 spans_list = self .memory_exporter .get_finished_spans ()
146145 self .assertEqual (len (spans_list ), 0 )
@@ -154,7 +153,7 @@ async def request(channel):
154153 msg = request .SerializeToString ()
155154 return await channel .unary_unary (rpc_call )(msg )
156155
157- await run_with_test_server (request )
156+ await run_with_test_server (request , interceptors = [ aio_server_interceptor ()] )
158157
159158 spans_list = self .memory_exporter .get_finished_spans ()
160159 self .assertEqual (len (spans_list ), 1 )
@@ -206,7 +205,7 @@ async def request(channel):
206205 msg = request .SerializeToString ()
207206 return await channel .unary_unary (rpc_call )(msg )
208207
209- await run_with_test_server (request , servicer = TwoSpanServicer ())
208+ await run_with_test_server (request , servicer = TwoSpanServicer (), interceptors = [ aio_server_interceptor ()] )
210209
211210 spans_list = self .memory_exporter .get_finished_spans ()
212211 self .assertEqual (len (spans_list ), 2 )
@@ -253,7 +252,7 @@ async def request(channel):
253252 async for response in channel .unary_stream (rpc_call )(msg ):
254253 print (response )
255254
256- await run_with_test_server (request )
255+ await run_with_test_server (request , interceptors = [ aio_server_interceptor ()] )
257256
258257 spans_list = self .memory_exporter .get_finished_spans ()
259258 self .assertEqual (len (spans_list ), 1 )
@@ -307,7 +306,7 @@ async def request(channel):
307306 async for response in channel .unary_stream (rpc_call )(msg ):
308307 print (response )
309308
310- await run_with_test_server (request , servicer = TwoSpanServicer ())
309+ await run_with_test_server (request , servicer = TwoSpanServicer (), interceptors = [ aio_server_interceptor ()] )
311310
312311 spans_list = self .memory_exporter .get_finished_spans ()
313312 self .assertEqual (len (spans_list ), 2 )
@@ -367,7 +366,7 @@ async def request(channel):
367366 lifetime_servicer = SpanLifetimeServicer ()
368367 active_span_before_call = trace .get_current_span ()
369368
370- await run_with_test_server (request , servicer = lifetime_servicer )
369+ await run_with_test_server (request , servicer = lifetime_servicer , interceptors = [ aio_server_interceptor ()] )
371370
372371 active_span_in_handler = lifetime_servicer .span
373372 active_span_after_call = trace .get_current_span ()
@@ -390,7 +389,7 @@ async def sequential_requests(channel):
390389 await request (channel )
391390 await request (channel )
392391
393- await run_with_test_server (sequential_requests )
392+ await run_with_test_server (sequential_requests , interceptors = [ aio_server_interceptor ()] )
394393
395394 spans_list = self .memory_exporter .get_finished_spans ()
396395 self .assertEqual (len (spans_list ), 2 )
@@ -450,7 +449,7 @@ async def concurrent_requests(channel):
450449 await asyncio .gather (request (channel ), request (channel ))
451450
452451 await run_with_test_server (
453- concurrent_requests , servicer = LatchedServicer ()
452+ concurrent_requests , servicer = LatchedServicer (), interceptors = [ aio_server_interceptor ()]
454453 )
455454
456455 spans_list = self .memory_exporter .get_finished_spans ()
@@ -504,7 +503,7 @@ async def request(channel):
504503 self .assertEqual (cm .exception .code (), grpc .StatusCode .INTERNAL )
505504 self .assertEqual (cm .exception .details (), failure_message )
506505
507- await run_with_test_server (request , servicer = AbortServicer ())
506+ await run_with_test_server (request , servicer = AbortServicer (), interceptors = [ aio_server_interceptor ()] )
508507
509508 spans_list = self .memory_exporter .get_finished_spans ()
510509 self .assertEqual (len (spans_list ), 1 )
@@ -569,7 +568,7 @@ async def request(channel):
569568 )
570569 self .assertEqual (cm .exception .details (), failure_message )
571570
572- await run_with_test_server (request , servicer = AbortServicer ())
571+ await run_with_test_server (request , servicer = AbortServicer (), interceptors = [ aio_server_interceptor ()] )
573572
574573 spans_list = self .memory_exporter .get_finished_spans ()
575574 self .assertEqual (len (spans_list ), 1 )
@@ -602,6 +601,55 @@ async def request(channel):
602601 },
603602 )
604603
604+ async def test_non_list_interceptors (self ):
605+ """Check that we handle non-list interceptors correctly."""
606+
607+ grpc_server_instrumentor = GrpcAioInstrumentorServer ()
608+ grpc_server_instrumentor .instrument ()
609+
610+ try :
611+ rpc_call = "/GRPCTestServer/SimpleMethod"
612+
613+ async def request (channel ):
614+ request = Request (client_id = 1 , request_data = "test" )
615+ msg = request .SerializeToString ()
616+ return await channel .unary_unary (rpc_call )(msg )
617+
618+ class MockInterceptor (grpc .aio .ServerInterceptor ):
619+ async def intercept_service (self , continuation , handler_call_details ):
620+ return await continuation (handler_call_details )
621+
622+ await run_with_test_server (request , interceptors = (MockInterceptor (),))
623+
624+ finally :
625+ grpc_server_instrumentor .uninstrument ()
626+
627+ spans_list = self .memory_exporter .get_finished_spans ()
628+ self .assertEqual (len (spans_list ), 1 )
629+ span = spans_list [0 ]
630+
631+ self .assertEqual (span .name , rpc_call )
632+ self .assertIs (span .kind , trace .SpanKind .SERVER )
633+
634+ # Check version and name in span's instrumentation info
635+ self .assertEqualSpanInstrumentationScope (
636+ span , opentelemetry .instrumentation .grpc
637+ )
638+
639+ # Check attributes
640+ self .assertSpanHasAttributes (
641+ span ,
642+ {
643+ SpanAttributes .NET_PEER_IP : "[::1]" ,
644+ SpanAttributes .NET_PEER_NAME : "localhost" ,
645+ SpanAttributes .RPC_METHOD : "SimpleMethod" ,
646+ SpanAttributes .RPC_SERVICE : "GRPCTestServer" ,
647+ SpanAttributes .RPC_SYSTEM : "grpc" ,
648+ SpanAttributes .RPC_GRPC_STATUS_CODE : grpc .StatusCode .OK .value [
649+ 0
650+ ],
651+ },
652+ )
605653
606654def get_latch (num ):
607655 """Get a countdown latch function for use in n threads."""
0 commit comments