Skip to content

Commit 6435bd4

Browse files
author
OlegZv
committed
Add connection instrumentation based on the existing pattern.
add a helper function assert_span_count to simplify tests add unit tests for pipeline hooks
1 parent 4a206aa commit 6435bd4

File tree

2 files changed

+189
-95
lines changed

2 files changed

+189
-95
lines changed

instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py

Lines changed: 94 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def response_hook(span, instance, response):
9191
---
9292
"""
9393

94+
import logging
9495
import typing
9596
from typing import Any, Collection
9697

@@ -121,16 +122,27 @@ def response_hook(span, instance, response):
121122
_ResponseHookT = typing.Optional[
122123
typing.Callable[[Span, redis.connection.Connection, Any], None]
123124
]
125+
_logger = logging.getLogger(__name__)
126+
assert hasattr(redis, "VERSION")
124127

125128
_REDIS_ASYNCIO_VERSION = (4, 2, 0)
126-
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
127-
import redis.asyncio
128-
129129
_REDIS_CLUSTER_VERSION = (4, 1, 0)
130130
_REDIS_ASYNCIO_CLUSTER_VERSION = (4, 3, 2)
131131

132132
_FIELD_TYPES = ["NUMERIC", "TEXT", "GEO", "TAG", "VECTOR"]
133133

134+
_CLIENT_ASYNCIO_SUPPORT = redis.VERSION >= _REDIS_ASYNCIO_VERSION
135+
_CLIENT_ASYNCIO_CLUSTER_SUPPORT = (
136+
redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION
137+
)
138+
_CLIENT_CLUSTER_SUPPORT = redis.VERSION >= _REDIS_CLUSTER_VERSION
139+
_CLIENT_BEFORE_3_0_0 = redis.VERSION < (3, 0, 0)
140+
141+
if _CLIENT_ASYNCIO_SUPPORT:
142+
import redis.asyncio
143+
144+
INSTRUMENTATION_ATTR = "_is_instrumented_by_opentelemetry"
145+
134146

135147
def _set_connection_attributes(span, conn):
136148
if not span.is_recording() or not hasattr(conn, "connection_pool"):
@@ -388,10 +400,8 @@ def _instrument(
388400
_traced_execute_pipeline = _traced_execute_pipeline_factory(
389401
tracer, request_hook, response_hook
390402
)
391-
pipeline_class = (
392-
"BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline"
393-
)
394-
redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis"
403+
pipeline_class = "BasePipeline" if _CLIENT_BEFORE_3_0_0 else "Pipeline"
404+
redis_class = "StrictRedis" if _CLIENT_BEFORE_3_0_0 else "Redis"
395405

396406
wrap_function_wrapper(
397407
"redis", f"{redis_class}.execute_command", _traced_execute_command
@@ -453,68 +463,55 @@ def _instrument(
453463
)
454464

455465

456-
def _instrument_client(
466+
def _instrument_connection(
457467
client,
458468
tracer,
459469
request_hook: _RequestHookT = None,
460470
response_hook: _ResponseHookT = None,
461471
):
462-
# first, handle async clients
463-
_async_traced_execute_command = _async_traced_execute_factory(
472+
# first, handle async clients and cluster clients
473+
_async_traced_execute = _async_traced_execute_factory(
464474
tracer, request_hook, response_hook
465475
)
466476
_async_traced_execute_pipeline = _async_traced_execute_pipeline_factory(
467477
tracer, request_hook, response_hook
468478
)
469479

470-
def _async_pipeline_wrapper(func, instance, args, kwargs):
471-
result = func(*args, **kwargs)
472-
wrap_function_wrapper(
473-
result, "execute", _async_traced_execute_pipeline
474-
)
475-
wrap_function_wrapper(
476-
result, "immediate_execute_command", _async_traced_execute_command
477-
)
478-
return result
480+
if _CLIENT_ASYNCIO_SUPPORT and isinstance(client, redis.asyncio.Redis):
479481

480-
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
481-
client_type = (
482-
redis.asyncio.StrictRedis
483-
if redis.VERSION < (3, 0, 0)
484-
else redis.asyncio.Redis
485-
)
486-
487-
if isinstance(client, client_type):
482+
def _async_pipeline_wrapper(func, instance, args, kwargs):
483+
result = func(*args, **kwargs)
488484
wrap_function_wrapper(
489-
client, "execute_command", _async_traced_execute_command
485+
result, "execute", _async_traced_execute_pipeline
490486
)
491-
wrap_function_wrapper(client, "pipeline", _async_pipeline_wrapper)
492-
return
487+
wrap_function_wrapper(
488+
result, "immediate_execute_command", _async_traced_execute
489+
)
490+
return result
493491

494-
def _async_cluster_pipeline_wrapper(func, instance, args, kwargs):
495-
result = func(*args, **kwargs)
496-
wrap_function_wrapper(
497-
result, "execute", _async_traced_execute_pipeline
498-
)
499-
return result
492+
wrap_function_wrapper(client, "execute_command", _async_traced_execute)
493+
wrap_function_wrapper(client, "pipeline", _async_pipeline_wrapper)
494+
return
500495

501-
# handle
502-
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION and isinstance(
496+
if _CLIENT_ASYNCIO_CLUSTER_SUPPORT and isinstance(
503497
client, redis.asyncio.RedisCluster
504498
):
505-
wrap_function_wrapper(
506-
client, "execute_command", _async_traced_execute_command
507-
)
499+
500+
def _async_cluster_pipeline_wrapper(func, instance, args, kwargs):
501+
result = func(*args, **kwargs)
502+
wrap_function_wrapper(
503+
result, "execute", _async_traced_execute_pipeline
504+
)
505+
return result
506+
507+
wrap_function_wrapper(client, "execute_command", _async_traced_execute)
508508
wrap_function_wrapper(
509509
client, "pipeline", _async_cluster_pipeline_wrapper
510510
)
511511
return
512512
# for redis.client.Redis, redis.Cluster and v3.0.0 redis.client.StrictRedis
513513
# the wrappers are the same
514-
# client_type = (
515-
# redis.client.StrictRedis if redis.VERSION < (3, 0, 0) else redis.client.Redis
516-
# )
517-
_traced_execute_command = _traced_execute_factory(
514+
_traced_execute = _traced_execute_factory(
518515
tracer, request_hook, response_hook
519516
)
520517
_traced_execute_pipeline = _traced_execute_pipeline_factory(
@@ -525,14 +522,14 @@ def _pipeline_wrapper(func, instance, args, kwargs):
525522
result = func(*args, **kwargs)
526523
wrap_function_wrapper(result, "execute", _traced_execute_pipeline)
527524
wrap_function_wrapper(
528-
result, "immediate_execute_command", _traced_execute_command
525+
result, "immediate_execute_command", _traced_execute
529526
)
530527
return result
531528

532529
wrap_function_wrapper(
533530
client,
534531
"execute_command",
535-
_traced_execute_command,
532+
_traced_execute,
536533
)
537534
wrap_function_wrapper(
538535
client,
@@ -546,6 +543,16 @@ class RedisInstrumentor(BaseInstrumentor):
546543
See `BaseInstrumentor`
547544
"""
548545

546+
@staticmethod
547+
def _get_tracer(**kwargs):
548+
tracer_provider = kwargs.get("tracer_provider")
549+
return trace.get_tracer(
550+
__name__,
551+
__version__,
552+
tracer_provider=tracer_provider,
553+
schema_url="https://opentelemetry.io/schemas/1.11.0",
554+
)
555+
549556
def instrumentation_dependencies(self) -> Collection[str]:
550557
return _instruments
551558

@@ -557,30 +564,14 @@ def _instrument(self, **kwargs):
557564
``tracer_provider``: a TracerProvider, defaults to global.
558565
``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
559566
"""
560-
tracer_provider = kwargs.get("tracer_provider")
561-
tracer = trace.get_tracer(
562-
__name__,
563-
__version__,
564-
tracer_provider=tracer_provider,
565-
schema_url="https://opentelemetry.io/schemas/1.11.0",
567+
_instrument(
568+
self._get_tracer(**kwargs),
569+
request_hook=kwargs.get("request_hook"),
570+
response_hook=kwargs.get("response_hook"),
566571
)
567-
redis_client = kwargs.get("client")
568-
if redis_client:
569-
_instrument_client(
570-
redis_client,
571-
tracer,
572-
request_hook=kwargs.get("request_hook"),
573-
response_hook=kwargs.get("response_hook"),
574-
)
575-
else:
576-
_instrument(
577-
tracer,
578-
request_hook=kwargs.get("request_hook"),
579-
response_hook=kwargs.get("response_hook"),
580-
)
581572

582573
def _uninstrument(self, **kwargs):
583-
if redis.VERSION < (3, 0, 0):
574+
if _CLIENT_BEFORE_3_0_0:
584575
unwrap(redis.StrictRedis, "execute_command")
585576
unwrap(redis.StrictRedis, "pipeline")
586577
unwrap(redis.Redis, "pipeline")
@@ -608,3 +599,38 @@ def _uninstrument(self, **kwargs):
608599
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
609600
unwrap(redis.asyncio.cluster.RedisCluster, "execute_command")
610601
unwrap(redis.asyncio.cluster.ClusterPipeline, "execute")
602+
603+
@staticmethod
604+
def instrument_connection(
605+
client, tracer_provider: None, request_hook=None, response_hook=None
606+
):
607+
if not hasattr(client, INSTRUMENTATION_ATTR):
608+
setattr(client, INSTRUMENTATION_ATTR, False)
609+
if not getattr(client, INSTRUMENTATION_ATTR):
610+
_instrument_connection(
611+
client,
612+
RedisInstrumentor._get_tracer(tracer_provider=tracer_provider),
613+
request_hook=request_hook,
614+
response_hook=response_hook,
615+
)
616+
setattr(client, INSTRUMENTATION_ATTR, True)
617+
else:
618+
_logger.warning(
619+
"Attempting to instrument Redis connection while already instrumented"
620+
)
621+
622+
@staticmethod
623+
def uninstrument_connection(client):
624+
if getattr(client, INSTRUMENTATION_ATTR):
625+
# for all clients we need to unwrap execute_command and pipeline functions
626+
unwrap(client, "execute_command")
627+
# pipeline was creating a pipeline and wrapping the functions of the
628+
# created instance. any pipeline created before un-instrumenting will
629+
# remain instrumented (pipelines should usually have a short span)
630+
unwrap(client, "pipeline")
631+
pass
632+
else:
633+
_logger.warning(
634+
"Attempting to un-instrument Redis connection that wasn't instrumented"
635+
)
636+
return

0 commit comments

Comments
 (0)