Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `opentelemetry-instrumentation-fastapi`: fix wrapping of middlewares
([#3012](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3012))
- `opentelemetry-instrumentation-grpc`: support non-list interceptors
([#3520](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3520))

### Breaking changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider")

def server(*args, **kwargs):
if "interceptors" in kwargs:
if "interceptors" in kwargs and kwargs["interceptors"]:
kwargs["interceptors"] = list(kwargs["interceptors"])
# add our interceptor as the first
kwargs["interceptors"].insert(
0,
Expand All @@ -348,6 +349,7 @@ def server(*args, **kwargs):
tracer_provider=tracer_provider, filter_=self._filter
)
]

return self._original_func(*args, **kwargs)

grpc.server = server
Expand Down Expand Up @@ -386,7 +388,8 @@ def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider")

def server(*args, **kwargs):
if "interceptors" in kwargs:
if "interceptors" in kwargs and kwargs["interceptors"]:
kwargs["interceptors"] = list(kwargs["interceptors"])
# add our interceptor as the first
kwargs["interceptors"].insert(
0,
Expand Down Expand Up @@ -516,6 +519,7 @@ def instrumentation_dependencies(self) -> Collection[str]:

def _add_interceptors(self, tracer_provider, kwargs):
if "interceptors" in kwargs and kwargs["interceptors"]:
kwargs["interceptors"] = list(kwargs["interceptors"])
kwargs["interceptors"] = (
aio_client_interceptors(
tracer_provider=tracer_provider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,9 @@ async def ServerStreamingMethod(self, request, context):


async def run_with_test_server(
runnable, servicer=Servicer(), add_interceptor=True
runnable, servicer=Servicer(), interceptors=None
):
if add_interceptor:
interceptors = [aio_server_interceptor()]
if interceptors is not None:
server = grpc.aio.server(interceptors=interceptors)
else:
server = grpc.aio.server()
Expand Down Expand Up @@ -95,7 +94,7 @@ async def request(channel):
msg = request.SerializeToString()
return await channel.unary_unary(rpc_call)(msg)

await run_with_test_server(request, add_interceptor=False)
await run_with_test_server(request)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
Expand Down Expand Up @@ -140,7 +139,7 @@ async def request(channel):
msg = request.SerializeToString()
return await channel.unary_unary(rpc_call)(msg)

await run_with_test_server(request, add_interceptor=False)
await run_with_test_server(request)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 0)
Expand All @@ -154,7 +153,7 @@ async def request(channel):
msg = request.SerializeToString()
return await channel.unary_unary(rpc_call)(msg)

await run_with_test_server(request)
await run_with_test_server(request, interceptors=[aio_server_interceptor()])

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
Expand Down Expand Up @@ -206,7 +205,7 @@ async def request(channel):
msg = request.SerializeToString()
return await channel.unary_unary(rpc_call)(msg)

await run_with_test_server(request, servicer=TwoSpanServicer())
await run_with_test_server(request, servicer=TwoSpanServicer(), interceptors=[aio_server_interceptor()])

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 2)
Expand Down Expand Up @@ -253,7 +252,7 @@ async def request(channel):
async for response in channel.unary_stream(rpc_call)(msg):
print(response)

await run_with_test_server(request)
await run_with_test_server(request, interceptors=[aio_server_interceptor()])

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
Expand Down Expand Up @@ -307,7 +306,7 @@ async def request(channel):
async for response in channel.unary_stream(rpc_call)(msg):
print(response)

await run_with_test_server(request, servicer=TwoSpanServicer())
await run_with_test_server(request, servicer=TwoSpanServicer(), interceptors=[aio_server_interceptor()])

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 2)
Expand Down Expand Up @@ -367,7 +366,7 @@ async def request(channel):
lifetime_servicer = SpanLifetimeServicer()
active_span_before_call = trace.get_current_span()

await run_with_test_server(request, servicer=lifetime_servicer)
await run_with_test_server(request, servicer=lifetime_servicer, interceptors=[aio_server_interceptor()])

active_span_in_handler = lifetime_servicer.span
active_span_after_call = trace.get_current_span()
Expand All @@ -390,7 +389,7 @@ async def sequential_requests(channel):
await request(channel)
await request(channel)

await run_with_test_server(sequential_requests)
await run_with_test_server(sequential_requests, interceptors=[aio_server_interceptor()])

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 2)
Expand Down Expand Up @@ -450,7 +449,7 @@ async def concurrent_requests(channel):
await asyncio.gather(request(channel), request(channel))

await run_with_test_server(
concurrent_requests, servicer=LatchedServicer()
concurrent_requests, servicer=LatchedServicer(), interceptors=[aio_server_interceptor()]
)

spans_list = self.memory_exporter.get_finished_spans()
Expand Down Expand Up @@ -504,7 +503,7 @@ async def request(channel):
self.assertEqual(cm.exception.code(), grpc.StatusCode.INTERNAL)
self.assertEqual(cm.exception.details(), failure_message)

await run_with_test_server(request, servicer=AbortServicer())
await run_with_test_server(request, servicer=AbortServicer(), interceptors=[aio_server_interceptor()])

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
Expand Down Expand Up @@ -569,7 +568,7 @@ async def request(channel):
)
self.assertEqual(cm.exception.details(), failure_message)

await run_with_test_server(request, servicer=AbortServicer())
await run_with_test_server(request, servicer=AbortServicer(), interceptors=[aio_server_interceptor()])

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
Expand Down Expand Up @@ -602,6 +601,55 @@ async def request(channel):
},
)

async def test_non_list_interceptors(self):
"""Check that we handle non-list interceptors correctly."""

grpc_server_instrumentor = GrpcAioInstrumentorServer()
grpc_server_instrumentor.instrument()

try:
rpc_call = "/GRPCTestServer/SimpleMethod"

async def request(channel):
request = Request(client_id=1, request_data="test")
msg = request.SerializeToString()
return await channel.unary_unary(rpc_call)(msg)

class MockInterceptor(grpc.aio.ServerInterceptor):
async def intercept_service(self, continuation, handler_call_details):
return await continuation(handler_call_details)

await run_with_test_server(request, interceptors=(MockInterceptor(),))

finally:
grpc_server_instrumentor.uninstrument()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

self.assertEqual(span.name, rpc_call)
self.assertIs(span.kind, trace.SpanKind.SERVER)

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationScope(
span, opentelemetry.instrumentation.grpc
)

# Check attributes
self.assertSpanHasAttributes(
span,
{
SpanAttributes.NET_PEER_IP: "[::1]",
SpanAttributes.NET_PEER_NAME: "localhost",
SpanAttributes.RPC_METHOD: "SimpleMethod",
SpanAttributes.RPC_SERVICE: "GRPCTestServer",
SpanAttributes.RPC_SYSTEM: "grpc",
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[
0
],
},
)

def get_latch(num):
"""Get a countdown latch function for use in n threads."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tempfile
import threading
from concurrent import futures
from unittest import mock

import grpc

Expand Down Expand Up @@ -104,41 +105,44 @@ def handler(request, context):

grpc_server_instrumentor = GrpcInstrumentorServer()
grpc_server_instrumentor.instrument()
with self.server(max_workers=1) as (server, channel):
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
rpc_call = "TestServicer/handler"
try:
server.start()
channel.unary_unary(rpc_call)(b"test")
finally:
server.stop(None)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertEqual(span.name, rpc_call)
self.assertIs(span.kind, trace.SpanKind.SERVER)
try:
with self.server(max_workers=1) as (server, channel):
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
rpc_call = "TestServicer/handler"
try:
server.start()
channel.unary_unary(rpc_call)(b"test")
finally:
server.stop(None)

finally:
grpc_server_instrumentor.uninstrument()

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationScope(
span, opentelemetry.instrumentation.grpc
)
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertEqual(span.name, rpc_call)
self.assertIs(span.kind, trace.SpanKind.SERVER)

# Check attributes
self.assertSpanHasAttributes(
span,
{
**self.net_peer_span_attributes,
SpanAttributes.RPC_METHOD: "handler",
SpanAttributes.RPC_SERVICE: "TestServicer",
SpanAttributes.RPC_SYSTEM: "grpc",
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[
0
],
},
)
# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationScope(
span, opentelemetry.instrumentation.grpc
)

grpc_server_instrumentor.uninstrument()
# Check attributes
self.assertSpanHasAttributes(
span,
{
**self.net_peer_span_attributes,
SpanAttributes.RPC_METHOD: "handler",
SpanAttributes.RPC_SERVICE: "TestServicer",
SpanAttributes.RPC_SYSTEM: "grpc",
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[
0
],
},
)

def test_uninstrument(self):
def handler(request, context):
Expand Down Expand Up @@ -647,6 +651,55 @@ def unset_status_handler(request, context):
},
)

def test_non_list_interceptors(self):
"""Check that we handle non-list interceptors correctly."""
grpc_server_instrumentor = GrpcInstrumentorServer()
grpc_server_instrumentor.instrument()

try:
with self.server(
max_workers=1,
interceptors=(mock.MagicMock(),),
) as (server, channel):
add_GRPCTestServerServicer_to_server(Servicer(), server)

rpc_call = "/GRPCTestServer/SimpleMethod"
request = Request(client_id=1, request_data="test")
msg = request.SerializeToString()
try:
server.start()
channel.unary_unary(rpc_call)(msg)
finally:
server.stop(None)
finally:
grpc_server_instrumentor.uninstrument()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

self.assertEqual(span.name, rpc_call)
self.assertIs(span.kind, trace.SpanKind.SERVER)

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationScope(
span, opentelemetry.instrumentation.grpc
)

# Check attributes
self.assertSpanHasAttributes(
span,
{
**self.net_peer_span_attributes,
SpanAttributes.RPC_METHOD: "SimpleMethod",
SpanAttributes.RPC_SERVICE: "GRPCTestServer",
SpanAttributes.RPC_SYSTEM: "grpc",
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[
0
],
},
)


class TestOpenTelemetryServerInterceptorUnix(
TestOpenTelemetryServerInterceptor,
Expand Down
Loading