diff --git a/CHANGES.rst b/CHANGES.rst index 08ed240ee..de0442dbf 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,6 +5,13 @@ Changelog 0.13.0 (????-??-??) =================== +New features: + +* Allow SASL Connections to Periodically Re-Authenticate (`KIP-368`_) (pr #1105 by @kprzybyla) + +.. _KIP-368: https://cwiki.apache.org/confluence/display/KAFKA/KIP-368%3A+Allow+SASL+Connections+to+Periodically+Re-Authenticate + + Improved Documentation: * Fix incomplete documentation for `AIOKafkaConsumer.offset_for_times`` diff --git a/aiokafka/conn.py b/aiokafka/conn.py index 859f8f245..489dc4a74 100644 --- a/aiokafka/conn.py +++ b/aiokafka/conn.py @@ -44,6 +44,9 @@ READER_LIMIT = 2**16 SASL_QOP_AUTH = 1 +SASL_REQUEST_API_KEYS = frozenset( + request.API_KEY for request in (*SaslHandShakeRequest, *SaslAuthenticateRequest) +) class CloseReason(IntEnum): @@ -116,6 +119,24 @@ async def create_conn( return conn +def calculate_sasl_reauthentication_time(session_lifetime_ms: int) -> int: + """ + Calculates the SASL session re-authentication time following + the Java Kafka implementation from SaslClientAuthenticator.java + ReauthInfo#setAuthenticationEndAndSessionReauthenticationTimes. + + The re-authentication factor is calculated by choosing random value + between 0.85 and 0.95, which accounts for both network latency and clock + drift as well as potential jitter which may cause re-authentication storm + across many channels simultaneously. + """ + + reauthentication_time_factor: float = random.uniform(0.85, 0.95) + expiration_time: float = (session_lifetime_ms * reauthentication_time_factor) / 1000 + + return int(expiration_time) + + class AIOKafkaProtocol(asyncio.StreamReaderProtocol): def __init__(self, closed_fut, *args, loop, **kw): self._closed_fut = closed_fut @@ -184,6 +205,9 @@ def __init__( self._sasl_kerberos_service_name = sasl_kerberos_service_name self._sasl_kerberos_domain_name = sasl_kerberos_domain_name self._sasl_oauth_token_provider = sasl_oauth_token_provider + self._sasl_reauthentication_task = None + self._sasl_reauthentication_done = asyncio.Event() + self._sasl_reauthentication_done.set() # Version hint is the version determined by initial client bootstrap self._version_hint = version_hint @@ -358,6 +382,16 @@ async def _do_sasl_handshake(self): raise exc auth_bytes = resp.sasl_auth_bytes + if ( + hasattr(resp, "session_lifetime_ms") + and resp.session_lifetime_ms != 0 + ): + self._sasl_reauthentication_task = ( + self._create_sasl_reauthentication_task( + resp.session_lifetime_ms + ) + ) + if self._sasl_mechanism == "GSSAPI": log.info("Authenticated as %s via GSSAPI", self.sasl_principal) elif self._sasl_mechanism == "OAUTHBEARER": @@ -369,6 +403,69 @@ async def _do_sasl_handshake(self): self._sasl_mechanism, ) + def _create_sasl_reauthentication_task( + self, session_lifetime_ms: int + ) -> asyncio.Task: + self_ref = weakref.ref(self) + timeout = calculate_sasl_reauthentication_time(session_lifetime_ms) + + log.info( + "SASL re-authentication required after %ds for connection %s:%s", + timeout, + self._host, + self._port, + ) + + sasl_reauthentication_task = create_task( + self._sasl_reauthentication(self_ref, timeout) + ) + sasl_reauthentication_task.add_done_callback( + functools.partial(self._on_sasl_reauthentication_task_error, self_ref) + ) + + return sasl_reauthentication_task + + @staticmethod + async def _sasl_reauthentication( + self_ref: weakref.ReferenceType["AIOKafkaConnection"], + sasl_reauthentication_time: int, + ) -> None: + self = self_ref() + + if self is None: + return + + await asyncio.sleep(sasl_reauthentication_time) + self._sasl_reauthentication_done.clear() + + await self._do_sasl_handshake() + self._sasl_reauthentication_done.set() + + log.info( + "SASL re-authentication complete for connection %s:%s", + self._host, + self._port, + ) + + @staticmethod + def _on_sasl_reauthentication_task_error( + self_ref: weakref.ReferenceType["AIOKafkaConnection"], + sasl_reauthentication_task: asyncio.Task, + ) -> None: + if sasl_reauthentication_task.cancelled(): + return + + try: + sasl_reauthentication_task.result() + except BaseException as exc: + if not isinstance(exc, (OSError, EOFError, ConnectionError)): + log.exception("Unexpected exception in AIOKafkaConnection") + + self = self_ref() + + if self is not None: + self.close(reason=CloseReason.AUTH_FAILURE, exc=exc) + def authenticator_plain(self): return SaslPlainAuthenticator( loop=self._loop, @@ -458,6 +555,18 @@ def send(self, request, expect_response=True): f"No connection to broker at {self._host}:{self._port}" ) + if ( + self._sasl_reauthentication_done.is_set() + or request.API_KEY in SASL_REQUEST_API_KEYS + ): + return self._send(request=request, expect_response=expect_response) + + return self._send_after_sasl_reauthentication( + request=request, + expect_response=expect_response, + ) + + def _send(self, request, expect_response=True): correlation_id = self._next_correlation_id() header = request.build_request_header( correlation_id=correlation_id, client_id=self._client_id @@ -482,6 +591,11 @@ def send(self, request, expect_response=True): ) return wait_for(fut, self._request_timeout) + async def _send_after_sasl_reauthentication(self, request, expect_response): + await self._sasl_reauthentication_done.wait() + + return await self._send(request=request, expect_response=expect_response) + def _send_sasl_token(self, payload, expect_response=True): if self._writer is None: raise Errors.KafkaConnectionError( @@ -530,6 +644,13 @@ def close(self, reason=None, exc=None): if self._idle_handle is not None: self._idle_handle.cancel() + if ( + self._sasl_reauthentication_task is not None + and not self._sasl_reauthentication_task.done() + ): + self._sasl_reauthentication_task.cancel() + self._sasl_reauthentication_task = None + # transport.close() will close socket, but not right ahead. Return # a future in case we need to wait on it. return self._closed_fut diff --git a/docker/scripts/start-kafka.sh b/docker/scripts/start-kafka.sh index ddb86889b..c4dc2572e 100755 --- a/docker/scripts/start-kafka.sh +++ b/docker/scripts/start-kafka.sh @@ -1,12 +1,12 @@ -#!/bin/sh +#!/bin/bash -OPTIONS="" +OPTIONS=() PATH="$HOME/bin:$PATH" # Configure the default number of log partitions per topic if [ ! -z "$NUM_PARTITIONS" ]; then echo "default number of partition: $NUM_PARTITIONS" - OPTIONS="$OPTIONS --override num.partitions=$NUM_PARTITIONS" + OPTIONS+=("--override" "num.partitions=$NUM_PARTITIONS") fi # Set the external host and port @@ -16,38 +16,54 @@ echo "advertised port: $ADVERTISED_PORT" LISTENERS="PLAINTEXT://:$ADVERTISED_PORT" ADVERTISED_LISTENERS="PLAINTEXT://$ADVERTISED_HOST:$ADVERTISED_PORT" -if [ ! -z "$ADVERTISED_SSL_PORT" ]; then +if [[ ! -z "$ADVERTISED_SSL_PORT" ]]; then echo "advertised ssl port: $ADVERTISED_SSL_PORT" # SSL options - OPTIONS="$OPTIONS --override ssl.protocol=TLS" - OPTIONS="$OPTIONS --override ssl.enabled.protocols=TLSv1.2,TLSv1.1,TLSv1" - OPTIONS="$OPTIONS --override ssl.keystore.type=JKS" - OPTIONS="$OPTIONS --override ssl.keystore.location=/ssl_cert/br_server.keystore.jks" - OPTIONS="$OPTIONS --override ssl.keystore.password=abcdefgh" - OPTIONS="$OPTIONS --override ssl.key.password=abcdefgh" - OPTIONS="$OPTIONS --override ssl.truststore.type=JKS" - OPTIONS="$OPTIONS --override ssl.truststore.location=/ssl_cert/br_server.truststore.jks" - OPTIONS="$OPTIONS --override ssl.truststore.password=abcdefgh" - OPTIONS="$OPTIONS --override ssl.client.auth=required" - OPTIONS="$OPTIONS --override security.inter.broker.protocol=SSL" - OPTIONS="$OPTIONS --override ssl.endpoint.identification.algorithm=" + OPTIONS+=("--override" "ssl.protocol=TLS") + OPTIONS+=("--override" "ssl.enabled.protocols=TLSv1.2,TLSv1.1,TLSv1") + OPTIONS+=("--override" "ssl.keystore.type=JKS") + OPTIONS+=("--override" "ssl.keystore.location=/ssl_cert/br_server.keystore.jks") + OPTIONS+=("--override" "ssl.keystore.password=abcdefgh") + OPTIONS+=("--override" "ssl.key.password=abcdefgh") + OPTIONS+=("--override" "ssl.truststore.type=JKS") + OPTIONS+=("--override" "ssl.truststore.location=/ssl_cert/br_server.truststore.jks") + OPTIONS+=("--override" "ssl.truststore.password=abcdefgh") + OPTIONS+=("--override" "ssl.client.auth=required") + OPTIONS+=("--override" "security.inter.broker.protocol=SSL") + OPTIONS+=("--override" "ssl.endpoint.identification.algorithm=") LISTENERS="$LISTENERS,SSL://:$ADVERTISED_SSL_PORT" ADVERTISED_LISTENERS="$ADVERTISED_LISTENERS,SSL://$ADVERTISED_HOST:$ADVERTISED_SSL_PORT" fi -if [ ! -z "$SASL_MECHANISMS" ]; then +if [[ ! -z "$SASL_MECHANISMS" ]]; then echo "sasl mechanisms: $SASL_MECHANISMS" echo "advertised sasl plaintext port: $ADVERTISED_SASL_PLAINTEXT_PORT" echo "advertised sasl ssl port: $ADVERTISED_SASL_SSL_PORT" - OPTIONS="$OPTIONS --override sasl.enabled.mechanisms=$SASL_MECHANISMS" - OPTIONS="$OPTIONS --override sasl.kerberos.service.name=kafka" - OPTIONS="$OPTIONS --override authorizer.class.name=kafka.security.auth.SimpleAclAuthorizer" - OPTIONS="$OPTIONS --override allow.everyone.if.no.acl.found=true" + OPTIONS+=("--override" "sasl.enabled.mechanisms=$SASL_MECHANISMS") + OPTIONS+=("--override" "sasl.kerberos.service.name=kafka") + OPTIONS+=("--override" "authorizer.class.name=kafka.security.auth.SimpleAclAuthorizer") + OPTIONS+=("--override" "allow.everyone.if.no.acl.found=true") export KAFKA_OPTS="-Djava.security.auth.login.config=/etc/kafka/$SASL_JAAS_FILE" + # OAUTHBEARER configuration is incompatible with other SASL configurations present in JAAS file + if [[ "$SASL_MECHANISMS" == "OAUTHBEARER" ]]; then + OPTIONS+=("--override" "listener.name.sasl_plaintext.oauthbearer.sasl.jaas.config= + org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required + unsecuredLoginStringClaim_sub=\"producer\" + unsecuredValidatorAllowableClockSkewMs=\"3000\";" + ) + OPTIONS+=("--override" "listener.name.sasl_ssl.oauthbearer.sasl.jaas.config= + org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required + unsecuredLoginStringClaim_sub=\"consumer\" + unsecuredValidatorAllowableClockSkewMs=\"3000\";" + ) + OPTIONS+=("--override" "listener.name.sasl_plaintext.oauthbearer.connections.max.reauth.ms=3600000") + OPTIONS+=("--override" "listener.name.sasl_ssl.oauthbearer.connections.max.reauth.ms=3600000") + fi + LISTENERS="$LISTENERS,SASL_PLAINTEXT://:$ADVERTISED_SASL_PLAINTEXT_PORT" ADVERTISED_LISTENERS="$ADVERTISED_LISTENERS,SASL_PLAINTEXT://$ADVERTISED_HOST:$ADVERTISED_SASL_PLAINTEXT_PORT" @@ -56,13 +72,13 @@ if [ ! -z "$SASL_MECHANISMS" ]; then fi # Enable auto creation of topics -OPTIONS="$OPTIONS --override auto.create.topics.enable=true" -OPTIONS="$OPTIONS --override listeners=$LISTENERS" -OPTIONS="$OPTIONS --override advertised.listeners=$ADVERTISED_LISTENERS" -OPTIONS="$OPTIONS --override super.users=User:admin" +OPTIONS+=("--override" "auto.create.topics.enable=true") +OPTIONS+=("--override" "listeners=$LISTENERS") +OPTIONS+=("--override" "advertised.listeners=$ADVERTISED_LISTENERS") +OPTIONS+=("--override" "super.users=User:admin") # Run Kafka echo "$KAFKA_HOME/bin/kafka-server-start.sh $KAFKA_HOME/config/server.properties $OPTIONS" -exec $KAFKA_HOME/bin/kafka-server-start.sh $KAFKA_HOME/config/server.properties $OPTIONS +exec $KAFKA_HOME/bin/kafka-server-start.sh $KAFKA_HOME/config/server.properties "${OPTIONS[@]}" diff --git a/pyproject.toml b/pyproject.toml index c279e24dd..ca9bca2b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ asyncio_mode = "auto" addopts = ["--strict-config", "--strict-markers"] markers = [ "ssl: Tests that require SSL certificates to run", + "oauthbearer: Tests that require SASL OAUTHBEARER mechanism to run", ] filterwarnings = [ "error", diff --git a/requirements-ci.txt b/requirements-ci.txt index 615ffb651..f0f268a99 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -11,3 +11,4 @@ Pygments==2.18.0 gssapi==1.9.0 async-timeout==4.0.3 cramjam==2.9.0 +pyjwt==2.10.1 diff --git a/tests/conftest.py b/tests/conftest.py index bd4fda2d6..689595b75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -202,8 +202,15 @@ def hosts(self): if sys.platform != "win32": + @pytest.fixture(scope="class") + def kafka_server(request, kafka_server_basic, kafka_server_oauthbearer): + if request.node.get_closest_marker("oauthbearer"): + return kafka_server_oauthbearer + + return kafka_server_basic + @pytest.fixture(scope="session") - def kafka_server( + def kafka_server_basic( kafka_image, docker, docker_ip_address, unused_port, session_id, ssl_folder ): kafka_host = docker_ip_address @@ -274,6 +281,75 @@ def kafka_server( finally: container.stop() + @pytest.fixture(scope="session") + def kafka_server_oauthbearer( + kafka_image, docker, docker_ip_address, unused_port, session_id, ssl_folder + ): + kafka_host = docker_ip_address + kafka_port = unused_port() + kafka_ssl_port = unused_port() + kafka_sasl_plain_port = unused_port() + kafka_sasl_ssl_port = unused_port() + + environment = { + "ADVERTISED_HOST": kafka_host, + "ADVERTISED_PORT": kafka_port, + "ADVERTISED_SSL_PORT": kafka_ssl_port, + "ADVERTISED_SASL_PLAINTEXT_PORT": kafka_sasl_plain_port, + "ADVERTISED_SASL_SSL_PORT": kafka_sasl_ssl_port, + "NUM_PARTITIONS": 2, + } + + kafka_version = kafka_image.split(":")[-1].split("_")[-1] + kafka_version = tuple(int(x) for x in kafka_version.split(".")) + + if kafka_version < (0, 10, 0): + pytest.skip("SASL OAUTHBEARER requires Kafka version >= 0.10.0") + + environment["SASL_MECHANISMS"] = "OAUTHBEARER" + environment["SASL_JAAS_FILE"] = "kafka_server_jaas.conf" + + container = docker.containers.run( + image=kafka_image, + name="aiokafka-oauthbearer-tests", + ports={ + kafka_port: kafka_port, + kafka_ssl_port: kafka_ssl_port, + kafka_sasl_plain_port: kafka_sasl_plain_port, + kafka_sasl_ssl_port: kafka_sasl_ssl_port, + }, + volumes={str(ssl_folder.resolve()): {"bind": "/ssl_cert", "mode": "ro"}}, + environment=environment, + tty=True, + detach=True, + remove=True, + ) + + try: + if not wait_kafka(kafka_host, kafka_port): + exit_code, output = container.exec_run( + ["supervisorctl", "tail", "-20000", "kafka"] + ) + print("Kafka (SASL OAUTHBEARER) failed to start. \n--- STDOUT:") + print(output.decode(), file=sys.stdout) + exit_code, output = container.exec_run( + ["supervisorctl", "tail", "-20000", "kafka", "stderr"] + ) + print("--- STDERR:") + print(output.decode(), file=sys.stderr) + pytest.exit("Could not start Kafka Server (SASL OAUTHBEARER)") + + yield KafkaServer( + kafka_host, + kafka_port, + kafka_ssl_port, + kafka_sasl_plain_port, + kafka_sasl_ssl_port, + container, + ) + finally: + container.stop() + else: @pytest.fixture(scope="session") diff --git a/tests/test_sasl.py b/tests/test_sasl.py index e9baaf5ff..bd00a4c9c 100644 --- a/tests/test_sasl.py +++ b/tests/test_sasl.py @@ -1,6 +1,13 @@ +import asyncio +from datetime import datetime, timedelta, timezone +from functools import partial + +import jwt import pytest +from aiokafka.abc import AbstractTokenProvider from aiokafka.admin import AIOKafkaAdminClient +from aiokafka.conn import CloseReason from aiokafka.consumer import AIOKafkaConsumer from aiokafka.errors import ( GroupAuthorizationFailedError, @@ -477,3 +484,210 @@ async def test_sasl_deny_txnid_during_transaction(self): ) with self.assertRaises(TransactionalIdAuthorizationFailed): await producer.send_and_wait(self.topic, b"123", partition=1) + + +class TokenProvider(AbstractTokenProvider): + def __init__(self, *, subject: str, expiration_time: timedelta) -> None: + self.subject = subject + self.expiration_time = expiration_time + + async def token(self) -> str: + return await asyncio.get_running_loop().run_in_executor(None, self._token) + + def _token(self) -> str: + sub = self.subject + iat = datetime.now(timezone.utc).timestamp() + exp = iat + self.expiration_time.total_seconds() + + access_token = jwt.encode( + key=None, + headers={"alg": "none"}, + payload={"sub": sub, "iat": iat, "exp": exp}, + ) + + return access_token + + +@pytest.mark.oauthbearer +@pytest.mark.usefixtures("setup_test_class") +class TestKafkaSASLOAuthBearer(KafkaIntegrationTestCase): + TEST_TIMEOUT = 60 + + @property + def group_id(self): + return self.topic + "_group" + + @property + def sasl_hosts(self): + # Produce/consume by SASL_PLAINTEXT + return f"{self.kafka_host}:{self.kafka_sasl_plain_port}" + + @property + def sasl_ssl_hosts(self): + # Produce/consume by SASL_SSL + return f"{self.kafka_host}:{self.kafka_sasl_ssl_port}" + + async def oauthbearer_producer_factory(self, token_expiration_time=None, **kw): + producer = AIOKafkaProducer( + api_version="0.10.2", + bootstrap_servers=[self.sasl_hosts], + security_protocol="SASL_PLAINTEXT", + sasl_mechanism="OAUTHBEARER", + sasl_oauth_token_provider=TokenProvider( + subject="producer", + expiration_time=token_expiration_time or timedelta(hours=1), + ), + **kw, + ) + self.add_cleanup(producer.stop) + await producer.start() + return producer + + async def oauthbearer_consumer_factory(self, token_expiration_time=None, **kw): + kwargs = { + "enable_auto_commit": True, + "auto_offset_reset": "earliest", + "group_id": self.group_id, + } + kwargs.update(kw) + consumer = AIOKafkaConsumer( + self.topic, + api_version="0.10.2", + bootstrap_servers=[self.sasl_ssl_hosts], + security_protocol="SASL_SSL", + ssl_context=self.create_ssl_context(), + sasl_mechanism="OAUTHBEARER", + sasl_oauth_token_provider=TokenProvider( + subject="consumer", + expiration_time=token_expiration_time or timedelta(hours=1), + ), + **kwargs, + ) + self.add_cleanup(consumer.stop) + await consumer.start() + return consumer + + @kafka_versions(">=0.10.2") + @run_until_complete + async def test_sasl_oauthbearer(self): + producer = await self.oauthbearer_producer_factory() + await producer.send_and_wait(topic=self.topic, value=b"Super oauthbearer msg") + + consumer = await self.oauthbearer_consumer_factory() + msg = await consumer.getone() + self.assertEqual(msg.value, b"Super oauthbearer msg") + + @kafka_versions(">=0.10.2") + @run_until_complete + async def test_sasl_oauthbearer_reauthentication(self): + reauthentication_done = asyncio.Event() + + def reauthentication_callback(_task: asyncio.Task) -> None: + reauthentication_done.set() + + token_expiration_time = timedelta(seconds=5) + producer = await self.oauthbearer_producer_factory(token_expiration_time) + await producer.send_and_wait(topic=self.topic, value=b"Before re-auth msg") + + (conn,) = producer.client._conns.values() + conn_reauthentication_task = conn._sasl_reauthentication_task + conn_reauthentication_task.add_done_callback(reauthentication_callback) + + consumer = await self.oauthbearer_consumer_factory() + msg = await consumer.getone() + self.assertEqual(msg.value, b"Before re-auth msg") + + assert not reauthentication_done.is_set() + assert not conn_reauthentication_task.done() + + await reauthentication_done.wait() + + assert conn_reauthentication_task.done() + assert conn_reauthentication_task is not conn._sasl_reauthentication_task + + await producer.send_and_wait(topic=self.topic, value=b"After re-auth msg") + msg = await consumer.getone() + self.assertEqual(msg.value, b"After re-auth msg") + + @kafka_versions(">=0.10.2") + @run_until_complete + async def test_sasl_oauthbearer_reauthentication_cannot_be_interrupted(self): + reauthentication_started = asyncio.Event() + reauthentication_done = asyncio.Event() + + def event_clear_wrapper(original_event_clear_fn) -> None: + original_event_clear_fn() + reauthentication_started.set() + + def reauthentication_callback(_task: asyncio.Task) -> None: + reauthentication_done.set() + + token_expiration_time = timedelta(seconds=5) + producer = await self.oauthbearer_producer_factory(token_expiration_time) + await producer.send_and_wait(topic=self.topic, value=b"Before re-auth msg") + + (conn,) = producer.client._conns.values() + conn_reauthentication_task = conn._sasl_reauthentication_task + conn_reauthentication_task.add_done_callback(reauthentication_callback) + conn_reauthentication_done = conn._sasl_reauthentication_done + conn_reauthentication_done.clear = partial( + event_clear_wrapper, + conn_reauthentication_done.clear, + ) + + consumer = await self.oauthbearer_consumer_factory() + msg = await consumer.getone() + self.assertEqual(msg.value, b"Before re-auth msg") + + assert not reauthentication_done.is_set() + assert not conn_reauthentication_task.done() + + await reauthentication_started.wait() + await producer.send_and_wait(topic=self.topic, value=b"During re-auth msg") + await reauthentication_done.wait() + + assert conn_reauthentication_task.done() + assert conn_reauthentication_task is not conn._sasl_reauthentication_task + + msg = await consumer.getone() + self.assertEqual(msg.value, b"During re-auth msg") + + @kafka_versions(">=0.10.2") + @run_until_complete + async def test_sasl_oauthbearer_reauthentication_handles_failure_gracefully(self): + conn_close_reason = None + reauthentication_done = asyncio.Event() + + def reauthentication_callback(_task: asyncio.Task) -> None: + reauthentication_done.set() + + def do_sasl_handshake() -> None: + raise ConnectionError() + + def on_connection_closed(_conn, reason: CloseReason) -> None: + nonlocal conn_close_reason + conn_close_reason = reason + + token_expiration_time = timedelta(seconds=5) + producer = await self.oauthbearer_producer_factory(token_expiration_time) + await producer.send_and_wait(topic=self.topic, value=b"Before re-auth msg") + + (conn,) = producer.client._conns.values() + conn_reauthentication_task = conn._sasl_reauthentication_task + conn_reauthentication_task.add_done_callback(reauthentication_callback) + conn._do_sasl_handshake = do_sasl_handshake + conn._on_close_cb = on_connection_closed + + consumer = await self.oauthbearer_consumer_factory() + msg = await consumer.getone() + self.assertEqual(msg.value, b"Before re-auth msg") + + assert not reauthentication_done.is_set() + assert not conn_reauthentication_task.done() + + await reauthentication_done.wait() + + assert conn_reauthentication_task.done() + assert conn_reauthentication_task is conn._sasl_reauthentication_task + assert conn.connected() is False + assert conn_close_reason is CloseReason.AUTH_FAILURE