Skip to content

Commit 174b8b5

Browse files
committed
fix(grpc): Support non-list interceptors
The GRPC integration assumes that if interceptors are provided, they will be a list, when GRPC itself types them as Sequence. With this change, we're making the codepaths using interceptors more robust by explicitly turning them into lists before manipulating them.
1 parent 4d6893e commit 174b8b5

File tree

3 files changed

+152
-47
lines changed

3 files changed

+152
-47
lines changed

instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ def _instrument(self, **kwargs):
334334
tracer_provider = kwargs.get("tracer_provider")
335335

336336
def server(*args, **kwargs):
337-
if "interceptors" in kwargs:
337+
if "interceptors" in kwargs and kwargs["interceptors"]:
338+
kwargs["interceptors"] = list(kwargs["interceptors"])
338339
# add our interceptor as the first
339340
kwargs["interceptors"].insert(
340341
0,
@@ -348,6 +349,7 @@ def server(*args, **kwargs):
348349
tracer_provider=tracer_provider, filter_=self._filter
349350
)
350351
]
352+
351353
return self._original_func(*args, **kwargs)
352354

353355
grpc.server = server
@@ -386,7 +388,8 @@ def _instrument(self, **kwargs):
386388
tracer_provider = kwargs.get("tracer_provider")
387389

388390
def server(*args, **kwargs):
389-
if "interceptors" in kwargs:
391+
if "interceptors" in kwargs and kwargs["interceptors"]:
392+
kwargs["interceptors"] = list(kwargs["interceptors"])
390393
# add our interceptor as the first
391394
kwargs["interceptors"].insert(
392395
0,
@@ -516,6 +519,7 @@ def instrumentation_dependencies(self) -> Collection[str]:
516519

517520
def _add_interceptors(self, tracer_provider, kwargs):
518521
if "interceptors" in kwargs and kwargs["interceptors"]:
522+
kwargs["interceptors"] = list(kwargs["interceptors"])
519523
kwargs["interceptors"] = (
520524
aio_client_interceptors(
521525
tracer_provider=tracer_provider,

instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,9 @@ async def ServerStreamingMethod(self, request, context):
5858

5959

6060
async def run_with_test_server(
61-
runnable, servicer=Servicer(), add_interceptor=True
61+
runnable, servicer=Servicer(), interceptors=None
6262
):
63-
if add_interceptor:
64-
interceptors = [aio_server_interceptor()]
63+
if interceptors is not None:
6564
server = grpc.aio.server(interceptors=interceptors)
6665
else:
6766
server = grpc.aio.server()
@@ -95,7 +94,7 @@ async def request(channel):
9594
msg = request.SerializeToString()
9695
return await channel.unary_unary(rpc_call)(msg)
9796

98-
await run_with_test_server(request, add_interceptor=False)
97+
await run_with_test_server(request)
9998

10099
spans_list = self.memory_exporter.get_finished_spans()
101100
self.assertEqual(len(spans_list), 1)
@@ -140,7 +139,7 @@ async def request(channel):
140139
msg = request.SerializeToString()
141140
return await channel.unary_unary(rpc_call)(msg)
142141

143-
await run_with_test_server(request, add_interceptor=False)
142+
await run_with_test_server(request)
144143

145144
spans_list = self.memory_exporter.get_finished_spans()
146145
self.assertEqual(len(spans_list), 0)
@@ -154,7 +153,7 @@ async def request(channel):
154153
msg = request.SerializeToString()
155154
return await channel.unary_unary(rpc_call)(msg)
156155

157-
await run_with_test_server(request)
156+
await run_with_test_server(request, interceptors=[aio_server_interceptor()])
158157

159158
spans_list = self.memory_exporter.get_finished_spans()
160159
self.assertEqual(len(spans_list), 1)
@@ -206,7 +205,7 @@ async def request(channel):
206205
msg = request.SerializeToString()
207206
return await channel.unary_unary(rpc_call)(msg)
208207

209-
await run_with_test_server(request, servicer=TwoSpanServicer())
208+
await run_with_test_server(request, servicer=TwoSpanServicer(), interceptors=[aio_server_interceptor()])
210209

211210
spans_list = self.memory_exporter.get_finished_spans()
212211
self.assertEqual(len(spans_list), 2)
@@ -253,7 +252,7 @@ async def request(channel):
253252
async for response in channel.unary_stream(rpc_call)(msg):
254253
print(response)
255254

256-
await run_with_test_server(request)
255+
await run_with_test_server(request, interceptors=[aio_server_interceptor()])
257256

258257
spans_list = self.memory_exporter.get_finished_spans()
259258
self.assertEqual(len(spans_list), 1)
@@ -307,7 +306,7 @@ async def request(channel):
307306
async for response in channel.unary_stream(rpc_call)(msg):
308307
print(response)
309308

310-
await run_with_test_server(request, servicer=TwoSpanServicer())
309+
await run_with_test_server(request, servicer=TwoSpanServicer(), interceptors=[aio_server_interceptor()])
311310

312311
spans_list = self.memory_exporter.get_finished_spans()
313312
self.assertEqual(len(spans_list), 2)
@@ -367,7 +366,7 @@ async def request(channel):
367366
lifetime_servicer = SpanLifetimeServicer()
368367
active_span_before_call = trace.get_current_span()
369368

370-
await run_with_test_server(request, servicer=lifetime_servicer)
369+
await run_with_test_server(request, servicer=lifetime_servicer, interceptors=[aio_server_interceptor()])
371370

372371
active_span_in_handler = lifetime_servicer.span
373372
active_span_after_call = trace.get_current_span()
@@ -390,7 +389,7 @@ async def sequential_requests(channel):
390389
await request(channel)
391390
await request(channel)
392391

393-
await run_with_test_server(sequential_requests)
392+
await run_with_test_server(sequential_requests, interceptors=[aio_server_interceptor()])
394393

395394
spans_list = self.memory_exporter.get_finished_spans()
396395
self.assertEqual(len(spans_list), 2)
@@ -450,7 +449,7 @@ async def concurrent_requests(channel):
450449
await asyncio.gather(request(channel), request(channel))
451450

452451
await run_with_test_server(
453-
concurrent_requests, servicer=LatchedServicer()
452+
concurrent_requests, servicer=LatchedServicer(), interceptors=[aio_server_interceptor()]
454453
)
455454

456455
spans_list = self.memory_exporter.get_finished_spans()
@@ -504,7 +503,7 @@ async def request(channel):
504503
self.assertEqual(cm.exception.code(), grpc.StatusCode.INTERNAL)
505504
self.assertEqual(cm.exception.details(), failure_message)
506505

507-
await run_with_test_server(request, servicer=AbortServicer())
506+
await run_with_test_server(request, servicer=AbortServicer(), interceptors=[aio_server_interceptor()])
508507

509508
spans_list = self.memory_exporter.get_finished_spans()
510509
self.assertEqual(len(spans_list), 1)
@@ -569,7 +568,7 @@ async def request(channel):
569568
)
570569
self.assertEqual(cm.exception.details(), failure_message)
571570

572-
await run_with_test_server(request, servicer=AbortServicer())
571+
await run_with_test_server(request, servicer=AbortServicer(), interceptors=[aio_server_interceptor()])
573572

574573
spans_list = self.memory_exporter.get_finished_spans()
575574
self.assertEqual(len(spans_list), 1)
@@ -602,6 +601,55 @@ async def request(channel):
602601
},
603602
)
604603

604+
async def test_non_list_interceptors(self):
605+
"""Check that we handle non-list interceptors correctly."""
606+
607+
grpc_server_instrumentor = GrpcAioInstrumentorServer()
608+
grpc_server_instrumentor.instrument()
609+
610+
try:
611+
rpc_call = "/GRPCTestServer/SimpleMethod"
612+
613+
async def request(channel):
614+
request = Request(client_id=1, request_data="test")
615+
msg = request.SerializeToString()
616+
return await channel.unary_unary(rpc_call)(msg)
617+
618+
class MockInterceptor(grpc.aio.ServerInterceptor):
619+
async def intercept_service(self, continuation, handler_call_details):
620+
return await continuation(handler_call_details)
621+
622+
await run_with_test_server(request, interceptors=(MockInterceptor(),))
623+
624+
finally:
625+
grpc_server_instrumentor.uninstrument()
626+
627+
spans_list = self.memory_exporter.get_finished_spans()
628+
self.assertEqual(len(spans_list), 1)
629+
span = spans_list[0]
630+
631+
self.assertEqual(span.name, rpc_call)
632+
self.assertIs(span.kind, trace.SpanKind.SERVER)
633+
634+
# Check version and name in span's instrumentation info
635+
self.assertEqualSpanInstrumentationScope(
636+
span, opentelemetry.instrumentation.grpc
637+
)
638+
639+
# Check attributes
640+
self.assertSpanHasAttributes(
641+
span,
642+
{
643+
SpanAttributes.NET_PEER_IP: "[::1]",
644+
SpanAttributes.NET_PEER_NAME: "localhost",
645+
SpanAttributes.RPC_METHOD: "SimpleMethod",
646+
SpanAttributes.RPC_SERVICE: "GRPCTestServer",
647+
SpanAttributes.RPC_SYSTEM: "grpc",
648+
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[
649+
0
650+
],
651+
},
652+
)
605653

606654
def get_latch(num):
607655
"""Get a countdown latch function for use in n threads."""

instrumentation/opentelemetry-instrumentation-grpc/tests/test_server_interceptor.py

Lines changed: 84 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import tempfile
2020
import threading
2121
from concurrent import futures
22+
from unittest import mock
2223

2324
import grpc
2425

@@ -104,41 +105,44 @@ def handler(request, context):
104105

105106
grpc_server_instrumentor = GrpcInstrumentorServer()
106107
grpc_server_instrumentor.instrument()
107-
with self.server(max_workers=1) as (server, channel):
108-
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
109-
rpc_call = "TestServicer/handler"
110-
try:
111-
server.start()
112-
channel.unary_unary(rpc_call)(b"test")
113-
finally:
114-
server.stop(None)
115108

116-
spans_list = self.memory_exporter.get_finished_spans()
117-
self.assertEqual(len(spans_list), 1)
118-
span = spans_list[0]
119-
self.assertEqual(span.name, rpc_call)
120-
self.assertIs(span.kind, trace.SpanKind.SERVER)
109+
try:
110+
with self.server(max_workers=1) as (server, channel):
111+
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
112+
rpc_call = "TestServicer/handler"
113+
try:
114+
server.start()
115+
channel.unary_unary(rpc_call)(b"test")
116+
finally:
117+
server.stop(None)
118+
119+
finally:
120+
grpc_server_instrumentor.uninstrument()
121121

122-
# Check version and name in span's instrumentation info
123-
self.assertEqualSpanInstrumentationScope(
124-
span, opentelemetry.instrumentation.grpc
125-
)
122+
spans_list = self.memory_exporter.get_finished_spans()
123+
self.assertEqual(len(spans_list), 1)
124+
span = spans_list[0]
125+
self.assertEqual(span.name, rpc_call)
126+
self.assertIs(span.kind, trace.SpanKind.SERVER)
126127

127-
# Check attributes
128-
self.assertSpanHasAttributes(
129-
span,
130-
{
131-
**self.net_peer_span_attributes,
132-
SpanAttributes.RPC_METHOD: "handler",
133-
SpanAttributes.RPC_SERVICE: "TestServicer",
134-
SpanAttributes.RPC_SYSTEM: "grpc",
135-
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[
136-
0
137-
],
138-
},
139-
)
128+
# Check version and name in span's instrumentation info
129+
self.assertEqualSpanInstrumentationScope(
130+
span, opentelemetry.instrumentation.grpc
131+
)
140132

141-
grpc_server_instrumentor.uninstrument()
133+
# Check attributes
134+
self.assertSpanHasAttributes(
135+
span,
136+
{
137+
**self.net_peer_span_attributes,
138+
SpanAttributes.RPC_METHOD: "handler",
139+
SpanAttributes.RPC_SERVICE: "TestServicer",
140+
SpanAttributes.RPC_SYSTEM: "grpc",
141+
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[
142+
0
143+
],
144+
},
145+
)
142146

143147
def test_uninstrument(self):
144148
def handler(request, context):
@@ -647,6 +651,55 @@ def unset_status_handler(request, context):
647651
},
648652
)
649653

654+
def test_non_list_interceptors(self):
655+
"""Check that we handle non-list interceptors correctly."""
656+
grpc_server_instrumentor = GrpcInstrumentorServer()
657+
grpc_server_instrumentor.instrument()
658+
659+
try:
660+
with self.server(
661+
max_workers=1,
662+
interceptors=(mock.MagicMock(),),
663+
) as (server, channel):
664+
add_GRPCTestServerServicer_to_server(Servicer(), server)
665+
666+
rpc_call = "/GRPCTestServer/SimpleMethod"
667+
request = Request(client_id=1, request_data="test")
668+
msg = request.SerializeToString()
669+
try:
670+
server.start()
671+
channel.unary_unary(rpc_call)(msg)
672+
finally:
673+
server.stop(None)
674+
finally:
675+
grpc_server_instrumentor.uninstrument()
676+
677+
spans_list = self.memory_exporter.get_finished_spans()
678+
self.assertEqual(len(spans_list), 1)
679+
span = spans_list[0]
680+
681+
self.assertEqual(span.name, rpc_call)
682+
self.assertIs(span.kind, trace.SpanKind.SERVER)
683+
684+
# Check version and name in span's instrumentation info
685+
self.assertEqualSpanInstrumentationScope(
686+
span, opentelemetry.instrumentation.grpc
687+
)
688+
689+
# Check attributes
690+
self.assertSpanHasAttributes(
691+
span,
692+
{
693+
**self.net_peer_span_attributes,
694+
SpanAttributes.RPC_METHOD: "SimpleMethod",
695+
SpanAttributes.RPC_SERVICE: "GRPCTestServer",
696+
SpanAttributes.RPC_SYSTEM: "grpc",
697+
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[
698+
0
699+
],
700+
},
701+
)
702+
650703

651704
class TestOpenTelemetryServerInterceptorUnix(
652705
TestOpenTelemetryServerInterceptor,

0 commit comments

Comments
 (0)