1+ from __future__ import annotations
2+
3+ import asyncio
14import json
25from logging import getLogger
3- from typing import Any , Awaitable , Callable , Dict , List , Optional , Tuple , Union
6+ from typing import (
7+ TYPE_CHECKING ,
8+ Any ,
9+ Awaitable ,
10+ Callable ,
11+ Dict ,
12+ List ,
13+ Optional ,
14+ Protocol ,
15+ Tuple ,
16+ Union ,
17+ )
418
519import aiokafka
6- from aiokafka import ConsumerRecord
720
821from opentelemetry import context , propagate , trace
922from opentelemetry .context import Context
1326from opentelemetry .trace import Tracer
1427from opentelemetry .trace .span import Span
1528
29+ if TYPE_CHECKING :
30+ from aiokafka .structs import RecordMetadata
31+
32+ class AIOKafkaGetOneProto (Protocol ):
33+ async def __call__ (
34+ self , * partitions : aiokafka .TopicPartition
35+ ) -> aiokafka .ConsumerRecord [object , object ]: ...
36+
37+ class AIOKafkaGetManyProto (Protocol ):
38+ async def __call__ (
39+ self ,
40+ * partitions : aiokafka .TopicPartition ,
41+ timeout_ms : int = 0 ,
42+ max_records : int | None = None ,
43+ ) -> dict [
44+ aiokafka .TopicPartition ,
45+ list [aiokafka .ConsumerRecord [object , object ]],
46+ ]: ...
47+
48+ class AIOKafkaSendProto (Protocol ):
49+ async def __call__ (
50+ self ,
51+ topic : str ,
52+ value : Any | None = None ,
53+ key : Any | None = None ,
54+ partition : int | None = None ,
55+ timestamp_ms : int | None = None ,
56+ headers : HeadersT | None = None ,
57+ ) -> asyncio .Future [RecordMetadata ]: ...
58+
59+
60+ ProduceHookT = Optional [
61+ Callable [[Span , Tuple [Any , ...], Dict [str , Any ]], Awaitable [None ]]
62+ ]
63+ ConsumeHookT = Optional [
64+ Callable [
65+ [
66+ Span ,
67+ aiokafka .ConsumerRecord [object , object ],
68+ Tuple [aiokafka .TopicPartition , ...],
69+ Dict [str , Any ],
70+ ],
71+ Awaitable [None ],
72+ ]
73+ ]
74+
75+ HeadersT = List [Tuple [str , Optional [bytes ]]]
76+
1677_LOG = getLogger (__name__ )
1778
1879
@@ -97,14 +158,6 @@ async def _extract_send_partition(
97158 return None
98159
99160
100- ProduceHookT = Optional [Callable [[Span , Tuple , Dict ], Awaitable [None ]]]
101- ConsumeHookT = Optional [
102- Callable [[Span , ConsumerRecord , Tuple , Dict ], Awaitable [None ]]
103- ]
104-
105- HeadersT = List [Tuple [str , Optional [bytes ]]]
106-
107-
108161class AIOKafkaContextGetter (textmap .Getter [HeadersT ]):
109162 def get (self , carrier : HeadersT , key : str ) -> Optional [List [str ]]:
110163 if carrier is None :
@@ -198,7 +251,7 @@ def _enrich_send_span(
198251 )
199252
200253
201- def _enrich_anext_span (
254+ def _enrich_getone_span (
202255 span : Span ,
203256 * ,
204257 bootstrap_servers : Union [str , List [str ]],
@@ -247,19 +300,93 @@ def _enrich_anext_span(
247300 )
248301
249302
303+ def _enrich_getmany_poll_span (
304+ span : Span ,
305+ * ,
306+ bootstrap_servers : Union [str , List [str ]],
307+ client_id : str ,
308+ consumer_group : Optional [str ],
309+ message_count : int ,
310+ ) -> None :
311+ if not span .is_recording ():
312+ return
313+
314+ span .set_attribute (
315+ messaging_attributes .MESSAGING_SYSTEM ,
316+ messaging_attributes .MessagingSystemValues .KAFKA .value ,
317+ )
318+ span .set_attribute (
319+ server_attributes .SERVER_ADDRESS , json .dumps (bootstrap_servers )
320+ )
321+ span .set_attribute (messaging_attributes .MESSAGING_CLIENT_ID , client_id )
322+
323+ if consumer_group is not None :
324+ span .set_attribute (
325+ messaging_attributes .MESSAGING_CONSUMER_GROUP_NAME , consumer_group
326+ )
327+
328+ span .set_attribute (
329+ messaging_attributes .MESSAGING_BATCH_MESSAGE_COUNT , message_count
330+ )
331+
332+ span .set_attribute (messaging_attributes .MESSAGING_OPERATION_NAME , "poll" )
333+ span .set_attribute (
334+ messaging_attributes .MESSAGING_OPERATION_TYPE ,
335+ messaging_attributes .MessagingOperationTypeValues .RECEIVE .value ,
336+ )
337+
338+
339+ def _enrich_getmany_topic_span (
340+ span : Span ,
341+ * ,
342+ bootstrap_servers : Union [str , List [str ]],
343+ client_id : str ,
344+ consumer_group : Optional [str ],
345+ topic : str ,
346+ partition : int ,
347+ message_count : int ,
348+ ) -> None :
349+ if not span .is_recording ():
350+ return
351+
352+ _enrich_base_span (
353+ span ,
354+ bootstrap_servers = bootstrap_servers ,
355+ client_id = client_id ,
356+ topic = topic ,
357+ partition = partition ,
358+ key = None ,
359+ )
360+
361+ if consumer_group is not None :
362+ span .set_attribute (
363+ messaging_attributes .MESSAGING_CONSUMER_GROUP_NAME , consumer_group
364+ )
365+
366+ span .set_attribute (
367+ messaging_attributes .MESSAGING_BATCH_MESSAGE_COUNT , message_count
368+ )
369+
370+ span .set_attribute (messaging_attributes .MESSAGING_OPERATION_NAME , "poll" )
371+ span .set_attribute (
372+ messaging_attributes .MESSAGING_OPERATION_TYPE ,
373+ messaging_attributes .MessagingOperationTypeValues .RECEIVE .value ,
374+ )
375+
376+
250377def _get_span_name (operation : str , topic : str ):
251378 return f"{ topic } { operation } "
252379
253380
254381def _wrap_send (
255382 tracer : Tracer , async_produce_hook : ProduceHookT
256- ) -> Callable [..., Awaitable [None ]]:
383+ ) -> Callable [..., Awaitable [asyncio . Future [ RecordMetadata ] ]]:
257384 async def _traced_send (
258- func : Callable [..., Awaitable [ None ]] ,
385+ func : AIOKafkaSendProto ,
259386 instance : aiokafka .AIOKafkaProducer ,
260387 args : Tuple [Any ],
261388 kwargs : Dict [str , Any ],
262- ) -> None :
389+ ) -> asyncio . Future [ RecordMetadata ] :
263390 headers = _extract_send_headers (args , kwargs )
264391 if headers is None :
265392 headers = []
@@ -301,14 +428,14 @@ async def _traced_send(
301428async def _create_consumer_span (
302429 tracer : Tracer ,
303430 async_consume_hook : ConsumeHookT ,
304- record : ConsumerRecord ,
431+ record : aiokafka . ConsumerRecord [ object , object ] ,
305432 extracted_context : Context ,
306433 bootstrap_servers : Union [str , List [str ]],
307434 client_id : str ,
308435 consumer_group : Optional [str ],
309436 args : Tuple [Any ],
310437 kwargs : Dict [str , Any ],
311- ):
438+ ) -> trace . Span :
312439 span_name = _get_span_name ("receive" , record .topic )
313440 with tracer .start_as_current_span (
314441 span_name ,
@@ -317,7 +444,7 @@ async def _create_consumer_span(
317444 ) as span :
318445 new_context = trace .set_span_in_context (span , extracted_context )
319446 token = context .attach (new_context )
320- _enrich_anext_span (
447+ _enrich_getone_span (
321448 span ,
322449 bootstrap_servers = bootstrap_servers ,
323450 client_id = client_id ,
@@ -334,16 +461,18 @@ async def _create_consumer_span(
334461 _LOG .exception (hook_exception )
335462 context .detach (token )
336463
464+ return span
465+
337466
338467def _wrap_getone (
339468 tracer : Tracer , async_consume_hook : ConsumeHookT
340- ) -> Callable [..., Awaitable [aiokafka .ConsumerRecord ]]:
341- async def _traced_next (
342- func : Callable [..., Awaitable [ aiokafka . ConsumerRecord ]] ,
469+ ) -> Callable [..., Awaitable [aiokafka .ConsumerRecord [ object , object ] ]]:
470+ async def _traced_getone (
471+ func : AIOKafkaGetOneProto ,
343472 instance : aiokafka .AIOKafkaConsumer ,
344473 args : Tuple [Any ],
345474 kwargs : Dict [str , Any ],
346- ) -> aiokafka .ConsumerRecord :
475+ ) -> aiokafka .ConsumerRecord [ object , object ] :
347476 record = await func (* args , ** kwargs )
348477
349478 if record :
@@ -367,4 +496,80 @@ async def _traced_next(
367496 )
368497 return record
369498
370- return _traced_next
499+ return _traced_getone
500+
501+
502+ def _wrap_getmany (
503+ tracer : Tracer , async_consume_hook : ConsumeHookT
504+ ) -> Callable [
505+ ...,
506+ Awaitable [
507+ dict [
508+ aiokafka .TopicPartition ,
509+ list [aiokafka .ConsumerRecord [object , object ]],
510+ ]
511+ ],
512+ ]:
513+ async def _traced_getmany (
514+ func : AIOKafkaGetManyProto ,
515+ instance : aiokafka .AIOKafkaConsumer ,
516+ args : Tuple [Any ],
517+ kwargs : Dict [str , Any ],
518+ ) -> dict [
519+ aiokafka .TopicPartition , list [aiokafka .ConsumerRecord [object , object ]]
520+ ]:
521+ records = await func (* args , ** kwargs )
522+
523+ if records :
524+ bootstrap_servers = _extract_bootstrap_servers (instance ._client )
525+ client_id = _extract_client_id (instance ._client )
526+ consumer_group = _extract_consumer_group (instance )
527+
528+ span_name = _get_span_name (
529+ "poll" , ", " .join ([topic .topic for topic in records .keys ()])
530+ )
531+ with tracer .start_as_current_span (
532+ span_name , kind = trace .SpanKind .CLIENT
533+ ) as poll_span :
534+ _enrich_getmany_poll_span (
535+ poll_span ,
536+ bootstrap_servers = bootstrap_servers ,
537+ client_id = client_id ,
538+ consumer_group = consumer_group ,
539+ message_count = sum (len (r ) for r in records .values ()),
540+ )
541+
542+ for topic , topic_records in records .items ():
543+ span_name = _get_span_name ("poll" , topic .topic )
544+ with tracer .start_as_current_span (
545+ span_name , kind = trace .SpanKind .CLIENT
546+ ) as topic_span :
547+ _enrich_getmany_topic_span (
548+ topic_span ,
549+ bootstrap_servers = bootstrap_servers ,
550+ client_id = client_id ,
551+ consumer_group = consumer_group ,
552+ topic = topic .topic ,
553+ partition = topic .partition ,
554+ message_count = len (topic_records ),
555+ )
556+
557+ for record in topic_records :
558+ extracted_context = propagate .extract (
559+ record .headers , getter = _aiokafka_getter
560+ )
561+ record_span = await _create_consumer_span (
562+ tracer ,
563+ async_consume_hook ,
564+ record ,
565+ extracted_context ,
566+ bootstrap_servers ,
567+ client_id ,
568+ consumer_group ,
569+ args ,
570+ kwargs ,
571+ )
572+ topic_span .add_link (record_span .get_span_context ())
573+ return records
574+
575+ return _traced_getmany
0 commit comments