diff --git a/CHANGELOG.md b/CHANGELOG.md index f0fd2243dd..c2ff24d2d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -94,6 +94,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#2461](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2461)) - Remove SDK dependency from opentelemetry-instrumentation-grpc ([#2474](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2474)) +- `opentelemetry-instrumentation-grpc` User should be able to cancel grpc stream + ([#2093](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2093)) - `opentelemetry-instrumentation-elasticsearch` Improved support for version 8 ([#2420](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2420)) - `opentelemetry-instrumentation-elasticsearch` Disabling instrumentation with native OTel support enabled diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_client.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_client.py index e27c9e826f..6ee6e95653 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_client.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_client.py @@ -24,6 +24,7 @@ from typing import Callable, MutableMapping import grpc +import wrapt from opentelemetry import trace from opentelemetry.instrumentation.grpc import grpcext @@ -73,6 +74,58 @@ def _safe_invoke(function: Callable, *args): ) +# pylint:disable=abstract-method +class OpenTelemetryStreamWrapper(wrapt.ObjectProxy): + def __init__(self, wrapped, span: trace.Span): + super().__init__(wrapped) + self._self_span = span + self._span_ended = False + + def _end_span_if_not_already_ended(self, status_code=None, status=None): + if self._span_ended: + return + + if status_code is not None: + self._self_span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, status_code + ) + if status is not None: + self._self_span.set_status(status) + self._span_ended = True + self._self_span.end() + + def __del__(self): + self._end_span_if_not_already_ended() + self.__wrapped__.__del__() + + def __iter__(self): + return self + + def cancel(self): + self._end_span_if_not_already_ended( + status_code=grpc.StatusCode.CANCELLED.value[0] + ) + return self.__wrapped__.cancel() + + def __next__(self): + return self._next() + + def next(self): + return self._next() + + def _next(self): + try: + return self.__wrapped__._next() + except StopIteration: + self._end_span_if_not_already_ended() + raise + except grpc.RpcError as err: + self._end_span_if_not_already_ended( + err.code().value[0], Status(StatusCode.ERROR) + ) + raise err + + class OpenTelemetryClientInterceptor( grpcext.UnaryClientInterceptor, grpcext.StreamClientInterceptor ): @@ -195,7 +248,9 @@ def _intercept_server_stream( else: mutable_metadata = OrderedDict(metadata) - with self._start_span(client_info.full_method) as span: + with self._start_span( + client_info.full_method, end_on_exit=False + ) as span: inject(mutable_metadata, setter=_carrier_setter) metadata = tuple(mutable_metadata.items()) rpc_info = RpcInfo( @@ -207,14 +262,9 @@ def _intercept_server_stream( if client_info.is_client_stream: rpc_info.request = request_or_iterator - try: - yield from invoker(request_or_iterator, metadata) - except grpc.RpcError as err: - span.set_status(Status(StatusCode.ERROR)) - span.set_attribute( - SpanAttributes.RPC_GRPC_STATUS_CODE, err.code().value[0] - ) - raise err + stream = invoker(request_or_iterator, metadata) + + return OpenTelemetryStreamWrapper(stream, span) def intercept_stream( self, request_or_iterator, metadata, client_info, invoker diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/_client.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/_client.py index 67e7d0a625..38e40fe2d0 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/tests/_client.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/_client.py @@ -45,17 +45,19 @@ def request_messages(): ) -def server_streaming_method(stub, error=False): +def server_streaming_method(stub, error=False, serialize=True): request = Request( client_id=CLIENT_ID, request_data="error" if error else "data" ) response_iterator = stub.ServerStreamingMethod( request, metadata=(("key", "value"),) ) - list(response_iterator) + if serialize: + list(response_iterator) + return response_iterator -def bidirectional_streaming_method(stub, error=False): +def bidirectional_streaming_method(stub, error=False, serialize=True): def request_messages(): for _ in range(5): request = Request( @@ -66,5 +68,6 @@ def request_messages(): response_iterator = stub.BidirectionalStreamingMethod( request_messages(), metadata=(("key", "value"),) ) - - list(response_iterator) + if serialize: + list(response_iterator) + return response_iterator diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor.py index 38759352b3..98cb472cb1 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor.py @@ -167,6 +167,63 @@ def test_unary_stream(self): }, ) + def test_unary_stream_can_be_cancel(self): + responses = server_streaming_method(self._stub, serialize=False) + for response_num, _ in enumerate(responses): + if response_num == 1: + responses.cancel() + break + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertEqual(span.name, "/GRPCTestServer/ServerStreamingMethod") + self.assertIs(span.kind, trace.SpanKind.CLIENT) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + self.assertSpanHasAttributes( + span, + { + SpanAttributes.RPC_METHOD: "ServerStreamingMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.CANCELLED.value[ + 0 + ], + }, + ) + + def test_finished_stream_cancel_does_not_change_status_of_span(self): + responses = server_streaming_method(self._stub, serialize=True) + responses.cancel() + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertEqual(span.name, "/GRPCTestServer/ServerStreamingMethod") + self.assertIs(span.kind, trace.SpanKind.CLIENT) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + self.assertSpanHasAttributes( + span, + { + SpanAttributes.RPC_METHOD: "ServerStreamingMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + def test_stream_unary(self): client_streaming_method(self._stub) spans = self.memory_exporter.get_finished_spans() @@ -221,6 +278,38 @@ def test_stream_stream(self): }, ) + def test_stream_stream_can_be_cancel(self): + responses = bidirectional_streaming_method(self._stub, serialize=False) + for response_num, _ in enumerate(responses): + if response_num == 1: + responses.cancel() + break + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertEqual( + span.name, "/GRPCTestServer/BidirectionalStreamingMethod" + ) + self.assertIs(span.kind, trace.SpanKind.CLIENT) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + self.assertSpanHasAttributes( + span, + { + SpanAttributes.RPC_METHOD: "BidirectionalStreamingMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.CANCELLED.value[ + 0 + ], + }, + ) + def test_error_simple(self): with self.assertRaises(grpc.RpcError): simple_method(self._stub, error=True)