1414# pylint: disable=unnecessary-dunder-call
1515
1616from logging import getLogger
17- from typing import Any , Collection , Dict , Optional
17+ from typing import Any , Collection , Dict , Optional , Union
1818
1919import pika
2020import wrapt
2424 BlockingChannel ,
2525 _QueueConsumerGeneratorInfo ,
2626)
27+ from pika .channel import Channel
28+ from pika .connection import Connection
2729
2830from opentelemetry import trace
2931from opentelemetry .instrumentation .instrumentor import BaseInstrumentor
@@ -53,12 +55,16 @@ class PikaInstrumentor(BaseInstrumentor): # type: ignore
5355
5456 # pylint: disable=attribute-defined-outside-init
5557 @staticmethod
56- def _instrument_blocking_channel_consumers (
57- channel : BlockingChannel ,
58+ def _instrument_channel_consumers (
59+ channel : Union [ BlockingChannel , Channel ] ,
5860 tracer : Tracer ,
5961 consume_hook : utils .HookT = utils .dummy_callback ,
6062 ) -> Any :
61- for consumer_tag , consumer_info in channel ._consumer_infos .items ():
63+ if isinstance (channel , BlockingChannel ):
64+ consumer_infos = channel ._consumer_infos
65+ elif isinstance (channel , Channel ):
66+ consumer_infos = channel ._consumers
67+ for consumer_tag , consumer_info in consumer_infos .items ():
6268 callback_attr = PikaInstrumentor .CONSUMER_CALLBACK_ATTR
6369 consumer_callback = getattr (consumer_info , callback_attr , None )
6470 if consumer_callback is None :
@@ -79,7 +85,7 @@ def _instrument_blocking_channel_consumers(
7985
8086 @staticmethod
8187 def _instrument_basic_publish (
82- channel : BlockingChannel ,
88+ channel : Union [ BlockingChannel , Channel ] ,
8389 tracer : Tracer ,
8490 publish_hook : utils .HookT = utils .dummy_callback ,
8591 ) -> None :
@@ -93,7 +99,7 @@ def _instrument_basic_publish(
9399
94100 @staticmethod
95101 def _instrument_channel_functions (
96- channel : BlockingChannel ,
102+ channel : Union [ BlockingChannel , Channel ] ,
97103 tracer : Tracer ,
98104 publish_hook : utils .HookT = utils .dummy_callback ,
99105 ) -> None :
@@ -103,7 +109,9 @@ def _instrument_channel_functions(
103109 )
104110
105111 @staticmethod
106- def _uninstrument_channel_functions (channel : BlockingChannel ) -> None :
112+ def _uninstrument_channel_functions (
113+ channel : Union [BlockingChannel , Channel ],
114+ ) -> None :
107115 for function_name in _FUNCTIONS_TO_UNINSTRUMENT :
108116 if not hasattr (channel , function_name ):
109117 continue
@@ -115,7 +123,7 @@ def _uninstrument_channel_functions(channel: BlockingChannel) -> None:
115123 @staticmethod
116124 # Make sure that the spans are created inside hash them set as parent and not as brothers
117125 def instrument_channel (
118- channel : BlockingChannel ,
126+ channel : Union [ BlockingChannel , Channel ] ,
119127 tracer_provider : Optional [TracerProvider ] = None ,
120128 publish_hook : utils .HookT = utils .dummy_callback ,
121129 consume_hook : utils .HookT = utils .dummy_callback ,
@@ -133,7 +141,7 @@ def instrument_channel(
133141 tracer_provider ,
134142 schema_url = "https://opentelemetry.io/schemas/1.11.0" ,
135143 )
136- PikaInstrumentor ._instrument_blocking_channel_consumers (
144+ PikaInstrumentor ._instrument_channel_consumers (
137145 channel , tracer , consume_hook
138146 )
139147 PikaInstrumentor ._decorate_basic_consume (channel , tracer , consume_hook )
@@ -178,16 +186,17 @@ def wrapper(wrapped, instance, args, kwargs):
178186 return channel
179187
180188 wrapt .wrap_function_wrapper (BlockingConnection , "channel" , wrapper )
189+ wrapt .wrap_function_wrapper (Connection , "channel" , wrapper )
181190
182191 @staticmethod
183192 def _decorate_basic_consume (
184- channel : BlockingChannel ,
193+ channel : Union [ BlockingChannel , Channel ] ,
185194 tracer : Optional [Tracer ],
186195 consume_hook : utils .HookT = utils .dummy_callback ,
187196 ) -> None :
188197 def wrapper (wrapped , instance , args , kwargs ):
189198 return_value = wrapped (* args , ** kwargs )
190- PikaInstrumentor ._instrument_blocking_channel_consumers (
199+ PikaInstrumentor ._instrument_channel_consumers (
191200 channel , tracer , consume_hook
192201 )
193202 return return_value
@@ -236,6 +245,7 @@ def _uninstrument(self, **kwargs: Dict[str, Any]) -> None:
236245 if hasattr (self , "__opentelemetry_tracer_provider" ):
237246 delattr (self , "__opentelemetry_tracer_provider" )
238247 unwrap (BlockingConnection , "channel" )
248+ unwrap (Connection , "channel" )
239249 unwrap (_QueueConsumerGeneratorInfo , "__init__" )
240250
241251 def instrumentation_dependencies (self ) -> Collection [str ]:
0 commit comments