diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ecdc6d396..beac03f24b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `opentelemetry-instrumentation-fastapi`: fix wrapping of middlewares ([#3012](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3012)) +- `opentelemetry-instrumentation-grpc`: support non-list interceptors + ([#3520](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3520)) ### Breaking changes diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py index ff0fa93902..a89a575d64 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py @@ -334,7 +334,8 @@ def _instrument(self, **kwargs): tracer_provider = kwargs.get("tracer_provider") def server(*args, **kwargs): - if "interceptors" in kwargs: + if "interceptors" in kwargs and kwargs["interceptors"]: + kwargs["interceptors"] = list(kwargs["interceptors"]) # add our interceptor as the first kwargs["interceptors"].insert( 0, @@ -348,6 +349,7 @@ def server(*args, **kwargs): tracer_provider=tracer_provider, filter_=self._filter ) ] + return self._original_func(*args, **kwargs) grpc.server = server @@ -386,7 +388,8 @@ def _instrument(self, **kwargs): tracer_provider = kwargs.get("tracer_provider") def server(*args, **kwargs): - if "interceptors" in kwargs: + if "interceptors" in kwargs and kwargs["interceptors"]: + kwargs["interceptors"] = list(kwargs["interceptors"]) # add our interceptor as the first kwargs["interceptors"].insert( 0, @@ -516,6 +519,7 @@ def instrumentation_dependencies(self) -> Collection[str]: def _add_interceptors(self, tracer_provider, kwargs): if "interceptors" in kwargs and kwargs["interceptors"]: + kwargs["interceptors"] = list(kwargs["interceptors"]) kwargs["interceptors"] = ( aio_client_interceptors( tracer_provider=tracer_provider, diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py index ee917ca26c..04a2c02476 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py @@ -58,13 +58,9 @@ async def ServerStreamingMethod(self, request, context): async def run_with_test_server( - runnable, servicer=Servicer(), add_interceptor=True + runnable, servicer=Servicer(), interceptors=None ): - if add_interceptor: - interceptors = [aio_server_interceptor()] - server = grpc.aio.server(interceptors=interceptors) - else: - server = grpc.aio.server() + server = grpc.aio.server(interceptors=interceptors) add_GRPCTestServerServicer_to_server(servicer, server) @@ -95,7 +91,7 @@ async def request(channel): msg = request.SerializeToString() return await channel.unary_unary(rpc_call)(msg) - await run_with_test_server(request, add_interceptor=False) + await run_with_test_server(request) spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) @@ -140,7 +136,7 @@ async def request(channel): msg = request.SerializeToString() return await channel.unary_unary(rpc_call)(msg) - await run_with_test_server(request, add_interceptor=False) + await run_with_test_server(request) spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 0) @@ -154,7 +150,9 @@ async def request(channel): msg = request.SerializeToString() return await channel.unary_unary(rpc_call)(msg) - await run_with_test_server(request) + await run_with_test_server( + request, interceptors=[aio_server_interceptor()] + ) spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) @@ -206,7 +204,11 @@ async def request(channel): msg = request.SerializeToString() return await channel.unary_unary(rpc_call)(msg) - await run_with_test_server(request, servicer=TwoSpanServicer()) + await run_with_test_server( + request, + servicer=TwoSpanServicer(), + interceptors=[aio_server_interceptor()], + ) spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 2) @@ -253,7 +255,9 @@ async def request(channel): async for response in channel.unary_stream(rpc_call)(msg): print(response) - await run_with_test_server(request) + await run_with_test_server( + request, interceptors=[aio_server_interceptor()] + ) spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) @@ -307,7 +311,11 @@ async def request(channel): async for response in channel.unary_stream(rpc_call)(msg): print(response) - await run_with_test_server(request, servicer=TwoSpanServicer()) + await run_with_test_server( + request, + servicer=TwoSpanServicer(), + interceptors=[aio_server_interceptor()], + ) spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 2) @@ -367,7 +375,11 @@ async def request(channel): lifetime_servicer = SpanLifetimeServicer() active_span_before_call = trace.get_current_span() - await run_with_test_server(request, servicer=lifetime_servicer) + await run_with_test_server( + request, + servicer=lifetime_servicer, + interceptors=[aio_server_interceptor()], + ) active_span_in_handler = lifetime_servicer.span active_span_after_call = trace.get_current_span() @@ -390,7 +402,9 @@ async def sequential_requests(channel): await request(channel) await request(channel) - await run_with_test_server(sequential_requests) + await run_with_test_server( + sequential_requests, interceptors=[aio_server_interceptor()] + ) spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 2) @@ -450,7 +464,9 @@ async def concurrent_requests(channel): await asyncio.gather(request(channel), request(channel)) await run_with_test_server( - concurrent_requests, servicer=LatchedServicer() + concurrent_requests, + servicer=LatchedServicer(), + interceptors=[aio_server_interceptor()], ) spans_list = self.memory_exporter.get_finished_spans() @@ -504,7 +520,11 @@ async def request(channel): self.assertEqual(cm.exception.code(), grpc.StatusCode.INTERNAL) self.assertEqual(cm.exception.details(), failure_message) - await run_with_test_server(request, servicer=AbortServicer()) + await run_with_test_server( + request, + servicer=AbortServicer(), + interceptors=[aio_server_interceptor()], + ) spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) @@ -569,7 +589,11 @@ async def request(channel): ) self.assertEqual(cm.exception.details(), failure_message) - await run_with_test_server(request, servicer=AbortServicer()) + await run_with_test_server( + request, + servicer=AbortServicer(), + interceptors=[aio_server_interceptor()], + ) spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) @@ -602,6 +626,60 @@ async def request(channel): }, ) + async def test_non_list_interceptors(self): + """Check that we handle non-list interceptors correctly.""" + + grpc_server_instrumentor = GrpcAioInstrumentorServer() + grpc_server_instrumentor.instrument() + + try: + rpc_call = "/GRPCTestServer/SimpleMethod" + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + class MockInterceptor(grpc.aio.ServerInterceptor): + async def intercept_service( + self, continuation, handler_call_details + ): + return await continuation(handler_call_details) + + await run_with_test_server( + request, interceptors=(MockInterceptor(),) + ) + + finally: + grpc_server_instrumentor.uninstrument() + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationScope( + span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + def get_latch(num): """Get a countdown latch function for use in n threads.""" diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_server_interceptor.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_server_interceptor.py index 08aa16187a..3fdfb8e815 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_server_interceptor.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_server_interceptor.py @@ -19,6 +19,7 @@ import tempfile import threading from concurrent import futures +from unittest import mock import grpc @@ -104,41 +105,46 @@ def handler(request, context): grpc_server_instrumentor = GrpcInstrumentorServer() grpc_server_instrumentor.instrument() - with self.server(max_workers=1) as (server, channel): - server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),)) - rpc_call = "TestServicer/handler" - try: - server.start() - channel.unary_unary(rpc_call)(b"test") - finally: - server.stop(None) - spans_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(spans_list), 1) - span = spans_list[0] - self.assertEqual(span.name, rpc_call) - self.assertIs(span.kind, trace.SpanKind.SERVER) + try: + with self.server(max_workers=1) as (server, channel): + server.add_generic_rpc_handlers( + (UnaryUnaryRpcHandler(handler),) + ) + rpc_call = "TestServicer/handler" + try: + server.start() + channel.unary_unary(rpc_call)(b"test") + finally: + server.stop(None) + + finally: + grpc_server_instrumentor.uninstrument() - # Check version and name in span's instrumentation info - self.assertEqualSpanInstrumentationScope( - span, opentelemetry.instrumentation.grpc - ) + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) - # Check attributes - self.assertSpanHasAttributes( - span, - { - **self.net_peer_span_attributes, - SpanAttributes.RPC_METHOD: "handler", - SpanAttributes.RPC_SERVICE: "TestServicer", - SpanAttributes.RPC_SYSTEM: "grpc", - SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ - 0 - ], - }, - ) + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationScope( + span, opentelemetry.instrumentation.grpc + ) - grpc_server_instrumentor.uninstrument() + # Check attributes + self.assertSpanHasAttributes( + span, + { + **self.net_peer_span_attributes, + SpanAttributes.RPC_METHOD: "handler", + SpanAttributes.RPC_SERVICE: "TestServicer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) def test_uninstrument(self): def handler(request, context): @@ -647,6 +653,55 @@ def unset_status_handler(request, context): }, ) + def test_non_list_interceptors(self): + """Check that we handle non-list interceptors correctly.""" + grpc_server_instrumentor = GrpcInstrumentorServer() + grpc_server_instrumentor.instrument() + + try: + with self.server( + max_workers=1, + interceptors=(mock.MagicMock(),), + ) as (server, channel): + add_GRPCTestServerServicer_to_server(Servicer(), server) + + rpc_call = "/GRPCTestServer/SimpleMethod" + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + try: + server.start() + channel.unary_unary(rpc_call)(msg) + finally: + server.stop(None) + finally: + grpc_server_instrumentor.uninstrument() + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationScope( + span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + **self.net_peer_span_attributes, + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + class TestOpenTelemetryServerInterceptorUnix( TestOpenTelemetryServerInterceptor,