|
13 | 13 | # limitations under the License. |
14 | 14 | # pylint:disable=cyclic-import |
15 | 15 |
|
| 16 | +from unittest import mock |
| 17 | + |
16 | 18 | import grpc |
17 | 19 |
|
18 | 20 | import opentelemetry.instrumentation.grpc |
|
26 | 28 | ) |
27 | 29 | from opentelemetry.instrumentation.utils import suppress_instrumentation |
28 | 30 | from opentelemetry.propagate import get_global_textmap, set_global_textmap |
| 31 | +from opentelemetry.sdk.trace import Span as SdkSpan |
29 | 32 | from opentelemetry.semconv.trace import SpanAttributes |
30 | 33 | from opentelemetry.test.mock_textmap import MockTextMapPropagator |
31 | 34 | from opentelemetry.test.test_base import TestBase |
@@ -274,41 +277,26 @@ def test_client_interceptor_falsy_response( |
274 | 277 | ): # pylint: disable=no-self-use |
275 | 278 | """ensure that client interceptor closes the span only once even if the response is falsy.""" |
276 | 279 |
|
277 | | - span_end_count = 0 |
278 | | - tracer_provider, _exporter = self.create_tracer_provider() |
279 | | - tracer = tracer_provider.get_tracer(__name__) |
280 | | - original_start_span = tracer.start_span |
281 | | - |
282 | | - def counting_span_end(original_end): |
283 | | - def wrapper(*args, **kwargs): |
284 | | - nonlocal span_end_count |
285 | | - span_end_count += 1 |
286 | | - return original_end(*args, **kwargs) |
287 | | - |
288 | | - return wrapper |
289 | | - |
290 | | - def patched_start_span(*args, **kwargs): |
291 | | - span = original_start_span(*args, **kwargs) |
292 | | - span.end = counting_span_end(span.end) |
293 | | - return span |
294 | | - |
295 | | - tracer.start_span = patched_start_span |
296 | | - interceptor = OpenTelemetryClientInterceptor(tracer) |
297 | | - |
298 | | - def invoker(_request, _metadata): |
299 | | - return {} |
300 | | - |
301 | | - request = Request(client_id=1, request_data="data") |
302 | | - interceptor.intercept_unary( |
303 | | - request, |
304 | | - {}, |
305 | | - _UnaryClientInfo( |
306 | | - full_method="/GRPCTestServer/SimpleMethod", |
307 | | - timeout=None, |
308 | | - ), |
309 | | - invoker=invoker, |
310 | | - ) |
311 | | - assert span_end_count == 1 |
| 280 | + with mock.patch.object(SdkSpan, "end") as span_end_mock: |
| 281 | + tracer_provider, _exporter = self.create_tracer_provider() |
| 282 | + tracer = tracer_provider.get_tracer(__name__) |
| 283 | + |
| 284 | + interceptor = OpenTelemetryClientInterceptor(tracer) |
| 285 | + |
| 286 | + def invoker(_request, _metadata): |
| 287 | + return {} |
| 288 | + |
| 289 | + request = Request(client_id=1, request_data="data") |
| 290 | + interceptor.intercept_unary( |
| 291 | + request, |
| 292 | + {}, |
| 293 | + _UnaryClientInfo( |
| 294 | + full_method="/GRPCTestServer/SimpleMethod", |
| 295 | + timeout=None, |
| 296 | + ), |
| 297 | + invoker=invoker, |
| 298 | + ) |
| 299 | + assert span_end_mock.call_count == 1 |
312 | 300 |
|
313 | 301 | def test_client_interceptor_trace_context_propagation( |
314 | 302 | self, |
|
0 commit comments