1515from __future__ import annotations
1616
1717import uuid
18- from typing import Any , List , Sequence , Tuple
18+ from typing import Any , List , Sequence , Tuple , cast
1919from unittest import IsolatedAsyncioTestCase , TestCase , mock
2020
2121import aiokafka
@@ -138,16 +138,22 @@ async def producer_factory() -> AIOKafkaProducer:
138138
139139 return producer
140140
141- async def test_getone (self ) -> None :
142- AIOKafkaInstrumentor ().uninstrument ()
141+ def setUp (self ):
142+ super ().setUp ()
143143 AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
144144
145+ def tearDown (self ):
146+ super ().tearDown ()
147+ AIOKafkaInstrumentor ().uninstrument ()
148+
149+ async def test_getone (self ) -> None :
145150 client_id = str (uuid .uuid4 ())
146151 group_id = str (uuid .uuid4 ())
147152 consumer = await self .consumer_factory (
148153 client_id = client_id , group_id = group_id
149154 )
150- next_record_mock : mock .AsyncMock = consumer ._fetcher .next_record
155+ self .addAsyncCleanup (consumer .stop )
156+ next_record_mock = cast (mock .AsyncMock , consumer ._fetcher .next_record )
151157
152158 expected_spans = [
153159 {
@@ -229,7 +235,8 @@ async def async_consume_hook(span, *_) -> None:
229235 )
230236
231237 consumer = await self .consumer_factory ()
232- next_record_mock : mock .AsyncMock = consumer ._fetcher .next_record
238+ self .addAsyncCleanup (consumer .stop )
239+ next_record_mock = cast (mock .AsyncMock , consumer ._fetcher .next_record )
233240
234241 self .memory_exporter .clear ()
235242
@@ -261,7 +268,8 @@ async def test_getone_consume_hook(self) -> None:
261268 )
262269
263270 consumer = await self .consumer_factory ()
264- next_record_mock : mock .AsyncMock = consumer ._fetcher .next_record
271+ self .addAsyncCleanup (consumer .stop )
272+ next_record_mock = cast (mock .AsyncMock , consumer ._fetcher .next_record )
265273
266274 next_record_mock .side_effect = [
267275 self .consumer_record_factory (1 , headers = ())
@@ -272,16 +280,14 @@ async def test_getone_consume_hook(self) -> None:
272280 async_consume_hook_mock .assert_awaited_once ()
273281
274282 async def test_getmany (self ) -> None :
275- AIOKafkaInstrumentor ().uninstrument ()
276- AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
277-
278283 client_id = str (uuid .uuid4 ())
279284 group_id = str (uuid .uuid4 ())
280285 consumer = await self .consumer_factory (
281286 client_id = client_id , group_id = group_id
282287 )
283- fetched_records_mock : mock .AsyncMock = (
284- consumer ._fetcher .fetched_records
288+ self .addAsyncCleanup (consumer .stop )
289+ fetched_records_mock = cast (
290+ mock .AsyncMock , consumer ._fetcher .fetched_records
285291 )
286292
287293 expected_spans = [
@@ -384,12 +390,10 @@ async def test_getmany(self) -> None:
384390 self ._compare_spans (span_list , expected_spans )
385391
386392 async def test_send (self ) -> None :
387- AIOKafkaInstrumentor ().uninstrument ()
388- AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
389-
390393 producer = await self .producer_factory ()
391- add_message_mock : mock .AsyncMock = (
392- producer ._message_accumulator .add_message
394+ self .addAsyncCleanup (producer .stop )
395+ add_message_mock = cast (
396+ mock .AsyncMock , producer ._message_accumulator .add_message
393397 )
394398
395399 tracer = self .tracer_provider .get_tracer (__name__ )
@@ -419,12 +423,10 @@ async def test_send(self) -> None:
419423 )
420424
421425 async def test_send_baggage (self ) -> None :
422- AIOKafkaInstrumentor ().uninstrument ()
423- AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
424-
425426 producer = await self .producer_factory ()
426- add_message_mock : mock .AsyncMock = (
427- producer ._message_accumulator .add_message
427+ self .addAsyncCleanup (producer .stop )
428+ add_message_mock = cast (
429+ mock .AsyncMock , producer ._message_accumulator .add_message
428430 )
429431
430432 tracer = self .tracer_provider .get_tracer (__name__ )
@@ -453,6 +455,7 @@ async def test_send_produce_hook(self) -> None:
453455 )
454456
455457 producer = await self .producer_factory ()
458+ self .addAsyncCleanup (producer .stop )
456459
457460 await producer .send ("topic_1" , b"value_1" )
458461
0 commit comments