| 
13 | 13 | # limitations under the License.  | 
14 | 14 | # pylint:disable=cyclic-import  | 
15 | 15 | 
 
  | 
 | 16 | +from unittest import mock  | 
 | 17 | + | 
16 | 18 | import grpc  | 
17 | 19 | 
 
  | 
18 | 20 | import opentelemetry.instrumentation.grpc  | 
 | 
26 | 28 | )  | 
27 | 29 | from opentelemetry.instrumentation.utils import suppress_instrumentation  | 
28 | 30 | from opentelemetry.propagate import get_global_textmap, set_global_textmap  | 
 | 31 | +from opentelemetry.sdk.trace import Span as SdkSpan  | 
29 | 32 | from opentelemetry.semconv.trace import SpanAttributes  | 
30 | 33 | from opentelemetry.test.mock_textmap import MockTextMapPropagator  | 
31 | 34 | from opentelemetry.test.test_base import TestBase  | 
@@ -274,41 +277,26 @@ def test_client_interceptor_falsy_response(  | 
274 | 277 |     ):  # pylint: disable=no-self-use  | 
275 | 278 |         """ensure that client interceptor closes the span only once even if the response is falsy."""  | 
276 | 279 | 
 
  | 
277 |  | -        span_end_count = 0  | 
278 |  | -        tracer_provider, _exporter = self.create_tracer_provider()  | 
279 |  | -        tracer = tracer_provider.get_tracer(__name__)  | 
280 |  | -        original_start_span = tracer.start_span  | 
281 |  | - | 
282 |  | -        def counting_span_end(original_end):  | 
283 |  | -            def wrapper(*args, **kwargs):  | 
284 |  | -                nonlocal span_end_count  | 
285 |  | -                span_end_count += 1  | 
286 |  | -                return original_end(*args, **kwargs)  | 
287 |  | - | 
288 |  | -            return wrapper  | 
289 |  | - | 
290 |  | -        def patched_start_span(*args, **kwargs):  | 
291 |  | -            span = original_start_span(*args, **kwargs)  | 
292 |  | -            span.end = counting_span_end(span.end)  | 
293 |  | -            return span  | 
294 |  | - | 
295 |  | -        tracer.start_span = patched_start_span  | 
296 |  | -        interceptor = OpenTelemetryClientInterceptor(tracer)  | 
297 |  | - | 
298 |  | -        def invoker(_request, _metadata):  | 
299 |  | -            return {}  | 
300 |  | - | 
301 |  | -        request = Request(client_id=1, request_data="data")  | 
302 |  | -        interceptor.intercept_unary(  | 
303 |  | -            request,  | 
304 |  | -            {},  | 
305 |  | -            _UnaryClientInfo(  | 
306 |  | -                full_method="/GRPCTestServer/SimpleMethod",  | 
307 |  | -                timeout=None,  | 
308 |  | -            ),  | 
309 |  | -            invoker=invoker,  | 
310 |  | -        )  | 
311 |  | -        assert span_end_count == 1  | 
 | 280 | +        with mock.patch.object(SdkSpan, "end") as span_end_mock:  | 
 | 281 | +            tracer_provider, _exporter = self.create_tracer_provider()  | 
 | 282 | +            tracer = tracer_provider.get_tracer(__name__)  | 
 | 283 | + | 
 | 284 | +            interceptor = OpenTelemetryClientInterceptor(tracer)  | 
 | 285 | + | 
 | 286 | +            def invoker(_request, _metadata):  | 
 | 287 | +                return {}  | 
 | 288 | + | 
 | 289 | +            request = Request(client_id=1, request_data="data")  | 
 | 290 | +            interceptor.intercept_unary(  | 
 | 291 | +                request,  | 
 | 292 | +                {},  | 
 | 293 | +                _UnaryClientInfo(  | 
 | 294 | +                    full_method="/GRPCTestServer/SimpleMethod",  | 
 | 295 | +                    timeout=None,  | 
 | 296 | +                ),  | 
 | 297 | +                invoker=invoker,  | 
 | 298 | +            )  | 
 | 299 | +            assert span_end_mock.call_count == 1  | 
312 | 300 | 
 
  | 
313 | 301 |     def test_client_interceptor_trace_context_propagation(  | 
314 | 302 |         self,  | 
 | 
0 commit comments