diff --git a/instrumentation/opentelemetry-instrumentation-kafka-python/src/opentelemetry/instrumentation/kafka/__init__.py b/instrumentation/opentelemetry-instrumentation-kafka-python/src/opentelemetry/instrumentation/kafka/__init__.py index f1ab77e4a2..b019e2b146 100644 --- a/instrumentation/opentelemetry-instrumentation-kafka-python/src/opentelemetry/instrumentation/kafka/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-kafka-python/src/opentelemetry/instrumentation/kafka/__init__.py @@ -95,7 +95,7 @@ def process_msg(message): _instruments_kafka_python, _instruments_kafka_python_ng, ) -from opentelemetry.instrumentation.kafka.utils import _wrap_next, _wrap_send +from opentelemetry.instrumentation.kafka.utils import _wrap_poll, _wrap_send from opentelemetry.instrumentation.kafka.version import __version__ from opentelemetry.instrumentation.utils import unwrap @@ -150,10 +150,10 @@ def _instrument(self, **kwargs): ) wrap_function_wrapper( kafka.KafkaConsumer, - "__next__", - _wrap_next(tracer, consume_hook), + "poll", + _wrap_poll(tracer, consume_hook), ) def _uninstrument(self, **kwargs): unwrap(kafka.KafkaProducer, "send") - unwrap(kafka.KafkaConsumer, "__next__") + unwrap(kafka.KafkaConsumer, "poll") diff --git a/instrumentation/opentelemetry-instrumentation-kafka-python/src/opentelemetry/instrumentation/kafka/utils.py b/instrumentation/opentelemetry-instrumentation-kafka-python/src/opentelemetry/instrumentation/kafka/utils.py index 3f9bd6f39c..aee05323bc 100644 --- a/instrumentation/opentelemetry-instrumentation-kafka-python/src/opentelemetry/instrumentation/kafka/utils.py +++ b/instrumentation/opentelemetry-instrumentation-kafka-python/src/opentelemetry/instrumentation/kafka/utils.py @@ -199,30 +199,30 @@ def _create_consumer_span( context.detach(token) -def _wrap_next( - tracer: Tracer, - consume_hook: ConsumeHookT, -) -> Callable: - def _traced_next(func, instance, args, kwargs): - record = func(*args, **kwargs) - - if record: - bootstrap_servers = ( - KafkaPropertiesExtractor.extract_bootstrap_servers(instance) - ) - - extracted_context = propagate.extract( - record.headers, getter=_kafka_getter - ) - _create_consumer_span( - tracer, - consume_hook, - record, - extracted_context, - bootstrap_servers, - args, - kwargs, - ) - return record - - return _traced_next +def _wrap_poll(tracer: Tracer, consume_hook: ConsumeHookT) -> Callable: + def _traced_poll(func, instance, args, kwargs): + records = func(*args, **kwargs) + + for items in records.values(): + for record in items: + if record: + bootstrap_servers = ( + KafkaPropertiesExtractor.extract_bootstrap_servers( + instance) + ) + + extracted_context = propagate.extract( + record.headers, getter=_kafka_getter + ) + _create_consumer_span( + tracer, + consume_hook, + record, + extracted_context, + bootstrap_servers, + args, + kwargs, + ) + return records + + return _traced_poll diff --git a/instrumentation/opentelemetry-instrumentation-kafka-python/tests/test_instrumentation.py b/instrumentation/opentelemetry-instrumentation-kafka-python/tests/test_instrumentation.py index 13587b0c3c..e9e3fb040e 100644 --- a/instrumentation/opentelemetry-instrumentation-kafka-python/tests/test_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-kafka-python/tests/test_instrumentation.py @@ -32,13 +32,13 @@ def test_instrument_api(self) -> None: instrumentation.instrument() self.assertTrue(isinstance(KafkaProducer.send, BoundFunctionWrapper)) self.assertTrue( - isinstance(KafkaConsumer.__next__, BoundFunctionWrapper) + isinstance(KafkaConsumer.poll, BoundFunctionWrapper) ) instrumentation.uninstrument() self.assertFalse(isinstance(KafkaProducer.send, BoundFunctionWrapper)) self.assertFalse( - isinstance(KafkaConsumer.__next__, BoundFunctionWrapper) + isinstance(KafkaConsumer.poll, BoundFunctionWrapper) ) @patch("opentelemetry.instrumentation.kafka.distribution") diff --git a/instrumentation/opentelemetry-instrumentation-kafka-python/tests/test_utils.py b/instrumentation/opentelemetry-instrumentation-kafka-python/tests/test_utils.py index 85397bcb73..6e7b3ac43e 100644 --- a/instrumentation/opentelemetry-instrumentation-kafka-python/tests/test_utils.py +++ b/instrumentation/opentelemetry-instrumentation-kafka-python/tests/test_utils.py @@ -21,7 +21,7 @@ _get_span_name, _kafka_getter, _kafka_setter, - _wrap_next, + _wrap_poll, _wrap_send, ) from opentelemetry.trace import SpanKind @@ -142,42 +142,46 @@ def wrap_send_helper( @mock.patch( "opentelemetry.instrumentation.kafka.utils.KafkaPropertiesExtractor.extract_bootstrap_servers" ) - def test_wrap_next( - self, - extract_bootstrap_servers: mock.MagicMock, - _create_consumer_span: mock.MagicMock, - extract: mock.MagicMock, + def test_wrap_poll( + self, + extract_bootstrap_servers: mock.MagicMock, + _create_consumer_span: mock.MagicMock, + extract: mock.MagicMock, ) -> None: + self.args = [] + self.kwargs = {"timeout_ms": 1000} tracer = mock.MagicMock() consume_hook = mock.MagicMock() - original_next_callback = mock.MagicMock() + original_poll_callback = mock.MagicMock() kafka_consumer = mock.MagicMock() - wrapped_next = _wrap_next(tracer, consume_hook) - record = wrapped_next( - original_next_callback, kafka_consumer, self.args, self.kwargs + wrapped_poll = _wrap_poll(tracer, consume_hook) + records = wrapped_poll( + original_poll_callback, kafka_consumer, self.args, self.kwargs ) extract_bootstrap_servers.assert_called_once_with(kafka_consumer) bootstrap_servers = extract_bootstrap_servers.return_value - original_next_callback.assert_called_once_with( + original_poll_callback.assert_called_once_with( *self.args, **self.kwargs ) - self.assertEqual(record, original_next_callback.return_value) - - extract.assert_called_once_with(record.headers, getter=_kafka_getter) - context = extract.return_value - - _create_consumer_span.assert_called_once_with( - tracer, - consume_hook, - record, - context, - bootstrap_servers, - self.args, - self.kwargs, - ) + self.assertEqual(records, original_poll_callback.return_value) + + for items in records.values(): + for record in items: + extract.assert_called_once_with(record.headers, getter=_kafka_getter) + context = extract.return_value + + _create_consumer_span.assert_called_once_with( + tracer, + consume_hook, + record, + context, + bootstrap_servers, + self.args, + self.kwargs, + ) @mock.patch("opentelemetry.trace.set_span_in_context") @mock.patch("opentelemetry.context.attach")