Skip to content

Commit c33f2b5

Browse files
committed
Add connection instrumentation based on the existing pattern.
add a helper function assert_span_count to simplify tests add unit tests for pipeline hooks
1 parent 2c291bf commit c33f2b5

File tree

2 files changed

+188
-96
lines changed

2 files changed

+188
-96
lines changed

instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py

Lines changed: 93 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def response_hook(span, instance, response):
9393

9494
from __future__ import annotations
9595

96+
import logging
9697
from typing import TYPE_CHECKING, Any, Callable, Collection
9798

9899
import redis
@@ -146,17 +147,26 @@ def response_hook(span, instance, response):
146147

147148

148149
_DEFAULT_SERVICE = "redis"
149-
150+
_logger = logging.getLogger(__name__)
150151

151152
_REDIS_ASYNCIO_VERSION = (4, 2, 0)
152-
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
153-
import redis.asyncio
154-
155153
_REDIS_CLUSTER_VERSION = (4, 1, 0)
156154
_REDIS_ASYNCIO_CLUSTER_VERSION = (4, 3, 2)
157155

158156
_FIELD_TYPES = ["NUMERIC", "TEXT", "GEO", "TAG", "VECTOR"]
159157

158+
_CLIENT_ASYNCIO_SUPPORT = redis.VERSION >= _REDIS_ASYNCIO_VERSION
159+
_CLIENT_ASYNCIO_CLUSTER_SUPPORT = (
160+
redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION
161+
)
162+
_CLIENT_CLUSTER_SUPPORT = redis.VERSION >= _REDIS_CLUSTER_VERSION
163+
_CLIENT_BEFORE_3_0_0 = redis.VERSION < (3, 0, 0)
164+
165+
if _CLIENT_ASYNCIO_SUPPORT:
166+
import redis.asyncio
167+
168+
INSTRUMENTATION_ATTR = "_is_instrumented_by_opentelemetry"
169+
160170

161171
def _set_connection_attributes(
162172
span: Span, conn: RedisInstance | AsyncRedisInstance
@@ -440,10 +450,8 @@ def _instrument(
440450
_traced_execute_pipeline = _traced_execute_pipeline_factory(
441451
tracer, request_hook, response_hook
442452
)
443-
pipeline_class = (
444-
"BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline"
445-
)
446-
redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis"
453+
pipeline_class = "BasePipeline" if _CLIENT_BEFORE_3_0_0 else "Pipeline"
454+
redis_class = "StrictRedis" if _CLIENT_BEFORE_3_0_0 else "Redis"
447455

448456
wrap_function_wrapper(
449457
"redis", f"{redis_class}.execute_command", _traced_execute_command
@@ -505,68 +513,55 @@ def _instrument(
505513
)
506514

507515

508-
def _instrument_client(
516+
def _instrument_connection(
509517
client,
510518
tracer,
511519
request_hook: _RequestHookT = None,
512520
response_hook: _ResponseHookT = None,
513521
):
514-
# first, handle async clients
515-
_async_traced_execute_command = _async_traced_execute_factory(
522+
# first, handle async clients and cluster clients
523+
_async_traced_execute = _async_traced_execute_factory(
516524
tracer, request_hook, response_hook
517525
)
518526
_async_traced_execute_pipeline = _async_traced_execute_pipeline_factory(
519527
tracer, request_hook, response_hook
520528
)
521529

522-
def _async_pipeline_wrapper(func, instance, args, kwargs):
523-
result = func(*args, **kwargs)
524-
wrap_function_wrapper(
525-
result, "execute", _async_traced_execute_pipeline
526-
)
527-
wrap_function_wrapper(
528-
result, "immediate_execute_command", _async_traced_execute_command
529-
)
530-
return result
531-
532-
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
533-
client_type = (
534-
redis.asyncio.StrictRedis
535-
if redis.VERSION < (3, 0, 0)
536-
else redis.asyncio.Redis
537-
)
530+
if _CLIENT_ASYNCIO_SUPPORT and isinstance(client, redis.asyncio.Redis):
538531

539-
if isinstance(client, client_type):
532+
def _async_pipeline_wrapper(func, instance, args, kwargs):
533+
result = func(*args, **kwargs)
540534
wrap_function_wrapper(
541-
client, "execute_command", _async_traced_execute_command
535+
result, "execute", _async_traced_execute_pipeline
542536
)
543-
wrap_function_wrapper(client, "pipeline", _async_pipeline_wrapper)
544-
return
537+
wrap_function_wrapper(
538+
result, "immediate_execute_command", _async_traced_execute
539+
)
540+
return result
545541

546-
def _async_cluster_pipeline_wrapper(func, instance, args, kwargs):
547-
result = func(*args, **kwargs)
548-
wrap_function_wrapper(
549-
result, "execute", _async_traced_execute_pipeline
550-
)
551-
return result
542+
wrap_function_wrapper(client, "execute_command", _async_traced_execute)
543+
wrap_function_wrapper(client, "pipeline", _async_pipeline_wrapper)
544+
return
552545

553-
# handle
554-
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION and isinstance(
546+
if _CLIENT_ASYNCIO_CLUSTER_SUPPORT and isinstance(
555547
client, redis.asyncio.RedisCluster
556548
):
557-
wrap_function_wrapper(
558-
client, "execute_command", _async_traced_execute_command
559-
)
549+
550+
def _async_cluster_pipeline_wrapper(func, instance, args, kwargs):
551+
result = func(*args, **kwargs)
552+
wrap_function_wrapper(
553+
result, "execute", _async_traced_execute_pipeline
554+
)
555+
return result
556+
557+
wrap_function_wrapper(client, "execute_command", _async_traced_execute)
560558
wrap_function_wrapper(
561559
client, "pipeline", _async_cluster_pipeline_wrapper
562560
)
563561
return
564562
# for redis.client.Redis, redis.Cluster and v3.0.0 redis.client.StrictRedis
565563
# the wrappers are the same
566-
# client_type = (
567-
# redis.client.StrictRedis if redis.VERSION < (3, 0, 0) else redis.client.Redis
568-
# )
569-
_traced_execute_command = _traced_execute_factory(
564+
_traced_execute = _traced_execute_factory(
570565
tracer, request_hook, response_hook
571566
)
572567
_traced_execute_pipeline = _traced_execute_pipeline_factory(
@@ -577,14 +572,14 @@ def _pipeline_wrapper(func, instance, args, kwargs):
577572
result = func(*args, **kwargs)
578573
wrap_function_wrapper(result, "execute", _traced_execute_pipeline)
579574
wrap_function_wrapper(
580-
result, "immediate_execute_command", _traced_execute_command
575+
result, "immediate_execute_command", _traced_execute
581576
)
582577
return result
583578

584579
wrap_function_wrapper(
585580
client,
586581
"execute_command",
587-
_traced_execute_command,
582+
_traced_execute,
588583
)
589584
wrap_function_wrapper(
590585
client,
@@ -599,6 +594,16 @@ class RedisInstrumentor(BaseInstrumentor):
599594
See `BaseInstrumentor`
600595
"""
601596

597+
@staticmethod
598+
def _get_tracer(**kwargs):
599+
tracer_provider = kwargs.get("tracer_provider")
600+
return trace.get_tracer(
601+
__name__,
602+
__version__,
603+
tracer_provider=tracer_provider,
604+
schema_url="https://opentelemetry.io/schemas/1.11.0",
605+
)
606+
602607
def instrumentation_dependencies(self) -> Collection[str]:
603608
return _instruments
604609

@@ -610,30 +615,14 @@ def _instrument(self, **kwargs: Any):
610615
``tracer_provider``: a TracerProvider, defaults to global.
611616
``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
612617
"""
613-
tracer_provider = kwargs.get("tracer_provider")
614-
tracer = trace.get_tracer(
615-
__name__,
616-
__version__,
617-
tracer_provider=tracer_provider,
618-
schema_url="https://opentelemetry.io/schemas/1.11.0",
618+
_instrument(
619+
self._get_tracer(**kwargs),
620+
request_hook=kwargs.get("request_hook"),
621+
response_hook=kwargs.get("response_hook"),
619622
)
620-
redis_client = kwargs.get("client")
621-
if redis_client:
622-
_instrument_client(
623-
redis_client,
624-
tracer,
625-
request_hook=kwargs.get("request_hook"),
626-
response_hook=kwargs.get("response_hook"),
627-
)
628-
else:
629-
_instrument(
630-
tracer,
631-
request_hook=kwargs.get("request_hook"),
632-
response_hook=kwargs.get("response_hook"),
633-
)
634623

635624
def _uninstrument(self, **kwargs: Any):
636-
if redis.VERSION < (3, 0, 0):
625+
if _CLIENT_BEFORE_3_0_0:
637626
unwrap(redis.StrictRedis, "execute_command")
638627
unwrap(redis.StrictRedis, "pipeline")
639628
unwrap(redis.Redis, "pipeline")
@@ -661,3 +650,38 @@ def _uninstrument(self, **kwargs: Any):
661650
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
662651
unwrap(redis.asyncio.cluster.RedisCluster, "execute_command")
663652
unwrap(redis.asyncio.cluster.ClusterPipeline, "execute")
653+
654+
@staticmethod
655+
def instrument_connection(
656+
client, tracer_provider: None, request_hook=None, response_hook=None
657+
):
658+
if not hasattr(client, INSTRUMENTATION_ATTR):
659+
setattr(client, INSTRUMENTATION_ATTR, False)
660+
if not getattr(client, INSTRUMENTATION_ATTR):
661+
_instrument_connection(
662+
client,
663+
RedisInstrumentor._get_tracer(tracer_provider=tracer_provider),
664+
request_hook=request_hook,
665+
response_hook=response_hook,
666+
)
667+
setattr(client, INSTRUMENTATION_ATTR, True)
668+
else:
669+
_logger.warning(
670+
"Attempting to instrument Redis connection while already instrumented"
671+
)
672+
673+
@staticmethod
674+
def uninstrument_connection(client):
675+
if getattr(client, INSTRUMENTATION_ATTR):
676+
# for all clients we need to unwrap execute_command and pipeline functions
677+
unwrap(client, "execute_command")
678+
# pipeline was creating a pipeline and wrapping the functions of the
679+
# created instance. any pipeline created before un-instrumenting will
680+
# remain instrumented (pipelines should usually have a short span)
681+
unwrap(client, "pipeline")
682+
pass
683+
else:
684+
_logger.warning(
685+
"Attempting to un-instrument Redis connection that wasn't instrumented"
686+
)
687+
return

0 commit comments

Comments
 (0)