Skip to content

Commit b28b222

Browse files
author
Kevin Rauwolf
committed
Use mock to count span end calls instead of manual patching
1 parent 8560386 commit b28b222

File tree

1 file changed

+23
-35
lines changed

1 file changed

+23
-35
lines changed

instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor.py

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
# pylint:disable=cyclic-import
1515

16+
from unittest import mock
17+
1618
import grpc
1719

1820
import opentelemetry.instrumentation.grpc
@@ -26,6 +28,7 @@
2628
)
2729
from opentelemetry.instrumentation.utils import suppress_instrumentation
2830
from opentelemetry.propagate import get_global_textmap, set_global_textmap
31+
from opentelemetry.sdk.trace import Span as SdkSpan
2932
from opentelemetry.semconv.trace import SpanAttributes
3033
from opentelemetry.test.mock_textmap import MockTextMapPropagator
3134
from opentelemetry.test.test_base import TestBase
@@ -274,41 +277,26 @@ def test_client_interceptor_falsy_response(
274277
): # pylint: disable=no-self-use
275278
"""ensure that client interceptor closes the span only once even if the response is falsy."""
276279

277-
span_end_count = 0
278-
tracer_provider, _exporter = self.create_tracer_provider()
279-
tracer = tracer_provider.get_tracer(__name__)
280-
original_start_span = tracer.start_span
281-
282-
def counting_span_end(original_end):
283-
def wrapper(*args, **kwargs):
284-
nonlocal span_end_count
285-
span_end_count += 1
286-
return original_end(*args, **kwargs)
287-
288-
return wrapper
289-
290-
def patched_start_span(*args, **kwargs):
291-
span = original_start_span(*args, **kwargs)
292-
span.end = counting_span_end(span.end)
293-
return span
294-
295-
tracer.start_span = patched_start_span
296-
interceptor = OpenTelemetryClientInterceptor(tracer)
297-
298-
def invoker(_request, _metadata):
299-
return {}
300-
301-
request = Request(client_id=1, request_data="data")
302-
interceptor.intercept_unary(
303-
request,
304-
{},
305-
_UnaryClientInfo(
306-
full_method="/GRPCTestServer/SimpleMethod",
307-
timeout=None,
308-
),
309-
invoker=invoker,
310-
)
311-
assert span_end_count == 1
280+
with mock.patch.object(SdkSpan, "end") as span_end_mock:
281+
tracer_provider, _exporter = self.create_tracer_provider()
282+
tracer = tracer_provider.get_tracer(__name__)
283+
284+
interceptor = OpenTelemetryClientInterceptor(tracer)
285+
286+
def invoker(_request, _metadata):
287+
return {}
288+
289+
request = Request(client_id=1, request_data="data")
290+
interceptor.intercept_unary(
291+
request,
292+
{},
293+
_UnaryClientInfo(
294+
full_method="/GRPCTestServer/SimpleMethod",
295+
timeout=None,
296+
),
297+
invoker=invoker,
298+
)
299+
assert span_end_mock.call_count == 1
312300

313301
def test_client_interceptor_trace_context_propagation(
314302
self,

0 commit comments

Comments
 (0)