@@ -91,6 +91,7 @@ def response_hook(span, instance, response):
9191---
9292"""
9393
94+ import logging
9495import typing
9596from typing import Any , Collection
9697
@@ -121,16 +122,27 @@ def response_hook(span, instance, response):
121122_ResponseHookT = typing .Optional [
122123 typing .Callable [[Span , redis .connection .Connection , Any ], None ]
123124]
125+ _logger = logging .getLogger (__name__ )
126+ assert hasattr (redis , "VERSION" )
124127
125128_REDIS_ASYNCIO_VERSION = (4 , 2 , 0 )
126- if redis .VERSION >= _REDIS_ASYNCIO_VERSION :
127- import redis .asyncio
128-
129129_REDIS_CLUSTER_VERSION = (4 , 1 , 0 )
130130_REDIS_ASYNCIO_CLUSTER_VERSION = (4 , 3 , 2 )
131131
132132_FIELD_TYPES = ["NUMERIC" , "TEXT" , "GEO" , "TAG" , "VECTOR" ]
133133
134+ _CLIENT_ASYNCIO_SUPPORT = redis .VERSION >= _REDIS_ASYNCIO_VERSION
135+ _CLIENT_ASYNCIO_CLUSTER_SUPPORT = (
136+ redis .VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION
137+ )
138+ _CLIENT_CLUSTER_SUPPORT = redis .VERSION >= _REDIS_CLUSTER_VERSION
139+ _CLIENT_BEFORE_3_0_0 = redis .VERSION < (3 , 0 , 0 )
140+
141+ if _CLIENT_ASYNCIO_SUPPORT :
142+ import redis .asyncio
143+
144+ INSTRUMENTATION_ATTR = "_is_instrumented_by_opentelemetry"
145+
134146
135147def _set_connection_attributes (span , conn ):
136148 if not span .is_recording () or not hasattr (conn , "connection_pool" ):
@@ -388,10 +400,8 @@ def _instrument(
388400 _traced_execute_pipeline = _traced_execute_pipeline_factory (
389401 tracer , request_hook , response_hook
390402 )
391- pipeline_class = (
392- "BasePipeline" if redis .VERSION < (3 , 0 , 0 ) else "Pipeline"
393- )
394- redis_class = "StrictRedis" if redis .VERSION < (3 , 0 , 0 ) else "Redis"
403+ pipeline_class = "BasePipeline" if _CLIENT_BEFORE_3_0_0 else "Pipeline"
404+ redis_class = "StrictRedis" if _CLIENT_BEFORE_3_0_0 else "Redis"
395405
396406 wrap_function_wrapper (
397407 "redis" , f"{ redis_class } .execute_command" , _traced_execute_command
@@ -453,68 +463,55 @@ def _instrument(
453463 )
454464
455465
456- def _instrument_client (
466+ def _instrument_connection (
457467 client ,
458468 tracer ,
459469 request_hook : _RequestHookT = None ,
460470 response_hook : _ResponseHookT = None ,
461471):
462- # first, handle async clients
463- _async_traced_execute_command = _async_traced_execute_factory (
472+ # first, handle async clients and cluster clients
473+ _async_traced_execute = _async_traced_execute_factory (
464474 tracer , request_hook , response_hook
465475 )
466476 _async_traced_execute_pipeline = _async_traced_execute_pipeline_factory (
467477 tracer , request_hook , response_hook
468478 )
469479
470- def _async_pipeline_wrapper (func , instance , args , kwargs ):
471- result = func (* args , ** kwargs )
472- wrap_function_wrapper (
473- result , "execute" , _async_traced_execute_pipeline
474- )
475- wrap_function_wrapper (
476- result , "immediate_execute_command" , _async_traced_execute_command
477- )
478- return result
480+ if _CLIENT_ASYNCIO_SUPPORT and isinstance (client , redis .asyncio .Redis ):
479481
480- if redis .VERSION >= _REDIS_ASYNCIO_VERSION :
481- client_type = (
482- redis .asyncio .StrictRedis
483- if redis .VERSION < (3 , 0 , 0 )
484- else redis .asyncio .Redis
485- )
486-
487- if isinstance (client , client_type ):
482+ def _async_pipeline_wrapper (func , instance , args , kwargs ):
483+ result = func (* args , ** kwargs )
488484 wrap_function_wrapper (
489- client , "execute_command " , _async_traced_execute_command
485+ result , "execute " , _async_traced_execute_pipeline
490486 )
491- wrap_function_wrapper (client , "pipeline" , _async_pipeline_wrapper )
492- return
487+ wrap_function_wrapper (
488+ result , "immediate_execute_command" , _async_traced_execute
489+ )
490+ return result
493491
494- def _async_cluster_pipeline_wrapper (func , instance , args , kwargs ):
495- result = func (* args , ** kwargs )
496- wrap_function_wrapper (
497- result , "execute" , _async_traced_execute_pipeline
498- )
499- return result
492+ wrap_function_wrapper (client , "execute_command" , _async_traced_execute )
493+ wrap_function_wrapper (client , "pipeline" , _async_pipeline_wrapper )
494+ return
500495
501- # handle
502- if redis .VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION and isinstance (
496+ if _CLIENT_ASYNCIO_CLUSTER_SUPPORT and isinstance (
503497 client , redis .asyncio .RedisCluster
504498 ):
505- wrap_function_wrapper (
506- client , "execute_command" , _async_traced_execute_command
507- )
499+
500+ def _async_cluster_pipeline_wrapper (func , instance , args , kwargs ):
501+ result = func (* args , ** kwargs )
502+ wrap_function_wrapper (
503+ result , "execute" , _async_traced_execute_pipeline
504+ )
505+ return result
506+
507+ wrap_function_wrapper (client , "execute_command" , _async_traced_execute )
508508 wrap_function_wrapper (
509509 client , "pipeline" , _async_cluster_pipeline_wrapper
510510 )
511511 return
512512 # for redis.client.Redis, redis.Cluster and v3.0.0 redis.client.StrictRedis
513513 # the wrappers are the same
514- # client_type = (
515- # redis.client.StrictRedis if redis.VERSION < (3, 0, 0) else redis.client.Redis
516- # )
517- _traced_execute_command = _traced_execute_factory (
514+ _traced_execute = _traced_execute_factory (
518515 tracer , request_hook , response_hook
519516 )
520517 _traced_execute_pipeline = _traced_execute_pipeline_factory (
@@ -525,14 +522,14 @@ def _pipeline_wrapper(func, instance, args, kwargs):
525522 result = func (* args , ** kwargs )
526523 wrap_function_wrapper (result , "execute" , _traced_execute_pipeline )
527524 wrap_function_wrapper (
528- result , "immediate_execute_command" , _traced_execute_command
525+ result , "immediate_execute_command" , _traced_execute
529526 )
530527 return result
531528
532529 wrap_function_wrapper (
533530 client ,
534531 "execute_command" ,
535- _traced_execute_command ,
532+ _traced_execute ,
536533 )
537534 wrap_function_wrapper (
538535 client ,
@@ -546,6 +543,16 @@ class RedisInstrumentor(BaseInstrumentor):
546543 See `BaseInstrumentor`
547544 """
548545
546+ @staticmethod
547+ def _get_tracer (** kwargs ):
548+ tracer_provider = kwargs .get ("tracer_provider" )
549+ return trace .get_tracer (
550+ __name__ ,
551+ __version__ ,
552+ tracer_provider = tracer_provider ,
553+ schema_url = "https://opentelemetry.io/schemas/1.11.0" ,
554+ )
555+
549556 def instrumentation_dependencies (self ) -> Collection [str ]:
550557 return _instruments
551558
@@ -557,30 +564,14 @@ def _instrument(self, **kwargs):
557564 ``tracer_provider``: a TracerProvider, defaults to global.
558565 ``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
559566 """
560- tracer_provider = kwargs .get ("tracer_provider" )
561- tracer = trace .get_tracer (
562- __name__ ,
563- __version__ ,
564- tracer_provider = tracer_provider ,
565- schema_url = "https://opentelemetry.io/schemas/1.11.0" ,
567+ _instrument (
568+ self ._get_tracer (** kwargs ),
569+ request_hook = kwargs .get ("request_hook" ),
570+ response_hook = kwargs .get ("response_hook" ),
566571 )
567- redis_client = kwargs .get ("client" )
568- if redis_client :
569- _instrument_client (
570- redis_client ,
571- tracer ,
572- request_hook = kwargs .get ("request_hook" ),
573- response_hook = kwargs .get ("response_hook" ),
574- )
575- else :
576- _instrument (
577- tracer ,
578- request_hook = kwargs .get ("request_hook" ),
579- response_hook = kwargs .get ("response_hook" ),
580- )
581572
582573 def _uninstrument (self , ** kwargs ):
583- if redis . VERSION < ( 3 , 0 , 0 ) :
574+ if _CLIENT_BEFORE_3_0_0 :
584575 unwrap (redis .StrictRedis , "execute_command" )
585576 unwrap (redis .StrictRedis , "pipeline" )
586577 unwrap (redis .Redis , "pipeline" )
@@ -608,3 +599,38 @@ def _uninstrument(self, **kwargs):
608599 if redis .VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION :
609600 unwrap (redis .asyncio .cluster .RedisCluster , "execute_command" )
610601 unwrap (redis .asyncio .cluster .ClusterPipeline , "execute" )
602+
603+ @staticmethod
604+ def instrument_connection (
605+ client , tracer_provider : None , request_hook = None , response_hook = None
606+ ):
607+ if not hasattr (client , INSTRUMENTATION_ATTR ):
608+ setattr (client , INSTRUMENTATION_ATTR , False )
609+ if not getattr (client , INSTRUMENTATION_ATTR ):
610+ _instrument_connection (
611+ client ,
612+ RedisInstrumentor ._get_tracer (tracer_provider = tracer_provider ),
613+ request_hook = request_hook ,
614+ response_hook = response_hook ,
615+ )
616+ setattr (client , INSTRUMENTATION_ATTR , True )
617+ else :
618+ _logger .warning (
619+ "Attempting to instrument Redis connection while already instrumented"
620+ )
621+
622+ @staticmethod
623+ def uninstrument_connection (client ):
624+ if getattr (client , INSTRUMENTATION_ATTR ):
625+ # for all clients we need to unwrap execute_command and pipeline functions
626+ unwrap (client , "execute_command" )
627+ # pipeline was creating a pipeline and wrapping the functions of the
628+ # created instance. any pipeline created before un-instrumenting will
629+ # remain instrumented (pipelines should usually have a short span)
630+ unwrap (client , "pipeline" )
631+ pass
632+ else :
633+ _logger .warning (
634+ "Attempting to un-instrument Redis connection that wasn't instrumented"
635+ )
636+ return
0 commit comments