@@ -93,6 +93,7 @@ def response_hook(span, instance, response):
9393
9494from __future__ import annotations
9595
96+ import logging
9697from typing import TYPE_CHECKING , Any , Callable , Collection
9798
9899import redis
@@ -146,17 +147,26 @@ def response_hook(span, instance, response):
146147
147148
148149_DEFAULT_SERVICE = "redis"
149-
150+ _logger = logging . getLogger ( __name__ )
150151
151152_REDIS_ASYNCIO_VERSION = (4 , 2 , 0 )
152- if redis .VERSION >= _REDIS_ASYNCIO_VERSION :
153- import redis .asyncio
154-
155153_REDIS_CLUSTER_VERSION = (4 , 1 , 0 )
156154_REDIS_ASYNCIO_CLUSTER_VERSION = (4 , 3 , 2 )
157155
158156_FIELD_TYPES = ["NUMERIC" , "TEXT" , "GEO" , "TAG" , "VECTOR" ]
159157
158+ _CLIENT_ASYNCIO_SUPPORT = redis .VERSION >= _REDIS_ASYNCIO_VERSION
159+ _CLIENT_ASYNCIO_CLUSTER_SUPPORT = (
160+ redis .VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION
161+ )
162+ _CLIENT_CLUSTER_SUPPORT = redis .VERSION >= _REDIS_CLUSTER_VERSION
163+ _CLIENT_BEFORE_3_0_0 = redis .VERSION < (3 , 0 , 0 )
164+
165+ if _CLIENT_ASYNCIO_SUPPORT :
166+ import redis .asyncio
167+
168+ INSTRUMENTATION_ATTR = "_is_instrumented_by_opentelemetry"
169+
160170
161171def _set_connection_attributes (
162172 span : Span , conn : RedisInstance | AsyncRedisInstance
@@ -440,10 +450,8 @@ def _instrument(
440450 _traced_execute_pipeline = _traced_execute_pipeline_factory (
441451 tracer , request_hook , response_hook
442452 )
443- pipeline_class = (
444- "BasePipeline" if redis .VERSION < (3 , 0 , 0 ) else "Pipeline"
445- )
446- redis_class = "StrictRedis" if redis .VERSION < (3 , 0 , 0 ) else "Redis"
453+ pipeline_class = "BasePipeline" if _CLIENT_BEFORE_3_0_0 else "Pipeline"
454+ redis_class = "StrictRedis" if _CLIENT_BEFORE_3_0_0 else "Redis"
447455
448456 wrap_function_wrapper (
449457 "redis" , f"{ redis_class } .execute_command" , _traced_execute_command
@@ -505,68 +513,55 @@ def _instrument(
505513 )
506514
507515
508- def _instrument_client (
516+ def _instrument_connection (
509517 client ,
510518 tracer ,
511519 request_hook : _RequestHookT = None ,
512520 response_hook : _ResponseHookT = None ,
513521):
514- # first, handle async clients
515- _async_traced_execute_command = _async_traced_execute_factory (
522+ # first, handle async clients and cluster clients
523+ _async_traced_execute = _async_traced_execute_factory (
516524 tracer , request_hook , response_hook
517525 )
518526 _async_traced_execute_pipeline = _async_traced_execute_pipeline_factory (
519527 tracer , request_hook , response_hook
520528 )
521529
522- def _async_pipeline_wrapper (func , instance , args , kwargs ):
523- result = func (* args , ** kwargs )
524- wrap_function_wrapper (
525- result , "execute" , _async_traced_execute_pipeline
526- )
527- wrap_function_wrapper (
528- result , "immediate_execute_command" , _async_traced_execute_command
529- )
530- return result
531-
532- if redis .VERSION >= _REDIS_ASYNCIO_VERSION :
533- client_type = (
534- redis .asyncio .StrictRedis
535- if redis .VERSION < (3 , 0 , 0 )
536- else redis .asyncio .Redis
537- )
530+ if _CLIENT_ASYNCIO_SUPPORT and isinstance (client , redis .asyncio .Redis ):
538531
539- if isinstance (client , client_type ):
532+ def _async_pipeline_wrapper (func , instance , args , kwargs ):
533+ result = func (* args , ** kwargs )
540534 wrap_function_wrapper (
541- client , "execute_command " , _async_traced_execute_command
535+ result , "execute " , _async_traced_execute_pipeline
542536 )
543- wrap_function_wrapper (client , "pipeline" , _async_pipeline_wrapper )
544- return
537+ wrap_function_wrapper (
538+ result , "immediate_execute_command" , _async_traced_execute
539+ )
540+ return result
545541
546- def _async_cluster_pipeline_wrapper (func , instance , args , kwargs ):
547- result = func (* args , ** kwargs )
548- wrap_function_wrapper (
549- result , "execute" , _async_traced_execute_pipeline
550- )
551- return result
542+ wrap_function_wrapper (client , "execute_command" , _async_traced_execute )
543+ wrap_function_wrapper (client , "pipeline" , _async_pipeline_wrapper )
544+ return
552545
553- # handle
554- if redis .VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION and isinstance (
546+ if _CLIENT_ASYNCIO_CLUSTER_SUPPORT and isinstance (
555547 client , redis .asyncio .RedisCluster
556548 ):
557- wrap_function_wrapper (
558- client , "execute_command" , _async_traced_execute_command
559- )
549+
550+ def _async_cluster_pipeline_wrapper (func , instance , args , kwargs ):
551+ result = func (* args , ** kwargs )
552+ wrap_function_wrapper (
553+ result , "execute" , _async_traced_execute_pipeline
554+ )
555+ return result
556+
557+ wrap_function_wrapper (client , "execute_command" , _async_traced_execute )
560558 wrap_function_wrapper (
561559 client , "pipeline" , _async_cluster_pipeline_wrapper
562560 )
563561 return
564562 # for redis.client.Redis, redis.Cluster and v3.0.0 redis.client.StrictRedis
565563 # the wrappers are the same
566- # client_type = (
567- # redis.client.StrictRedis if redis.VERSION < (3, 0, 0) else redis.client.Redis
568- # )
569- _traced_execute_command = _traced_execute_factory (
564+ _traced_execute = _traced_execute_factory (
570565 tracer , request_hook , response_hook
571566 )
572567 _traced_execute_pipeline = _traced_execute_pipeline_factory (
@@ -577,14 +572,14 @@ def _pipeline_wrapper(func, instance, args, kwargs):
577572 result = func (* args , ** kwargs )
578573 wrap_function_wrapper (result , "execute" , _traced_execute_pipeline )
579574 wrap_function_wrapper (
580- result , "immediate_execute_command" , _traced_execute_command
575+ result , "immediate_execute_command" , _traced_execute
581576 )
582577 return result
583578
584579 wrap_function_wrapper (
585580 client ,
586581 "execute_command" ,
587- _traced_execute_command ,
582+ _traced_execute ,
588583 )
589584 wrap_function_wrapper (
590585 client ,
@@ -599,6 +594,16 @@ class RedisInstrumentor(BaseInstrumentor):
599594 See `BaseInstrumentor`
600595 """
601596
597+ @staticmethod
598+ def _get_tracer (** kwargs ):
599+ tracer_provider = kwargs .get ("tracer_provider" )
600+ return trace .get_tracer (
601+ __name__ ,
602+ __version__ ,
603+ tracer_provider = tracer_provider ,
604+ schema_url = "https://opentelemetry.io/schemas/1.11.0" ,
605+ )
606+
602607 def instrumentation_dependencies (self ) -> Collection [str ]:
603608 return _instruments
604609
@@ -610,30 +615,14 @@ def _instrument(self, **kwargs: Any):
610615 ``tracer_provider``: a TracerProvider, defaults to global.
611616 ``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
612617 """
613- tracer_provider = kwargs .get ("tracer_provider" )
614- tracer = trace .get_tracer (
615- __name__ ,
616- __version__ ,
617- tracer_provider = tracer_provider ,
618- schema_url = "https://opentelemetry.io/schemas/1.11.0" ,
618+ _instrument (
619+ self ._get_tracer (** kwargs ),
620+ request_hook = kwargs .get ("request_hook" ),
621+ response_hook = kwargs .get ("response_hook" ),
619622 )
620- redis_client = kwargs .get ("client" )
621- if redis_client :
622- _instrument_client (
623- redis_client ,
624- tracer ,
625- request_hook = kwargs .get ("request_hook" ),
626- response_hook = kwargs .get ("response_hook" ),
627- )
628- else :
629- _instrument (
630- tracer ,
631- request_hook = kwargs .get ("request_hook" ),
632- response_hook = kwargs .get ("response_hook" ),
633- )
634623
635624 def _uninstrument (self , ** kwargs : Any ):
636- if redis . VERSION < ( 3 , 0 , 0 ) :
625+ if _CLIENT_BEFORE_3_0_0 :
637626 unwrap (redis .StrictRedis , "execute_command" )
638627 unwrap (redis .StrictRedis , "pipeline" )
639628 unwrap (redis .Redis , "pipeline" )
@@ -661,3 +650,38 @@ def _uninstrument(self, **kwargs: Any):
661650 if redis .VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION :
662651 unwrap (redis .asyncio .cluster .RedisCluster , "execute_command" )
663652 unwrap (redis .asyncio .cluster .ClusterPipeline , "execute" )
653+
654+ @staticmethod
655+ def instrument_connection (
656+ client , tracer_provider : None , request_hook = None , response_hook = None
657+ ):
658+ if not hasattr (client , INSTRUMENTATION_ATTR ):
659+ setattr (client , INSTRUMENTATION_ATTR , False )
660+ if not getattr (client , INSTRUMENTATION_ATTR ):
661+ _instrument_connection (
662+ client ,
663+ RedisInstrumentor ._get_tracer (tracer_provider = tracer_provider ),
664+ request_hook = request_hook ,
665+ response_hook = response_hook ,
666+ )
667+ setattr (client , INSTRUMENTATION_ATTR , True )
668+ else :
669+ _logger .warning (
670+ "Attempting to instrument Redis connection while already instrumented"
671+ )
672+
673+ @staticmethod
674+ def uninstrument_connection (client ):
675+ if getattr (client , INSTRUMENTATION_ATTR ):
676+ # for all clients we need to unwrap execute_command and pipeline functions
677+ unwrap (client , "execute_command" )
678+ # pipeline was creating a pipeline and wrapping the functions of the
679+ # created instance. any pipeline created before un-instrumenting will
680+ # remain instrumented (pipelines should usually have a short span)
681+ unwrap (client , "pipeline" )
682+ pass
683+ else :
684+ _logger .warning (
685+ "Attempting to un-instrument Redis connection that wasn't instrumented"
686+ )
687+ return
0 commit comments