From c50656f37327ca5d2c1b9e119e261e99c81f1dd1 Mon Sep 17 00:00:00 2001 From: Naxin Date: Thu, 4 Sep 2025 10:20:18 -0400 Subject: [PATCH 01/31] update --- src/confluent_kafka/cimpl.pyi | 77 +++++++++++++++++------------------ 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/src/confluent_kafka/cimpl.pyi b/src/confluent_kafka/cimpl.pyi index 99a888acc..75c43f832 100644 --- a/src/confluent_kafka/cimpl.pyi +++ b/src/confluent_kafka/cimpl.pyi @@ -43,6 +43,9 @@ from ._types import HeadersType if TYPE_CHECKING: from confluent_kafka.admin._metadata import ClusterMetadata, GroupMetadata +# Type aliases for common patterns +ConfigDict = Dict[str, Union[str, int, float, bool]] + # Callback types with proper class references (defined locally to avoid circular imports) DeliveryCallback = Callable[[Optional['KafkaError'], 'Message'], None] RebalanceCallback = Callable[['Consumer', List['TopicPartition']], None] @@ -73,8 +76,9 @@ class KafkaError: def __ge__(self, other: Union['KafkaError', int]) -> bool: ... class KafkaException(Exception): - def __init__(self, *args: Any, **kwargs: Any) -> None: ... - args: Tuple[Any, ...] + def __init__(self, kafka_error: KafkaError) -> None: ... + @property + def args(self) -> Tuple[KafkaError, ...]: ... class Message: def topic(self) -> str: ... @@ -88,8 +92,8 @@ class Message: def latency(self) -> Optional[float]: ... def leader_epoch(self) -> Optional[int]: ... def set_headers(self, headers: HeadersType) -> None: ... - def set_key(self, key: Any) -> None: ... - def set_value(self, value: Any) -> None: ... + def set_key(self, key: bytes) -> None: ... + def set_value(self, value: bytes) -> None: ... def __len__(self) -> int: ... class TopicPartition: @@ -115,7 +119,7 @@ class Uuid: def __eq__(self, other: object) -> bool: ... class Producer: - def __init__(self, config: Dict[str, Union[str, int, float, bool]]) -> None: ... + def __init__(self, config: ConfigDict) -> None: ... def produce( self, topic: str, @@ -124,7 +128,7 @@ class Producer: partition: int = -1, callback: Optional[DeliveryCallback] = None, on_delivery: Optional[DeliveryCallback] = None, - timestamp: int = 0, + timestamp: Optional[int] = None, headers: Optional[HeadersType] = None ) -> None: ... def produce_batch( @@ -142,7 +146,7 @@ class Producer: in_queue: bool = True, in_flight: bool = True, blocking: bool = True - ) -> None: ... + ) -> int: ... def abort_transaction(self, timeout: float = -1) -> None: ... def begin_transaction(self) -> None: ... def commit_transaction(self, timeout: float = -1) -> None: ... @@ -159,7 +163,7 @@ class Producer: def __bool__(self) -> bool: ... class Consumer: - def __init__(self, config: Dict[str, Union[str, int, float, bool, None]]) -> None: ... + def __init__(self, config: ConfigDict) -> None: ... def subscribe( self, topics: List[str], @@ -187,6 +191,11 @@ class Consumer: offsets: Optional[List[TopicPartition]] = None, asynchronous: Literal[False] = False ) -> List[TopicPartition]: ... + def committed( + self, + partitions: List[TopicPartition], + timeout: float = -1 + ) -> List[TopicPartition]: ... def get_watermark_offsets( self, partition: TopicPartition, @@ -202,11 +211,6 @@ class Consumer: message: Optional['Message'] = None, offsets: Optional[List[TopicPartition]] = None ) -> None: ... - def committed( - self, - partitions: List[TopicPartition], - timeout: float = -1 - ) -> List[TopicPartition]: ... def close(self) -> None: ... def list_topics(self, topic: Optional[str] = None, timeout: float = -1) -> Any: ... 
def offsets_for_times( @@ -222,7 +226,7 @@ class Consumer: def __bool__(self) -> bool: ... class _AdminClientImpl: - def __init__(self, config: Dict[str, Union[str, int, float, bool]]) -> None: ... + def __init__(self, config: ConfigDict) -> None: ... def create_topics( self, topics: List['NewTopic'], @@ -248,8 +252,8 @@ class _AdminClientImpl: ) -> None: ... def describe_topics( self, + topics: List[str], future: Any, - topic_names: List[str], request_timeout: float = -1, include_authorized_operations: bool = False ) -> None: ... @@ -271,7 +275,7 @@ class _AdminClientImpl: ) -> List[GroupMetadata]: ... def describe_consumer_groups( self, - group_ids: List[str], + groups: List[str], future: Any, request_timeout: float = -1, include_authorized_operations: bool = False @@ -279,15 +283,13 @@ class _AdminClientImpl: def list_consumer_groups( self, future: Any, - states_int: Optional[List[int]] = None, - types_int: Optional[List[int]] = None, - request_timeout: float = -1 + request_timeout: float = -1, + states: Optional[List[str]] = None ) -> None: ... def list_consumer_group_offsets( self, request: Any, # ConsumerGroupTopicPartitions future: Any, - require_stable: bool = False, request_timeout: float = -1 ) -> None: ... def alter_consumer_group_offsets( @@ -298,7 +300,7 @@ class _AdminClientImpl: ) -> None: ... def delete_consumer_groups( self, - group_ids: List[str], + groups: List[str], future: Any, request_timeout: float = -1 ) -> None: ... @@ -310,13 +312,13 @@ class _AdminClientImpl: ) -> None: ... def describe_acls( self, - acl_binding_filter: Any, # AclBindingFilter + acl_filter: Any, # AclBindingFilter future: Any, request_timeout: float = -1 ) -> None: ... def delete_acls( self, - acls: List[Any], # List[AclBindingFilter] + acl_filters: List[Any], # List[AclBindingFilter] future: Any, request_timeout: float = -1 ) -> None: ... @@ -325,23 +327,22 @@ class _AdminClientImpl: resources: List[Any], # List[ConfigResource] future: Any, request_timeout: float = -1, - broker: int = -1 + include_synonyms: bool = False, + include_documentation: bool = False ) -> None: ... def alter_configs( self, - resources: List[Any], # List[ConfigResource] + resources: Dict[Any, Dict[str, str]], # Dict[ConfigResource, Dict[str, str]] future: Any, - validate_only: bool = False, request_timeout: float = -1, - broker: int = -1 + validate_only: bool = False ) -> None: ... def incremental_alter_configs( self, - resources: List[Any], # List[ConfigResource] + resources: Dict[Any, Dict[str, Any]], # Dict[ConfigResource, Dict[str, ConfigEntry]] future: Any, - validate_only: bool = False, request_timeout: float = -1, - broker: int = -1 + validate_only: bool = False ) -> None: ... def describe_user_scram_credentials( self, @@ -359,25 +360,23 @@ class _AdminClientImpl: self, topic_partitions: List[TopicPartition], future: Any, - isolation_level_value: Optional[int] = None, - request_timeout: float = -1 + request_timeout: float = -1, + isolation_level: Optional[int] = None ) -> None: ... def delete_records( self, - topic_partition_offsets: List[TopicPartition], + topic_partitions: List[TopicPartition], future: Any, - request_timeout: float = -1, - operation_timeout: float = -1 + request_timeout: float = -1 ) -> None: ... def elect_leaders( self, - election_type: int, - partitions: Optional[List[TopicPartition]], + topic_partitions: Optional[List[TopicPartition]], future: Any, request_timeout: float = -1, - operation_timeout: float = -1 + election_type: int = 0 ) -> None: ... 
- def poll(self, timeout: float = -1) -> int: ... + def poll(self, timeout: float = -1) -> Any: ... def set_sasl_credentials(self, username: str, password: str) -> None: ... class NewTopic: From 343dbfe40390516e1c71dd17b3858cf018775b50 Mon Sep 17 00:00:00 2001 From: Naxin Date: Thu, 4 Sep 2025 15:14:18 -0400 Subject: [PATCH 02/31] remove py.typed for now --- src/confluent_kafka/py.typed | 1 - 1 file changed, 1 deletion(-) delete mode 100644 src/confluent_kafka/py.typed diff --git a/src/confluent_kafka/py.typed b/src/confluent_kafka/py.typed deleted file mode 100644 index 0519ecba6..000000000 --- a/src/confluent_kafka/py.typed +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file From 1e7f2ac018f23058aef1219c36b0103bf1cdc9e3 Mon Sep 17 00:00:00 2001 From: Naxin Date: Fri, 5 Sep 2025 11:53:27 -0400 Subject: [PATCH 03/31] update --- src/confluent_kafka/cimpl.pyi | 54 +++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/src/confluent_kafka/cimpl.pyi b/src/confluent_kafka/cimpl.pyi index 75c43f832..aee974daa 100644 --- a/src/confluent_kafka/cimpl.pyi +++ b/src/confluent_kafka/cimpl.pyi @@ -252,8 +252,8 @@ class _AdminClientImpl: ) -> None: ... def describe_topics( self, - topics: List[str], future: Any, + topic_names: List[str], request_timeout: float = -1, include_authorized_operations: bool = False ) -> None: ... @@ -265,42 +265,45 @@ class _AdminClientImpl: ) -> None: ... def list_topics( self, - topic: Optional[str] = None, - timeout: float = -1 - ) -> ClusterMetadata: ... + future: Any, + request_timeout: float = -1 + ) -> None: ... def list_groups( self, - group: Optional[str] = None, - timeout: float = -1 - ) -> List[GroupMetadata]: ... + future: Any, + request_timeout: float = -1, + states: Optional[List[str]] = None + ) -> None: ... def describe_consumer_groups( self, - groups: List[str], future: Any, + group_ids: List[str], request_timeout: float = -1, include_authorized_operations: bool = False ) -> None: ... def list_consumer_groups( self, future: Any, - request_timeout: float = -1, - states: Optional[List[str]] = None + states_int: Optional[List[int]] = None, + types_int: Optional[List[int]] = None, + request_timeout: float = -1 ) -> None: ... def list_consumer_group_offsets( self, request: Any, # ConsumerGroupTopicPartitions future: Any, + require_stable: bool = False, request_timeout: float = -1 ) -> None: ... def alter_consumer_group_offsets( self, - requests: Any, # List[ConsumerGroupTopicPartitions] + requests: Any, # List[ConsumerGroupTopicPartitions] - exactly 1 item required future: Any, request_timeout: float = -1 ) -> None: ... def delete_consumer_groups( self, - groups: List[str], + group_ids: List[str], future: Any, request_timeout: float = -1 ) -> None: ... @@ -312,13 +315,13 @@ class _AdminClientImpl: ) -> None: ... def describe_acls( self, - acl_filter: Any, # AclBindingFilter + acl_binding_filter: Any, # AclBindingFilter future: Any, request_timeout: float = -1 ) -> None: ... def delete_acls( self, - acl_filters: List[Any], # List[AclBindingFilter] + acls: List[Any], # List[AclBindingFilter] future: Any, request_timeout: float = -1 ) -> None: ... @@ -327,22 +330,23 @@ class _AdminClientImpl: resources: List[Any], # List[ConfigResource] future: Any, request_timeout: float = -1, - include_synonyms: bool = False, - include_documentation: bool = False + broker: int = -1 ) -> None: ... 
def alter_configs( self, resources: Dict[Any, Dict[str, str]], # Dict[ConfigResource, Dict[str, str]] future: Any, + validate_only: bool = False, request_timeout: float = -1, - validate_only: bool = False + broker: int = -1 ) -> None: ... def incremental_alter_configs( self, resources: Dict[Any, Dict[str, Any]], # Dict[ConfigResource, Dict[str, ConfigEntry]] future: Any, + validate_only: bool = False, request_timeout: float = -1, - validate_only: bool = False + broker: int = -1 ) -> None: ... def describe_user_scram_credentials( self, @@ -360,21 +364,23 @@ class _AdminClientImpl: self, topic_partitions: List[TopicPartition], future: Any, - request_timeout: float = -1, - isolation_level: Optional[int] = None + isolation_level_value: Optional[int] = None, + request_timeout: float = -1 ) -> None: ... def delete_records( self, - topic_partitions: List[TopicPartition], + topic_partition_offsets: List[TopicPartition], future: Any, - request_timeout: float = -1 + request_timeout: float = -1, + operation_timeout: float = -1 ) -> None: ... def elect_leaders( self, - topic_partitions: Optional[List[TopicPartition]], + election_type: int, + partitions: Optional[List[TopicPartition]], future: Any, request_timeout: float = -1, - election_type: int = 0 + operation_timeout: float = -1 ) -> None: ... def poll(self, timeout: float = -1) -> Any: ... def set_sasl_credentials(self, username: str, password: str) -> None: ... From 69ca6b8b6281222d015867dd2ba994dd3b1b0506 Mon Sep 17 00:00:00 2001 From: Naxin Date: Fri, 5 Sep 2025 14:40:32 -0400 Subject: [PATCH 04/31] fix cimply and add types to serde producer/consumer --- src/confluent_kafka/_types.py | 2 ++ src/confluent_kafka/cimpl.pyi | 21 ++++++++----------- src/confluent_kafka/deserializing_consumer.py | 11 +++++----- src/confluent_kafka/error.py | 6 ++---- src/confluent_kafka/py.typed | 1 + src/confluent_kafka/serializing_producer.py | 18 +++++++--------- 6 files changed, 28 insertions(+), 31 deletions(-) create mode 100644 src/confluent_kafka/py.typed diff --git a/src/confluent_kafka/_types.py b/src/confluent_kafka/_types.py index 2cf8ed8ea..b0e6d20b5 100644 --- a/src/confluent_kafka/_types.py +++ b/src/confluent_kafka/_types.py @@ -25,6 +25,8 @@ from typing import Any, Optional, Dict, Union, Callable, List, Tuple +# Configuration dictionary type +ConfigDict = Dict[str, Union[str, int, float, bool]] # Headers can be either dict format or list of tuples format HeadersType = Union[ Dict[str, Union[str, bytes, None]], diff --git a/src/confluent_kafka/cimpl.pyi b/src/confluent_kafka/cimpl.pyi index aee974daa..ea8b3846d 100644 --- a/src/confluent_kafka/cimpl.pyi +++ b/src/confluent_kafka/cimpl.pyi @@ -38,14 +38,11 @@ from typing import Any, Optional, Callable, List, Tuple, Dict, Union, overload, from typing_extensions import Self, Literal import builtins -from ._types import HeadersType +from ._types import ConfigDict, HeadersType if TYPE_CHECKING: from confluent_kafka.admin._metadata import ClusterMetadata, GroupMetadata -# Type aliases for common patterns -ConfigDict = Dict[str, Union[str, int, float, bool]] - # Callback types with proper class references (defined locally to avoid circular imports) DeliveryCallback = Callable[[Optional['KafkaError'], 'Message'], None] RebalanceCallback = Callable[['Consumer', List['TopicPartition']], None] @@ -68,8 +65,8 @@ class KafkaError: def __str__(self) -> builtins.str: ... def __bool__(self) -> bool: ... def __hash__(self) -> int: ... - def __eq__(self, other: object) -> bool: ... 
- def __ne__(self, other: object) -> bool: ... + def __eq__(self, other: Union['KafkaError', int]) -> bool: ... + def __ne__(self, other: Union['KafkaError', int]) -> bool: ... def __lt__(self, other: Union['KafkaError', int]) -> bool: ... def __le__(self, other: Union['KafkaError', int]) -> bool: ... def __gt__(self, other: Union['KafkaError', int]) -> bool: ... @@ -128,7 +125,7 @@ class Producer: partition: int = -1, callback: Optional[DeliveryCallback] = None, on_delivery: Optional[DeliveryCallback] = None, - timestamp: Optional[int] = None, + timestamp: int = 0, headers: Optional[HeadersType] = None ) -> None: ... def produce_batch( @@ -297,7 +294,7 @@ class _AdminClientImpl: ) -> None: ... def alter_consumer_group_offsets( self, - requests: Any, # List[ConsumerGroupTopicPartitions] - exactly 1 item required + requests: Any, # List[ConsumerGroupTopicPartitions] future: Any, request_timeout: float = -1 ) -> None: ... @@ -401,8 +398,8 @@ class NewTopic: config: Optional[Dict[str, str]] def __str__(self) -> str: ... def __hash__(self) -> int: ... - def __eq__(self, other: object) -> bool: ... - def __ne__(self, other: object) -> bool: ... + def __eq__(self, other: 'NewTopic') -> bool: ... + def __ne__(self, other: 'NewTopic') -> bool: ... def __lt__(self, other: 'NewTopic') -> bool: ... def __le__(self, other: 'NewTopic') -> bool: ... def __gt__(self, other: 'NewTopic') -> bool: ... @@ -420,8 +417,8 @@ class NewPartitions: replica_assignment: Optional[List[List[int]]] def __str__(self) -> str: ... def __hash__(self) -> int: ... - def __eq__(self, other: object) -> bool: ... - def __ne__(self, other: object) -> bool: ... + def __eq__(self, other: 'NewPartitions') -> bool: ... + def __ne__(self, other: 'NewPartitions') -> bool: ... def __lt__(self, other: 'NewPartitions') -> bool: ... def __le__(self, other: 'NewPartitions') -> bool: ... def __gt__(self, other: 'NewPartitions') -> bool: ... diff --git a/src/confluent_kafka/deserializing_consumer.py b/src/confluent_kafka/deserializing_consumer.py index 324239f25..51bc8e941 100644 --- a/src/confluent_kafka/deserializing_consumer.py +++ b/src/confluent_kafka/deserializing_consumer.py @@ -16,7 +16,7 @@ # limitations under the License. # -from typing import Any, Dict, List, Optional +from typing import Any, Optional, Callable, List from confluent_kafka.cimpl import Consumer as _ConsumerImpl, Message from .error import (ConsumeError, @@ -24,6 +24,7 @@ ValueDeserializationError) from .serialization import (SerializationContext, MessageField) +from ._types import ConfigDict, Deserializer class DeserializingConsumer(_ConsumerImpl): @@ -72,14 +73,14 @@ class DeserializingConsumer(_ConsumerImpl): ValueError: if configuration validation fails """ # noqa: E501 - def __init__(self, conf: Dict[str, Any]) -> None: + def __init__(self, conf: ConfigDict) -> None: conf_copy = conf.copy() - self._key_deserializer = conf_copy.pop('key.deserializer', None) - self._value_deserializer = conf_copy.pop('value.deserializer', None) + self._key_deserializer: Optional[Deserializer] = conf_copy.pop('key.deserializer', None) + self._value_deserializer: Optional[Deserializer] = conf_copy.pop('value.deserializer', None) super(DeserializingConsumer, self).__init__(conf_copy) - def poll(self, timeout: float = -1) -> Optional[Message]: + def poll(self, timeout: float = -1) -> Optional['Message']: """ Consume messages and calls callbacks. 
diff --git a/src/confluent_kafka/error.py b/src/confluent_kafka/error.py index 2df3557cb..6f14fbf1c 100644 --- a/src/confluent_kafka/error.py +++ b/src/confluent_kafka/error.py @@ -35,8 +35,7 @@ class _KafkaClientError(KafkaException): by the broker. """ - def __init__(self, kafka_error: KafkaError, exception: Optional[Exception] = None, - kafka_message: Optional[Message] = None) -> None: + def __init__(self, kafka_error: KafkaError, exception: Optional[Exception] = None, kafka_message: Optional[Message] = None) -> None: super(_KafkaClientError, self).__init__(kafka_error) self.exception = exception self.kafka_message = kafka_message @@ -68,8 +67,7 @@ class ConsumeError(_KafkaClientError): """ - def __init__(self, kafka_error: KafkaError, exception: Optional[Exception] = None, - kafka_message: Optional[Message] = None) -> None: + def __init__(self, kafka_error: KafkaError, exception: Optional[Exception] = None, kafka_message: Optional[Message] = None) -> None: super(ConsumeError, self).__init__(kafka_error, exception, kafka_message) diff --git a/src/confluent_kafka/py.typed b/src/confluent_kafka/py.typed new file mode 100644 index 000000000..0519ecba6 --- /dev/null +++ b/src/confluent_kafka/py.typed @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/confluent_kafka/serializing_producer.py b/src/confluent_kafka/serializing_producer.py index c4547d9d4..2ca74a990 100644 --- a/src/confluent_kafka/serializing_producer.py +++ b/src/confluent_kafka/serializing_producer.py @@ -16,14 +16,14 @@ # limitations under the License. # -from typing import Any, Dict, Optional +from typing import Any, Optional, Callable from confluent_kafka.cimpl import Producer as _ProducerImpl from .serialization import (MessageField, SerializationContext) from .error import (KeySerializationError, ValueSerializationError) -from ._types import HeadersType, DeliveryCallback +from ._types import ConfigDict, HeadersType, DeliveryCallback, Serializer class SerializingProducer(_ProducerImpl): @@ -69,19 +69,17 @@ class SerializingProducer(_ProducerImpl): conf (producer): SerializingProducer configuration. """ # noqa E501 - def __init__(self, conf: Dict[str, Any]) -> None: + def __init__(self, conf: ConfigDict) -> None: conf_copy = conf.copy() - self._key_serializer = conf_copy.pop('key.serializer', None) - self._value_serializer = conf_copy.pop('value.serializer', None) + self._key_serializer: Optional[Serializer] = conf_copy.pop('key.serializer', None) + self._value_serializer: Optional[Serializer] = conf_copy.pop('value.serializer', None) super(SerializingProducer, self).__init__(conf_copy) - def produce( # type: ignore[override] - self, topic: str, key: Any = None, value: Any = None, partition: int = -1, - on_delivery: Optional[DeliveryCallback] = None, timestamp: int = 0, - headers: Optional[HeadersType] = None - ) -> None: + def produce(self, topic: str, key: Any = None, value: Any = None, partition: int = -1, + on_delivery: Optional[DeliveryCallback] = None, timestamp: int = 0, + headers: Optional[HeadersType] = None) -> None: """ Produce a message. 
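A rough usage sketch of the serde clients typed in the patch above; the broker address, topic, group id, and the string serializers below are illustrative assumptions, not part of the patch:

```python
from confluent_kafka import SerializingProducer, DeserializingConsumer
from confluent_kafka.serialization import StringSerializer, StringDeserializer

# The key/value (de)serializer entries are popped from the config copy before it
# reaches the underlying cimpl Producer/Consumer, per the patched __init__ methods.
producer = SerializingProducer({
    "bootstrap.servers": "localhost:9092",           # assumed broker address
    "key.serializer": StringSerializer("utf_8"),
    "value.serializer": StringSerializer("utf_8"),
})
producer.produce("example-topic", key="k1", value="hello")
producer.flush()

consumer = DeserializingConsumer({
    "bootstrap.servers": "localhost:9092",
    "group.id": "example-group",                     # assumed group id
    "auto.offset.reset": "earliest",
    "key.deserializer": StringDeserializer("utf_8"),
    "value.deserializer": StringDeserializer("utf_8"),
})
consumer.subscribe(["example-topic"])
msg = consumer.poll(1.0)                             # returns Optional[Message]
if msg is not None and msg.error() is None:
    print(msg.key(), msg.value())
consumer.close()
```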
From d789244b4b13125fcec816ee9f6494fecbd05314 Mon Sep 17 00:00:00 2001 From: Naxin Date: Fri, 5 Sep 2025 19:13:08 -0400 Subject: [PATCH 05/31] admin --- src/confluent_kafka/_model/__init__.py | 20 +++++++++---------- src/confluent_kafka/_util/conversion_util.py | 7 +++++-- src/confluent_kafka/admin/_cluster.py | 2 +- src/confluent_kafka/admin/_resource.py | 8 ++++---- src/confluent_kafka/admin/_scram.py | 6 +++--- src/confluent_kafka/admin/_topic.py | 2 +- src/confluent_kafka/deserializing_consumer.py | 4 ++-- src/confluent_kafka/error.py | 6 ++++-- src/confluent_kafka/serializing_producer.py | 2 +- 9 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/confluent_kafka/_model/__init__.py b/src/confluent_kafka/_model/__init__.py index f3bce031f..bd05bf556 100644 --- a/src/confluent_kafka/_model/__init__.py +++ b/src/confluent_kafka/_model/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List, Optional, Any from enum import Enum from .. import cimpl from ..cimpl import TopicPartition @@ -62,7 +62,7 @@ class ConsumerGroupTopicPartitions: List of topic partitions information. """ - def __init__(self, group_id: str, topic_partitions: Optional[List[TopicPartition]] = None) -> None: + def __init__(self, group_id: str, topic_partitions: Optional[List['cimpl.TopicPartition']] = None) -> None: self.group_id = group_id self.topic_partitions = topic_partitions @@ -91,8 +91,8 @@ class ConsumerGroupState(Enum): #: Consumer Group is empty. EMPTY = cimpl.CONSUMER_GROUP_STATE_EMPTY - def __lt__(self, other: object) -> bool: - if not isinstance(other, ConsumerGroupState): + def __lt__(self, other: 'ConsumerGroupState') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self.value < other.value @@ -111,8 +111,8 @@ class ConsumerGroupType(Enum): #: Classic Type CLASSIC = cimpl.CONSUMER_GROUP_TYPE_CLASSIC - def __lt__(self, other: object) -> bool: - if not isinstance(other, ConsumerGroupType): + def __lt__(self, other: 'ConsumerGroupType') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self.value < other.value @@ -167,8 +167,8 @@ class IsolationLevel(Enum): READ_UNCOMMITTED = cimpl.ISOLATION_LEVEL_READ_UNCOMMITTED #: Receive all the offsets. READ_COMMITTED = cimpl.ISOLATION_LEVEL_READ_COMMITTED #: Skip offsets belonging to an aborted transaction. - def __lt__(self, other: object) -> bool: - if not isinstance(other, IsolationLevel): + def __lt__(self, other: 'IsolationLevel') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self.value < other.value @@ -186,7 +186,7 @@ class ElectionType(Enum): #: Unclean election UNCLEAN = cimpl.ELECTION_TYPE_UNCLEAN - def __lt__(self, other: object) -> bool: - if not isinstance(other, ElectionType): + def __lt__(self, other: 'ElectionType') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self.value < other.value diff --git a/src/confluent_kafka/_util/conversion_util.py b/src/confluent_kafka/_util/conversion_util.py index 4bbbc6c38..ef513406b 100644 --- a/src/confluent_kafka/_util/conversion_util.py +++ b/src/confluent_kafka/_util/conversion_util.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Union, Type +from typing import Union, Type, TypeVar from enum import Enum +# Generic type for enum conversion +E = TypeVar('E', bound=Enum) + class ConversionUtil: @staticmethod - def convert_to_enum(val: Union[str, int, Enum], enum_clazz: Type[Enum]) -> Enum: + def convert_to_enum(val: Union[str, int, E], enum_clazz: Type[E]) -> E: if type(enum_clazz) is not type(Enum): raise TypeError("'enum_clazz' must be of type Enum") diff --git a/src/confluent_kafka/admin/_cluster.py b/src/confluent_kafka/admin/_cluster.py index 0e83c4cc3..610a28b76 100644 --- a/src/confluent_kafka/admin/_cluster.py +++ b/src/confluent_kafka/admin/_cluster.py @@ -41,7 +41,7 @@ def __init__(self, controller: Node, nodes: List[Node], cluster_id: Optional[str self.cluster_id = cluster_id self.controller = controller self.nodes = nodes - self.authorized_operations = None + self.authorized_operations: Optional[List[AclOperation]] = None if authorized_operations: self.authorized_operations = [] for op in authorized_operations: diff --git a/src/confluent_kafka/admin/_resource.py b/src/confluent_kafka/admin/_resource.py index 55c6d783d..8fa6dd19b 100644 --- a/src/confluent_kafka/admin/_resource.py +++ b/src/confluent_kafka/admin/_resource.py @@ -28,8 +28,8 @@ class ResourceType(Enum): BROKER = _cimpl.RESOURCE_BROKER #: Broker resource. Resource name is broker id. TRANSACTIONAL_ID = _cimpl.RESOURCE_TRANSACTIONAL_ID #: Transactional ID resource. - def __lt__(self, other: object) -> Any: - if not isinstance(other, ResourceType): + def __lt__(self, other: 'ResourceType') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self.value < other.value @@ -44,7 +44,7 @@ class ResourcePatternType(Enum): LITERAL = _cimpl.RESOURCE_PATTERN_LITERAL #: Literal: A literal resource name PREFIXED = _cimpl.RESOURCE_PATTERN_PREFIXED #: Prefixed: A prefixed resource name - def __lt__(self, other: object) -> Any: - if not isinstance(other, ResourcePatternType): + def __lt__(self, other: 'ResourcePatternType') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self.value < other.value diff --git a/src/confluent_kafka/admin/_scram.py b/src/confluent_kafka/admin/_scram.py index 2bb19a414..ff14a20fb 100644 --- a/src/confluent_kafka/admin/_scram.py +++ b/src/confluent_kafka/admin/_scram.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List, Optional, Any from enum import Enum from .. 
import cimpl @@ -26,8 +26,8 @@ class ScramMechanism(Enum): SCRAM_SHA_256 = cimpl.SCRAM_MECHANISM_SHA_256 #: SCRAM-SHA-256 mechanism SCRAM_SHA_512 = cimpl.SCRAM_MECHANISM_SHA_512 #: SCRAM-SHA-512 mechanism - def __lt__(self, other: object) -> bool: - if not isinstance(other, ScramMechanism): + def __lt__(self, other: 'ScramMechanism') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self.value < other.value diff --git a/src/confluent_kafka/admin/_topic.py b/src/confluent_kafka/admin/_topic.py index a804631be..fecca4ed3 100644 --- a/src/confluent_kafka/admin/_topic.py +++ b/src/confluent_kafka/admin/_topic.py @@ -46,7 +46,7 @@ def __init__(self, name: str, topic_id: Uuid, is_internal: bool, self.topic_id = topic_id self.is_internal = is_internal self.partitions = partitions - self.authorized_operations = None + self.authorized_operations: Optional[List[AclOperation]] = None if authorized_operations: self.authorized_operations = [] for op in authorized_operations: diff --git a/src/confluent_kafka/deserializing_consumer.py b/src/confluent_kafka/deserializing_consumer.py index 51bc8e941..f32c9a166 100644 --- a/src/confluent_kafka/deserializing_consumer.py +++ b/src/confluent_kafka/deserializing_consumer.py @@ -16,7 +16,7 @@ # limitations under the License. # -from typing import Any, Optional, Callable, List +from typing import Optional, List from confluent_kafka.cimpl import Consumer as _ConsumerImpl, Message from .error import (ConsumeError, @@ -80,7 +80,7 @@ def __init__(self, conf: ConfigDict) -> None: super(DeserializingConsumer, self).__init__(conf_copy) - def poll(self, timeout: float = -1) -> Optional['Message']: + def poll(self, timeout: float = -1) -> Optional[Message]: """ Consume messages and calls callbacks. diff --git a/src/confluent_kafka/error.py b/src/confluent_kafka/error.py index 6f14fbf1c..2df3557cb 100644 --- a/src/confluent_kafka/error.py +++ b/src/confluent_kafka/error.py @@ -35,7 +35,8 @@ class _KafkaClientError(KafkaException): by the broker. """ - def __init__(self, kafka_error: KafkaError, exception: Optional[Exception] = None, kafka_message: Optional[Message] = None) -> None: + def __init__(self, kafka_error: KafkaError, exception: Optional[Exception] = None, + kafka_message: Optional[Message] = None) -> None: super(_KafkaClientError, self).__init__(kafka_error) self.exception = exception self.kafka_message = kafka_message @@ -67,7 +68,8 @@ class ConsumeError(_KafkaClientError): """ - def __init__(self, kafka_error: KafkaError, exception: Optional[Exception] = None, kafka_message: Optional[Message] = None) -> None: + def __init__(self, kafka_error: KafkaError, exception: Optional[Exception] = None, + kafka_message: Optional[Message] = None) -> None: super(ConsumeError, self).__init__(kafka_error, exception, kafka_message) diff --git a/src/confluent_kafka/serializing_producer.py b/src/confluent_kafka/serializing_producer.py index 2ca74a990..61ad488d1 100644 --- a/src/confluent_kafka/serializing_producer.py +++ b/src/confluent_kafka/serializing_producer.py @@ -16,7 +16,7 @@ # limitations under the License. 
# -from typing import Any, Optional, Callable +from typing import Any, Optional from confluent_kafka.cimpl import Producer as _ProducerImpl from .serialization import (MessageField, From 0653fe03850f3534522b7b55e7ede7302f978334 Mon Sep 17 00:00:00 2001 From: Naxin Date: Mon, 8 Sep 2025 23:18:02 -0400 Subject: [PATCH 06/31] address feedback --- src/confluent_kafka/_model/__init__.py | 2 +- src/confluent_kafka/_util/conversion_util.py | 7 ++----- src/confluent_kafka/admin/_cluster.py | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/confluent_kafka/_model/__init__.py b/src/confluent_kafka/_model/__init__.py index bd05bf556..bcffdcf41 100644 --- a/src/confluent_kafka/_model/__init__.py +++ b/src/confluent_kafka/_model/__init__.py @@ -62,7 +62,7 @@ class ConsumerGroupTopicPartitions: List of topic partitions information. """ - def __init__(self, group_id: str, topic_partitions: Optional[List['cimpl.TopicPartition']] = None) -> None: + def __init__(self, group_id: str, topic_partitions: Optional[List[TopicPartition]] = None) -> None: self.group_id = group_id self.topic_partitions = topic_partitions diff --git a/src/confluent_kafka/_util/conversion_util.py b/src/confluent_kafka/_util/conversion_util.py index ef513406b..4bbbc6c38 100644 --- a/src/confluent_kafka/_util/conversion_util.py +++ b/src/confluent_kafka/_util/conversion_util.py @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Type, TypeVar +from typing import Union, Type from enum import Enum -# Generic type for enum conversion -E = TypeVar('E', bound=Enum) - class ConversionUtil: @staticmethod - def convert_to_enum(val: Union[str, int, E], enum_clazz: Type[E]) -> E: + def convert_to_enum(val: Union[str, int, Enum], enum_clazz: Type[Enum]) -> Enum: if type(enum_clazz) is not type(Enum): raise TypeError("'enum_clazz' must be of type Enum") diff --git a/src/confluent_kafka/admin/_cluster.py b/src/confluent_kafka/admin/_cluster.py index 610a28b76..0e83c4cc3 100644 --- a/src/confluent_kafka/admin/_cluster.py +++ b/src/confluent_kafka/admin/_cluster.py @@ -41,7 +41,7 @@ def __init__(self, controller: Node, nodes: List[Node], cluster_id: Optional[str self.cluster_id = cluster_id self.controller = controller self.nodes = nodes - self.authorized_operations: Optional[List[AclOperation]] = None + self.authorized_operations = None if authorized_operations: self.authorized_operations = [] for op in authorized_operations: From 77c496575bd85b60de82f6a2dc5b7b0548a616b5 Mon Sep 17 00:00:00 2001 From: Naxin Date: Tue, 9 Sep 2025 15:48:08 -0400 Subject: [PATCH 07/31] add warning to stub and c files; admin typing more --- CONTRIBUTOR.md | 219 ------------------ INSTALL.md | 121 ---------- src/confluent_kafka/_model/__init__.py | 8 +- src/confluent_kafka/admin/_acl.py | 20 +- src/confluent_kafka/admin/_config.py | 8 +- src/confluent_kafka/admin/_group.py | 2 +- src/confluent_kafka/admin/_listoffsets.py | 7 +- src/confluent_kafka/admin/_metadata.py | 6 +- src/confluent_kafka/admin/_scram.py | 2 +- src/confluent_kafka/admin/_topic.py | 2 +- src/confluent_kafka/deserializing_consumer.py | 4 +- src/confluent_kafka/serializing_producer.py | 4 +- 12 files changed, 28 insertions(+), 375 deletions(-) delete mode 100644 CONTRIBUTOR.md delete mode 100644 INSTALL.md diff --git a/CONTRIBUTOR.md b/CONTRIBUTOR.md deleted file mode 100644 index e9a2f925a..000000000 --- a/CONTRIBUTOR.md +++ /dev/null @@ -1,219 +0,0 @@ -# Contributing to 
confluent-kafka-python - -Thank you for your interest in contributing to confluent-kafka-python! This document provides guidelines and best practices for contributing to this project. - -## Table of Contents - -- [Getting Started](#getting-started) -- [Development Environment Setup](#development-environment-setup) -- [Code Style and Standards](#code-style-and-standards) -- [Testing](#testing) -- [Submitting Changes](#submitting-changes) -- [Reporting Issues](#reporting-issues) -- [Community Guidelines](#community-guidelines) - -## Getting Started - -### Ways to Contribute - -- **Bug Reports**: Report bugs and issues you encounter -- **Feature Requests**: Suggest new features or improvements -- **Code Contributions**: Fix bugs, implement features, or improve documentation -- **Documentation**: Improve existing docs or add new documentation -- **Testing**: Help improve test coverage and quality - -### Before You Start - -1. Check existing [issues](../../issues) to see if your bug/feature has already been reported -2. For major changes, open an issue first to discuss the proposed changes -3. Fork the repository and create a feature branch for your work - -## Development Environment Setup - -For complete development environment setup instructions, including prerequisites, virtual environment creation, and dependency installation, see the [Development Environment Setup section in DEVELOPER.md](DEVELOPER.md#development-environment-setup). - -## Code Style and Standards - -### Python Code Style - -- **PEP 8**: Follow [PEP 8](https://pep8.org/) style guidelines as a default, with exceptions captured in the `tox.ini` flake8 rules for modern updates to the recommendations -- **Docstrings**: Use Google-style docstrings for all public functions and classes - -### Code Formatting - -We use automated tools to maintain consistent code style: - -```bash -# Install formatting tools -pip install flake8 - -# Check style -flake8 src/ tests/ -``` - -### Naming Conventions - -- **Functions and Variables**: `snake_case` -- **Classes**: `PascalCase` -- **Constants**: `UPPER_SNAKE_CASE` -- **Private Methods/Objects**: Prefix with single underscore `_private_method` - -### Documentation - -- All public APIs must have docstrings -- Include examples in docstrings when helpful -- Keep docstrings concise but complete -- Update relevant documentation files when making changes - -## Testing - -### Running Tests - -See [tests/README.md](tests/README.md) for comprehensive testing instructions. - -### Test Requirements - -- **Unit Tests**: All new functionality must include unit tests -- **Integration Tests**: Add integration tests for complex features -- **Test Coverage**: Maintain or improve existing test coverage -- **Test Naming**: Use descriptive test names that explain what is being tested - -### Test Structure - -```python -def test_feature_should_behave_correctly_when_condition(): - # Arrange - setup_data = create_test_data() - - # Act - result = function_under_test(setup_data) - - # Assert - assert result.expected_property == expected_value -``` - -## Submitting Changes - -### Pull Request Process - -1. **Create Feature Branch** - ```bash - git checkout -b feature/your-feature-name - # or - git checkout -b fix/issue-number-description - ``` - -2. **Make Your Changes** - - Write clean, well-documented code - - Add appropriate tests - - Update documentation if needed - - Add an entry to the CHANGELOG.md file for the proposed change - -3. **Test Your Changes** - Refer to [tests/README.md](tests/README.md) - -4. 
**Commit Your Changes** - ```bash - git add . - git commit -m "Clear, descriptive commit message" - ``` - - **Commit Message Guidelines:** - - Use present tense ("Add feature" not "Added feature") - - Keep first line under 50 characters - - Reference issue numbers when applicable (#123) - - Include breaking change notes if applicable - -5. **Push and Create Pull Request** - ```bash - git push origin feature/your-feature-name - ``` - - Then create a pull request through GitHub's interface. - -### Pull Request Guidelines - -- **Title**: Clear and descriptive -- **Description**: Explain what changes you made and why -- **Linked Issues**: Reference related issues using "Fixes #123" or "Relates to #123" -- **Labels**: Review available issue/PR labels and apply relevant ones to help with categorization and triage -- **Documentation**: Update documentation for user-facing changes -- **Tests**: Include appropriate tests -- **Breaking Changes**: Clearly mark any breaking changes - -### Code Review Process - -- All pull requests require review before merging -- Address reviewer feedback promptly -- Keep discussions respectful and constructive -- Be open to suggestions and alternative approaches - -## Reporting Issues - -### Using Labels - -When creating issues or pull requests, please review the available labels and apply those that are relevant to your submission. This helps maintainers categorize and prioritize work effectively. Common label categories include (look at available labels / other issues for options): - -- **Type**: bug, enhancement, documentation, question -- **Priority**: high, medium, low -- **Component**: producer, consumer, admin, schema-registry, etc -- **Status**: needs-investigation, help-wanted, good-first-issue, etc - -### Bug Reports - -When reporting bugs, please include: - -- **Clear Title**: Describe the issue concisely -- **Environment**: Python version, OS, library versions -- **Steps to Reproduce**: Detailed steps to reproduce the issue -- **Expected Behavior**: What you expected to happen -- **Actual Behavior**: What actually happened -- **Code Sample**: Minimal code that demonstrates the issue -- **Error Messages**: Full error messages and stack traces -- **Client Configuration**: Specify how the client was configured and setup -- **Logs**: Client logs when possible -- **Labels**: Apply relevant labels such as "bug" and component-specific labels - -### Feature Requests - -For feature requests, please include: - -- **Use Case**: Describe the problem you're trying to solve -- **Proposed Solution**: Your idea for how to address it -- **Alternatives**: Other solutions you've considered -- **Additional Context**: Any other relevant information -- **Labels**: Apply relevant labels such as "enhancement" and component-specific labels - -## Community Guidelines - -### Code of Conduct - -This project follows the [Contributor Covenant Code of Conduct](https://www.contributor-covenant.org/). By participating, you agree to uphold this code. 
- -### Communication - -- **Be Respectful**: Treat all community members with respect -- **Be Constructive**: Provide helpful feedback and suggestions -- **Be Patient**: Remember that maintainers and contributors volunteer their time -- **Be Clear**: Communicate clearly and provide sufficient context - -### Getting Help - -- **Issues**: Use GitHub issues for bug reports and feature requests -- **Discussions**: Use GitHub Discussions for questions and general discussion -- **Documentation**: Check existing documentation before asking questions - -## Recognition - -Contributors are recognized in the following ways: - -- Contributors are listed in the project's contributor history -- Significant contributions may be mentioned in release notes - -## License - -By contributing to this project, you agree that your contributions will be licensed under the same license as the project (see LICENSE file). - ---- - -Thank you for contributing to confluent-kafka-python! Your contributions help make this project better for everyone. \ No newline at end of file diff --git a/INSTALL.md b/INSTALL.md deleted file mode 100644 index 6b642463b..000000000 --- a/INSTALL.md +++ /dev/null @@ -1,121 +0,0 @@ -# confluent-kafka-python installation instructions - -## Install pre-built wheels (recommended) - -Confluent provides pre-built Python wheels of confluent-kafka-python with -all dependencies included. - -To install, simply do: - -```bash -python3 -m pip install confluent-kafka -``` - -If you get a build error or require Kerberos/GSSAPI support please read the next section: *Install from source* - - -## Install from source - -It is sometimes necessary to install confluent-kafka from source, rather -than from prebuilt binary wheels, such as when: - - You need GSSAPI/Kerberos authentication. - - You're on a Python version we do not provide prebuilt wheels for. - - You're on an architecture or platform we do not provide prebuilt wheels for. - - You want to build confluent-kafka-python from the master branch. - - -### Install from source on RedHat, CentOS, Fedora, etc - -```bash -# -# Perform these steps as the root user (e.g., in a 'sudo bash' shell) -# - -# Install build tools and Kerberos support. - -yum install -y python3 python3-pip python3-devel gcc make cyrus-sasl-gssapi krb5-workstation - -# Install the latest version of librdkafka: - -rpm --import https://packages.confluent.io/rpm/7.0/archive.key - -echo ' -[Confluent-Clients] -name=Confluent Clients repository -baseurl=https://packages.confluent.io/clients/rpm/centos/$releasever/$basearch -gpgcheck=1 -gpgkey=https://packages.confluent.io/clients/rpm/archive.key -enabled=1' > /etc/yum.repos.d/confluent.repo - -yum install -y librdkafka-devel - - -# -# Now build and install confluent-kafka-python as your standard user -# (e.g., exit the root shell first). -# - -python3 -m pip install --no-binary confluent-kafka confluent-kafka - - -# Verify that confluent_kafka is installed: - -python3 -c 'import confluent_kafka; print(confluent_kafka.version())' -``` - -### Install from source on Debian or Ubuntu - -```bash -# -# Perform these steps as the root user (e.g., in a 'sudo bash' shell) -# - -# Install build tools and Kerberos support. 
- -apt install -y wget software-properties-common lsb-release gcc make python3 python3-pip python3-dev libsasl2-modules-gssapi-mit krb5-user - - -# Install the latest version of librdkafka: - -wget -qO - https://packages.confluent.io/deb/7.0/archive.key | apt-key add - - -add-apt-repository "deb https://packages.confluent.io/clients/deb $(lsb_release -cs) main" - -apt update - -apt install -y librdkafka-dev - - -# -# Now build and install confluent-kafka-python as your standard user -# (e.g., exit the root shell first). -# - -python3 -m pip install --no-binary confluent-kafka confluent-kafka - - -# Verify that confluent_kafka is installed: - -python3 -c 'import confluent_kafka; print(confluent_kafka.version())' -``` - - -### Install from source on Mac OS X - -```bash - -# Install librdkafka from homebrew - -brew install librdkafka - - -# Build and install confluent-kafka-python - -python3 -m pip install --no-binary confluent-kafka confluent-kafka - - -# Verify that confluent_kafka is installed: - -python3 -c 'import confluent_kafka; print(confluent_kafka.version())' - -``` diff --git a/src/confluent_kafka/_model/__init__.py b/src/confluent_kafka/_model/__init__.py index bcffdcf41..8c775dbd9 100644 --- a/src/confluent_kafka/_model/__init__.py +++ b/src/confluent_kafka/_model/__init__.py @@ -91,7 +91,7 @@ class ConsumerGroupState(Enum): #: Consumer Group is empty. EMPTY = cimpl.CONSUMER_GROUP_STATE_EMPTY - def __lt__(self, other: 'ConsumerGroupState') -> Any: + def __lt__(self, other) -> Any: if self.__class__ != other.__class__: return NotImplemented return self.value < other.value @@ -111,7 +111,7 @@ class ConsumerGroupType(Enum): #: Classic Type CLASSIC = cimpl.CONSUMER_GROUP_TYPE_CLASSIC - def __lt__(self, other: 'ConsumerGroupType') -> Any: + def __lt__(self, other) -> Any: if self.__class__ != other.__class__: return NotImplemented return self.value < other.value @@ -167,7 +167,7 @@ class IsolationLevel(Enum): READ_UNCOMMITTED = cimpl.ISOLATION_LEVEL_READ_UNCOMMITTED #: Receive all the offsets. READ_COMMITTED = cimpl.ISOLATION_LEVEL_READ_COMMITTED #: Skip offsets belonging to an aborted transaction. 
- def __lt__(self, other: 'IsolationLevel') -> Any: + def __lt__(self, other) -> Any: if self.__class__ != other.__class__: return NotImplemented return self.value < other.value @@ -186,7 +186,7 @@ class ElectionType(Enum): #: Unclean election UNCLEAN = cimpl.ELECTION_TYPE_UNCLEAN - def __lt__(self, other: 'ElectionType') -> Any: + def __lt__(self, other) -> Any: if self.__class__ != other.__class__: return NotImplemented return self.value < other.value diff --git a/src/confluent_kafka/admin/_acl.py b/src/confluent_kafka/admin/_acl.py index 75adc0c8f..940d25913 100644 --- a/src/confluent_kafka/admin/_acl.py +++ b/src/confluent_kafka/admin/_acl.py @@ -44,8 +44,8 @@ class AclOperation(Enum): ALTER_CONFIGS = _cimpl.ACL_OPERATION_ALTER_CONFIGS #: ALTER_CONFIGS operation IDEMPOTENT_WRITE = _cimpl.ACL_OPERATION_IDEMPOTENT_WRITE #: IDEMPOTENT_WRITE operation - def __lt__(self, other: object) -> bool: - if not isinstance(other, AclOperation): + def __lt__(self, other: 'AclOperation') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self.value < other.value @@ -59,8 +59,8 @@ class AclPermissionType(Enum): DENY = _cimpl.ACL_PERMISSION_TYPE_DENY #: Disallows access ALLOW = _cimpl.ACL_PERMISSION_TYPE_ALLOW #: Grants access - def __lt__(self, other: object) -> bool: - if not isinstance(other, AclPermissionType): + def __lt__(self, other: 'AclPermissionType') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self.value < other.value @@ -110,7 +110,7 @@ def __init__(self, restype: Union[ResourceType, str, int], name: str, self.permission_type_int = int(self.permission_type.value) # type: ignore[union-attr] def _convert_enums(self) -> None: - self.restype = ConversionUtil.convert_to_enum(self.restype, ResourceType) # type: ignore[assignment] + self.restype = ConversionUtil.convert_to_enum(self.restype, ResourceType) self.resource_pattern_type = ConversionUtil.convert_to_enum( self.resource_pattern_type, ResourcePatternType) # type: ignore[assignment] self.operation = ConversionUtil.convert_to_enum( @@ -154,20 +154,20 @@ def __repr__(self) -> str: return "%s(%s,%s,%s,%s,%s,%s,%s)" % ((type_name,) + self._to_tuple()) def _to_tuple(self) -> Tuple[ResourceType, str, ResourcePatternType, str, str, AclOperation, AclPermissionType]: - return (self.restype, self.name, self.resource_pattern_type, # type: ignore[return-value] + return (self.restype, self.name, self.resource_pattern_type, self.principal, self.host, self.operation, self.permission_type) def __hash__(self) -> int: return hash(self._to_tuple()) - def __lt__(self, other: object) -> bool: - if not isinstance(other, AclBinding): + def __lt__(self, other: 'AclBinding') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self._to_tuple() < other._to_tuple() - def __eq__(self, other: object) -> Any: - if not isinstance(other, AclBinding): + def __eq__(self, other: 'AclBinding') -> Any: + if self.__class__ != other.__class__: return NotImplemented return self._to_tuple() == other._to_tuple() diff --git a/src/confluent_kafka/admin/_config.py b/src/confluent_kafka/admin/_config.py index 70b19c6aa..abccffc66 100644 --- a/src/confluent_kafka/admin/_config.py +++ b/src/confluent_kafka/admin/_config.py @@ -185,16 +185,12 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash((self.restype, self.name)) - def __lt__(self, other: object) -> bool: - if not isinstance(other, ConfigResource): - return NotImplemented + def __lt__(self, other: 'ConfigResource') -> bool: if 
self.restype < other.restype: return True return self.name.__lt__(other.name) - def __eq__(self, other: object) -> bool: - if not isinstance(other, ConfigResource): - return NotImplemented + def __eq__(self, other: 'ConfigResource') -> bool: return self.restype == other.restype and self.name == other.name def __len__(self) -> int: diff --git a/src/confluent_kafka/admin/_group.py b/src/confluent_kafka/admin/_group.py index 1db7923bb..3ff3fa8be 100644 --- a/src/confluent_kafka/admin/_group.py +++ b/src/confluent_kafka/admin/_group.py @@ -80,7 +80,7 @@ class MemberAssignment: The topic partitions assigned to a group member. """ - def __init__(self, topic_partitions: Optional[List[TopicPartition]] = None) -> None: + def __init__(self, topic_partitions: List[TopicPartition] = []) -> None: self.topic_partitions = topic_partitions or [] diff --git a/src/confluent_kafka/admin/_listoffsets.py b/src/confluent_kafka/admin/_listoffsets.py index 6b567088e..0d815266a 100644 --- a/src/confluent_kafka/admin/_listoffsets.py +++ b/src/confluent_kafka/admin/_listoffsets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +from typing import Dict, Any from abc import ABC, abstractmethod from .. import cimpl @@ -24,9 +24,6 @@ class OffsetSpec(ABC): of the partition being queried. """ _values: Dict[int, 'OffsetSpec'] = {} - _max_timestamp: Optional['MaxTimestampSpec'] = None - _earliest: Optional['EarliestSpec'] = None - _latest: Optional['LatestSpec'] = None @property @abstractmethod @@ -68,7 +65,7 @@ def __new__(cls, index: int): else: return cls.for_timestamp(index) - def __lt__(self, other: object) -> bool: + def __lt__(self, other) -> Any: if not isinstance(other, OffsetSpec): return NotImplemented return self._value < other._value diff --git a/src/confluent_kafka/admin/_metadata.py b/src/confluent_kafka/admin/_metadata.py index 90dc061c0..51e9f7b37 100644 --- a/src/confluent_kafka/admin/_metadata.py +++ b/src/confluent_kafka/admin/_metadata.py @@ -79,7 +79,7 @@ class TopicMetadata(object): # on other classes which raises a warning/error. def __init__(self) -> None: - self.topic: Optional[str] = None + self.topic = None """Topic name""" self.partitions: Dict[int, 'PartitionMetadata'] = {} """Map of partitions indexed by partition id. 
Value is a PartitionMetadata object.""" @@ -93,7 +93,7 @@ def __repr__(self) -> str: return "TopicMetadata({}, {} partitions)".format(self.topic, len(self.partitions)) def __str__(self) -> str: - return str(self.topic) + return self.topic class PartitionMetadata(object): @@ -181,4 +181,4 @@ def __repr__(self) -> str: return "GroupMetadata({})".format(self.id) def __str__(self) -> str: - return str(self.id) + return self.id diff --git a/src/confluent_kafka/admin/_scram.py b/src/confluent_kafka/admin/_scram.py index ff14a20fb..76c999dbc 100644 --- a/src/confluent_kafka/admin/_scram.py +++ b/src/confluent_kafka/admin/_scram.py @@ -26,7 +26,7 @@ class ScramMechanism(Enum): SCRAM_SHA_256 = cimpl.SCRAM_MECHANISM_SHA_256 #: SCRAM-SHA-256 mechanism SCRAM_SHA_512 = cimpl.SCRAM_MECHANISM_SHA_512 #: SCRAM-SHA-512 mechanism - def __lt__(self, other: 'ScramMechanism') -> Any: + def __lt__(self, other) -> Any: if self.__class__ != other.__class__: return NotImplemented return self.value < other.value diff --git a/src/confluent_kafka/admin/_topic.py b/src/confluent_kafka/admin/_topic.py index fecca4ed3..a804631be 100644 --- a/src/confluent_kafka/admin/_topic.py +++ b/src/confluent_kafka/admin/_topic.py @@ -46,7 +46,7 @@ def __init__(self, name: str, topic_id: Uuid, is_internal: bool, self.topic_id = topic_id self.is_internal = is_internal self.partitions = partitions - self.authorized_operations: Optional[List[AclOperation]] = None + self.authorized_operations = None if authorized_operations: self.authorized_operations = [] for op in authorized_operations: diff --git a/src/confluent_kafka/deserializing_consumer.py b/src/confluent_kafka/deserializing_consumer.py index f32c9a166..5abd8147b 100644 --- a/src/confluent_kafka/deserializing_consumer.py +++ b/src/confluent_kafka/deserializing_consumer.py @@ -75,8 +75,8 @@ class DeserializingConsumer(_ConsumerImpl): def __init__(self, conf: ConfigDict) -> None: conf_copy = conf.copy() - self._key_deserializer: Optional[Deserializer] = conf_copy.pop('key.deserializer', None) - self._value_deserializer: Optional[Deserializer] = conf_copy.pop('value.deserializer', None) + self._key_deserializer = conf_copy.pop('key.deserializer', None) + self._value_deserializer = conf_copy.pop('value.deserializer', None) super(DeserializingConsumer, self).__init__(conf_copy) diff --git a/src/confluent_kafka/serializing_producer.py b/src/confluent_kafka/serializing_producer.py index 61ad488d1..88b2defd6 100644 --- a/src/confluent_kafka/serializing_producer.py +++ b/src/confluent_kafka/serializing_producer.py @@ -72,8 +72,8 @@ class SerializingProducer(_ProducerImpl): def __init__(self, conf: ConfigDict) -> None: conf_copy = conf.copy() - self._key_serializer: Optional[Serializer] = conf_copy.pop('key.serializer', None) - self._value_serializer: Optional[Serializer] = conf_copy.pop('value.serializer', None) + self._key_serializer = conf_copy.pop('key.serializer', None) + self._value_serializer = conf_copy.pop('value.serializer', None) super(SerializingProducer, self).__init__(conf_copy) From cfa723f3044b1135343dd1c285ccde69f83103ca Mon Sep 17 00:00:00 2001 From: Naxin Date: Tue, 9 Sep 2025 15:59:34 -0400 Subject: [PATCH 08/31] add accidentally removed md files --- CHANGELOG.md | 58 +++++++++++++ CONTRIBUTOR.md | 219 +++++++++++++++++++++++++++++++++++++++++++++++++ DEVELOPER.md | 38 +++++++++ INSTALL.md | 121 +++++++++++++++++++++++++++ README.md | 134 ++++++++++++++++++++++++++++++ 5 files changed, 570 insertions(+) create mode 100644 CONTRIBUTOR.md create mode 100644 INSTALL.md 
diff --git a/CHANGELOG.md b/CHANGELOG.md index 77f0dcf85..73bf093c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +<<<<<<< HEAD # Confluent Python Client for Apache Kafka - CHANGELOG ## v2.12.1 - 2025-10-21 @@ -66,6 +67,11 @@ for a complete list of changes, enhancements, fixes and upgrade considerations. ## v2.11.1 - 2025-08-18 +======= +# Confluent's Python client for Apache Kafka + +## v2.11.1 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.11.1 is a maintenance release with the following fixes: @@ -74,7 +80,11 @@ confluent-kafka-python v2.11.1 is based on librdkafka v2.11.1, see the for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.11.0 - 2025-07-03 +======= +## v2.11.0 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.11.0 is a feature release with the following enhancements: @@ -83,7 +93,11 @@ confluent-kafka-python v2.11.0 is based on librdkafka v2.11.0, see the for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.10.1 - 2025-06-11 +======= +## v2.10.1 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.10.1 is a maintenance release with the following fixes: @@ -99,7 +113,11 @@ confluent-kafka-python v2.10.1 is based on librdkafka v2.10.1, see the [librdkafka release notes](https://github.com/confluentinc/librdkafka/releases/tag/v2.10.1) for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.10.0 - 2025-04-18 +======= +## v2.10.0 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.10.0 is a feature release with the following fixes and enhancements: @@ -110,7 +128,11 @@ confluent-kafka-python v2.10.0 is based on librdkafka v2.10.0, see the [librdkafka release notes](https://github.com/confluentinc/librdkafka/releases/tag/v2.10.0) for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.9.0 - 2025-03-28 +======= +## v2.9.0 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.9.0 is a feature release with the following fixes and enhancements: @@ -123,7 +145,11 @@ confluent-kafka-python v2.9.0 is based on librdkafka v2.8.0, see the [librdkafka release notes](https://github.com/confluentinc/librdkafka/releases/tag/v2.8.0) for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.8.2 - 2025-02-28 +======= +## v2.8.2 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.8.2 is a maintenance release with the following fixes and enhancements: @@ -138,7 +164,11 @@ Note: Versioning is skipped due to breaking change in v2.8.1. Do not run software with v2.8.1 installed. +<<<<<<< HEAD ## v2.8.0 - 2025-01-07 +======= +## v2.8.0 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.8.0 is a feature release with the features, fixes and enhancements: @@ -147,7 +177,11 @@ confluent-kafka-python v2.8.0 is based on librdkafka v2.8.0, see the for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.7.0 - 2024-12-21 +======= +## v2.7.0 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.7.0 is a feature release with the features, fixes and enhancements present in v2.6.2 including the following fix: @@ -158,7 +192,11 @@ confluent-kafka-python v2.7.0 is based on librdkafka v2.6.1, see the for a complete list of changes, enhancements, fixes and upgrade considerations. 
+<<<<<<< HEAD ## v2.6.2 - 2024-12-18 +======= +## v2.6.2 +>>>>>>> 7b378e7 (add accidentally removed md files) > [!WARNING] > Due to an error in which we included dependency changes to a recent patch release, Confluent recommends users to **refrain from upgrading to 2.6.2** of Confluent Kafka. Confluent will release a new minor version, 2.7.0, where the dependency changes will be appropriately included. Users who have already upgraded to 2.6.2 and made the required dependency changes are free to remain on that version and are recommended to upgrade to 2.7.0 when that version is available. Upon the release of 2.7.0, the 2.6.2 version will be marked deprecated. @@ -201,7 +239,11 @@ confluent-kafka-python is based on librdkafka v2.6.1, see the for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.6.1 - 2024-11-18 +======= +## v2.6.1 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.6.1 is a maintenance release with the following fixes and enhancements: @@ -214,7 +256,11 @@ confluent-kafka-python is based on librdkafka v2.6.1, see the for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.6.0 - 2024-10-11 +======= +## v2.6.0 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.6.0 is a feature release with the following features, fixes and enhancements: @@ -228,7 +274,11 @@ confluent-kafka-python is based on librdkafka v2.6.0, see the for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.5.3 - 2024-09-02 +======= +## v2.5.3 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.5.3 is a maintenance release with the following fixes and enhancements: @@ -243,7 +293,11 @@ for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.5.0 - 2024-07-10 +======= +## v2.5.0 +>>>>>>> 7b378e7 (add accidentally removed md files) > [!WARNING] This version has introduced a regression in which an assert is triggered during **PushTelemetry** call. This happens when no metric is matched on the client side among those requested by broker subscription. @@ -276,7 +330,11 @@ confluent-kafka-python is based on librdkafka v2.5.0, see the for a complete list of changes, enhancements, fixes and upgrade considerations. +<<<<<<< HEAD ## v2.4.0 - 2024-05-07 +======= +## v2.4.0 +>>>>>>> 7b378e7 (add accidentally removed md files) v2.4.0 is a feature release with the following features, fixes and enhancements: diff --git a/CONTRIBUTOR.md b/CONTRIBUTOR.md new file mode 100644 index 000000000..e9a2f925a --- /dev/null +++ b/CONTRIBUTOR.md @@ -0,0 +1,219 @@ +# Contributing to confluent-kafka-python + +Thank you for your interest in contributing to confluent-kafka-python! This document provides guidelines and best practices for contributing to this project. 
+ +## Table of Contents + +- [Getting Started](#getting-started) +- [Development Environment Setup](#development-environment-setup) +- [Code Style and Standards](#code-style-and-standards) +- [Testing](#testing) +- [Submitting Changes](#submitting-changes) +- [Reporting Issues](#reporting-issues) +- [Community Guidelines](#community-guidelines) + +## Getting Started + +### Ways to Contribute + +- **Bug Reports**: Report bugs and issues you encounter +- **Feature Requests**: Suggest new features or improvements +- **Code Contributions**: Fix bugs, implement features, or improve documentation +- **Documentation**: Improve existing docs or add new documentation +- **Testing**: Help improve test coverage and quality + +### Before You Start + +1. Check existing [issues](../../issues) to see if your bug/feature has already been reported +2. For major changes, open an issue first to discuss the proposed changes +3. Fork the repository and create a feature branch for your work + +## Development Environment Setup + +For complete development environment setup instructions, including prerequisites, virtual environment creation, and dependency installation, see the [Development Environment Setup section in DEVELOPER.md](DEVELOPER.md#development-environment-setup). + +## Code Style and Standards + +### Python Code Style + +- **PEP 8**: Follow [PEP 8](https://pep8.org/) style guidelines as a default, with exceptions captured in the `tox.ini` flake8 rules for modern updates to the recommendations +- **Docstrings**: Use Google-style docstrings for all public functions and classes + +### Code Formatting + +We use automated tools to maintain consistent code style: + +```bash +# Install formatting tools +pip install flake8 + +# Check style +flake8 src/ tests/ +``` + +### Naming Conventions + +- **Functions and Variables**: `snake_case` +- **Classes**: `PascalCase` +- **Constants**: `UPPER_SNAKE_CASE` +- **Private Methods/Objects**: Prefix with single underscore `_private_method` + +### Documentation + +- All public APIs must have docstrings +- Include examples in docstrings when helpful +- Keep docstrings concise but complete +- Update relevant documentation files when making changes + +## Testing + +### Running Tests + +See [tests/README.md](tests/README.md) for comprehensive testing instructions. + +### Test Requirements + +- **Unit Tests**: All new functionality must include unit tests +- **Integration Tests**: Add integration tests for complex features +- **Test Coverage**: Maintain or improve existing test coverage +- **Test Naming**: Use descriptive test names that explain what is being tested + +### Test Structure + +```python +def test_feature_should_behave_correctly_when_condition(): + # Arrange + setup_data = create_test_data() + + # Act + result = function_under_test(setup_data) + + # Assert + assert result.expected_property == expected_value +``` + +## Submitting Changes + +### Pull Request Process + +1. **Create Feature Branch** + ```bash + git checkout -b feature/your-feature-name + # or + git checkout -b fix/issue-number-description + ``` + +2. **Make Your Changes** + - Write clean, well-documented code + - Add appropriate tests + - Update documentation if needed + - Add an entry to the CHANGELOG.md file for the proposed change + +3. **Test Your Changes** + Refer to [tests/README.md](tests/README.md) + +4. **Commit Your Changes** + ```bash + git add . 
+ git commit -m "Clear, descriptive commit message" + ``` + + **Commit Message Guidelines:** + - Use present tense ("Add feature" not "Added feature") + - Keep first line under 50 characters + - Reference issue numbers when applicable (#123) + - Include breaking change notes if applicable + +5. **Push and Create Pull Request** + ```bash + git push origin feature/your-feature-name + ``` + + Then create a pull request through GitHub's interface. + +### Pull Request Guidelines + +- **Title**: Clear and descriptive +- **Description**: Explain what changes you made and why +- **Linked Issues**: Reference related issues using "Fixes #123" or "Relates to #123" +- **Labels**: Review available issue/PR labels and apply relevant ones to help with categorization and triage +- **Documentation**: Update documentation for user-facing changes +- **Tests**: Include appropriate tests +- **Breaking Changes**: Clearly mark any breaking changes + +### Code Review Process + +- All pull requests require review before merging +- Address reviewer feedback promptly +- Keep discussions respectful and constructive +- Be open to suggestions and alternative approaches + +## Reporting Issues + +### Using Labels + +When creating issues or pull requests, please review the available labels and apply those that are relevant to your submission. This helps maintainers categorize and prioritize work effectively. Common label categories include (look at available labels / other issues for options): + +- **Type**: bug, enhancement, documentation, question +- **Priority**: high, medium, low +- **Component**: producer, consumer, admin, schema-registry, etc +- **Status**: needs-investigation, help-wanted, good-first-issue, etc + +### Bug Reports + +When reporting bugs, please include: + +- **Clear Title**: Describe the issue concisely +- **Environment**: Python version, OS, library versions +- **Steps to Reproduce**: Detailed steps to reproduce the issue +- **Expected Behavior**: What you expected to happen +- **Actual Behavior**: What actually happened +- **Code Sample**: Minimal code that demonstrates the issue +- **Error Messages**: Full error messages and stack traces +- **Client Configuration**: Specify how the client was configured and setup +- **Logs**: Client logs when possible +- **Labels**: Apply relevant labels such as "bug" and component-specific labels + +### Feature Requests + +For feature requests, please include: + +- **Use Case**: Describe the problem you're trying to solve +- **Proposed Solution**: Your idea for how to address it +- **Alternatives**: Other solutions you've considered +- **Additional Context**: Any other relevant information +- **Labels**: Apply relevant labels such as "enhancement" and component-specific labels + +## Community Guidelines + +### Code of Conduct + +This project follows the [Contributor Covenant Code of Conduct](https://www.contributor-covenant.org/). By participating, you agree to uphold this code. 
+ +### Communication + +- **Be Respectful**: Treat all community members with respect +- **Be Constructive**: Provide helpful feedback and suggestions +- **Be Patient**: Remember that maintainers and contributors volunteer their time +- **Be Clear**: Communicate clearly and provide sufficient context + +### Getting Help + +- **Issues**: Use GitHub issues for bug reports and feature requests +- **Discussions**: Use GitHub Discussions for questions and general discussion +- **Documentation**: Check existing documentation before asking questions + +## Recognition + +Contributors are recognized in the following ways: + +- Contributors are listed in the project's contributor history +- Significant contributions may be mentioned in release notes + +## License + +By contributing to this project, you agree that your contributions will be licensed under the same license as the project (see LICENSE file). + +--- + +Thank you for contributing to confluent-kafka-python! Your contributions help make this project better for everyone. \ No newline at end of file diff --git a/DEVELOPER.md b/DEVELOPER.md index 201816c73..a6a0c5838 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -10,12 +10,18 @@ This document provides information useful to developers working on confluent-kaf - Git - librdkafka (for Kafka functionality) +<<<<<<< HEAD ### Quick start (editable install) 1. **Fork and Clone** +======= +### Setup Steps + +1. **Fork and Clone** +>>>>>>> 7b378e7 (add accidentally removed md files) ```bash git clone https://github.com/your-username/confluent-kafka-python.git cd confluent-kafka-python @@ -30,8 +36,12 @@ This document provides information useful to developers working on confluent-kaf **Note**: On Windows the variables for Visual Studio are named INCLUDE and LIB 3. **Install librdkafka** (if not already installed) +<<<<<<< HEAD See the main README.md for platform-specific installation instructions. +======= +See the main README.md for platform-specific installation instructions +>>>>>>> 7b378e7 (add accidentally removed md files) If librdkafka is installed in a non-standard location provide the include and library directories with: @@ -39,24 +49,36 @@ If librdkafka is installed in a non-standard location provide the include and li C_INCLUDE_PATH=/path/to/include LIBRARY_PATH=/path/to/lib python -m build ``` +<<<<<<< HEAD 4. **Install confluent-kafka-python (editable) with dev/test/docs extras** +======= +4. **Install confluent-kafka-python with optional dependencies** +>>>>>>> 7b378e7 (add accidentally removed md files) ```bash pip3 install -e .[dev,tests,docs] ``` +<<<<<<< HEAD Alternatively you can build the bundle independently with: +======= + This will also build the wheel be default. Alternatively you can build the bundle independently with: +>>>>>>> 7b378e7 (add accidentally removed md files) ```bash python3 -m build ``` 5. **Verify Setup** +<<<<<<< HEAD +======= +>>>>>>> 7b378e7 (add accidentally removed md files) ```bash python3 -c "import confluent_kafka; print('Setup successful!')" ``` +<<<<<<< HEAD ## Project layout @@ -68,6 +90,8 @@ C_INCLUDE_PATH=/path/to/include LIBRARY_PATH=/path/to/lib python -m build - `examples/` — runnable samples (includes asyncio example) - `tools/unasync.py` — SR-only sync code generation from async sources +======= +>>>>>>> 7b378e7 (add accidentally removed md files) ## Generate Documentation Install docs dependencies: @@ -92,7 +116,11 @@ python3 setup.py build_sphinx Documentation will be generated in `build/sphinx/html`. 
+<<<<<<< HEAD ## Unasync — maintaining sync versions of async code (Schema Registry only) +======= +## Unasync -- maintaining sync versions of async code +>>>>>>> 7b378e7 (add accidentally removed md files) ```bash python3 tools/unasync.py @@ -101,6 +129,7 @@ python3 tools/unasync.py python3 tools/unasync.py --check ``` +<<<<<<< HEAD If you make any changes to the async code (in `src/confluent_kafka/schema_registry/_async` and `tests/integration/schema_registry/_async`), you **must** run this script to generate the sync counterparts (in `src/confluent_kafka/schema_registry/_sync` and `tests/integration/schema_registry/_sync`). Otherwise, this script will be run in CI with the `--check` flag and fail the build. Note: The AsyncIO Producer/Consumer under `src/confluent_kafka/experimental/aio/` are first-class asyncio implementations and are not generated using `unasync`. @@ -183,3 +212,12 @@ See “Generate Documentation” above; ensure examples and code blocks compile - Build errors related to librdkafka: ensure headers and libraries are discoverable; see “Install librdkafka” above for `C_INCLUDE_PATH` and `LIBRARY_PATH`. - Async tests hanging: check event loop usage and that `await producer.close()` is called to stop background tasks. +======= +If you make any changes to the async code (in `src/confluent_kafka/schema_registry/_async` and `tests/integration/schema_registry/_async`), you **must** run this script to generate the sync counter parts (in `src/confluent_kafka/schema_registry/_sync` and `tests/integration/schema_registry/_sync`). Otherwise, this script will be run in CI with the --check flag and fail the build. + + +## Tests + + +See [tests/README.md](tests/README.md) for instructions on how to run tests. +>>>>>>> 7b378e7 (add accidentally removed md files) diff --git a/INSTALL.md b/INSTALL.md new file mode 100644 index 000000000..6b642463b --- /dev/null +++ b/INSTALL.md @@ -0,0 +1,121 @@ +# confluent-kafka-python installation instructions + +## Install pre-built wheels (recommended) + +Confluent provides pre-built Python wheels of confluent-kafka-python with +all dependencies included. + +To install, simply do: + +```bash +python3 -m pip install confluent-kafka +``` + +If you get a build error or require Kerberos/GSSAPI support please read the next section: *Install from source* + + +## Install from source + +It is sometimes necessary to install confluent-kafka from source, rather +than from prebuilt binary wheels, such as when: + - You need GSSAPI/Kerberos authentication. + - You're on a Python version we do not provide prebuilt wheels for. + - You're on an architecture or platform we do not provide prebuilt wheels for. + - You want to build confluent-kafka-python from the master branch. + + +### Install from source on RedHat, CentOS, Fedora, etc + +```bash +# +# Perform these steps as the root user (e.g., in a 'sudo bash' shell) +# + +# Install build tools and Kerberos support. 
+ +yum install -y python3 python3-pip python3-devel gcc make cyrus-sasl-gssapi krb5-workstation + +# Install the latest version of librdkafka: + +rpm --import https://packages.confluent.io/rpm/7.0/archive.key + +echo ' +[Confluent-Clients] +name=Confluent Clients repository +baseurl=https://packages.confluent.io/clients/rpm/centos/$releasever/$basearch +gpgcheck=1 +gpgkey=https://packages.confluent.io/clients/rpm/archive.key +enabled=1' > /etc/yum.repos.d/confluent.repo + +yum install -y librdkafka-devel + + +# +# Now build and install confluent-kafka-python as your standard user +# (e.g., exit the root shell first). +# + +python3 -m pip install --no-binary confluent-kafka confluent-kafka + + +# Verify that confluent_kafka is installed: + +python3 -c 'import confluent_kafka; print(confluent_kafka.version())' +``` + +### Install from source on Debian or Ubuntu + +```bash +# +# Perform these steps as the root user (e.g., in a 'sudo bash' shell) +# + +# Install build tools and Kerberos support. + +apt install -y wget software-properties-common lsb-release gcc make python3 python3-pip python3-dev libsasl2-modules-gssapi-mit krb5-user + + +# Install the latest version of librdkafka: + +wget -qO - https://packages.confluent.io/deb/7.0/archive.key | apt-key add - + +add-apt-repository "deb https://packages.confluent.io/clients/deb $(lsb_release -cs) main" + +apt update + +apt install -y librdkafka-dev + + +# +# Now build and install confluent-kafka-python as your standard user +# (e.g., exit the root shell first). +# + +python3 -m pip install --no-binary confluent-kafka confluent-kafka + + +# Verify that confluent_kafka is installed: + +python3 -c 'import confluent_kafka; print(confluent_kafka.version())' +``` + + +### Install from source on Mac OS X + +```bash + +# Install librdkafka from homebrew + +brew install librdkafka + + +# Build and install confluent-kafka-python + +python3 -m pip install --no-binary confluent-kafka confluent-kafka + + +# Verify that confluent_kafka is installed: + +python3 -c 'import confluent_kafka; print(confluent_kafka.version())' + +``` diff --git a/README.md b/README.md index 126e9ac0b..ce92155a9 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,17 @@ +<<<<<<< HEAD # Confluent Python Client for Apache Kafka [![Try Confluent Cloud - The Data Streaming Platform](https://images.ctfassets.net/8vofjvai1hpv/10bgcSfn5MzmvS4nNqr94J/af43dd2336e3f9e0c0ca4feef4398f6f/confluent-banner-v2.svg)](https://confluent.cloud/signup?utm_source=github&utm_medium=banner&utm_campaign=tm.plg.cflt-oss-repos&utm_term=confluent-kafka-python) +======= +> [!WARNING] +> Due to an error in which we included dependency changes to a recent patch release, Confluent recommends users to **refrain from upgrading to 2.6.2** of Confluent Kafka. Confluent will release a new minor version, 2.7.0, where the dependency changes will be appropriately included. Users who have already upgraded to 2.6.2 and made the required dependency changes are free to remain on that version and are recommended to upgrade to 2.7.0 when that version is available. Upon the release of 2.7.0, the 2.6.2 version will be marked deprecated. +We apologize for the inconvenience and appreciate the feedback that we have gotten from the community. 
+>>>>>>> 7b378e7 (add accidentally removed md files) Confluent's Python Client for Apache KafkaTM ======================================================= +<<<<<<< HEAD **confluent-kafka-python** provides a high-level `Producer`, `Consumer` and `AdminClient` compatible with all [Apache Kafka™](http://kafka.apache.org/) brokers >= v0.8, [Confluent Cloud](https://www.confluent.io/confluent-cloud/) and [Confluent Platform](https://www.confluent.io/product/compare/). **Recommended for Production:** While this client works with any Kafka deployment, it's optimized for and fully supported with [Confluent Cloud](https://www.confluent.io/confluent-cloud/) (fully managed) and [Confluent Platform](https://www.confluent.io/product/compare/) (self-managed), which provide enterprise-grade security, monitoring, and support. @@ -46,11 +53,33 @@ For a step-by-step guide on using the client, see [Getting Started with Apache K Additional examples can be found in the [examples](examples) directory or the [confluentinc/examples](https://github.com/confluentinc/examples/tree/master/clients/cloud/python) GitHub repo, which include demonstrations of: +======= +**confluent-kafka-python** provides a high-level Producer, Consumer and AdminClient compatible with all +[Apache KafkaTM](http://kafka.apache.org/) brokers >= v0.8, [Confluent Cloud](https://www.confluent.io/confluent-cloud/) +and [Confluent Platform](https://www.confluent.io/product/compare/). The client is: + +- **Reliable** - It's a wrapper around [librdkafka](https://github.com/edenhill/librdkafka) (provided automatically via binary wheels) which is widely deployed in a diverse set of production scenarios. It's tested using [the same set of system tests](https://github.com/confluentinc/confluent-kafka-python/tree/master/src/confluent_kafka/kafkatest) as the Java client [and more](https://github.com/confluentinc/confluent-kafka-python/tree/master/tests). It's supported by [Confluent](https://confluent.io). + +- **Performant** - Performance is a key design consideration. Maximum throughput is on par with the Java client for larger message sizes (where the overhead of the Python interpreter has less impact). Latency is on par with the Java client. + +- **Future proof** - Confluent, founded by the +creators of Kafka, is building a [streaming platform](https://www.confluent.io/product/compare/) +with Apache Kafka at its core. It's high priority for us that client features keep +pace with core Apache Kafka and components of the [Confluent Platform](https://www.confluent.io/product/compare/). + + +## Usage + +For a step-by-step guide on using the client see [Getting Started with Apache Kafka and Python](https://developer.confluent.io/get-started/python/). + +Aditional examples can be found in the [examples](examples) directory or the [confluentinc/examples](https://github.com/confluentinc/examples/tree/master/clients/cloud/python) github repo, which include demonstration of: +>>>>>>> 7b378e7 (add accidentally removed md files) - Exactly once data processing using the transactional API. - Integration with asyncio. - (De)serializing Protobuf, JSON, and Avro data with Confluent Schema Registry integration. - [Confluent Cloud](https://www.confluent.io/confluent-cloud/) configuration. +<<<<<<< HEAD Also see the [Python client docs](https://docs.confluent.io/kafka-clients/python/current/overview.html) and the [API reference](https://docs.confluent.io/kafka-clients/python/current/). Finally, the [tests](tests) are useful as a reference for example usage. 
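The transactional API mentioned in the list above follows an init/begin/commit flow on the producer. A minimal sketch, assuming a local broker and illustrative topic and `transactional.id` values; the full exactly-once pattern (including `send_offsets_to_transaction`) is covered in the examples directory.

```python
from confluent_kafka import Producer, KafkaException

producer = Producer({
    'bootstrap.servers': 'localhost:9092',   # assumption: local broker
    'transactional.id': 'example-txn-id',    # assumption: illustrative id
})

producer.init_transactions()
producer.begin_transaction()
try:
    producer.produce('example-topic', key='k1', value='v1')
    producer.commit_transaction()
except KafkaException:
    # Abort on failure so no partial writes become visible to read_committed consumers.
    producer.abort_transaction()
    raise
```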
@@ -98,6 +127,13 @@ For a more detailed example that includes both an async producer and consumer, s The AsyncIO producer and consumer integrate seamlessly with async Schema Registry serializers. See the [Schema Registry Integration](#schema-registry-integration) section below for full details. ### Basic Producer example +======= +Also refer to the [API documentation](http://docs.confluent.io/current/clients/confluent-kafka-python/index.html). + +Finally, the [tests](tests) are useful as a reference for example usage. + +### Basic Producer Example +>>>>>>> 7b378e7 (add accidentally removed md files) ```python from confluent_kafka import Producer @@ -129,6 +165,7 @@ p.flush() For a discussion on the poll based producer API, refer to the [Integrating Apache Kafka With Python Asyncio Web Applications](https://www.confluent.io/blog/kafka-python-asyncio-integration/) blog post. +<<<<<<< HEAD ### Schema Registry Integration This client provides full integration with Schema Registry for schema management and message serialization, and is compatible with both [Confluent Platform](https://docs.confluent.io/platform/current/schema-registry/index.html) and [Confluent Cloud](https://docs.confluent.io/cloud/current/sr/index.html). Both synchronous and asynchronous clients are available. @@ -216,6 +253,11 @@ from confluent_kafka.schema_registry._async.protobuf import AsyncProtobufSeriali - **401/403 Unauthorized when using Confluent Cloud:** Verify your `basic.auth.user.info` (SR API key/secret) is correct and that the Schema Registry URL is for your specific cluster. Ensure you are using an SR API key, not a Kafka API key. - **Schema not found:** Check that your `subject.name.strategy` configuration matches how your schemas are registered in Schema Registry, and that the topic and message field (key/value) pairing is correct. ### Basic Consumer example +======= + + +### Basic Consumer Example +>>>>>>> 7b378e7 (add accidentally removed md files) ```python from confluent_kafka import Consumer @@ -241,7 +283,13 @@ while True: c.close() ``` +<<<<<<< HEAD ### Basic AdminClient example +======= + + +### Basic AdminClient Example +>>>>>>> 7b378e7 (add accidentally removed md files) Create topics: @@ -265,6 +313,7 @@ for topic, f in fs.items(): except Exception as e: print("Failed to create topic {}: {}".format(topic, e)) ``` +<<<<<<< HEAD ## Thread safety The `Producer`, `Consumer`, and `AdminClient` are all thread safe. @@ -284,6 +333,29 @@ pip install "confluent-kafka[avro,schemaregistry,rules]" ``` **Note:** Pre-built Linux wheels do not include SASL Kerberos/GSSAPI support. For Kerberos, see the source installation instructions in [INSTALL.md](INSTALL.md). +======= + + +## Thread Safety + +The `Producer`, `Consumer` and `AdminClient` are all thread safe. + + +## Install + +**Install self-contained binary wheels** + +```bash +pip install confluent-kafka +``` + +**NOTE:** The pre-built Linux wheels do NOT contain SASL Kerberos/GSSAPI support. + If you need SASL Kerberos/GSSAPI support you must install librdkafka and + its dependencies using the repositories below and then build + confluent-kafka using the instructions in the + "Install from source" section below. + +>>>>>>> 7b378e7 (add accidentally removed md files) To use Schema Registry with the Avro serializer/deserializer: ```bash @@ -312,16 +384,78 @@ pip install "confluent-kafka[avro,schemaregistry,rules]" For source install, see the *Install from source* section in [INSTALL.md](INSTALL.md). 
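A minimal sketch of the synchronous Schema Registry flow described above, using an Avro value serializer. The registry URL, broker address, topic, and schema are illustrative assumptions; complete programs live in the examples directory.

```python
from confluent_kafka import Producer
from confluent_kafka.schema_registry import SchemaRegistryClient
from confluent_kafka.schema_registry.avro import AvroSerializer
from confluent_kafka.serialization import SerializationContext, MessageField

schema_str = """
{"type": "record", "name": "User",
 "fields": [{"name": "name", "type": "string"}]}
"""

sr_client = SchemaRegistryClient({'url': 'http://localhost:8081'})  # assumption: local registry
avro_serializer = AvroSerializer(sr_client, schema_str)

producer = Producer({'bootstrap.servers': 'localhost:9092'})        # assumption: local broker

# Serialize the value explicitly, then hand the bytes to the plain Producer.
payload = avro_serializer({'name': 'alice'},
                          SerializationContext('users', MessageField.VALUE))
producer.produce('users', value=payload)
producer.flush()
```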
+<<<<<<< HEAD ## Broker compatibility The Python client (as well as the underlying C library librdkafka) supports all broker versions >= 0.8. +======= + +## Broker Compatibility + +The Python client (as well as the underlying C library librdkafka) supports +all broker versions >= 0.8. +>>>>>>> 7b378e7 (add accidentally removed md files) But due to the nature of the Kafka protocol in broker versions 0.8 and 0.9 it is not safe for a client to assume what protocol version is actually supported by the broker, thus you will need to hint the Python client what protocol version it may use. This is done through two configuration settings: +<<<<<<< HEAD - `broker.version.fallback=YOUR_BROKER_VERSION` (default 0.9.0.1) - `api.version.request=true|false` (default true) When using a Kafka 0.10 broker or later you don't need to do anything +======= + * `broker.version.fallback=YOUR_BROKER_VERSION` (default 0.9.0.1) + * `api.version.request=true|false` (default true) + +When using a Kafka 0.10 broker or later you don't need to do anything +(`api.version.request=true` is the default). +If you use Kafka broker 0.9 or 0.8 you must set +`api.version.request=false` and set +`broker.version.fallback` to your broker version, +e.g `broker.version.fallback=0.9.0.1`. + +More info here: +https://github.com/edenhill/librdkafka/wiki/Broker-version-compatibility + + +## SSL certificates + +If you're connecting to a Kafka cluster through SSL you will need to configure +the client with `'security.protocol': 'SSL'` (or `'SASL_SSL'` if SASL +authentication is used). + +The client will use CA certificates to verify the broker's certificate. +The embedded OpenSSL library will look for CA certificates in `/usr/lib/ssl/certs/` +or `/usr/lib/ssl/cacert.pem`. CA certificates are typically provided by the +Linux distribution's `ca-certificates` package which needs to be installed +through `apt`, `yum`, et.al. + +If your system stores CA certificates in another location you will need to +configure the client with `'ssl.ca.location': '/path/to/cacert.pem'`. + +Alternatively, the CA certificates can be provided by the [certifi](https://pypi.org/project/certifi/) +Python package. To use certifi, add an `import certifi` line and configure the +client's CA location with `'ssl.ca.location': certifi.where()`. + + +## License + +[Apache License v2.0](http://www.apache.org/licenses/LICENSE-2.0) + +KAFKA is a registered trademark of The Apache Software Foundation and has been licensed for use +by confluent-kafka-python. confluent-kafka-python has no affiliation with and is not endorsed by +The Apache Software Foundation. + + +## Developer Notes + +Instructions on building and testing confluent-kafka-python can be found [here](DEVELOPER.md). + + +## Confluent Cloud + +For a step-by-step guide on using the Python client with Confluent Cloud see [Getting Started with Apache Kafka and Python](https://developer.confluent.io/get-started/python/) on [Confluent Developer](https://developer.confluent.io/). 
+>>>>>>> 7b378e7 (add accidentally removed md files) From 1abe216c69c2f56cdad7f4b57adcb859dcadd1cf Mon Sep 17 00:00:00 2001 From: Naxin Date: Wed, 15 Oct 2025 00:39:20 -0400 Subject: [PATCH 09/31] fix merge conflicts in md files, add types to admin and serialization entrypoint init files --- CHANGELOG.md | 58 ------- DEVELOPER.md | 38 ----- README.md | 134 --------------- src/confluent_kafka/admin/__init__.py | 157 ++++++------------ src/confluent_kafka/serialization/__init__.py | 6 +- 5 files changed, 54 insertions(+), 339 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73bf093c3..77f0dcf85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,3 @@ -<<<<<<< HEAD # Confluent Python Client for Apache Kafka - CHANGELOG ## v2.12.1 - 2025-10-21 @@ -67,11 +66,6 @@ for a complete list of changes, enhancements, fixes and upgrade considerations. ## v2.11.1 - 2025-08-18 -======= -# Confluent's Python client for Apache Kafka - -## v2.11.1 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.11.1 is a maintenance release with the following fixes: @@ -80,11 +74,7 @@ confluent-kafka-python v2.11.1 is based on librdkafka v2.11.1, see the for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.11.0 - 2025-07-03 -======= -## v2.11.0 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.11.0 is a feature release with the following enhancements: @@ -93,11 +83,7 @@ confluent-kafka-python v2.11.0 is based on librdkafka v2.11.0, see the for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.10.1 - 2025-06-11 -======= -## v2.10.1 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.10.1 is a maintenance release with the following fixes: @@ -113,11 +99,7 @@ confluent-kafka-python v2.10.1 is based on librdkafka v2.10.1, see the [librdkafka release notes](https://github.com/confluentinc/librdkafka/releases/tag/v2.10.1) for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.10.0 - 2025-04-18 -======= -## v2.10.0 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.10.0 is a feature release with the following fixes and enhancements: @@ -128,11 +110,7 @@ confluent-kafka-python v2.10.0 is based on librdkafka v2.10.0, see the [librdkafka release notes](https://github.com/confluentinc/librdkafka/releases/tag/v2.10.0) for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.9.0 - 2025-03-28 -======= -## v2.9.0 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.9.0 is a feature release with the following fixes and enhancements: @@ -145,11 +123,7 @@ confluent-kafka-python v2.9.0 is based on librdkafka v2.8.0, see the [librdkafka release notes](https://github.com/confluentinc/librdkafka/releases/tag/v2.8.0) for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.8.2 - 2025-02-28 -======= -## v2.8.2 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.8.2 is a maintenance release with the following fixes and enhancements: @@ -164,11 +138,7 @@ Note: Versioning is skipped due to breaking change in v2.8.1. Do not run software with v2.8.1 installed. 
-<<<<<<< HEAD ## v2.8.0 - 2025-01-07 -======= -## v2.8.0 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.8.0 is a feature release with the features, fixes and enhancements: @@ -177,11 +147,7 @@ confluent-kafka-python v2.8.0 is based on librdkafka v2.8.0, see the for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.7.0 - 2024-12-21 -======= -## v2.7.0 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.7.0 is a feature release with the features, fixes and enhancements present in v2.6.2 including the following fix: @@ -192,11 +158,7 @@ confluent-kafka-python v2.7.0 is based on librdkafka v2.6.1, see the for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.6.2 - 2024-12-18 -======= -## v2.6.2 ->>>>>>> 7b378e7 (add accidentally removed md files) > [!WARNING] > Due to an error in which we included dependency changes to a recent patch release, Confluent recommends users to **refrain from upgrading to 2.6.2** of Confluent Kafka. Confluent will release a new minor version, 2.7.0, where the dependency changes will be appropriately included. Users who have already upgraded to 2.6.2 and made the required dependency changes are free to remain on that version and are recommended to upgrade to 2.7.0 when that version is available. Upon the release of 2.7.0, the 2.6.2 version will be marked deprecated. @@ -239,11 +201,7 @@ confluent-kafka-python is based on librdkafka v2.6.1, see the for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.6.1 - 2024-11-18 -======= -## v2.6.1 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.6.1 is a maintenance release with the following fixes and enhancements: @@ -256,11 +214,7 @@ confluent-kafka-python is based on librdkafka v2.6.1, see the for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.6.0 - 2024-10-11 -======= -## v2.6.0 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.6.0 is a feature release with the following features, fixes and enhancements: @@ -274,11 +228,7 @@ confluent-kafka-python is based on librdkafka v2.6.0, see the for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.5.3 - 2024-09-02 -======= -## v2.5.3 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.5.3 is a maintenance release with the following fixes and enhancements: @@ -293,11 +243,7 @@ for a complete list of changes, enhancements, fixes and upgrade considerations. -<<<<<<< HEAD ## v2.5.0 - 2024-07-10 -======= -## v2.5.0 ->>>>>>> 7b378e7 (add accidentally removed md files) > [!WARNING] This version has introduced a regression in which an assert is triggered during **PushTelemetry** call. This happens when no metric is matched on the client side among those requested by broker subscription. @@ -330,11 +276,7 @@ confluent-kafka-python is based on librdkafka v2.5.0, see the for a complete list of changes, enhancements, fixes and upgrade considerations. 
-<<<<<<< HEAD ## v2.4.0 - 2024-05-07 -======= -## v2.4.0 ->>>>>>> 7b378e7 (add accidentally removed md files) v2.4.0 is a feature release with the following features, fixes and enhancements: diff --git a/DEVELOPER.md b/DEVELOPER.md index a6a0c5838..201816c73 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -10,18 +10,12 @@ This document provides information useful to developers working on confluent-kaf - Git - librdkafka (for Kafka functionality) -<<<<<<< HEAD ### Quick start (editable install) 1. **Fork and Clone** -======= -### Setup Steps - -1. **Fork and Clone** ->>>>>>> 7b378e7 (add accidentally removed md files) ```bash git clone https://github.com/your-username/confluent-kafka-python.git cd confluent-kafka-python @@ -36,12 +30,8 @@ This document provides information useful to developers working on confluent-kaf **Note**: On Windows the variables for Visual Studio are named INCLUDE and LIB 3. **Install librdkafka** (if not already installed) -<<<<<<< HEAD See the main README.md for platform-specific installation instructions. -======= -See the main README.md for platform-specific installation instructions ->>>>>>> 7b378e7 (add accidentally removed md files) If librdkafka is installed in a non-standard location provide the include and library directories with: @@ -49,36 +39,24 @@ If librdkafka is installed in a non-standard location provide the include and li C_INCLUDE_PATH=/path/to/include LIBRARY_PATH=/path/to/lib python -m build ``` -<<<<<<< HEAD 4. **Install confluent-kafka-python (editable) with dev/test/docs extras** -======= -4. **Install confluent-kafka-python with optional dependencies** ->>>>>>> 7b378e7 (add accidentally removed md files) ```bash pip3 install -e .[dev,tests,docs] ``` -<<<<<<< HEAD Alternatively you can build the bundle independently with: -======= - This will also build the wheel be default. Alternatively you can build the bundle independently with: ->>>>>>> 7b378e7 (add accidentally removed md files) ```bash python3 -m build ``` 5. **Verify Setup** -<<<<<<< HEAD -======= ->>>>>>> 7b378e7 (add accidentally removed md files) ```bash python3 -c "import confluent_kafka; print('Setup successful!')" ``` -<<<<<<< HEAD ## Project layout @@ -90,8 +68,6 @@ C_INCLUDE_PATH=/path/to/include LIBRARY_PATH=/path/to/lib python -m build - `examples/` — runnable samples (includes asyncio example) - `tools/unasync.py` — SR-only sync code generation from async sources -======= ->>>>>>> 7b378e7 (add accidentally removed md files) ## Generate Documentation Install docs dependencies: @@ -116,11 +92,7 @@ python3 setup.py build_sphinx Documentation will be generated in `build/sphinx/html`. -<<<<<<< HEAD ## Unasync — maintaining sync versions of async code (Schema Registry only) -======= -## Unasync -- maintaining sync versions of async code ->>>>>>> 7b378e7 (add accidentally removed md files) ```bash python3 tools/unasync.py @@ -129,7 +101,6 @@ python3 tools/unasync.py python3 tools/unasync.py --check ``` -<<<<<<< HEAD If you make any changes to the async code (in `src/confluent_kafka/schema_registry/_async` and `tests/integration/schema_registry/_async`), you **must** run this script to generate the sync counterparts (in `src/confluent_kafka/schema_registry/_sync` and `tests/integration/schema_registry/_sync`). Otherwise, this script will be run in CI with the `--check` flag and fail the build. Note: The AsyncIO Producer/Consumer under `src/confluent_kafka/experimental/aio/` are first-class asyncio implementations and are not generated using `unasync`. 
@@ -212,12 +183,3 @@ See “Generate Documentation” above; ensure examples and code blocks compile - Build errors related to librdkafka: ensure headers and libraries are discoverable; see “Install librdkafka” above for `C_INCLUDE_PATH` and `LIBRARY_PATH`. - Async tests hanging: check event loop usage and that `await producer.close()` is called to stop background tasks. -======= -If you make any changes to the async code (in `src/confluent_kafka/schema_registry/_async` and `tests/integration/schema_registry/_async`), you **must** run this script to generate the sync counter parts (in `src/confluent_kafka/schema_registry/_sync` and `tests/integration/schema_registry/_sync`). Otherwise, this script will be run in CI with the --check flag and fail the build. - - -## Tests - - -See [tests/README.md](tests/README.md) for instructions on how to run tests. ->>>>>>> 7b378e7 (add accidentally removed md files) diff --git a/README.md b/README.md index ce92155a9..126e9ac0b 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,10 @@ -<<<<<<< HEAD # Confluent Python Client for Apache Kafka [![Try Confluent Cloud - The Data Streaming Platform](https://images.ctfassets.net/8vofjvai1hpv/10bgcSfn5MzmvS4nNqr94J/af43dd2336e3f9e0c0ca4feef4398f6f/confluent-banner-v2.svg)](https://confluent.cloud/signup?utm_source=github&utm_medium=banner&utm_campaign=tm.plg.cflt-oss-repos&utm_term=confluent-kafka-python) -======= -> [!WARNING] -> Due to an error in which we included dependency changes to a recent patch release, Confluent recommends users to **refrain from upgrading to 2.6.2** of Confluent Kafka. Confluent will release a new minor version, 2.7.0, where the dependency changes will be appropriately included. Users who have already upgraded to 2.6.2 and made the required dependency changes are free to remain on that version and are recommended to upgrade to 2.7.0 when that version is available. Upon the release of 2.7.0, the 2.6.2 version will be marked deprecated. -We apologize for the inconvenience and appreciate the feedback that we have gotten from the community. ->>>>>>> 7b378e7 (add accidentally removed md files) Confluent's Python Client for Apache KafkaTM ======================================================= -<<<<<<< HEAD **confluent-kafka-python** provides a high-level `Producer`, `Consumer` and `AdminClient` compatible with all [Apache Kafka™](http://kafka.apache.org/) brokers >= v0.8, [Confluent Cloud](https://www.confluent.io/confluent-cloud/) and [Confluent Platform](https://www.confluent.io/product/compare/). **Recommended for Production:** While this client works with any Kafka deployment, it's optimized for and fully supported with [Confluent Cloud](https://www.confluent.io/confluent-cloud/) (fully managed) and [Confluent Platform](https://www.confluent.io/product/compare/) (self-managed), which provide enterprise-grade security, monitoring, and support. @@ -53,33 +46,11 @@ For a step-by-step guide on using the client, see [Getting Started with Apache K Additional examples can be found in the [examples](examples) directory or the [confluentinc/examples](https://github.com/confluentinc/examples/tree/master/clients/cloud/python) GitHub repo, which include demonstrations of: -======= -**confluent-kafka-python** provides a high-level Producer, Consumer and AdminClient compatible with all -[Apache KafkaTM](http://kafka.apache.org/) brokers >= v0.8, [Confluent Cloud](https://www.confluent.io/confluent-cloud/) -and [Confluent Platform](https://www.confluent.io/product/compare/). 
The client is: - -- **Reliable** - It's a wrapper around [librdkafka](https://github.com/edenhill/librdkafka) (provided automatically via binary wheels) which is widely deployed in a diverse set of production scenarios. It's tested using [the same set of system tests](https://github.com/confluentinc/confluent-kafka-python/tree/master/src/confluent_kafka/kafkatest) as the Java client [and more](https://github.com/confluentinc/confluent-kafka-python/tree/master/tests). It's supported by [Confluent](https://confluent.io). - -- **Performant** - Performance is a key design consideration. Maximum throughput is on par with the Java client for larger message sizes (where the overhead of the Python interpreter has less impact). Latency is on par with the Java client. - -- **Future proof** - Confluent, founded by the -creators of Kafka, is building a [streaming platform](https://www.confluent.io/product/compare/) -with Apache Kafka at its core. It's high priority for us that client features keep -pace with core Apache Kafka and components of the [Confluent Platform](https://www.confluent.io/product/compare/). - - -## Usage - -For a step-by-step guide on using the client see [Getting Started with Apache Kafka and Python](https://developer.confluent.io/get-started/python/). - -Aditional examples can be found in the [examples](examples) directory or the [confluentinc/examples](https://github.com/confluentinc/examples/tree/master/clients/cloud/python) github repo, which include demonstration of: ->>>>>>> 7b378e7 (add accidentally removed md files) - Exactly once data processing using the transactional API. - Integration with asyncio. - (De)serializing Protobuf, JSON, and Avro data with Confluent Schema Registry integration. - [Confluent Cloud](https://www.confluent.io/confluent-cloud/) configuration. -<<<<<<< HEAD Also see the [Python client docs](https://docs.confluent.io/kafka-clients/python/current/overview.html) and the [API reference](https://docs.confluent.io/kafka-clients/python/current/). Finally, the [tests](tests) are useful as a reference for example usage. @@ -127,13 +98,6 @@ For a more detailed example that includes both an async producer and consumer, s The AsyncIO producer and consumer integrate seamlessly with async Schema Registry serializers. See the [Schema Registry Integration](#schema-registry-integration) section below for full details. ### Basic Producer example -======= -Also refer to the [API documentation](http://docs.confluent.io/current/clients/confluent-kafka-python/index.html). - -Finally, the [tests](tests) are useful as a reference for example usage. - -### Basic Producer Example ->>>>>>> 7b378e7 (add accidentally removed md files) ```python from confluent_kafka import Producer @@ -165,7 +129,6 @@ p.flush() For a discussion on the poll based producer API, refer to the [Integrating Apache Kafka With Python Asyncio Web Applications](https://www.confluent.io/blog/kafka-python-asyncio-integration/) blog post. -<<<<<<< HEAD ### Schema Registry Integration This client provides full integration with Schema Registry for schema management and message serialization, and is compatible with both [Confluent Platform](https://docs.confluent.io/platform/current/schema-registry/index.html) and [Confluent Cloud](https://docs.confluent.io/cloud/current/sr/index.html). Both synchronous and asynchronous clients are available. 
@@ -253,11 +216,6 @@ from confluent_kafka.schema_registry._async.protobuf import AsyncProtobufSeriali - **401/403 Unauthorized when using Confluent Cloud:** Verify your `basic.auth.user.info` (SR API key/secret) is correct and that the Schema Registry URL is for your specific cluster. Ensure you are using an SR API key, not a Kafka API key. - **Schema not found:** Check that your `subject.name.strategy` configuration matches how your schemas are registered in Schema Registry, and that the topic and message field (key/value) pairing is correct. ### Basic Consumer example -======= - - -### Basic Consumer Example ->>>>>>> 7b378e7 (add accidentally removed md files) ```python from confluent_kafka import Consumer @@ -283,13 +241,7 @@ while True: c.close() ``` -<<<<<<< HEAD ### Basic AdminClient example -======= - - -### Basic AdminClient Example ->>>>>>> 7b378e7 (add accidentally removed md files) Create topics: @@ -313,7 +265,6 @@ for topic, f in fs.items(): except Exception as e: print("Failed to create topic {}: {}".format(topic, e)) ``` -<<<<<<< HEAD ## Thread safety The `Producer`, `Consumer`, and `AdminClient` are all thread safe. @@ -333,29 +284,6 @@ pip install "confluent-kafka[avro,schemaregistry,rules]" ``` **Note:** Pre-built Linux wheels do not include SASL Kerberos/GSSAPI support. For Kerberos, see the source installation instructions in [INSTALL.md](INSTALL.md). -======= - - -## Thread Safety - -The `Producer`, `Consumer` and `AdminClient` are all thread safe. - - -## Install - -**Install self-contained binary wheels** - -```bash -pip install confluent-kafka -``` - -**NOTE:** The pre-built Linux wheels do NOT contain SASL Kerberos/GSSAPI support. - If you need SASL Kerberos/GSSAPI support you must install librdkafka and - its dependencies using the repositories below and then build - confluent-kafka using the instructions in the - "Install from source" section below. - ->>>>>>> 7b378e7 (add accidentally removed md files) To use Schema Registry with the Avro serializer/deserializer: ```bash @@ -384,78 +312,16 @@ pip install "confluent-kafka[avro,schemaregistry,rules]" For source install, see the *Install from source* section in [INSTALL.md](INSTALL.md). -<<<<<<< HEAD ## Broker compatibility The Python client (as well as the underlying C library librdkafka) supports all broker versions >= 0.8. -======= - -## Broker Compatibility - -The Python client (as well as the underlying C library librdkafka) supports -all broker versions >= 0.8. ->>>>>>> 7b378e7 (add accidentally removed md files) But due to the nature of the Kafka protocol in broker versions 0.8 and 0.9 it is not safe for a client to assume what protocol version is actually supported by the broker, thus you will need to hint the Python client what protocol version it may use. This is done through two configuration settings: -<<<<<<< HEAD - `broker.version.fallback=YOUR_BROKER_VERSION` (default 0.9.0.1) - `api.version.request=true|false` (default true) When using a Kafka 0.10 broker or later you don't need to do anything -======= - * `broker.version.fallback=YOUR_BROKER_VERSION` (default 0.9.0.1) - * `api.version.request=true|false` (default true) - -When using a Kafka 0.10 broker or later you don't need to do anything -(`api.version.request=true` is the default). -If you use Kafka broker 0.9 or 0.8 you must set -`api.version.request=false` and set -`broker.version.fallback` to your broker version, -e.g `broker.version.fallback=0.9.0.1`. 
- -More info here: -https://github.com/edenhill/librdkafka/wiki/Broker-version-compatibility - - -## SSL certificates - -If you're connecting to a Kafka cluster through SSL you will need to configure -the client with `'security.protocol': 'SSL'` (or `'SASL_SSL'` if SASL -authentication is used). - -The client will use CA certificates to verify the broker's certificate. -The embedded OpenSSL library will look for CA certificates in `/usr/lib/ssl/certs/` -or `/usr/lib/ssl/cacert.pem`. CA certificates are typically provided by the -Linux distribution's `ca-certificates` package which needs to be installed -through `apt`, `yum`, et.al. - -If your system stores CA certificates in another location you will need to -configure the client with `'ssl.ca.location': '/path/to/cacert.pem'`. - -Alternatively, the CA certificates can be provided by the [certifi](https://pypi.org/project/certifi/) -Python package. To use certifi, add an `import certifi` line and configure the -client's CA location with `'ssl.ca.location': certifi.where()`. - - -## License - -[Apache License v2.0](http://www.apache.org/licenses/LICENSE-2.0) - -KAFKA is a registered trademark of The Apache Software Foundation and has been licensed for use -by confluent-kafka-python. confluent-kafka-python has no affiliation with and is not endorsed by -The Apache Software Foundation. - - -## Developer Notes - -Instructions on building and testing confluent-kafka-python can be found [here](DEVELOPER.md). - - -## Confluent Cloud - -For a step-by-step guide on using the Python client with Confluent Cloud see [Getting Started with Apache Kafka and Python](https://developer.confluent.io/get-started/python/) on [Confluent Developer](https://developer.confluent.io/). ->>>>>>> 7b378e7 (add accidentally removed md files) diff --git a/src/confluent_kafka/admin/__init__.py b/src/confluent_kafka/admin/__init__.py index 4a8394ce0..6a1c0e9f4 100644 --- a/src/confluent_kafka/admin/__init__.py +++ b/src/confluent_kafka/admin/__init__.py @@ -87,9 +87,11 @@ ConsumerGroupState as _ConsumerGroupState, \ IsolationLevel as _IsolationLevel +from .._types import ConfigDict + try: - string_type = basestring # type: ignore[name-defined] + string_type = basestring except NameError: string_type = str @@ -116,7 +118,7 @@ class AdminClient (_AdminClientImpl): Requires broker version v0.11.0.0 or later. """ - def __init__(self, conf: Dict[str, Union[str, int, float, bool]], **kwargs: Any) -> None: + def __init__(self, conf: ConfigDict, **kwargs: Any) -> None: """ Create a new AdminClient using the provided configuration dictionary. @@ -154,8 +156,7 @@ def _make_topics_result(f: concurrent.futures.Future, futmap: Dict[str, concurre fut.set_exception(e) @staticmethod - def _make_resource_result(f: concurrent.futures.Future, - futmap: Dict[ConfigResource, concurrent.futures.Future]) -> None: + def _make_resource_result(f: concurrent.futures.Future, futmap: Dict[ConfigResource, concurrent.futures.Future]) -> None: """ Map per-resource results to per-resource futures in futmap. The result value of each (successful) future is a ConfigResource. @@ -184,8 +185,7 @@ def _make_list_consumer_groups_result(f: concurrent.futures.Future, futmap: Any) pass @staticmethod - def _make_consumer_groups_result(f: concurrent.futures.Future, - futmap: Dict[str, concurrent.futures.Future]) -> None: + def _make_consumer_groups_result(f: concurrent.futures.Future, futmap: Dict[str, concurrent.futures.Future]) -> None: """ Map per-group results to per-group futures in futmap. 
""" @@ -210,8 +210,7 @@ def _make_consumer_groups_result(f: concurrent.futures.Future, fut.set_exception(e) @staticmethod - def _make_consumer_group_offsets_result(f: concurrent.futures.Future, - futmap: Dict[str, concurrent.futures.Future]) -> None: + def _make_consumer_group_offsets_result(f: concurrent.futures.Future, futmap: Dict[str, concurrent.futures.Future]) -> None: """ Map per-group results to per-group futures in futmap. The result value of each (successful) future is ConsumerGroupTopicPartitions. @@ -263,8 +262,7 @@ def _make_acls_result(f: concurrent.futures.Future, futmap: Dict[Any, concurrent fut.set_exception(e) @staticmethod - def _make_futmap_result_from_list(f: concurrent.futures.Future, - futmap: Dict[Any, concurrent.futures.Future]) -> None: + def _make_futmap_result_from_list(f: concurrent.futures.Future, futmap: Dict[Any, concurrent.futures.Future]) -> None: try: results = f.result() @@ -309,15 +307,13 @@ def _make_futmap_result(f: concurrent.futures.Future, futmap: Dict[str, concurre @staticmethod def _create_future() -> concurrent.futures.Future: - f: concurrent.futures.Future = concurrent.futures.Future() + f = concurrent.futures.Future() if not f.set_running_or_notify_cancel(): raise RuntimeError("Future was cancelled prematurely") return f @staticmethod - def _make_futures(futmap_keys: List[Any], class_check: Optional[type], - make_result_fn: Any) -> Tuple[concurrent.futures.Future, - Dict[Any, concurrent.futures.Future]]: + def _make_futures(futmap_keys: List[Any], class_check: Optional[type], make_result_fn: Any) -> Tuple[concurrent.futures.Future, Dict[Any, concurrent.futures.Future]]: """ Create futures and a futuremap for the keys in futmap_keys, and create a request-level future to be bassed to the C API. @@ -339,9 +335,7 @@ def _make_futures(futmap_keys: List[Any], class_check: Optional[type], return f, futmap @staticmethod - def _make_futures_v2(futmap_keys: Union[List[Any], Set[Any]], class_check: Optional[type], - make_result_fn: Any) -> Tuple[concurrent.futures.Future, - Dict[Any, concurrent.futures.Future]]: + def _make_futures_v2(futmap_keys: Union[List[Any], Set[Any]], class_check: Optional[type], make_result_fn: Any) -> Tuple[concurrent.futures.Future, Dict[Any, concurrent.futures.Future]]: """ Create futures and a futuremap for the keys in futmap_keys, and create a request-level future to be bassed to the C API. 
@@ -427,7 +421,7 @@ def _check_list_consumer_group_offsets_request(request: Optional[List[_ConsumerG raise ValueError("Element of 'topic_partitions' must not have 'offset' value") @staticmethod - def _check_alter_consumer_group_offsets_request(request: Optional[List[_ConsumerGroupTopicPartitions]]) -> None: + def _check_alter_consumer_group_offsets_request(request): if request is None: raise TypeError("request cannot be None") if not isinstance(request, list): @@ -466,7 +460,7 @@ def _check_alter_consumer_group_offsets_request(request: Optional[List[_Consumer "Element of 'topic_partitions' must not have negative value for 'offset' field") @staticmethod - def _check_describe_user_scram_credentials_request(users: Optional[List[str]]) -> None: + def _check_describe_user_scram_credentials_request(users): if users is None: return if not isinstance(users, list): @@ -480,7 +474,7 @@ def _check_describe_user_scram_credentials_request(users: Optional[List[str]]) - raise ValueError("'user' cannot be empty") @staticmethod - def _check_alter_user_scram_credentials_request(alterations: List[UserScramCredentialAlteration]) -> None: + def _check_alter_user_scram_credentials_request(alterations): if not isinstance(alterations, list): raise TypeError("Expected input to be list") if len(alterations) == 0: @@ -523,8 +517,7 @@ def _check_alter_user_scram_credentials_request(alterations: List[UserScramCrede "UserScramCredentialDeletion") @staticmethod - def _check_list_offsets_request(topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], - kwargs: Dict[str, Any]) -> None: + def _check_list_offsets_request(topic_partition_offsets, kwargs): if not isinstance(topic_partition_offsets, dict): raise TypeError("Expected topic_partition_offsets to be " + "dict of [TopicPartitions,OffsetSpec] for list offsets request") @@ -552,7 +545,7 @@ def _check_list_offsets_request(topic_partition_offsets: Dict[_TopicPartition, O raise TypeError("isolation_level argument should be an IsolationLevel") @staticmethod - def _check_delete_records(request: List[_TopicPartition]) -> None: + def _check_delete_records(request): if not isinstance(request, list): raise TypeError(f"Expected Request to be a list, got '{type(request).__name__}' ") for req in request: @@ -563,7 +556,7 @@ def _check_delete_records(request: List[_TopicPartition]) -> None: raise ValueError("'partition' cannot be negative") @staticmethod - def _check_elect_leaders(election_type: _ElectionType, partitions: Optional[List[_TopicPartition]]) -> None: + def _check_elect_leaders(election_type, partitions): if not isinstance(election_type, _ElectionType): raise TypeError("Expected 'election_type' to be of type 'ElectionType'") if partitions is not None: @@ -578,9 +571,7 @@ def _check_elect_leaders(election_type: _ElectionType, partitions: Optional[List raise ValueError("Elements of the 'partitions' list must not have negative value" + " for 'partition' field") - def create_topics( # type: ignore[override] - self, new_topics: List[NewTopic], **kwargs: Any - ) -> Dict[str, concurrent.futures.Future]: + def create_topics(self, new_topics: List[NewTopic], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Create one or more new topics. 
@@ -615,9 +606,7 @@ def create_topics( # type: ignore[override] return futmap - def delete_topics( # type: ignore[override] - self, topics: List[str], **kwargs: Any - ) -> Dict[str, concurrent.futures.Future]: + def delete_topics(self, topics: List[str], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Delete one or more topics. @@ -652,13 +641,11 @@ def list_topics(self, *args: Any, **kwargs: Any) -> ClusterMetadata: return super(AdminClient, self).list_topics(*args, **kwargs) - def list_groups(self, *args: Any, **kwargs: Any) -> List[GroupMetadata]: + def list_groups(self, *args: Any, **kwargs: Any) -> GroupMetadata: return super(AdminClient, self).list_groups(*args, **kwargs) - def create_partitions( # type: ignore[override] - self, new_partitions: List[NewPartitions], **kwargs: Any - ) -> Dict[str, concurrent.futures.Future]: + def create_partitions(self, new_partitions: List[NewPartitions], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Create additional partitions for the given topics. @@ -692,9 +679,7 @@ def create_partitions( # type: ignore[override] return futmap - def describe_configs( # type: ignore[override] - self, resources: List[ConfigResource], **kwargs: Any - ) -> Dict[ConfigResource, concurrent.futures.Future]: + def describe_configs(self, resources: List[ConfigResource], **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: """ Get the configuration of the specified resources. @@ -726,9 +711,7 @@ def describe_configs( # type: ignore[override] return futmap - def alter_configs( # type: ignore[override] - self, resources: List[ConfigResource], **kwargs: Any - ) -> Dict[ConfigResource, concurrent.futures.Future]: + def alter_configs(self, resources, **kwargs): """ .. deprecated:: 2.2.0 @@ -776,9 +759,7 @@ def alter_configs( # type: ignore[override] return futmap - def incremental_alter_configs( # type: ignore[override] - self, resources: List[ConfigResource], **kwargs: Any - ) -> Dict[ConfigResource, concurrent.futures.Future]: + def incremental_alter_configs(self, resources, **kwargs): """ Update configuration properties for the specified resources. Updates are incremental, i.e only the values mentioned are changed @@ -811,9 +792,7 @@ def incremental_alter_configs( # type: ignore[override] return futmap - def create_acls( # type: ignore[override] - self, acls: List[AclBinding], **kwargs: Any - ) -> Dict[AclBinding, concurrent.futures.Future]: + def create_acls(self, acls, **kwargs): """ Create one or more ACL bindings. @@ -842,9 +821,7 @@ def create_acls( # type: ignore[override] return futmap - def describe_acls( # type: ignore[override] - self, acl_binding_filter: AclBindingFilter, **kwargs: Any - ) -> concurrent.futures.Future: + def describe_acls(self, acl_binding_filter, **kwargs): """ Match ACL bindings by filter. @@ -879,9 +856,7 @@ def describe_acls( # type: ignore[override] return f - def delete_acls( # type: ignore[override] - self, acl_binding_filters: List[AclBindingFilter], **kwargs: Any - ) -> Dict[AclBindingFilter, concurrent.futures.Future]: + def delete_acls(self, acl_binding_filters, **kwargs): """ Delete ACL bindings matching one or more ACL binding filters. @@ -920,9 +895,7 @@ def delete_acls( # type: ignore[override] return futmap - def list_consumer_groups( # type: ignore[override] - self, **kwargs: Any - ) -> concurrent.futures.Future: + def list_consumer_groups(self, **kwargs): """ List consumer groups. 
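Unlike most of the Admin API, list_consumer_groups returns a single request-level future rather than a futmap, which the annotation reflects. A usage sketch, assuming a placeholder broker address and that the resolved result exposes valid and errors collections as in recent client releases:

    from confluent_kafka.admin import AdminClient

    admin = AdminClient({"bootstrap.servers": "localhost:9092"})

    future = admin.list_consumer_groups(request_timeout=10)
    result = future.result()
    for group in result.valid:
        print(group.group_id, group.state)
    for error in result.errors:
        print("error:", error)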
@@ -969,9 +942,7 @@ def list_consumer_groups( # type: ignore[override] return f - def describe_consumer_groups( # type: ignore[override] - self, group_ids: List[str], **kwargs: Any - ) -> Dict[str, concurrent.futures.Future]: + def describe_consumer_groups(self, group_ids: List[str], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Describe consumer groups. @@ -1004,9 +975,7 @@ def describe_consumer_groups( # type: ignore[override] return futmap - def describe_topics( # type: ignore[override] - self, topics: _TopicCollection, **kwargs: Any - ) -> Dict[str, concurrent.futures.Future]: + def describe_topics(self, topics: _TopicCollection, **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Describe topics. @@ -1037,13 +1006,11 @@ def describe_topics( # type: ignore[override] f, futmap = AdminClient._make_futures_v2(topic_names, None, AdminClient._make_futmap_result_from_list) - super(AdminClient, self).describe_topics(topic_names, f, **kwargs) # type: ignore[arg-type] + super(AdminClient, self).describe_topics(topic_names, f, **kwargs) return futmap - def describe_cluster( # type: ignore[override] - self, **kwargs: Any - ) -> concurrent.futures.Future: + def describe_cluster(self, **kwargs: Any) -> concurrent.futures.Future: """ Describe cluster. @@ -1067,9 +1034,7 @@ def describe_cluster( # type: ignore[override] return f - def delete_consumer_groups( # type: ignore[override] - self, group_ids: List[str], **kwargs: Any - ) -> Dict[str, concurrent.futures.Future]: + def delete_consumer_groups(self, group_ids: List[str], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Delete the given consumer groups. @@ -1093,17 +1058,13 @@ def delete_consumer_groups( # type: ignore[override] if len(group_ids) == 0: raise ValueError("Expected at least one group to be deleted") - f, futmap = AdminClient._make_futures(group_ids, string_type, - AdminClient._make_consumer_groups_result) + f, futmap = AdminClient._make_futures(group_ids, string_type, AdminClient._make_consumer_groups_result) super(AdminClient, self).delete_consumer_groups(group_ids, f, **kwargs) return futmap - def list_consumer_group_offsets( # type: ignore[override] - self, list_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], - **kwargs: Any - ) -> Dict[str, concurrent.futures.Future]: + def list_consumer_group_offsets(self, list_consumer_group_offsets_request, **kwargs): """ List offset information for the consumer group and (optional) topic partition provided in the request. @@ -1131,19 +1092,15 @@ def list_consumer_group_offsets( # type: ignore[override] AdminClient._check_list_consumer_group_offsets_request(list_consumer_group_offsets_request) - f, futmap = AdminClient._make_futures( - [request.group_id for request in list_consumer_group_offsets_request], - string_type, - AdminClient._make_consumer_group_offsets_result) + f, futmap = AdminClient._make_futures([request.group_id for request in list_consumer_group_offsets_request], + string_type, + AdminClient._make_consumer_group_offsets_result) super(AdminClient, self).list_consumer_group_offsets(list_consumer_group_offsets_request, f, **kwargs) return futmap - def alter_consumer_group_offsets( # type: ignore[override] - self, alter_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], - **kwargs: Any - ) -> Dict[str, concurrent.futures.Future]: + def alter_consumer_group_offsets(self, alter_consumer_group_offsets_request, **kwargs): """ Alter offset for the consumer group and topic partition provided in the request. 
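A hedged usage sketch for the alter_consumer_group_offsets futmap; the group, topic, and offset values are placeholders, and the per-group future is assumed to resolve to a ConsumerGroupTopicPartitions echoing the committed offsets:

    from confluent_kafka import ConsumerGroupTopicPartitions, TopicPartition
    from confluent_kafka.admin import AdminClient

    admin = AdminClient({"bootstrap.servers": "localhost:9092"})

    request = [ConsumerGroupTopicPartitions(
        "example-group", [TopicPartition("example-topic", 0, 42)])]
    futures = admin.alter_consumer_group_offsets(request)
    for group_id, future in futures.items():
        result = future.result()
        print(group_id, result.topic_partitions)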
@@ -1195,9 +1152,7 @@ def set_sasl_credentials(self, username: str, password: str) -> None: """ super(AdminClient, self).set_sasl_credentials(username, password) - def describe_user_scram_credentials( # type: ignore[override] - self, users: Optional[List[str]] = None, **kwargs: Any - ) -> Union[concurrent.futures.Future, Dict[str, concurrent.futures.Future]]: + def describe_user_scram_credentials(self, users=None, **kwargs): """ Describe user SASL/SCRAM credentials. @@ -1228,14 +1183,12 @@ def describe_user_scram_credentials( # type: ignore[override] if users is None: internal_f, ret_fut = AdminClient._make_single_future_pair() else: - internal_f, ret_fut = AdminClient._make_futures_v2( # type: ignore[assignment] - users, None, AdminClient._make_futmap_result) + internal_f, ret_fut = AdminClient._make_futures_v2(users, None, + AdminClient._make_futmap_result) super(AdminClient, self).describe_user_scram_credentials(users, internal_f, **kwargs) return ret_fut - def alter_user_scram_credentials( # type: ignore[override] - self, alterations: List[UserScramCredentialAlteration], **kwargs: Any - ) -> Dict[str, concurrent.futures.Future]: + def alter_user_scram_credentials(self, alterations, **kwargs): """ Alter user SASL/SCRAM credentials. @@ -1257,16 +1210,13 @@ def alter_user_scram_credentials( # type: ignore[override] """ AdminClient._check_alter_user_scram_credentials_request(alterations) - f, futmap = AdminClient._make_futures_v2( - set([alteration.user for alteration in alterations]), None, - AdminClient._make_futmap_result) + f, futmap = AdminClient._make_futures_v2(set([alteration.user for alteration in alterations]), None, + AdminClient._make_futmap_result) super(AdminClient, self).alter_user_scram_credentials(alterations, f, **kwargs) return futmap - def list_offsets( # type: ignore[override] - self, topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], **kwargs: Any - ) -> Dict[_TopicPartition, concurrent.futures.Future]: + def list_offsets(self, topic_partition_offsets, **kwargs): """ Enables to find the beginning offset, end offset as well as the offset matching a timestamp @@ -1300,16 +1250,14 @@ def list_offsets( # type: ignore[override] int(offset_spec._value)) for topic_partition, offset_spec in topic_partition_offsets.items()] - f, futmap = AdminClient._make_futures_v2( - topic_partition_offsets_list, _TopicPartition, - AdminClient._make_futmap_result) + f, futmap = AdminClient._make_futures_v2(topic_partition_offsets_list, + _TopicPartition, + AdminClient._make_futmap_result) super(AdminClient, self).list_offsets(topic_partition_offsets_list, f, **kwargs) return futmap - def delete_records( # type: ignore[override] - self, topic_partition_offsets: List[_TopicPartition], **kwargs: Any - ) -> Dict[_TopicPartition, concurrent.futures.Future]: + def delete_records(self, topic_partition_offsets, **kwargs): """ Deletes all the records before the specified offsets (not including), in the specified topics and partitions. @@ -1346,10 +1294,7 @@ def delete_records( # type: ignore[override] super(AdminClient, self).delete_records(topic_partition_offsets, f, **kwargs) return futmap - def elect_leaders( # type: ignore[override] - self, election_type: _ElectionType, partitions: Optional[List[_TopicPartition]] = None, - **kwargs: Any - ) -> concurrent.futures.Future: + def elect_leaders(self, election_type, partitions=None, **kwargs): """ Perform Preferred or Unclean leader election for all the specified partitions or all partitions in the cluster. 
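elect_leaders returns a single request-level future. A sketch under the assumption that the resolved value maps each TopicPartition to a per-partition error (None on success); the broker address, topic, and partition are placeholders:

    from confluent_kafka import ElectionType, TopicPartition
    from confluent_kafka.admin import AdminClient

    admin = AdminClient({"bootstrap.servers": "localhost:9092"})

    # Passing partitions=None would run the election for every partition in the cluster.
    future = admin.elect_leaders(ElectionType.PREFERRED,
                                 [TopicPartition("example-topic", 0)])
    for tp, err in future.result().items():
        print(tp, err)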
diff --git a/src/confluent_kafka/serialization/__init__.py b/src/confluent_kafka/serialization/__init__.py index 315ac0e99..ed59f3c1e 100644 --- a/src/confluent_kafka/serialization/__init__.py +++ b/src/confluent_kafka/serialization/__init__.py @@ -17,7 +17,7 @@ # import struct as _struct from enum import Enum -from typing import Any, List, Optional +from typing import Any, Optional from confluent_kafka.error import KafkaException from confluent_kafka._types import HeadersType @@ -114,7 +114,7 @@ class Serializer(object): - unicode(encoding) """ - __slots__: List[str] = [] + __slots__ = [] def __call__(self, obj: Any, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ @@ -171,7 +171,7 @@ class Deserializer(object): - unicode(encoding) """ - __slots__: List[str] = [] + __slots__ = [] def __call__(self, value: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Any: """ From 8f789a074e87177a401004d696d1df031c0c8c21 Mon Sep 17 00:00:00 2001 From: Naxin Date: Wed, 15 Oct 2025 13:42:02 -0400 Subject: [PATCH 10/31] finish admin init --- src/confluent_kafka/admin/__init__.py | 38 +++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/confluent_kafka/admin/__init__.py b/src/confluent_kafka/admin/__init__.py index 6a1c0e9f4..8c8f45ea9 100644 --- a/src/confluent_kafka/admin/__init__.py +++ b/src/confluent_kafka/admin/__init__.py @@ -421,7 +421,7 @@ def _check_list_consumer_group_offsets_request(request: Optional[List[_ConsumerG raise ValueError("Element of 'topic_partitions' must not have 'offset' value") @staticmethod - def _check_alter_consumer_group_offsets_request(request): + def _check_alter_consumer_group_offsets_request(request: Optional[List[_ConsumerGroupTopicPartitions]]) -> None: if request is None: raise TypeError("request cannot be None") if not isinstance(request, list): @@ -460,7 +460,7 @@ def _check_alter_consumer_group_offsets_request(request): "Element of 'topic_partitions' must not have negative value for 'offset' field") @staticmethod - def _check_describe_user_scram_credentials_request(users): + def _check_describe_user_scram_credentials_request(users: Optional[List[str]]) -> None: if users is None: return if not isinstance(users, list): @@ -474,7 +474,7 @@ def _check_describe_user_scram_credentials_request(users): raise ValueError("'user' cannot be empty") @staticmethod - def _check_alter_user_scram_credentials_request(alterations): + def _check_alter_user_scram_credentials_request(alterations: List[UserScramCredentialAlteration]) -> None: if not isinstance(alterations, list): raise TypeError("Expected input to be list") if len(alterations) == 0: @@ -517,7 +517,7 @@ def _check_alter_user_scram_credentials_request(alterations): "UserScramCredentialDeletion") @staticmethod - def _check_list_offsets_request(topic_partition_offsets, kwargs): + def _check_list_offsets_request(topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], kwargs: Dict[str, Any]) -> None: if not isinstance(topic_partition_offsets, dict): raise TypeError("Expected topic_partition_offsets to be " + "dict of [TopicPartitions,OffsetSpec] for list offsets request") @@ -545,7 +545,7 @@ def _check_list_offsets_request(topic_partition_offsets, kwargs): raise TypeError("isolation_level argument should be an IsolationLevel") @staticmethod - def _check_delete_records(request): + def _check_delete_records(request: List[_TopicPartition]) -> None: if not isinstance(request, list): raise TypeError(f"Expected Request to be a list, got 
'{type(request).__name__}' ") for req in request: @@ -556,7 +556,7 @@ def _check_delete_records(request): raise ValueError("'partition' cannot be negative") @staticmethod - def _check_elect_leaders(election_type, partitions): + def _check_elect_leaders(election_type: _ElectionType, partitions: Optional[List[_TopicPartition]]) -> None: if not isinstance(election_type, _ElectionType): raise TypeError("Expected 'election_type' to be of type 'ElectionType'") if partitions is not None: @@ -711,7 +711,7 @@ def describe_configs(self, resources: List[ConfigResource], **kwargs: Any) -> Di return futmap - def alter_configs(self, resources, **kwargs): + def alter_configs(self, resources: List[ConfigResource], **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: """ .. deprecated:: 2.2.0 @@ -759,7 +759,7 @@ def alter_configs(self, resources, **kwargs): return futmap - def incremental_alter_configs(self, resources, **kwargs): + def incremental_alter_configs(self, resources: List[ConfigResource], **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: """ Update configuration properties for the specified resources. Updates are incremental, i.e only the values mentioned are changed @@ -792,7 +792,7 @@ def incremental_alter_configs(self, resources, **kwargs): return futmap - def create_acls(self, acls, **kwargs): + def create_acls(self, acls: List[AclBinding], **kwargs: Any) -> Dict[AclBinding, concurrent.futures.Future]: """ Create one or more ACL bindings. @@ -821,7 +821,7 @@ def create_acls(self, acls, **kwargs): return futmap - def describe_acls(self, acl_binding_filter, **kwargs): + def describe_acls(self, acl_binding_filter: AclBindingFilter, **kwargs: Any) -> concurrent.futures.Future: """ Match ACL bindings by filter. @@ -856,7 +856,7 @@ def describe_acls(self, acl_binding_filter, **kwargs): return f - def delete_acls(self, acl_binding_filters, **kwargs): + def delete_acls(self, acl_binding_filters: List[AclBindingFilter], **kwargs: Any) -> Dict[AclBindingFilter, concurrent.futures.Future]: """ Delete ACL bindings matching one or more ACL binding filters. @@ -895,7 +895,7 @@ def delete_acls(self, acl_binding_filters, **kwargs): return futmap - def list_consumer_groups(self, **kwargs): + def list_consumer_groups(self, **kwargs: Any) -> concurrent.futures.Future: """ List consumer groups. @@ -1064,7 +1064,7 @@ def delete_consumer_groups(self, group_ids: List[str], **kwargs: Any) -> Dict[st return futmap - def list_consumer_group_offsets(self, list_consumer_group_offsets_request, **kwargs): + def list_consumer_group_offsets(self, list_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ List offset information for the consumer group and (optional) topic partition provided in the request. @@ -1100,7 +1100,7 @@ def list_consumer_group_offsets(self, list_consumer_group_offsets_request, **kwa return futmap - def alter_consumer_group_offsets(self, alter_consumer_group_offsets_request, **kwargs): + def alter_consumer_group_offsets(self, alter_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Alter offset for the consumer group and topic partition provided in the request. 
@@ -1152,7 +1152,7 @@ def set_sasl_credentials(self, username: str, password: str) -> None: """ super(AdminClient, self).set_sasl_credentials(username, password) - def describe_user_scram_credentials(self, users=None, **kwargs): + def describe_user_scram_credentials(self, users: Optional[List[str]] = None, **kwargs: Any) -> Union[concurrent.futures.Future, Dict[str, concurrent.futures.Future]]: """ Describe user SASL/SCRAM credentials. @@ -1188,7 +1188,7 @@ def describe_user_scram_credentials(self, users=None, **kwargs): super(AdminClient, self).describe_user_scram_credentials(users, internal_f, **kwargs) return ret_fut - def alter_user_scram_credentials(self, alterations, **kwargs): + def alter_user_scram_credentials(self, alterations: List[UserScramCredentialAlteration], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Alter user SASL/SCRAM credentials. @@ -1216,7 +1216,7 @@ def alter_user_scram_credentials(self, alterations, **kwargs): super(AdminClient, self).alter_user_scram_credentials(alterations, f, **kwargs) return futmap - def list_offsets(self, topic_partition_offsets, **kwargs): + def list_offsets(self, topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], **kwargs: Any) -> Dict[_TopicPartition, concurrent.futures.Future]: """ Enables to find the beginning offset, end offset as well as the offset matching a timestamp @@ -1257,7 +1257,7 @@ def list_offsets(self, topic_partition_offsets, **kwargs): super(AdminClient, self).list_offsets(topic_partition_offsets_list, f, **kwargs) return futmap - def delete_records(self, topic_partition_offsets, **kwargs): + def delete_records(self, topic_partition_offsets: List[_TopicPartition], **kwargs: Any) -> Dict[_TopicPartition, concurrent.futures.Future]: """ Deletes all the records before the specified offsets (not including), in the specified topics and partitions. @@ -1294,7 +1294,7 @@ def delete_records(self, topic_partition_offsets, **kwargs): super(AdminClient, self).delete_records(topic_partition_offsets, f, **kwargs) return futmap - def elect_leaders(self, election_type, partitions=None, **kwargs): + def elect_leaders(self, election_type: _ElectionType, partitions: Optional[List[_TopicPartition]] = None, **kwargs: Any) -> concurrent.futures.Future: """ Perform Preferred or Unclean leader election for all the specified partitions or all partitions in the cluster. From 9c99020984b995d556c169e9a5d6886d0d050038 Mon Sep 17 00:00:00 2001 From: Naxin Date: Wed, 15 Oct 2025 15:55:01 -0400 Subject: [PATCH 11/31] add types for AIO module --- .../experimental/aio/_AIOConsumer.py | 5 ++- .../experimental/aio/_common.py | 4 +- .../experimental/aio/producer/_AIOProducer.py | 7 ++-- .../aio/producer/_kafka_batch_executor.py | 41 ++++++------------- 4 files changed, 23 insertions(+), 34 deletions(-) diff --git a/src/confluent_kafka/experimental/aio/_AIOConsumer.py b/src/confluent_kafka/experimental/aio/_AIOConsumer.py index b169da80c..5ad8ae20b 100644 --- a/src/confluent_kafka/experimental/aio/_AIOConsumer.py +++ b/src/confluent_kafka/experimental/aio/_AIOConsumer.py @@ -14,17 +14,18 @@ import asyncio import concurrent.futures -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Optional, Tuple import confluent_kafka from . 
import _common as _common +from ..._types import ConfigDict class AIOConsumer: def __init__( self, - consumer_conf: Dict[str, Any], + consumer_conf: ConfigDict, max_workers: int = 2, executor: Optional[concurrent.futures.Executor] = None ) -> None: diff --git a/src/confluent_kafka/experimental/aio/_common.py b/src/confluent_kafka/experimental/aio/_common.py index 3bf274064..c9f8f6ea5 100644 --- a/src/confluent_kafka/experimental/aio/_common.py +++ b/src/confluent_kafka/experimental/aio/_common.py @@ -18,6 +18,8 @@ import concurrent.futures from typing import Any, Callable, Dict, Optional, Tuple, TypeVar +from ..._types import ConfigDict + T = TypeVar('T') @@ -32,7 +34,7 @@ def __init__( self.logger = logger def log(self, *args: Any, **kwargs: Any) -> None: - self.loop.call_soon_threadsafe(lambda: self.logger.log(*args, **kwargs)) + self.loop.call_soon_threadsafe(self.logger.log, *args, **kwargs) def wrap_callback( diff --git a/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py b/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py index 5e2fd8fb6..8546f794f 100644 --- a/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py +++ b/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py @@ -15,7 +15,7 @@ import asyncio import concurrent.futures import logging -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import confluent_kafka @@ -23,6 +23,7 @@ from ._producer_batch_processor import ProducerBatchManager from ._kafka_batch_executor import ProducerBatchExecutor from ._buffer_timeout_manager import BufferTimeoutManager +from ..._types import ConfigDict logger = logging.getLogger(__name__) @@ -36,7 +37,7 @@ class AIOProducer: def __init__( self, - producer_conf: Dict[str, Any], + producer_conf: ConfigDict, max_workers: int = 4, executor: Optional[concurrent.futures.Executor] = None, batch_size: int = 1000, @@ -224,7 +225,7 @@ async def flush(self, *args: Any, **kwargs: Any) -> Any: # Update buffer activity since we just flushed self._buffer_timeout_manager.mark_activity() - # Then flush the underlying producer and wait for delivery confirmation + # Then flush underlying producer and wait for delivery confirmation return await self._call(self._producer.flush, *args, **kwargs) async def purge(self, *args: Any, **kwargs: Any) -> Any: diff --git a/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py b/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py index 925849c58..dadecaa3a 100644 --- a/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py +++ b/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py @@ -19,8 +19,6 @@ import confluent_kafka -from .. 
import _common - logger = logging.getLogger(__name__) @@ -65,10 +63,10 @@ async def execute_batch( Args: topic: Target topic for the batch batch_messages: List of prepared messages with callbacks assigned - partition: Target partition for the batch (-1 = RD_KAFKA_PARTITION_UA) + partition: Target partition (-1 = RD_KAFKA_PARTITION_UA) Returns: - Result from producer.poll() indicating number of delivery reports processed + Result from producer.poll() indicating # of delivery reports processed Raises: Exception: Any exception from the batch operation is propagated @@ -77,9 +75,9 @@ def _produce_batch_and_poll() -> int: """Helper function to run in thread pool This function encapsulates all the blocking Kafka operations: - - Call produce_batch with specific partition and individual message callbacks + - Call produce_batch with specific partition & individual callbacks - Handle partial batch failures for messages that fail immediately - - Poll for delivery reports to trigger callbacks for successful messages + - Poll for delivery reports to trigger callbacks for successful msgs """ # Call produce_batch with specific partition and individual callbacks # Convert tuple to list since produce_batch expects a list @@ -90,7 +88,8 @@ def _produce_batch_and_poll() -> int: ) # Use the provided partition for the entire batch - # This enables proper partition control while working around librdkafka limitations + # This enables proper partition control while working around + # librdkafka limitations self._producer.produce_batch(topic, messages_list, partition=partition) # Handle partial batch failures: Check for messages that failed @@ -99,7 +98,7 @@ def _produce_batch_and_poll() -> int: # so we need to manually invoke their callbacks self._handle_partial_failures(messages_list) - # Immediately poll to process delivery callbacks for successful messages + # Immediately poll to process delivery callbacks for successful msgs poll_result = self._producer.poll(0) return poll_result @@ -108,24 +107,10 @@ def _produce_batch_and_poll() -> int: loop = asyncio.get_running_loop() return await loop.run_in_executor(self._executor, _produce_batch_and_poll) - async def flush_librdkafka_queue(self, timeout=-1): - """Flush the librdkafka queue and wait for all messages to be delivered - - This method awaits until all outstanding produce requests are completed - or the timeout is reached, unless the timeout is set to 0 (non-blocking). - - Args: - timeout: Maximum time to wait in seconds: - - -1 = wait indefinitely (default) - - 0 = non-blocking, return immediately - - >0 = wait up to timeout seconds - - Returns: - Number of messages still in queue after flush attempt - """ - return await _common.async_call(self._executor, self._producer.flush, timeout) - - def _handle_partial_failures(self, batch_messages: List[Dict[str, Any]]) -> None: + def _handle_partial_failures( + self, + batch_messages: List[Dict[str, Any]] + ) -> None: """Handle messages that failed during produce_batch When produce_batch encounters messages that fail immediately (e.g., @@ -135,7 +120,7 @@ def _handle_partial_failures(self, batch_messages: List[Dict[str, Any]]) -> None manually invoke the simple future-resolving callbacks. 
Args: - batch_messages: List of message dictionaries that were passed to produce_batch + batch_messages: List of message dicts passed to produce_batch """ for msg_dict in batch_messages: if '_error' in msg_dict: @@ -146,7 +131,7 @@ def _handle_partial_failures(self, batch_messages: List[Dict[str, Any]]) -> None # Extract the error from the message dict (set by Producer.c) error = msg_dict['_error'] # Manually invoke the callback with the error - # Note: msg is None since the message failed before being queued + # Note: msg is None since message failed before being queued try: callback(error, None) except Exception: From a387ea7ce72c10995a6415ca421ac2138112d76e Mon Sep 17 00:00:00 2001 From: Naxin Date: Wed, 15 Oct 2025 16:45:25 -0400 Subject: [PATCH 12/31] linter fix --- src/confluent_kafka/admin/__init__.py | 89 ++++++++++++------- src/confluent_kafka/admin/_metadata.py | 4 - src/confluent_kafka/deserializing_consumer.py | 2 +- .../experimental/aio/_common.py | 2 - src/confluent_kafka/serializing_producer.py | 2 +- 5 files changed, 61 insertions(+), 38 deletions(-) diff --git a/src/confluent_kafka/admin/__init__.py b/src/confluent_kafka/admin/__init__.py index 8c8f45ea9..a5c7eaa9a 100644 --- a/src/confluent_kafka/admin/__init__.py +++ b/src/confluent_kafka/admin/__init__.py @@ -156,7 +156,8 @@ def _make_topics_result(f: concurrent.futures.Future, futmap: Dict[str, concurre fut.set_exception(e) @staticmethod - def _make_resource_result(f: concurrent.futures.Future, futmap: Dict[ConfigResource, concurrent.futures.Future]) -> None: + def _make_resource_result(f: concurrent.futures.Future, + futmap: Dict[ConfigResource, concurrent.futures.Future]) -> None: """ Map per-resource results to per-resource futures in futmap. The result value of each (successful) future is a ConfigResource. @@ -185,7 +186,8 @@ def _make_list_consumer_groups_result(f: concurrent.futures.Future, futmap: Any) pass @staticmethod - def _make_consumer_groups_result(f: concurrent.futures.Future, futmap: Dict[str, concurrent.futures.Future]) -> None: + def _make_consumer_groups_result(f: concurrent.futures.Future, + futmap: Dict[str, concurrent.futures.Future]) -> None: """ Map per-group results to per-group futures in futmap. """ @@ -210,7 +212,8 @@ def _make_consumer_groups_result(f: concurrent.futures.Future, futmap: Dict[str, fut.set_exception(e) @staticmethod - def _make_consumer_group_offsets_result(f: concurrent.futures.Future, futmap: Dict[str, concurrent.futures.Future]) -> None: + def _make_consumer_group_offsets_result(f: concurrent.futures.Future, + futmap: Dict[str, concurrent.futures.Future]) -> None: """ Map per-group results to per-group futures in futmap. The result value of each (successful) future is ConsumerGroupTopicPartitions. 
@@ -262,7 +265,8 @@ def _make_acls_result(f: concurrent.futures.Future, futmap: Dict[Any, concurrent fut.set_exception(e) @staticmethod - def _make_futmap_result_from_list(f: concurrent.futures.Future, futmap: Dict[Any, concurrent.futures.Future]) -> None: + def _make_futmap_result_from_list(f: concurrent.futures.Future, + futmap: Dict[Any, concurrent.futures.Future]) -> None: try: results = f.result() @@ -313,7 +317,9 @@ def _create_future() -> concurrent.futures.Future: return f @staticmethod - def _make_futures(futmap_keys: List[Any], class_check: Optional[type], make_result_fn: Any) -> Tuple[concurrent.futures.Future, Dict[Any, concurrent.futures.Future]]: + def _make_futures(futmap_keys: List[Any], class_check: Optional[type], + make_result_fn: Any) -> Tuple[concurrent.futures.Future, + Dict[Any, concurrent.futures.Future]]: """ Create futures and a futuremap for the keys in futmap_keys, and create a request-level future to be bassed to the C API. @@ -335,7 +341,9 @@ def _make_futures(futmap_keys: List[Any], class_check: Optional[type], make_resu return f, futmap @staticmethod - def _make_futures_v2(futmap_keys: Union[List[Any], Set[Any]], class_check: Optional[type], make_result_fn: Any) -> Tuple[concurrent.futures.Future, Dict[Any, concurrent.futures.Future]]: + def _make_futures_v2(futmap_keys: Union[List[Any], Set[Any]], class_check: Optional[type], + make_result_fn: Any) -> Tuple[concurrent.futures.Future, + Dict[Any, concurrent.futures.Future]]: """ Create futures and a futuremap for the keys in futmap_keys, and create a request-level future to be bassed to the C API. @@ -517,7 +525,8 @@ def _check_alter_user_scram_credentials_request(alterations: List[UserScramCrede "UserScramCredentialDeletion") @staticmethod - def _check_list_offsets_request(topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], kwargs: Dict[str, Any]) -> None: + def _check_list_offsets_request(topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], + kwargs: Dict[str, Any]) -> None: if not isinstance(topic_partition_offsets, dict): raise TypeError("Expected topic_partition_offsets to be " + "dict of [TopicPartitions,OffsetSpec] for list offsets request") @@ -645,7 +654,8 @@ def list_groups(self, *args: Any, **kwargs: Any) -> GroupMetadata: return super(AdminClient, self).list_groups(*args, **kwargs) - def create_partitions(self, new_partitions: List[NewPartitions], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def create_partitions(self, new_partitions: List[NewPartitions], + **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Create additional partitions for the given topics. @@ -679,7 +689,8 @@ def create_partitions(self, new_partitions: List[NewPartitions], **kwargs: Any) return futmap - def describe_configs(self, resources: List[ConfigResource], **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: + def describe_configs(self, resources: List[ConfigResource], + **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: """ Get the configuration of the specified resources. @@ -711,7 +722,8 @@ def describe_configs(self, resources: List[ConfigResource], **kwargs: Any) -> Di return futmap - def alter_configs(self, resources: List[ConfigResource], **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: + def alter_configs(self, resources: List[ConfigResource], + **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: """ .. 
deprecated:: 2.2.0 @@ -759,7 +771,8 @@ def alter_configs(self, resources: List[ConfigResource], **kwargs: Any) -> Dict[ return futmap - def incremental_alter_configs(self, resources: List[ConfigResource], **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: + def incremental_alter_configs(self, resources: List[ConfigResource], + **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: """ Update configuration properties for the specified resources. Updates are incremental, i.e only the values mentioned are changed @@ -856,7 +869,8 @@ def describe_acls(self, acl_binding_filter: AclBindingFilter, **kwargs: Any) -> return f - def delete_acls(self, acl_binding_filters: List[AclBindingFilter], **kwargs: Any) -> Dict[AclBindingFilter, concurrent.futures.Future]: + def delete_acls(self, acl_binding_filters: List[AclBindingFilter], + **kwargs: Any) -> Dict[AclBindingFilter, concurrent.futures.Future]: """ Delete ACL bindings matching one or more ACL binding filters. @@ -1058,13 +1072,16 @@ def delete_consumer_groups(self, group_ids: List[str], **kwargs: Any) -> Dict[st if len(group_ids) == 0: raise ValueError("Expected at least one group to be deleted") - f, futmap = AdminClient._make_futures(group_ids, string_type, AdminClient._make_consumer_groups_result) + f, futmap = AdminClient._make_futures(group_ids, string_type, + AdminClient._make_consumer_groups_result) super(AdminClient, self).delete_consumer_groups(group_ids, f, **kwargs) return futmap - def list_consumer_group_offsets(self, list_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def list_consumer_group_offsets( + self, list_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], + **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ List offset information for the consumer group and (optional) topic partition provided in the request. @@ -1092,15 +1109,18 @@ def list_consumer_group_offsets(self, list_consumer_group_offsets_request: List[ AdminClient._check_list_consumer_group_offsets_request(list_consumer_group_offsets_request) - f, futmap = AdminClient._make_futures([request.group_id for request in list_consumer_group_offsets_request], - string_type, - AdminClient._make_consumer_group_offsets_result) + f, futmap = AdminClient._make_futures( + [request.group_id for request in list_consumer_group_offsets_request], + string_type, + AdminClient._make_consumer_group_offsets_result) super(AdminClient, self).list_consumer_group_offsets(list_consumer_group_offsets_request, f, **kwargs) return futmap - def alter_consumer_group_offsets(self, alter_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def alter_consumer_group_offsets( + self, alter_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], + **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Alter offset for the consumer group and topic partition provided in the request. 
@@ -1152,7 +1172,9 @@ def set_sasl_credentials(self, username: str, password: str) -> None: """ super(AdminClient, self).set_sasl_credentials(username, password) - def describe_user_scram_credentials(self, users: Optional[List[str]] = None, **kwargs: Any) -> Union[concurrent.futures.Future, Dict[str, concurrent.futures.Future]]: + def describe_user_scram_credentials( + self, users: Optional[List[str]] = None, + **kwargs: Any) -> Union[concurrent.futures.Future, Dict[str, concurrent.futures.Future]]: """ Describe user SASL/SCRAM credentials. @@ -1183,12 +1205,14 @@ def describe_user_scram_credentials(self, users: Optional[List[str]] = None, **k if users is None: internal_f, ret_fut = AdminClient._make_single_future_pair() else: - internal_f, ret_fut = AdminClient._make_futures_v2(users, None, - AdminClient._make_futmap_result) + internal_f, ret_fut = AdminClient._make_futures_v2( + users, None, AdminClient._make_futmap_result) super(AdminClient, self).describe_user_scram_credentials(users, internal_f, **kwargs) return ret_fut - def alter_user_scram_credentials(self, alterations: List[UserScramCredentialAlteration], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def alter_user_scram_credentials( + self, alterations: List[UserScramCredentialAlteration], + **kwargs: Any) -> Dict[str, concurrent.futures.Future]: """ Alter user SASL/SCRAM credentials. @@ -1210,13 +1234,15 @@ def alter_user_scram_credentials(self, alterations: List[UserScramCredentialAlte """ AdminClient._check_alter_user_scram_credentials_request(alterations) - f, futmap = AdminClient._make_futures_v2(set([alteration.user for alteration in alterations]), None, - AdminClient._make_futmap_result) + f, futmap = AdminClient._make_futures_v2( + set([alteration.user for alteration in alterations]), None, + AdminClient._make_futmap_result) super(AdminClient, self).alter_user_scram_credentials(alterations, f, **kwargs) return futmap - def list_offsets(self, topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], **kwargs: Any) -> Dict[_TopicPartition, concurrent.futures.Future]: + def list_offsets(self, topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], + **kwargs: Any) -> Dict[_TopicPartition, concurrent.futures.Future]: """ Enables to find the beginning offset, end offset as well as the offset matching a timestamp @@ -1250,14 +1276,15 @@ def list_offsets(self, topic_partition_offsets: Dict[_TopicPartition, OffsetSpec int(offset_spec._value)) for topic_partition, offset_spec in topic_partition_offsets.items()] - f, futmap = AdminClient._make_futures_v2(topic_partition_offsets_list, - _TopicPartition, - AdminClient._make_futmap_result) + f, futmap = AdminClient._make_futures_v2( + topic_partition_offsets_list, _TopicPartition, + AdminClient._make_futmap_result) super(AdminClient, self).list_offsets(topic_partition_offsets_list, f, **kwargs) return futmap - def delete_records(self, topic_partition_offsets: List[_TopicPartition], **kwargs: Any) -> Dict[_TopicPartition, concurrent.futures.Future]: + def delete_records(self, topic_partition_offsets: List[_TopicPartition], + **kwargs: Any) -> Dict[_TopicPartition, concurrent.futures.Future]: """ Deletes all the records before the specified offsets (not including), in the specified topics and partitions. 
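A usage sketch for delete_records; the offset carried by each TopicPartition is the exclusive cut-off, and the per-partition future is assumed to resolve to a DeletedRecords object exposing low_watermark. Broker and topic names are placeholders:

    from confluent_kafka import TopicPartition
    from confluent_kafka.admin import AdminClient

    admin = AdminClient({"bootstrap.servers": "localhost:9092"})

    # Delete everything before offset 100 on partition 0 of "example-topic".
    futures = admin.delete_records([TopicPartition("example-topic", 0, 100)])
    for tp, future in futures.items():
        deleted = future.result()
        print(tp, "new low watermark:", deleted.low_watermark)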
@@ -1294,7 +1321,9 @@ def delete_records(self, topic_partition_offsets: List[_TopicPartition], **kwarg super(AdminClient, self).delete_records(topic_partition_offsets, f, **kwargs) return futmap - def elect_leaders(self, election_type: _ElectionType, partitions: Optional[List[_TopicPartition]] = None, **kwargs: Any) -> concurrent.futures.Future: + def elect_leaders(self, election_type: _ElectionType, + partitions: Optional[List[_TopicPartition]] = None, + **kwargs: Any) -> concurrent.futures.Future: """ Perform Preferred or Unclean leader election for all the specified partitions or all partitions in the cluster. diff --git a/src/confluent_kafka/admin/_metadata.py b/src/confluent_kafka/admin/_metadata.py index 51e9f7b37..d2c115003 100644 --- a/src/confluent_kafka/admin/_metadata.py +++ b/src/confluent_kafka/admin/_metadata.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional - -from confluent_kafka.cimpl import KafkaError - class ClusterMetadata(object): """ diff --git a/src/confluent_kafka/deserializing_consumer.py b/src/confluent_kafka/deserializing_consumer.py index 5abd8147b..417eac0b2 100644 --- a/src/confluent_kafka/deserializing_consumer.py +++ b/src/confluent_kafka/deserializing_consumer.py @@ -24,7 +24,7 @@ ValueDeserializationError) from .serialization import (SerializationContext, MessageField) -from ._types import ConfigDict, Deserializer +from ._types import ConfigDict class DeserializingConsumer(_ConsumerImpl): diff --git a/src/confluent_kafka/experimental/aio/_common.py b/src/confluent_kafka/experimental/aio/_common.py index c9f8f6ea5..24659ed9f 100644 --- a/src/confluent_kafka/experimental/aio/_common.py +++ b/src/confluent_kafka/experimental/aio/_common.py @@ -18,8 +18,6 @@ import concurrent.futures from typing import Any, Callable, Dict, Optional, Tuple, TypeVar -from ..._types import ConfigDict - T = TypeVar('T') diff --git a/src/confluent_kafka/serializing_producer.py b/src/confluent_kafka/serializing_producer.py index 88b2defd6..6d27f2586 100644 --- a/src/confluent_kafka/serializing_producer.py +++ b/src/confluent_kafka/serializing_producer.py @@ -23,7 +23,7 @@ SerializationContext) from .error import (KeySerializationError, ValueSerializationError) -from ._types import ConfigDict, HeadersType, DeliveryCallback, Serializer +from ._types import ConfigDict, HeadersType, DeliveryCallback class SerializingProducer(_ProducerImpl): From c1e2f91686d54d81b183489643476b8bec8e4a96 Mon Sep 17 00:00:00 2001 From: Naxin Date: Thu, 16 Oct 2025 14:48:22 -0400 Subject: [PATCH 13/31] address mypy complaints --- src/confluent_kafka/_types.py | 2 - src/confluent_kafka/admin/__init__.py | 118 +++++++++++------- src/confluent_kafka/admin/_acl.py | 8 +- src/confluent_kafka/admin/_config.py | 4 +- src/confluent_kafka/admin/_listoffsets.py | 9 +- src/confluent_kafka/admin/_metadata.py | 6 +- src/confluent_kafka/cimpl.pyi | 56 ++++----- src/confluent_kafka/deserializing_consumer.py | 10 +- .../experimental/aio/_AIOConsumer.py | 5 +- .../experimental/aio/_common.py | 2 +- .../experimental/aio/producer/_AIOProducer.py | 5 +- src/confluent_kafka/serializing_producer.py | 14 ++- 12 files changed, 129 insertions(+), 110 deletions(-) diff --git a/src/confluent_kafka/_types.py b/src/confluent_kafka/_types.py index b0e6d20b5..2cf8ed8ea 100644 --- a/src/confluent_kafka/_types.py +++ b/src/confluent_kafka/_types.py @@ -25,8 +25,6 @@ from typing import Any, Optional, Dict, Union, Callable, 
List, Tuple -# Configuration dictionary type -ConfigDict = Dict[str, Union[str, int, float, bool]] # Headers can be either dict format or list of tuples format HeadersType = Union[ Dict[str, Union[str, bytes, None]], diff --git a/src/confluent_kafka/admin/__init__.py b/src/confluent_kafka/admin/__init__.py index a5c7eaa9a..360d7b1cc 100644 --- a/src/confluent_kafka/admin/__init__.py +++ b/src/confluent_kafka/admin/__init__.py @@ -87,11 +87,9 @@ ConsumerGroupState as _ConsumerGroupState, \ IsolationLevel as _IsolationLevel -from .._types import ConfigDict - try: - string_type = basestring + string_type = basestring # type: ignore[name-defined] except NameError: string_type = str @@ -118,7 +116,7 @@ class AdminClient (_AdminClientImpl): Requires broker version v0.11.0.0 or later. """ - def __init__(self, conf: ConfigDict, **kwargs: Any) -> None: + def __init__(self, conf: Dict[str, Union[str, int, float, bool]], **kwargs: Any) -> None: """ Create a new AdminClient using the provided configuration dictionary. @@ -311,7 +309,7 @@ def _make_futmap_result(f: concurrent.futures.Future, futmap: Dict[str, concurre @staticmethod def _create_future() -> concurrent.futures.Future: - f = concurrent.futures.Future() + f: concurrent.futures.Future = concurrent.futures.Future() if not f.set_running_or_notify_cancel(): raise RuntimeError("Future was cancelled prematurely") return f @@ -580,7 +578,9 @@ def _check_elect_leaders(election_type: _ElectionType, partitions: Optional[List raise ValueError("Elements of the 'partitions' list must not have negative value" + " for 'partition' field") - def create_topics(self, new_topics: List[NewTopic], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def create_topics( # type: ignore[override] + self, new_topics: List[NewTopic], **kwargs: Any + ) -> Dict[str, concurrent.futures.Future]: """ Create one or more new topics. @@ -615,7 +615,9 @@ def create_topics(self, new_topics: List[NewTopic], **kwargs: Any) -> Dict[str, return futmap - def delete_topics(self, topics: List[str], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def delete_topics( # type: ignore[override] + self, topics: List[str], **kwargs: Any + ) -> Dict[str, concurrent.futures.Future]: """ Delete one or more topics. @@ -654,8 +656,9 @@ def list_groups(self, *args: Any, **kwargs: Any) -> GroupMetadata: return super(AdminClient, self).list_groups(*args, **kwargs) - def create_partitions(self, new_partitions: List[NewPartitions], - **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def create_partitions( # type: ignore[override] + self, new_partitions: List[NewPartitions], **kwargs: Any + ) -> Dict[str, concurrent.futures.Future]: """ Create additional partitions for the given topics. @@ -689,8 +692,9 @@ def create_partitions(self, new_partitions: List[NewPartitions], return futmap - def describe_configs(self, resources: List[ConfigResource], - **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: + def describe_configs( # type: ignore[override] + self, resources: List[ConfigResource], **kwargs: Any + ) -> Dict[ConfigResource, concurrent.futures.Future]: """ Get the configuration of the specified resources. 
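The describe_configs futmap is keyed by ConfigResource, as the new annotation states. A minimal sketch with placeholder broker and topic names; each future resolves to a dict mapping config names to ConfigEntry objects:

    from confluent_kafka.admin import AdminClient, ConfigResource

    admin = AdminClient({"bootstrap.servers": "localhost:9092"})

    resource = ConfigResource(ConfigResource.Type.TOPIC, "example-topic")
    futures = admin.describe_configs([resource])
    for res, future in futures.items():
        configs = future.result()
        for name, entry in configs.items():
            print(res, name, entry.value)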
@@ -722,8 +726,9 @@ def describe_configs(self, resources: List[ConfigResource], return futmap - def alter_configs(self, resources: List[ConfigResource], - **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: + def alter_configs( # type: ignore[override] + self, resources: List[ConfigResource], **kwargs: Any + ) -> Dict[ConfigResource, concurrent.futures.Future]: """ .. deprecated:: 2.2.0 @@ -771,8 +776,9 @@ def alter_configs(self, resources: List[ConfigResource], return futmap - def incremental_alter_configs(self, resources: List[ConfigResource], - **kwargs: Any) -> Dict[ConfigResource, concurrent.futures.Future]: + def incremental_alter_configs( # type: ignore[override] + self, resources: List[ConfigResource], **kwargs: Any + ) -> Dict[ConfigResource, concurrent.futures.Future]: """ Update configuration properties for the specified resources. Updates are incremental, i.e only the values mentioned are changed @@ -805,7 +811,9 @@ def incremental_alter_configs(self, resources: List[ConfigResource], return futmap - def create_acls(self, acls: List[AclBinding], **kwargs: Any) -> Dict[AclBinding, concurrent.futures.Future]: + def create_acls( # type: ignore[override] + self, acls: List[AclBinding], **kwargs: Any + ) -> Dict[AclBinding, concurrent.futures.Future]: """ Create one or more ACL bindings. @@ -834,7 +842,9 @@ def create_acls(self, acls: List[AclBinding], **kwargs: Any) -> Dict[AclBinding, return futmap - def describe_acls(self, acl_binding_filter: AclBindingFilter, **kwargs: Any) -> concurrent.futures.Future: + def describe_acls( # type: ignore[override] + self, acl_binding_filter: AclBindingFilter, **kwargs: Any + ) -> concurrent.futures.Future: """ Match ACL bindings by filter. @@ -869,8 +879,9 @@ def describe_acls(self, acl_binding_filter: AclBindingFilter, **kwargs: Any) -> return f - def delete_acls(self, acl_binding_filters: List[AclBindingFilter], - **kwargs: Any) -> Dict[AclBindingFilter, concurrent.futures.Future]: + def delete_acls( # type: ignore[override] + self, acl_binding_filters: List[AclBindingFilter], **kwargs: Any + ) -> Dict[AclBindingFilter, concurrent.futures.Future]: """ Delete ACL bindings matching one or more ACL binding filters. @@ -909,7 +920,9 @@ def delete_acls(self, acl_binding_filters: List[AclBindingFilter], return futmap - def list_consumer_groups(self, **kwargs: Any) -> concurrent.futures.Future: + def list_consumer_groups( # type: ignore[override] + self, **kwargs: Any + ) -> concurrent.futures.Future: """ List consumer groups. @@ -956,7 +969,9 @@ def list_consumer_groups(self, **kwargs: Any) -> concurrent.futures.Future: return f - def describe_consumer_groups(self, group_ids: List[str], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def describe_consumer_groups( # type: ignore[override] + self, group_ids: List[str], **kwargs: Any + ) -> Dict[str, concurrent.futures.Future]: """ Describe consumer groups. 
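A short usage sketch for the describe_consumer_groups futmap; the group id is a placeholder and the per-group result is assumed to be a ConsumerGroupDescription with state and members attributes:

    from confluent_kafka.admin import AdminClient

    admin = AdminClient({"bootstrap.servers": "localhost:9092"})

    futures = admin.describe_consumer_groups(["example-group"])
    for group_id, future in futures.items():
        desc = future.result()
        print(group_id, desc.state, len(desc.members))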
@@ -985,11 +1000,13 @@ def describe_consumer_groups(self, group_ids: List[str], **kwargs: Any) -> Dict[ f, futmap = AdminClient._make_futures(group_ids, None, AdminClient._make_consumer_groups_result) - super(AdminClient, self).describe_consumer_groups(group_ids, f, **kwargs) + super(AdminClient, self).describe_consumer_groups(group_ids, f, **kwargs) # type: ignore[arg-type] return futmap - def describe_topics(self, topics: _TopicCollection, **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def describe_topics( # type: ignore[override] + self, topics: _TopicCollection, **kwargs: Any + ) -> Dict[str, concurrent.futures.Future]: """ Describe topics. @@ -1020,11 +1037,13 @@ def describe_topics(self, topics: _TopicCollection, **kwargs: Any) -> Dict[str, f, futmap = AdminClient._make_futures_v2(topic_names, None, AdminClient._make_futmap_result_from_list) - super(AdminClient, self).describe_topics(topic_names, f, **kwargs) + super(AdminClient, self).describe_topics(topic_names, f, **kwargs) # type: ignore[arg-type] return futmap - def describe_cluster(self, **kwargs: Any) -> concurrent.futures.Future: + def describe_cluster( # type: ignore[override] + self, **kwargs: Any + ) -> concurrent.futures.Future: """ Describe cluster. @@ -1048,7 +1067,9 @@ def describe_cluster(self, **kwargs: Any) -> concurrent.futures.Future: return f - def delete_consumer_groups(self, group_ids: List[str], **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def delete_consumer_groups( # type: ignore[override] + self, group_ids: List[str], **kwargs: Any + ) -> Dict[str, concurrent.futures.Future]: """ Delete the given consumer groups. @@ -1079,9 +1100,10 @@ def delete_consumer_groups(self, group_ids: List[str], **kwargs: Any) -> Dict[st return futmap - def list_consumer_group_offsets( - self, list_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], - **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def list_consumer_group_offsets( # type: ignore[override] + self, list_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], + **kwargs: Any + ) -> Dict[str, concurrent.futures.Future]: """ List offset information for the consumer group and (optional) topic partition provided in the request. @@ -1118,9 +1140,10 @@ def list_consumer_group_offsets( return futmap - def alter_consumer_group_offsets( - self, alter_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], - **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def alter_consumer_group_offsets( # type: ignore[override] + self, alter_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], + **kwargs: Any + ) -> Dict[str, concurrent.futures.Future]: """ Alter offset for the consumer group and topic partition provided in the request. @@ -1172,9 +1195,9 @@ def set_sasl_credentials(self, username: str, password: str) -> None: """ super(AdminClient, self).set_sasl_credentials(username, password) - def describe_user_scram_credentials( - self, users: Optional[List[str]] = None, - **kwargs: Any) -> Union[concurrent.futures.Future, Dict[str, concurrent.futures.Future]]: + def describe_user_scram_credentials( # type: ignore[override] + self, users: Optional[List[str]] = None, **kwargs: Any + ) -> Union[concurrent.futures.Future, Dict[str, concurrent.futures.Future]]: """ Describe user SASL/SCRAM credentials. 
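The Union return type annotated here reflects two call shapes: with users=None a single future covers all users, while an explicit user list yields a per-user futmap. A sketch with placeholder broker and user names:

    from confluent_kafka.admin import AdminClient

    admin = AdminClient({"bootstrap.servers": "localhost:9092"})

    # users=None: one future describing all users.
    print(admin.describe_user_scram_credentials().result())

    # Explicit list: Dict[str, Future], one future per user.
    for user, future in admin.describe_user_scram_credentials(["alice"]).items():
        print(user, future.result())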
@@ -1205,14 +1228,14 @@ def describe_user_scram_credentials( if users is None: internal_f, ret_fut = AdminClient._make_single_future_pair() else: - internal_f, ret_fut = AdminClient._make_futures_v2( + internal_f, ret_fut = AdminClient._make_futures_v2( # type: ignore[assignment] users, None, AdminClient._make_futmap_result) super(AdminClient, self).describe_user_scram_credentials(users, internal_f, **kwargs) return ret_fut - def alter_user_scram_credentials( - self, alterations: List[UserScramCredentialAlteration], - **kwargs: Any) -> Dict[str, concurrent.futures.Future]: + def alter_user_scram_credentials( # type: ignore[override] + self, alterations: List[UserScramCredentialAlteration], **kwargs: Any + ) -> Dict[str, concurrent.futures.Future]: """ Alter user SASL/SCRAM credentials. @@ -1241,8 +1264,9 @@ def alter_user_scram_credentials( super(AdminClient, self).alter_user_scram_credentials(alterations, f, **kwargs) return futmap - def list_offsets(self, topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], - **kwargs: Any) -> Dict[_TopicPartition, concurrent.futures.Future]: + def list_offsets( # type: ignore[override] + self, topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], **kwargs: Any + ) -> Dict[_TopicPartition, concurrent.futures.Future]: """ Enables to find the beginning offset, end offset as well as the offset matching a timestamp @@ -1283,8 +1307,9 @@ def list_offsets(self, topic_partition_offsets: Dict[_TopicPartition, OffsetSpec super(AdminClient, self).list_offsets(topic_partition_offsets_list, f, **kwargs) return futmap - def delete_records(self, topic_partition_offsets: List[_TopicPartition], - **kwargs: Any) -> Dict[_TopicPartition, concurrent.futures.Future]: + def delete_records( # type: ignore[override] + self, topic_partition_offsets: List[_TopicPartition], **kwargs: Any + ) -> Dict[_TopicPartition, concurrent.futures.Future]: """ Deletes all the records before the specified offsets (not including), in the specified topics and partitions. @@ -1321,9 +1346,10 @@ def delete_records(self, topic_partition_offsets: List[_TopicPartition], super(AdminClient, self).delete_records(topic_partition_offsets, f, **kwargs) return futmap - def elect_leaders(self, election_type: _ElectionType, - partitions: Optional[List[_TopicPartition]] = None, - **kwargs: Any) -> concurrent.futures.Future: + def elect_leaders( # type: ignore[override] + self, election_type: _ElectionType, partitions: Optional[List[_TopicPartition]] = None, + **kwargs: Any + ) -> concurrent.futures.Future: """ Perform Preferred or Unclean leader election for all the specified partitions or all partitions in the cluster. 
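The repeated # type: ignore[override] markers in this patch are needed because the high-level wrappers deliberately change the signatures of the C-implemented base class (no explicit future argument, a futmap return value instead). A library-independent sketch of the pattern mypy is flagging; _BaseImpl and Wrapper are hypothetical names, not the library's classes:

    import concurrent.futures
    from typing import Any, Dict, List

    class _BaseImpl:
        # Stand-in for the C extension method: the caller supplies the
        # request-level future explicitly and nothing is returned.
        def create_topics(self, topics: List[Any],
                          future: concurrent.futures.Future,
                          **kwargs: Any) -> None: ...

    class Wrapper(_BaseImpl):
        # Different parameters and return type than the base method, so mypy
        # reports an [override] error unless it is silenced.
        def create_topics(  # type: ignore[override]
                self, topics: List[str], **kwargs: Any
        ) -> Dict[str, concurrent.futures.Future]:
            futmap = {t: concurrent.futures.Future() for t in topics}
            request_future: concurrent.futures.Future = concurrent.futures.Future()
            super().create_topics(topics, request_future, **kwargs)
            return futmap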
diff --git a/src/confluent_kafka/admin/_acl.py b/src/confluent_kafka/admin/_acl.py index 940d25913..d318c97ec 100644 --- a/src/confluent_kafka/admin/_acl.py +++ b/src/confluent_kafka/admin/_acl.py @@ -110,7 +110,7 @@ def __init__(self, restype: Union[ResourceType, str, int], name: str, self.permission_type_int = int(self.permission_type.value) # type: ignore[union-attr] def _convert_enums(self) -> None: - self.restype = ConversionUtil.convert_to_enum(self.restype, ResourceType) + self.restype = ConversionUtil.convert_to_enum(self.restype, ResourceType) # type: ignore[assignment] self.resource_pattern_type = ConversionUtil.convert_to_enum( self.resource_pattern_type, ResourcePatternType) # type: ignore[assignment] self.operation = ConversionUtil.convert_to_enum( @@ -154,7 +154,7 @@ def __repr__(self) -> str: return "%s(%s,%s,%s,%s,%s,%s,%s)" % ((type_name,) + self._to_tuple()) def _to_tuple(self) -> Tuple[ResourceType, str, ResourcePatternType, str, str, AclOperation, AclPermissionType]: - return (self.restype, self.name, self.resource_pattern_type, + return (self.restype, self.name, self.resource_pattern_type, # type: ignore[return-value] self.principal, self.host, self.operation, self.permission_type) @@ -166,8 +166,8 @@ def __lt__(self, other: 'AclBinding') -> Any: return NotImplemented return self._to_tuple() < other._to_tuple() - def __eq__(self, other: 'AclBinding') -> Any: - if self.__class__ != other.__class__: + def __eq__(self, other: object) -> Any: + if not isinstance(other, AclBinding): return NotImplemented return self._to_tuple() == other._to_tuple() diff --git a/src/confluent_kafka/admin/_config.py b/src/confluent_kafka/admin/_config.py index abccffc66..c303f7bbf 100644 --- a/src/confluent_kafka/admin/_config.py +++ b/src/confluent_kafka/admin/_config.py @@ -190,7 +190,9 @@ def __lt__(self, other: 'ConfigResource') -> bool: return True return self.name.__lt__(other.name) - def __eq__(self, other: 'ConfigResource') -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, ConfigResource): + return NotImplemented return self.restype == other.restype and self.name == other.name def __len__(self) -> int: diff --git a/src/confluent_kafka/admin/_listoffsets.py b/src/confluent_kafka/admin/_listoffsets.py index 0d815266a..d8def288c 100644 --- a/src/confluent_kafka/admin/_listoffsets.py +++ b/src/confluent_kafka/admin/_listoffsets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any +from typing import Dict, Any, Optional from abc import ABC, abstractmethod from .. import cimpl @@ -24,6 +24,9 @@ class OffsetSpec(ABC): of the partition being queried. 
""" _values: Dict[int, 'OffsetSpec'] = {} + _max_timestamp: Optional['MaxTimestampSpec'] = None + _earliest: Optional['EarliestSpec'] = None + _latest: Optional['LatestSpec'] = None @property @abstractmethod @@ -156,6 +159,4 @@ class ListOffsetsResultInfo: def __init__(self, offset: int, timestamp: int, leader_epoch: int) -> None: self.offset = offset self.timestamp = timestamp - self.leader_epoch: Optional[int] = leader_epoch - if leader_epoch < 0: - self.leader_epoch = None + self.leader_epoch: Optional[int] = leader_epoch if leader_epoch >= 0 else None diff --git a/src/confluent_kafka/admin/_metadata.py b/src/confluent_kafka/admin/_metadata.py index d2c115003..8132e3bc1 100644 --- a/src/confluent_kafka/admin/_metadata.py +++ b/src/confluent_kafka/admin/_metadata.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, List + class ClusterMetadata(object): """ @@ -89,7 +91,7 @@ def __repr__(self) -> str: return "TopicMetadata({}, {} partitions)".format(self.topic, len(self.partitions)) def __str__(self) -> str: - return self.topic + return str(self.topic) class PartitionMetadata(object): @@ -177,4 +179,4 @@ def __repr__(self) -> str: return "GroupMetadata({})".format(self.id) def __str__(self) -> str: - return self.id + return str(self.id) diff --git a/src/confluent_kafka/cimpl.pyi b/src/confluent_kafka/cimpl.pyi index ea8b3846d..9faa97f01 100644 --- a/src/confluent_kafka/cimpl.pyi +++ b/src/confluent_kafka/cimpl.pyi @@ -34,14 +34,11 @@ TODO: Consider migrating to Cython in the future to eliminate this dual maintenance burden and get type hints directly from the implementation. """ -from typing import Any, Optional, Callable, List, Tuple, Dict, Union, overload, TYPE_CHECKING +from typing import Any, Optional, Callable, List, Tuple, Dict, Union, overload from typing_extensions import Self, Literal import builtins -from ._types import ConfigDict, HeadersType - -if TYPE_CHECKING: - from confluent_kafka.admin._metadata import ClusterMetadata, GroupMetadata +from ._types import HeadersType # Callback types with proper class references (defined locally to avoid circular imports) DeliveryCallback = Callable[[Optional['KafkaError'], 'Message'], None] @@ -65,17 +62,16 @@ class KafkaError: def __str__(self) -> builtins.str: ... def __bool__(self) -> bool: ... def __hash__(self) -> int: ... - def __eq__(self, other: Union['KafkaError', int]) -> bool: ... - def __ne__(self, other: Union['KafkaError', int]) -> bool: ... + def __eq__(self, other: object) -> bool: ... + def __ne__(self, other: object) -> bool: ... def __lt__(self, other: Union['KafkaError', int]) -> bool: ... def __le__(self, other: Union['KafkaError', int]) -> bool: ... def __gt__(self, other: Union['KafkaError', int]) -> bool: ... def __ge__(self, other: Union['KafkaError', int]) -> bool: ... class KafkaException(Exception): - def __init__(self, kafka_error: KafkaError) -> None: ... - @property - def args(self) -> Tuple[KafkaError, ...]: ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + args: Tuple[Any, ...] class Message: def topic(self) -> str: ... @@ -116,7 +112,7 @@ class Uuid: def __eq__(self, other: object) -> bool: ... class Producer: - def __init__(self, config: ConfigDict) -> None: ... + def __init__(self, config: Dict[str, Union[str, int, float, bool]]) -> None: ... def produce( self, topic: str, @@ -143,7 +139,7 @@ class Producer: in_queue: bool = True, in_flight: bool = True, blocking: bool = True - ) -> int: ... 
+ ) -> None: ... def abort_transaction(self, timeout: float = -1) -> None: ... def begin_transaction(self) -> None: ... def commit_transaction(self, timeout: float = -1) -> None: ... @@ -160,7 +156,7 @@ class Producer: def __bool__(self) -> bool: ... class Consumer: - def __init__(self, config: ConfigDict) -> None: ... + def __init__(self, config: Dict[str, Union[str, int, float, bool, None]]) -> None: ... def subscribe( self, topics: List[str], @@ -188,11 +184,6 @@ class Consumer: offsets: Optional[List[TopicPartition]] = None, asynchronous: Literal[False] = False ) -> List[TopicPartition]: ... - def committed( - self, - partitions: List[TopicPartition], - timeout: float = -1 - ) -> List[TopicPartition]: ... def get_watermark_offsets( self, partition: TopicPartition, @@ -223,7 +214,7 @@ class Consumer: def __bool__(self) -> bool: ... class _AdminClientImpl: - def __init__(self, config: ConfigDict) -> None: ... + def __init__(self, config: Dict[str, Union[str, int, float, bool]]) -> None: ... def create_topics( self, topics: List['NewTopic'], @@ -262,15 +253,14 @@ class _AdminClientImpl: ) -> None: ... def list_topics( self, - future: Any, - request_timeout: float = -1 - ) -> None: ... + topic: Optional[str] = None, + timeout: float = -1 + ) -> Any: ... def list_groups( self, - future: Any, - request_timeout: float = -1, - states: Optional[List[str]] = None - ) -> None: ... + group: Optional[str] = None, + timeout: float = -1 + ) -> Any: ... def describe_consumer_groups( self, future: Any, @@ -331,7 +321,7 @@ class _AdminClientImpl: ) -> None: ... def alter_configs( self, - resources: Dict[Any, Dict[str, str]], # Dict[ConfigResource, Dict[str, str]] + resources: List[Any], # List[ConfigResource] future: Any, validate_only: bool = False, request_timeout: float = -1, @@ -339,7 +329,7 @@ class _AdminClientImpl: ) -> None: ... def incremental_alter_configs( self, - resources: Dict[Any, Dict[str, Any]], # Dict[ConfigResource, Dict[str, ConfigEntry]] + resources: List[Any], # List[ConfigResource] future: Any, validate_only: bool = False, request_timeout: float = -1, @@ -379,7 +369,7 @@ class _AdminClientImpl: request_timeout: float = -1, operation_timeout: float = -1 ) -> None: ... - def poll(self, timeout: float = -1) -> Any: ... + def poll(self, timeout: float = -1) -> int: ... def set_sasl_credentials(self, username: str, password: str) -> None: ... class NewTopic: @@ -398,8 +388,8 @@ class NewTopic: config: Optional[Dict[str, str]] def __str__(self) -> str: ... def __hash__(self) -> int: ... - def __eq__(self, other: 'NewTopic') -> bool: ... - def __ne__(self, other: 'NewTopic') -> bool: ... + def __eq__(self, other: object) -> bool: ... + def __ne__(self, other: object) -> bool: ... def __lt__(self, other: 'NewTopic') -> bool: ... def __le__(self, other: 'NewTopic') -> bool: ... def __gt__(self, other: 'NewTopic') -> bool: ... @@ -417,8 +407,8 @@ class NewPartitions: replica_assignment: Optional[List[List[int]]] def __str__(self) -> str: ... def __hash__(self) -> int: ... - def __eq__(self, other: 'NewPartitions') -> bool: ... - def __ne__(self, other: 'NewPartitions') -> bool: ... + def __eq__(self, other: object) -> bool: ... + def __ne__(self, other: object) -> bool: ... def __lt__(self, other: 'NewPartitions') -> bool: ... def __le__(self, other: 'NewPartitions') -> bool: ... def __gt__(self, other: 'NewPartitions') -> bool: ... 
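A pattern worth noting across these hunks (AclBinding, ConfigResource, and the NewTopic/NewPartitions stubs): __eq__ is typed as accepting object, narrowed with isinstance, and returns NotImplemented for unrelated types so Python can fall back to the reflected comparison. A self-contained illustration of that idiom — the Pair class below is hypothetical, not part of the library:

class Pair:
    def __init__(self, restype: int, name: str) -> None:
        self.restype = restype
        self.name = name

    def __eq__(self, other: object) -> bool:
        # Accepting `object` keeps the signature compatible with object.__eq__,
        # which is what mypy requires; NotImplemented defers to the other operand.
        if not isinstance(other, Pair):
            return NotImplemented
        return (self.restype, self.name) == (other.restype, other.name)

    def __hash__(self) -> int:
        # Defining __eq__ removes the inherited __hash__, so restore one explicitly.
        return hash((self.restype, self.name))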
diff --git a/src/confluent_kafka/deserializing_consumer.py b/src/confluent_kafka/deserializing_consumer.py index 417eac0b2..793271b86 100644 --- a/src/confluent_kafka/deserializing_consumer.py +++ b/src/confluent_kafka/deserializing_consumer.py @@ -16,7 +16,7 @@ # limitations under the License. # -from typing import Optional, List +from typing import Any, Dict, List, Optional from confluent_kafka.cimpl import Consumer as _ConsumerImpl, Message from .error import (ConsumeError, @@ -24,7 +24,7 @@ ValueDeserializationError) from .serialization import (SerializationContext, MessageField) -from ._types import ConfigDict +from ._types import Deserializer class DeserializingConsumer(_ConsumerImpl): @@ -73,7 +73,7 @@ class DeserializingConsumer(_ConsumerImpl): ValueError: if configuration validation fails """ # noqa: E501 - def __init__(self, conf: ConfigDict) -> None: + def __init__(self, conf: Dict[str, Any]) -> None: conf_copy = conf.copy() self._key_deserializer = conf_copy.pop('key.deserializer', None) self._value_deserializer = conf_copy.pop('value.deserializer', None) @@ -123,8 +123,8 @@ def poll(self, timeout: float = -1) -> Optional[Message]: except Exception as se: raise KeyDeserializationError(exception=se, kafka_message=msg) - msg.set_key(key) - msg.set_value(value) + msg.set_key(key) # type: ignore[arg-type] + msg.set_value(value) # type: ignore[arg-type] return msg def consume(self, num_messages: int = 1, timeout: float = -1) -> List[Message]: diff --git a/src/confluent_kafka/experimental/aio/_AIOConsumer.py b/src/confluent_kafka/experimental/aio/_AIOConsumer.py index 5ad8ae20b..b169da80c 100644 --- a/src/confluent_kafka/experimental/aio/_AIOConsumer.py +++ b/src/confluent_kafka/experimental/aio/_AIOConsumer.py @@ -14,18 +14,17 @@ import asyncio import concurrent.futures -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import confluent_kafka from . 
import _common as _common -from ..._types import ConfigDict class AIOConsumer: def __init__( self, - consumer_conf: ConfigDict, + consumer_conf: Dict[str, Any], max_workers: int = 2, executor: Optional[concurrent.futures.Executor] = None ) -> None: diff --git a/src/confluent_kafka/experimental/aio/_common.py b/src/confluent_kafka/experimental/aio/_common.py index 24659ed9f..3bf274064 100644 --- a/src/confluent_kafka/experimental/aio/_common.py +++ b/src/confluent_kafka/experimental/aio/_common.py @@ -32,7 +32,7 @@ def __init__( self.logger = logger def log(self, *args: Any, **kwargs: Any) -> None: - self.loop.call_soon_threadsafe(self.logger.log, *args, **kwargs) + self.loop.call_soon_threadsafe(lambda: self.logger.log(*args, **kwargs)) def wrap_callback( diff --git a/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py b/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py index 8546f794f..b2c7b86df 100644 --- a/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py +++ b/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py @@ -15,7 +15,7 @@ import asyncio import concurrent.futures import logging -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional import confluent_kafka @@ -23,7 +23,6 @@ from ._producer_batch_processor import ProducerBatchManager from ._kafka_batch_executor import ProducerBatchExecutor from ._buffer_timeout_manager import BufferTimeoutManager -from ..._types import ConfigDict logger = logging.getLogger(__name__) @@ -37,7 +36,7 @@ class AIOProducer: def __init__( self, - producer_conf: ConfigDict, + producer_conf: Dict[str, Any], max_workers: int = 4, executor: Optional[concurrent.futures.Executor] = None, batch_size: int = 1000, diff --git a/src/confluent_kafka/serializing_producer.py b/src/confluent_kafka/serializing_producer.py index 6d27f2586..9234ce5e1 100644 --- a/src/confluent_kafka/serializing_producer.py +++ b/src/confluent_kafka/serializing_producer.py @@ -16,14 +16,14 @@ # limitations under the License. # -from typing import Any, Optional +from typing import Any, Dict, Optional from confluent_kafka.cimpl import Producer as _ProducerImpl from .serialization import (MessageField, SerializationContext) from .error import (KeySerializationError, ValueSerializationError) -from ._types import ConfigDict, HeadersType, DeliveryCallback +from ._types import HeadersType, DeliveryCallback, Serializer class SerializingProducer(_ProducerImpl): @@ -69,7 +69,7 @@ class SerializingProducer(_ProducerImpl): conf (producer): SerializingProducer configuration. """ # noqa E501 - def __init__(self, conf: ConfigDict) -> None: + def __init__(self, conf: Dict[str, Any]) -> None: conf_copy = conf.copy() self._key_serializer = conf_copy.pop('key.serializer', None) @@ -77,9 +77,11 @@ def __init__(self, conf: ConfigDict) -> None: super(SerializingProducer, self).__init__(conf_copy) - def produce(self, topic: str, key: Any = None, value: Any = None, partition: int = -1, - on_delivery: Optional[DeliveryCallback] = None, timestamp: int = 0, - headers: Optional[HeadersType] = None) -> None: + def produce( # type: ignore[override] + self, topic: str, key: Any = None, value: Any = None, partition: int = -1, + on_delivery: Optional[DeliveryCallback] = None, timestamp: int = 0, + headers: Optional[HeadersType] = None + ) -> None: """ Produce a message. 
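The type: ignore[override] on produce() above is needed because SerializingProducer intentionally narrows Producer.produce() (serializer hooks applied to key and value, no positional callback parameter), which mypy reports as an incompatible override. Call sites are unaffected; a rough usage sketch, where the broker address and topic name are placeholders:

from confluent_kafka.serializing_producer import SerializingProducer
from confluent_kafka.serialization import StringSerializer

def on_delivery(err, msg):
    # err is a KafkaError on failure, None on success.
    print("delivery failed:" if err else "delivered to:", err or msg.topic())

producer = SerializingProducer({
    "bootstrap.servers": "localhost:9092",        # placeholder broker
    "key.serializer": StringSerializer("utf_8"),
    "value.serializer": StringSerializer("utf_8"),
})
producer.produce("users", key="user-42", value="hello", on_delivery=on_delivery)
producer.flush()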
From 8bdb0af830cecf5c7ac4be25ee5b9170598a036b Mon Sep 17 00:00:00 2001 From: Naxin Date: Thu, 16 Oct 2025 15:10:01 -0400 Subject: [PATCH 14/31] revert some accidental doc change --- .../aio/producer/_kafka_batch_executor.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py b/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py index dadecaa3a..3253bb083 100644 --- a/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py +++ b/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py @@ -63,10 +63,10 @@ async def execute_batch( Args: topic: Target topic for the batch batch_messages: List of prepared messages with callbacks assigned - partition: Target partition (-1 = RD_KAFKA_PARTITION_UA) + partition: Target partition for the batch (-1 = RD_KAFKA_PARTITION_UA) Returns: - Result from producer.poll() indicating # of delivery reports processed + Result from producer.poll() indicating number of delivery reports processed Raises: Exception: Any exception from the batch operation is propagated @@ -75,9 +75,9 @@ def _produce_batch_and_poll() -> int: """Helper function to run in thread pool This function encapsulates all the blocking Kafka operations: - - Call produce_batch with specific partition & individual callbacks + - Call produce_batch with specific partition and individual message callbacks - Handle partial batch failures for messages that fail immediately - - Poll for delivery reports to trigger callbacks for successful msgs + - Poll for delivery reports to trigger callbacks for successful messages """ # Call produce_batch with specific partition and individual callbacks # Convert tuple to list since produce_batch expects a list @@ -88,8 +88,7 @@ def _produce_batch_and_poll() -> int: ) # Use the provided partition for the entire batch - # This enables proper partition control while working around - # librdkafka limitations + # This enables proper partition control while working around librdkafka limitations self._producer.produce_batch(topic, messages_list, partition=partition) # Handle partial batch failures: Check for messages that failed @@ -98,7 +97,7 @@ def _produce_batch_and_poll() -> int: # so we need to manually invoke their callbacks self._handle_partial_failures(messages_list) - # Immediately poll to process delivery callbacks for successful msgs + # Immediately poll to process delivery callbacks for successful messages poll_result = self._producer.poll(0) return poll_result @@ -120,7 +119,7 @@ def _handle_partial_failures( manually invoke the simple future-resolving callbacks. 
Args: - batch_messages: List of message dicts passed to produce_batch + batch_messages: List of message dictionaries that were passed to produce_batch """ for msg_dict in batch_messages: if '_error' in msg_dict: @@ -131,7 +130,7 @@ def _handle_partial_failures( # Extract the error from the message dict (set by Producer.c) error = msg_dict['_error'] # Manually invoke the callback with the error - # Note: msg is None since message failed before being queued + # Note: msg is None since the message failed before being queued try: callback(error, None) except Exception: From 26694e6fc550c94e181943bbfed707456dd57608 Mon Sep 17 00:00:00 2001 From: Naxin Date: Tue, 21 Oct 2025 13:12:35 -0400 Subject: [PATCH 15/31] fix some suggestions by copilot --- src/confluent_kafka/admin/__init__.py | 2 +- src/confluent_kafka/admin/_listoffsets.py | 4 +++- src/confluent_kafka/cimpl.pyi | 6 +++--- src/confluent_kafka/deserializing_consumer.py | 4 ++-- src/confluent_kafka/experimental/aio/_common.py | 2 +- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/confluent_kafka/admin/__init__.py b/src/confluent_kafka/admin/__init__.py index 360d7b1cc..bb3b81543 100644 --- a/src/confluent_kafka/admin/__init__.py +++ b/src/confluent_kafka/admin/__init__.py @@ -1000,7 +1000,7 @@ def describe_consumer_groups( # type: ignore[override] f, futmap = AdminClient._make_futures(group_ids, None, AdminClient._make_consumer_groups_result) - super(AdminClient, self).describe_consumer_groups(group_ids, f, **kwargs) # type: ignore[arg-type] + super(AdminClient, self).describe_consumer_groups(group_ids, f, **kwargs) return futmap diff --git a/src/confluent_kafka/admin/_listoffsets.py b/src/confluent_kafka/admin/_listoffsets.py index d8def288c..205e852be 100644 --- a/src/confluent_kafka/admin/_listoffsets.py +++ b/src/confluent_kafka/admin/_listoffsets.py @@ -159,4 +159,6 @@ class ListOffsetsResultInfo: def __init__(self, offset: int, timestamp: int, leader_epoch: int) -> None: self.offset = offset self.timestamp = timestamp - self.leader_epoch: Optional[int] = leader_epoch if leader_epoch >= 0 else None + self.leader_epoch: Optional[int] = leader_epoch + if leader_epoch < 0: + self.leader_epoch = None diff --git a/src/confluent_kafka/cimpl.pyi b/src/confluent_kafka/cimpl.pyi index 9faa97f01..332d552bb 100644 --- a/src/confluent_kafka/cimpl.pyi +++ b/src/confluent_kafka/cimpl.pyi @@ -85,8 +85,8 @@ class Message: def latency(self) -> Optional[float]: ... def leader_epoch(self) -> Optional[int]: ... def set_headers(self, headers: HeadersType) -> None: ... - def set_key(self, key: bytes) -> None: ... - def set_value(self, value: bytes) -> None: ... + def set_key(self, key: Any) -> None: ... + def set_value(self, value: Any) -> None: ... def __len__(self) -> int: ... class TopicPartition: @@ -263,8 +263,8 @@ class _AdminClientImpl: ) -> Any: ... def describe_consumer_groups( self, - future: Any, group_ids: List[str], + future: Any, request_timeout: float = -1, include_authorized_operations: bool = False ) -> None: ... 
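Reverting set_key()/set_value() to Any matches how DeserializingConsumer actually uses them: after running the configured deserializers it stores arbitrary Python objects back on the Message, not bytes. A minimal consumption sketch — the broker, group id, topic, and UTF-8 string payload format are placeholders:

from confluent_kafka.deserializing_consumer import DeserializingConsumer
from confluent_kafka.serialization import StringDeserializer

consumer = DeserializingConsumer({
    "bootstrap.servers": "localhost:9092",       # placeholder broker
    "group.id": "example-group",                 # placeholder group
    "auto.offset.reset": "earliest",
    "key.deserializer": StringDeserializer("utf_8"),
    "value.deserializer": StringDeserializer("utf_8"),
})
consumer.subscribe(["users"])

msg = consumer.poll(timeout=5.0)
if msg is not None and msg.error() is None:
    # key()/value() return the deserialized objects stored via set_key()/set_value().
    print(msg.key(), msg.value())
consumer.close()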
diff --git a/src/confluent_kafka/deserializing_consumer.py b/src/confluent_kafka/deserializing_consumer.py index 793271b86..c645a5b31 100644 --- a/src/confluent_kafka/deserializing_consumer.py +++ b/src/confluent_kafka/deserializing_consumer.py @@ -123,8 +123,8 @@ def poll(self, timeout: float = -1) -> Optional[Message]: except Exception as se: raise KeyDeserializationError(exception=se, kafka_message=msg) - msg.set_key(key) # type: ignore[arg-type] - msg.set_value(value) # type: ignore[arg-type] + msg.set_key(key) + msg.set_value(value) return msg def consume(self, num_messages: int = 1, timeout: float = -1) -> List[Message]: diff --git a/src/confluent_kafka/experimental/aio/_common.py b/src/confluent_kafka/experimental/aio/_common.py index 3bf274064..24659ed9f 100644 --- a/src/confluent_kafka/experimental/aio/_common.py +++ b/src/confluent_kafka/experimental/aio/_common.py @@ -32,7 +32,7 @@ def __init__( self.logger = logger def log(self, *args: Any, **kwargs: Any) -> None: - self.loop.call_soon_threadsafe(lambda: self.logger.log(*args, **kwargs)) + self.loop.call_soon_threadsafe(self.logger.log, *args, **kwargs) def wrap_callback( From c7865d8aea217be5637e2ee2d33376806f973d46 Mon Sep 17 00:00:00 2001 From: Naxin Date: Tue, 21 Oct 2025 13:43:18 -0400 Subject: [PATCH 16/31] linter --- src/confluent_kafka/admin/__init__.py | 12 ++++++------ src/confluent_kafka/deserializing_consumer.py | 1 - src/confluent_kafka/serializing_producer.py | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/confluent_kafka/admin/__init__.py b/src/confluent_kafka/admin/__init__.py index bb3b81543..b4705fa03 100644 --- a/src/confluent_kafka/admin/__init__.py +++ b/src/confluent_kafka/admin/__init__.py @@ -185,7 +185,7 @@ def _make_list_consumer_groups_result(f: concurrent.futures.Future, futmap: Any) @staticmethod def _make_consumer_groups_result(f: concurrent.futures.Future, - futmap: Dict[str, concurrent.futures.Future]) -> None: + futmap: Dict[str, concurrent.futures.Future]) -> None: """ Map per-group results to per-group futures in futmap. """ @@ -211,7 +211,7 @@ def _make_consumer_groups_result(f: concurrent.futures.Future, @staticmethod def _make_consumer_group_offsets_result(f: concurrent.futures.Future, - futmap: Dict[str, concurrent.futures.Future]) -> None: + futmap: Dict[str, concurrent.futures.Future]) -> None: """ Map per-group results to per-group futures in futmap. The result value of each (successful) future is ConsumerGroupTopicPartitions. @@ -264,7 +264,7 @@ def _make_acls_result(f: concurrent.futures.Future, futmap: Dict[Any, concurrent @staticmethod def _make_futmap_result_from_list(f: concurrent.futures.Future, - futmap: Dict[Any, concurrent.futures.Future]) -> None: + futmap: Dict[Any, concurrent.futures.Future]) -> None: try: results = f.result() @@ -317,7 +317,7 @@ def _create_future() -> concurrent.futures.Future: @staticmethod def _make_futures(futmap_keys: List[Any], class_check: Optional[type], make_result_fn: Any) -> Tuple[concurrent.futures.Future, - Dict[Any, concurrent.futures.Future]]: + Dict[Any, concurrent.futures.Future]]: """ Create futures and a futuremap for the keys in futmap_keys, and create a request-level future to be bassed to the C API. 
@@ -341,7 +341,7 @@ def _make_futures(futmap_keys: List[Any], class_check: Optional[type], @staticmethod def _make_futures_v2(futmap_keys: Union[List[Any], Set[Any]], class_check: Optional[type], make_result_fn: Any) -> Tuple[concurrent.futures.Future, - Dict[Any, concurrent.futures.Future]]: + Dict[Any, concurrent.futures.Future]]: """ Create futures and a futuremap for the keys in futmap_keys, and create a request-level future to be bassed to the C API. @@ -524,7 +524,7 @@ def _check_alter_user_scram_credentials_request(alterations: List[UserScramCrede @staticmethod def _check_list_offsets_request(topic_partition_offsets: Dict[_TopicPartition, OffsetSpec], - kwargs: Dict[str, Any]) -> None: + kwargs: Dict[str, Any]) -> None: if not isinstance(topic_partition_offsets, dict): raise TypeError("Expected topic_partition_offsets to be " + "dict of [TopicPartitions,OffsetSpec] for list offsets request") diff --git a/src/confluent_kafka/deserializing_consumer.py b/src/confluent_kafka/deserializing_consumer.py index c645a5b31..324239f25 100644 --- a/src/confluent_kafka/deserializing_consumer.py +++ b/src/confluent_kafka/deserializing_consumer.py @@ -24,7 +24,6 @@ ValueDeserializationError) from .serialization import (SerializationContext, MessageField) -from ._types import Deserializer class DeserializingConsumer(_ConsumerImpl): diff --git a/src/confluent_kafka/serializing_producer.py b/src/confluent_kafka/serializing_producer.py index 9234ce5e1..c4547d9d4 100644 --- a/src/confluent_kafka/serializing_producer.py +++ b/src/confluent_kafka/serializing_producer.py @@ -23,7 +23,7 @@ SerializationContext) from .error import (KeySerializationError, ValueSerializationError) -from ._types import HeadersType, DeliveryCallback, Serializer +from ._types import HeadersType, DeliveryCallback class SerializingProducer(_ProducerImpl): From 791c4adea6a972ff25bff30191e0ee6ac1b23506 Mon Sep 17 00:00:00 2001 From: Naxin Date: Fri, 17 Oct 2025 17:01:06 -0400 Subject: [PATCH 17/31] fix --- .../schema_registry/_async/avro.py | 30 +++++++++++------ .../schema_registry/_async/json_schema.py | 32 +++++++++++++------ .../schema_registry/_async/protobuf.py | 22 ++++++++----- .../schema_registry/_sync/avro.py | 30 +++++++++++------ .../schema_registry/_sync/json_schema.py | 32 +++++++++++++------ .../schema_registry/_sync/protobuf.py | 20 ++++++++---- 6 files changed, 113 insertions(+), 53 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index d29611853..3bf8e4e80 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -16,7 +16,7 @@ # limitations under the License. 
import io import json -from typing import Dict, Union, Optional, Callable +from typing import Any, Coroutine, Dict, Union, Optional, Callable, cast from fastavro import schemaless_reader, schemaless_writer from confluent_kafka.schema_registry.common import asyncinit @@ -206,7 +206,7 @@ async def __init_impl( self._registry = schema_registry_client self._schema_id = None self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._known_subjects = set() + self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() if to_dict is not None and not callable(to_dict): @@ -243,11 +243,17 @@ async def __init_impl( not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") @@ -286,7 +292,7 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__serialize(obj, ctx) async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -486,11 +492,17 @@ async def __init_impl( not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") @@ -518,11 +530,11 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Union[dict, object, None]]: return self.__deserialize(data, ctx) async def __deserialize( - self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: """ Deserialize Avro binary encoded data with Confluent Schema Registry framing to a dict, or object instance according to from_dict, if specified. 
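The cast(...) wrappers added around conf_copy.pop(...) here (and mirrored in the JSON and Protobuf serializers below) only narrow the static type of a value pulled from a heterogeneous config dict; the runtime callable() check remains the actual validation. A stripped-down sketch of the same idiom, using an illustrative strategy function rather than the library's real subject-name strategies:

from typing import Any, Callable, Dict, Optional, cast

def topic_subject_name(ctx: Optional[object], record_name: Optional[str]) -> Optional[str]:
    # Illustrative strategy: derive the subject from the record name.
    return None if record_name is None else f"{record_name}-value"

conf: Dict[str, Any] = {"subject.name.strategy": topic_subject_name}

# cast() only informs the type checker; it performs no runtime conversion,
# so the callable() check below is still what actually validates the value.
strategy = cast(
    Callable[[Optional[object], Optional[str]], Optional[str]],
    conf.pop("subject.name.strategy"),
)
if not callable(strategy):
    raise ValueError("subject.name.strategy must be callable")

print(strategy(None, "User"))  # -> "User-value"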
diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py index d522838b0..120641eb0 100644 --- a/src/confluent_kafka/schema_registry/_async/json_schema.py +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -16,7 +16,7 @@ # limitations under the License. import io import orjson -from typing import Union, Optional, Tuple, Callable +from typing import Any, Coroutine, Union, Optional, Tuple, Callable, cast from cachetools import LRUCache from jsonschema import ValidationError @@ -226,7 +226,7 @@ async def __init_impl( rule_registry if rule_registry else RuleRegistry.get_global_instance() ) self._schema_id = None - self._known_subjects = set() + self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() self._validators = LRUCache(1000) @@ -264,11 +264,17 @@ async def __init_impl( not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") @@ -296,7 +302,7 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__serialize(obj, ctx) async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -525,12 +531,18 @@ async def __init_impl( not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') - if not callable(self._subject_name_func): + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) + if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") self._validate = conf_copy.pop('validate') @@ -558,10 +570,10 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__deserialize(data, ctx) - async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + async def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> 
Optional[bytes]: """ Deserialize a JSON encoded record with Confluent Schema Registry framing to a dict, or object instance according to from_dict if from_dict is specified. diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index 41b8df970..eab735284 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -16,7 +16,7 @@ # limitations under the License. import io -from typing import Set, List, Union, Optional, Tuple +from typing import Any, Coroutine, Set, List, Union, Optional, Tuple, Callable, cast from google.protobuf import json_format, descriptor_pb2 from google.protobuf.descriptor_pool import DescriptorPool @@ -272,7 +272,7 @@ async def __init_impl( self._registry = schema_registry_client self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() self._schema_id = None - self._known_subjects = set() + self._known_subjects: set[str] = set() self._msg_class = msg_type self._parsed_schemas = ParsedSchemaCache() @@ -360,7 +360,7 @@ async def _resolve_dependencies( reference.version)) return schema_refs - def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__serialize(message, ctx) async def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -535,11 +535,17 @@ async def __init_impl( not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") @@ -558,10 +564,10 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__deserialize(data, ctx) - async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + async def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a serialized protobuf message with Confluent Schema Registry framing. 
@@ -601,7 +607,7 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = if subject is None: subject = self._subject_name_func(ctx, writer_desc.full_name) if subject is not None: - latest_schema = self._get_reader_schema(subject, fmt='serialized') + latest_schema = await self._get_reader_schema(subject, fmt='serialized') else: writer_schema_raw = None writer_schema = None diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index b4282624e..686369099 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -16,7 +16,7 @@ # limitations under the License. import io import json -from typing import Dict, Union, Optional, Callable +from typing import Any, Coroutine, Dict, Union, Optional, Callable, cast from fastavro import schemaless_reader, schemaless_writer @@ -206,7 +206,7 @@ def __init_impl( self._registry = schema_registry_client self._schema_id = None self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._known_subjects = set() + self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() if to_dict is not None and not callable(to_dict): @@ -243,11 +243,17 @@ def __init_impl( not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") @@ -286,7 +292,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__serialize(obj, ctx) def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -486,11 +492,17 @@ def __init_impl( not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") @@ -518,11 +530,11 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + def __call__(self, data: Optional[bytes], ctx: 
Optional[SerializationContext] = None) -> Coroutine[Any, Any, Union[dict, object, None]]: return self.__deserialize(data, ctx) def __deserialize( - self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: """ Deserialize Avro binary encoded data with Confluent Schema Registry framing to a dict, or object instance according to from_dict, if specified. diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 9d809386c..36877a6b0 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -16,7 +16,7 @@ # limitations under the License. import io import orjson -from typing import Union, Optional, Tuple, Callable +from typing import Any, Coroutine, Union, Optional, Tuple, Callable, cast from cachetools import LRUCache from jsonschema import ValidationError @@ -226,7 +226,7 @@ def __init_impl( rule_registry if rule_registry else RuleRegistry.get_global_instance() ) self._schema_id = None - self._known_subjects = set() + self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() self._validators = LRUCache(1000) @@ -264,11 +264,17 @@ def __init_impl( not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") @@ -296,7 +302,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__serialize(obj, ctx) def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -525,12 +531,18 @@ def __init_impl( not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') - if not callable(self._subject_name_func): + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) + if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") self._validate = conf_copy.pop('validate') @@ -558,10 +570,10 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, 
ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__deserialize(data, ctx) - def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a JSON encoded record with Confluent Schema Registry framing to a dict, or object instance according to from_dict if from_dict is specified. diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 501823716..0c3addcb5 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -16,7 +16,7 @@ # limitations under the License. import io -from typing import Set, List, Union, Optional, Tuple +from typing import Any, Coroutine, Set, List, Union, Optional, Tuple, Callable, cast from google.protobuf import json_format, descriptor_pb2 from google.protobuf.descriptor_pool import DescriptorPool @@ -272,7 +272,7 @@ def __init_impl( self._registry = schema_registry_client self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() self._schema_id = None - self._known_subjects = set() + self._known_subjects: set[str] = set() self._msg_class = msg_type self._parsed_schemas = ParsedSchemaCache() @@ -360,7 +360,7 @@ def _resolve_dependencies( reference.version)) return schema_refs - def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__serialize(message, ctx) def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -535,11 +535,17 @@ def __init_impl( not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") @@ -558,10 +564,10 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__deserialize(data, ctx) - def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a serialized protobuf message with Confluent Schema Registry framing. 
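With __call__ on the async serializers annotated as returning a Coroutine, both construction (via the asyncinit helper) and invocation are awaited. A rough sketch of the calling convention — the AsyncSchemaRegistryClient and AsyncAvroSerializer names, their import paths, and the registry URL are assumptions about the experimental async API rather than anything this patch defines:

import asyncio
from confluent_kafka.schema_registry import AsyncSchemaRegistryClient      # assumed export
from confluent_kafka.schema_registry.avro import AsyncAvroSerializer       # assumed export
from confluent_kafka.serialization import SerializationContext, MessageField

USER_SCHEMA = '{"type": "record", "name": "User", "fields": [{"name": "name", "type": "string"}]}'

async def main() -> None:
    client = AsyncSchemaRegistryClient({"url": "http://localhost:8081"})  # placeholder URL
    # asyncinit makes construction awaitable, mirroring the async __init_impl above.
    serializer = await AsyncAvroSerializer(client, USER_SCHEMA)
    # __call__ returns a coroutine, so the serialized bytes come from awaiting it.
    payload = await serializer({"name": "alice"}, SerializationContext("users", MessageField.VALUE))
    print(type(payload))  # bytes, or None for a None record

asyncio.run(main())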
From ffb118eaf14dff67b1ccd58b452151cfbf8a3b88 Mon Sep 17 00:00:00 2001 From: Naxin Date: Tue, 21 Oct 2025 14:46:09 -0400 Subject: [PATCH 18/31] resolve conflict --- pyproject.toml | 15 +++ requirements/requirements-tests.txt | 2 + .../schema_registry/__init__.py | 2 +- .../schema_registry/_async/avro.py | 90 +++++++------ .../schema_registry/_async/json_schema.py | 124 +++++++++-------- .../_async/mock_schema_registry_client.py | 33 +++-- .../schema_registry/_async/protobuf.py | 97 ++++++++------ .../_async/schema_registry_client.py | 109 +++++++-------- .../schema_registry/_async/serde.py | 53 ++++++-- .../schema_registry/_sync/avro.py | 92 +++++++------ .../schema_registry/_sync/json_schema.py | 126 ++++++++++-------- .../_sync/mock_schema_registry_client.py | 33 +++-- .../schema_registry/_sync/protobuf.py | 101 ++++++++------ .../_sync/schema_registry_client.py | 109 +++++++-------- .../schema_registry/_sync/serde.py | 53 ++++++-- .../schema_registry/common/avro.py | 18 +-- .../schema_registry/common/json_schema.py | 22 +-- .../schema_registry/common/protobuf.py | 16 +-- .../schema_registry/rules/cel/cel_executor.py | 12 +- .../schema_registry/rules/cel/constraints.py | 10 +- .../schema_registry/rules/cel/extra_func.py | 4 +- .../rules/cel/string_format.py | 30 ++--- .../rules/encryption/azurekms/azure_client.py | 7 +- .../dek_registry/dek_registry_client.py | 10 +- .../rules/encryption/encrypt_executor.py | 48 +++---- .../encryption/hcvault/hcvault_client.py | 19 +-- .../rules/encryption/localkms/local_client.py | 2 +- 27 files changed, 708 insertions(+), 529 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bf382ef2d..e33ffbf1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,21 @@ Homepage = "https://github.com/confluentinc/confluent-kafka-python" [tool.mypy] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = [ + "confluent_kafka.schema_registry.avro", + "confluent_kafka.schema_registry.json_schema", + "confluent_kafka.schema_registry.protobuf", +] +disable_error_code = ["assignment", "no-redef"] + +[[tool.mypy.overrides]] +module = [ + "confluent_kafka.schema_registry.confluent.meta_pb2", + "confluent_kafka.schema_registry.confluent.types.decimal_pb2", +] +ignore_errors = true + [tool.setuptools] include-package-data = false diff --git a/requirements/requirements-tests.txt b/requirements/requirements-tests.txt index e9cc5ca35..730bd2be5 100644 --- a/requirements/requirements-tests.txt +++ b/requirements/requirements-tests.txt @@ -1,6 +1,8 @@ # core test requirements urllib3<3 flake8 +mypy +types-cachetools orjson pytest pytest-timeout diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index 0cf16eef3..582cda859 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -205,7 +205,7 @@ def dual_schema_id_deserializer(payload: bytes, ctx: Optional[SerializationConte # Parse schema ID from determined source and return appropriate payload if header_value is not None: - schema_id.from_bytes(io.BytesIO(header_value)) + schema_id.from_bytes(io.BytesIO(header_value)) # type: ignore[arg-type] return io.BytesIO(payload) # Return full payload when schema ID is in header else: return schema_id.from_bytes(io.BytesIO(payload)) # Parse from payload, return remainder diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index 3bf8e4e80..d700d8dd6 100644 --- 
a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -57,12 +57,13 @@ async def _resolve_named_schema( named_schemas = {} if schema.references is not None: for ref in schema.references: - referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) - ref_named_schemas = await _resolve_named_schema(referenced_schema.schema, schema_registry_client) + # References in registered schemas are validated by server to be complete + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) # type: ignore[arg-type] + ref_named_schemas = await _resolve_named_schema(referenced_schema.schema, schema_registry_client) # type: ignore[arg-type] parsed_schema = parse_schema_with_repo( - referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) + referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) # type: ignore[union-attr,arg-type] named_schemas.update(ref_named_schemas) - named_schemas[ref.name] = parsed_schema + named_schemas[ref.name] = parsed_schema # type: ignore[index] return named_schemas @@ -204,7 +205,7 @@ async def __init_impl( schema = None self._registry = schema_registry_client - self._schema_id = None + self._schema_id: Optional[SchemaId] = None self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() @@ -219,26 +220,26 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") @@ -276,8 +277,11 @@ async def __init_impl( # i.e. {"type": "string"} has a name of string. # This function does not comply. 
# https://github.com/fastavro/fastavro/issues/415 - schema_dict = json.loads(schema.schema_str) - schema_name = parsed_schema.get("name", schema_dict.get("type")) + if schema.schema_str is not None: + schema_dict = json.loads(schema.schema_str) + schema_name = parsed_schema.get("name", schema_dict.get("type")) # type: ignore[union-attr] + else: + schema_name = None else: schema_name = None parsed_schema = None @@ -292,7 +296,7 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: # type: ignore[override] return self.__serialize(obj, ctx) async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -319,10 +323,10 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = await self._get_reader_schema(subject) + latest_schema = await self._get_reader_schema(subject) if subject else None # type: ignore[arg-type] if latest_schema is not None: self._schema_id = SchemaId(AVRO_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects: + elif subject is not None and subject not in self._known_subjects: # Check to ensure this schema has been registered under subject_name. if self._auto_register: # The schema name will always be the same. We can't however register @@ -339,26 +343,26 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N self._known_subjects.add(subject) if self._to_dict is not None: - value = self._to_dict(obj, ctx) + value = self._to_dict(obj, ctx) # type: ignore[arg-type] else: - value = obj + value = obj # type: ignore[assignment] if latest_schema is not None: - parsed_schema = await self._get_parsed_schema(latest_schema.schema) + parsed_schema = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, # type: ignore[arg-type] latest_schema.schema, value, get_inline_tags(parsed_schema), field_transformer) else: - parsed_schema = self._parsed_schema + parsed_schema = self._parsed_schema # type: ignore[assignment] with _ContextStringIO() as fo: # write the record to the rest of the buffer schemaless_writer(fo, parsed_schema, value) buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -371,7 +375,11 @@ async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema named_schemas = await _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) + if prepared_schema.schema_str is None: + raise ValueError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) @@ -483,11 +491,11 @@ async def __init_impl( if conf is 
not None: conf_copy.update(conf) - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") @@ -511,9 +519,9 @@ async def __init_impl( .format(", ".join(conf_copy.keys()))) if schema: - self._reader_schema = await self._get_parsed_schema(self._schema) + self._reader_schema = await self._get_parsed_schema(self._schema) # type: ignore[arg-type] else: - self._reader_schema = None + self._reader_schema = None # type: ignore[assignment] if from_dict is not None and not callable(from_dict): raise ValueError("from_dict must be callable with the signature " @@ -571,23 +579,24 @@ async def __deserialize( payload = self._schema_id_deserializer(data, ctx, schema_id) writer_schema_raw = await self._get_writer_schema(schema_id, subject) - writer_schema = await self._get_parsed_schema(writer_schema_raw) + writer_schema = await self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] if subject is None: - subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None + subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None # type: ignore[union-attr] if subject is not None: latest_schema = await self._get_reader_schema(subject) - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) - if latest_schema is not None: - migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) + if latest_schema is not None and subject is not None: + migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] reader_schema_raw = latest_schema.schema - reader_schema = await self._get_parsed_schema(latest_schema.schema) + reader_schema = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] elif self._schema is not None: migrations = None reader_schema_raw = self._schema @@ -597,7 +606,7 @@ async def __deserialize( reader_schema_raw = writer_schema_raw reader_schema = writer_schema - if migrations: + if migrations and ctx is not None and subject is not None: obj_dict = schemaless_reader(payload, writer_schema, None, @@ -611,12 +620,13 @@ async def __deserialize( def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, get_inline_tags(reader_schema), - field_transformer) + if ctx is not None and subject is not None: + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, # type: ignore[arg-type] + reader_schema_raw, obj_dict, get_inline_tags(reader_schema), + field_transformer) if self._from_dict is not 
None: - return self._from_dict(obj_dict, ctx) + return self._from_dict(obj_dict, ctx) # type: ignore[arg-type] return obj_dict @@ -626,7 +636,11 @@ async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema named_schemas = await _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) + if prepared_schema.schema_str is None: + raise ValueError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py index 120641eb0..66378f9ee 100644 --- a/src/confluent_kafka/schema_registry/_async/json_schema.py +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -61,15 +61,15 @@ async def _resolve_named_schema( """ if ref_registry is None: # Retrieve external schemas for backward compatibility - ref_registry = Registry(retrieve=_retrieve_via_httpx) + ref_registry = Registry(retrieve=_retrieve_via_httpx) # type: ignore[call-arg] if schema.references is not None: for ref in schema.references: - referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) - ref_registry = await _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) - referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) # type: ignore[arg-type] + ref_registry = await _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) # type: ignore[arg-type] + referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) # type: ignore[union-attr,arg-type] resource = Resource.from_contents( referenced_schema_dict, default_specification=DEFAULT_SPEC) - ref_registry = ref_registry.with_resource(ref.name, resource) + ref_registry = ref_registry.with_resource(ref.name, resource) # type: ignore[arg-type] return ref_registry @@ -213,6 +213,7 @@ async def __init_impl( json_encode: Optional[Callable] = None, ): super().__init__() + self._schema: Optional[Schema] if isinstance(schema_str, str): self._schema = Schema(schema_str, schema_type="JSON") elif isinstance(schema_str, Schema): @@ -225,10 +226,10 @@ async def __init_impl( self._rule_registry = ( rule_registry if rule_registry else RuleRegistry.get_global_instance() ) - self._schema_id = None + self._schema_id: Optional[SchemaId] = None self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) + self._validators: LRUCache[Schema, Validator] = LRUCache(1000) if to_dict is not None and not callable(to_dict): raise ValueError("to_dict must be callable with the signature " @@ -240,26 +241,26 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - 
self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") @@ -278,7 +279,7 @@ async def __init_impl( if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") - self._validate = conf_copy.pop('validate') + self._validate = cast(bool, conf_copy.pop('validate')) if not isinstance(self._validate, bool): raise ValueError("validate must be a boolean value") @@ -286,8 +287,8 @@ async def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - schema_dict, ref_registry = await self._get_parsed_schema(self._schema) - if schema_dict: + schema_dict, ref_registry = await self._get_parsed_schema(self._schema) # type: ignore[arg-type] + if schema_dict and isinstance(schema_dict, dict): schema_name = schema_dict.get('title', None) else: schema_name = None @@ -302,7 +303,7 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: # type: ignore[override] return self.__serialize(obj, ctx) async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -328,10 +329,10 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = await self._get_reader_schema(subject) + latest_schema = await self._get_reader_schema(subject) if subject else None # type: ignore[arg-type] if latest_schema is not None: self._schema_id = SchemaId(JSON_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects: + elif subject is not None and subject not in self._known_subjects: # Check to ensure this schema has been registered under subject_name. if self._auto_register: # The schema name will always be the same. 
We can't however register @@ -348,26 +349,27 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N self._known_subjects.add(subject) if self._to_dict is not None: - value = self._to_dict(obj, ctx) + value = self._to_dict(obj, ctx) # type: ignore[arg-type] else: - value = obj + value = obj # type: ignore[assignment] if latest_schema is not None: schema = latest_schema.schema - parsed_schema, ref_registry = await self._get_parsed_schema(latest_schema.schema) - root_resource = Resource.from_contents( - parsed_schema, default_specification=DEFAULT_SPEC) - ref_resolver = ref_registry.resolver_with_root(root_resource) - def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 - transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, value, None, - field_transformer) + parsed_schema, ref_registry = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + if ref_registry is not None: + root_resource = Resource.from_contents( + parsed_schema, default_specification=DEFAULT_SPEC) + ref_resolver = ref_registry.resolver_with_root(root_resource) + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 + transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, # type: ignore[arg-type] + latest_schema.schema, value, None, + field_transformer) else: schema = self._schema parsed_schema, ref_registry = self._parsed_schema, self._ref_registry - if self._validate: + if self._validate and schema is not None and parsed_schema is not None and ref_registry is not None: try: validator = self._get_validator(schema, parsed_schema, ref_registry) validator.validate(value) @@ -383,7 +385,7 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 fo.write(encoded_value) buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -399,6 +401,8 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema] return result ref_registry = await _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) @@ -494,6 +498,7 @@ async def __init_impl( json_decode: Optional[Callable] = None, ): super().__init__() + schema: Optional[Schema] if isinstance(schema_str, str): schema = Schema(schema_str, schema_type="JSON") elif isinstance(schema_str, Schema): @@ -510,11 +515,11 @@ async def __init_impl( else: raise TypeError('You must pass either str or Schema') - self._schema = schema + self._schema: Optional[Schema] = schema self._registry = schema_registry_client self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) + self._validators: LRUCache[Schema, Validator] = LRUCache(1000) self._json_decode = json_decode or orjson.loads self._use_schema_id = None @@ -522,11 +527,11 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = 
conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") @@ -545,7 +550,7 @@ async def __init_impl( if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") - self._validate = conf_copy.pop('validate') + self._validate = cast(bool, conf_copy.pop('validate')) if not isinstance(self._validate, bool): raise ValueError("validate must be a boolean value") @@ -554,7 +559,7 @@ async def __init_impl( .format(", ".join(conf_copy.keys()))) if schema: - self._reader_schema, self._ref_registry = await self._get_parsed_schema(self._schema) + self._reader_schema, self._ref_registry = await self._get_parsed_schema(self._schema) # type: ignore[arg-type] else: self._reader_schema, self._ref_registry = None, None @@ -604,8 +609,8 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization if self._registry is not None: writer_schema_raw = await self._get_writer_schema(schema_id, subject) - writer_schema, writer_ref_registry = await self._get_parsed_schema(writer_schema_raw) - if subject is None: + writer_schema, writer_ref_registry = await self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] + if subject is None and isinstance(writer_schema, dict): subject = self._subject_name_func(ctx, writer_schema.get("title")) if subject is not None: latest_schema = await self._get_reader_schema(subject) @@ -613,19 +618,20 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization writer_schema_raw = None writer_schema, writer_ref_registry = None, None - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) # JSON documents are self-describing; no need to query schema obj_dict = self._json_decode(payload.read()) - if latest_schema is not None: - migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) + if latest_schema is not None and subject is not None: + migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] reader_schema_raw = latest_schema.schema - reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema) + reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] elif self._schema is not None: migrations = None reader_schema_raw = self._schema @@ -635,21 +641,23 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization reader_schema_raw = writer_schema_raw reader_schema, reader_ref_registry = writer_schema, writer_ref_registry - if migrations: + if migrations and ctx is not None and subject is not None: obj_dict = self._execute_migrations(ctx, subject, 
migrations, obj_dict) - reader_root_resource = Resource.from_contents( - reader_schema, default_specification=DEFAULT_SPEC) - reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) + if reader_ref_registry is not None: + reader_root_resource = Resource.from_contents( + reader_schema, default_specification=DEFAULT_SPEC) + reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) - def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 - transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, - "$", message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, None, - field_transformer) + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 + transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, + "$", message, field_transform)) + if ctx is not None and subject is not None: + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, None, + field_transformer) # type: ignore[arg-type] - if self._validate: + if self._validate and reader_schema_raw is not None and reader_schema is not None and reader_ref_registry is not None: try: validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) validator.validate(obj_dict) @@ -657,7 +665,7 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 raise SerializationError(ve.message) if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) + return self._from_dict(obj_dict, ctx) # type: ignore[arg-type,return-value] return obj_dict @@ -670,6 +678,8 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema] return result ref_registry = await _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) diff --git a/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py index 1a395ade4..ddfe68805 100644 --- a/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py @@ -18,7 +18,7 @@ import uuid from collections import defaultdict from threading import Lock -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union from .schema_registry_client import AsyncSchemaRegistryClient from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig @@ -157,7 +157,7 @@ async def register_schema( ) -> int: registered_schema = await self.register_schema_full_response( subject_name, schema, normalize_schemas=normalize_schemas) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] async def register_schema_full_response( self, subject_name: str, schema: 'Schema', @@ -168,7 +168,7 @@ async def register_schema_full_response( return registered_schema latest_schema = self._store.get_latest_version(subject_name) - latest_version = 1 if latest_schema is None else latest_schema.version + 1 + latest_version = 1 if latest_schema is None or latest_schema.version is None else latest_schema.version + 1 registered_schema = RegisteredSchema( schema_id=1, @@ -184,7 +184,7 @@ async 
def register_schema_full_response( async def get_schema( self, schema_id: int, subject_name: Optional[str] = None, - fmt: Optional[str] = None + fmt: Optional[str] = None, reference_format: Optional[str] = None ) -> 'Schema': schema = self._store.get_schema(schema_id) if schema is not None: @@ -212,7 +212,10 @@ async def lookup_schema( raise SchemaRegistryError(404, 40400, "Schema Not Found") - async def get_subjects(self) -> List[str]: + async def get_subjects( + self, subject_prefix: Optional[str] = None, deleted: bool = False, + deleted_only: bool = False, offset: int = 0, limit: int = -1 + ) -> List[str]: return self._store.get_subjects() async def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: @@ -236,30 +239,36 @@ async def get_latest_with_metadata( raise SchemaRegistryError(404, 40400, "Schema Not Found") async def get_version( - self, subject_name: str, version: int, + self, subject_name: str, version: Union[int, str] = "latest", deleted: bool = False, fmt: Optional[str] = None ) -> 'RegisteredSchema': - registered_schema = self._store.get_version(subject_name, version) + if version == "latest": + registered_schema = self._store.get_latest_version(subject_name) + else: + registered_schema = self._store.get_version(subject_name, version) # type: ignore[arg-type] if registered_schema is not None: return registered_schema raise SchemaRegistryError(404, 40400, "Schema Not Found") - async def get_versions(self, subject_name: str) -> List[int]: + async def get_versions( + self, subject_name: str, deleted: bool = False, deleted_only: bool = False, + offset: int = 0, limit: int = -1 + ) -> List[int]: return self._store.get_versions(subject_name) async def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: registered_schema = self._store.get_version(subject_name, version) if registered_schema is not None: self._store.remove_by_schema(registered_schema) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] raise SchemaRegistryError(404, 40400, "Schema Not Found") async def set_config( - self, subject_name: Optional[str] = None, config: 'ServerConfig' = None # noqa F821 + self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None # noqa F821 ) -> 'ServerConfig': # noqa F821 - return None + return None # type: ignore[return-value] async def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': # noqa F821 - return None + return None # type: ignore[return-value] diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index eab735284..91589822c 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -67,12 +67,13 @@ async def _resolve_named_schema( visited = set() if schema.references is not None: for ref in schema.references: - if _is_builtin(ref.name) or ref.name in visited: + # References in registered schemas are validated by server to be complete + if _is_builtin(ref.name) or ref.name in visited: # type: ignore[arg-type] continue - visited.add(ref.name) - referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') - await _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) - file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) + visited.add(ref.name) # type: ignore[arg-type] + 
referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') # type: ignore[arg-type] + await _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) # type: ignore[arg-type] + file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) # type: ignore[arg-type,union-attr] pool.Add(file_descriptor_proto) @@ -218,50 +219,58 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._skip_known_types = conf_copy.pop('skip.known.types') + self._skip_known_types = cast(bool, conf_copy.pop('skip.known.types')) if not isinstance(self._skip_known_types, bool): raise ValueError("skip.known.types must be a boolean value") - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + self._use_deprecated_format = cast(bool, conf_copy.pop('use.deprecated.format')) if not isinstance(self._use_deprecated_format, bool): raise ValueError("use.deprecated.format must be a boolean value") if self._use_deprecated_format: raise ValueError("use.deprecated.format is no longer supported") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._ref_reference_subject_func = conf_copy.pop( - 'reference.subject.name.strategy') + self._ref_reference_subject_func = cast( + Callable[[Optional[SerializationContext], Any], Optional[str]], + conf_copy.pop('reference.subject.name.strategy') + ) if not callable(self._ref_reference_subject_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not 
callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") @@ -271,7 +280,7 @@ async def __init_impl( self._registry = schema_registry_client self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._schema_id = None + self._schema_id: Optional[SchemaId] = None self._known_subjects: set[str] = set() self._msg_class = msg_type self._parsed_schemas = ParsedSchemaCache() @@ -360,7 +369,7 @@ async def _resolve_dependencies( reference.version)) return schema_refs - def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: # type: ignore[override] return self.__serialize(message, ctx) async def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -397,7 +406,7 @@ async def __serialize(self, message: Message, ctx: Optional[SerializationContext if latest_schema is not None: self._schema_id = SchemaId(PROTOBUF_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects and ctx is not None: + elif subject is not None and subject not in self._known_subjects and ctx is not None: references = await self._resolve_dependencies(ctx, message.DESCRIPTOR.file) self._schema = Schema( self._schema.schema_str, @@ -417,21 +426,23 @@ async def __serialize(self, message: Message, ctx: Optional[SerializationContext self._known_subjects.add(subject) if latest_schema is not None: - fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) + fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] fd = pool.FindFileByName(fd_proto.name) desc = fd.message_types_by_name[message.DESCRIPTOR.name] def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, desc, msg, field_transform)) - message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, message, None, - field_transformer) + if ctx is not None and subject is not None: + message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, message, None, + field_transformer) with _ContextStringIO() as fo: fo.write(message.SerializeToString()) - self._schema_id.message_indexes = self._index_array + if self._schema_id is not None: + self._schema_id.message_indexes = self._index_array buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -446,6 +457,8 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileD pool = DescriptorPool() _init_pool(pool) await _resolve_named_schema(schema, self._registry, pool) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") fd_proto = _str_to_proto("default", schema.schema_str) pool.Add(fd_proto) self._parsed_schemas.set(schema, (fd_proto, pool)) @@ -526,11 +539,11 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - 
self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") @@ -549,7 +562,7 @@ async def __init_impl( if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + self._use_deprecated_format = cast(bool, conf_copy.pop('use.deprecated.format')) if not isinstance(self._use_deprecated_format, bool): raise ValueError("use.deprecated.format must be a boolean value") if self._use_deprecated_format: @@ -601,9 +614,9 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization if self._registry is not None: writer_schema_raw = await self._get_writer_schema(schema_id, subject, fmt='serialized') - fd_proto, pool = await self._get_parsed_schema(writer_schema_raw) + fd_proto, pool = await self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] writer_schema = pool.FindFileByName(fd_proto.name) - writer_desc = self._get_message_desc(pool, writer_schema, msg_index) + writer_desc = self._get_message_desc(pool, writer_schema, msg_index) # type: ignore[arg-type] if subject is None: subject = self._subject_name_func(ctx, writer_desc.full_name) if subject is not None: @@ -612,16 +625,17 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization writer_schema_raw = None writer_schema = None - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) - if latest_schema is not None: - migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) + if latest_schema is not None and subject is not None: + migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] reader_schema_raw = latest_schema.schema - fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) + fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] reader_schema = pool.FindFileByName(fd_proto.name) else: migrations = None @@ -634,7 +648,7 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization # Attempt to find a reader desc with the same name as the writer reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) - if migrations: + if migrations and ctx is not None and subject is not None: msg = GetMessageClass(writer_desc)() try: msg.ParseFromString(payload.read()) @@ -655,9 +669,10 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_desc, message, field_transform)) - msg = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, msg, None, - field_transformer) + if ctx is not None and subject is not None: + msg = self._execute_rules(ctx, subject, RuleMode.READ, None, # type: ignore[arg-type] + reader_schema_raw, msg, 
None, + field_transformer) return msg async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: @@ -668,6 +683,8 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileD pool = DescriptorPool() _init_pool(pool) await _resolve_named_schema(schema, self._registry, pool) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") fd_proto = _str_to_proto("default", schema.schema_str) pool.Add(fd_proto) self._parsed_schemas.set(schema, (fd_proto, pool)) diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index efc7d72ff..44200ae0e 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -65,10 +65,10 @@ # six: https://pypi.org/project/six/ # compat file : https://github.com/psf/requests/blob/master/requests/compat.py try: - string_type = basestring # noqa + string_type = basestring # type: ignore[name-defined] # noqa def _urlencode(value: str) -> str: - return urllib.quote(value, safe='') + return urllib.quote(value, safe='') # type: ignore[attr-defined] except NameError: string_type = str @@ -83,8 +83,8 @@ def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict) self.custom_function = custom_function self.custom_config = custom_config - async def get_bearer_fields(self) -> dict: - return await self.custom_function(self.custom_config) + async def get_bearer_fields(self) -> dict: # type: ignore[override] + return await self.custom_function(self.custom_config) # type: ignore[misc] class _AsyncOAuthClient(_BearerFieldProvider): @@ -100,7 +100,7 @@ def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoin self.retries_max_wait_ms = retries_max_wait_ms self.token_expiry_threshold = 0.8 - async def get_bearer_fields(self) -> dict: + async def get_bearer_fields(self) -> dict: # type: ignore[override] return { 'bearer.auth.token': await self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, @@ -108,15 +108,15 @@ async def get_bearer_fields(self) -> dict: } def token_expired(self) -> bool: - expiry_window = self.token['expires_in'] * self.token_expiry_threshold + expiry_window = self.token['expires_in'] * self.token_expiry_threshold # type: ignore[index] - return self.token['expires_at'] < time.time() + expiry_window + return self.token['expires_at'] < time.time() + expiry_window # type: ignore[index] async def get_access_token(self) -> str: if not self.token or self.token_expired(): await self.generate_access_token() - return self.token['access_token'] + return self.token['access_token'] # type: ignore[index] async def generate_access_token(self) -> None: for i in range(self.max_retries + 1): @@ -227,7 +227,7 @@ def __init__(self, conf: dict): if cache_capacity is not None: if not isinstance(cache_capacity, (int, float)): raise TypeError("cache.capacity must be a number, not " + str(type(cache_capacity))) - self.cache_capacity = cache_capacity + self.cache_capacity = int(cache_capacity) self.cache_latest_ttl_sec = None cache_latest_ttl_sec = conf_copy.pop('cache.latest.ttl.sec', None) @@ -241,7 +241,7 @@ def __init__(self, conf: dict): if max_retries is not None: if not isinstance(max_retries, (int, float)): raise TypeError("max.retries must be a number, not " + str(type(max_retries))) - self.max_retries = max_retries + 
self.max_retries = int(max_retries) self.retries_wait_ms = 1000 retries_wait_ms = conf_copy.pop('retries.wait.ms', None) @@ -249,7 +249,7 @@ def __init__(self, conf: dict): if not isinstance(retries_wait_ms, (int, float)): raise TypeError("retries.wait.ms must be a number, not " + str(type(retries_wait_ms))) - self.retries_wait_ms = retries_wait_ms + self.retries_wait_ms = int(retries_wait_ms) self.retries_max_wait_ms = 20000 retries_max_wait_ms = conf_copy.pop('retries.max.wait.ms', None) @@ -257,7 +257,7 @@ def __init__(self, conf: dict): if not isinstance(retries_max_wait_ms, (int, float)): raise TypeError("retries.max.wait.ms must be a number, not " + str(type(retries_max_wait_ms))) - self.retries_max_wait_ms = retries_max_wait_ms + self.retries_max_wait_ms = int(retries_max_wait_ms) self.bearer_field_provider = None logical_cluster = None @@ -308,14 +308,14 @@ def __init__(self, conf: dict): self.bearer_field_provider = _AsyncOAuthClient( self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, + self.token_endpoint, logical_cluster, identity_pool, # type: ignore[arg-type] self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms) elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': if 'bearer.auth.token' not in conf_copy: raise ValueError("Missing bearer.auth.token") static_token = conf_copy.pop('bearer.auth.token') - self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) + self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) # type: ignore[assignment,arg-type] if not isinstance(static_token, string_type): raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) elif self.bearer_auth_credentials_source == 'CUSTOM': @@ -336,7 +336,7 @@ def __init__(self, conf: dict): raise TypeError("bearer.auth.custom.provider.config must be a dict, not " + str(type(custom_config))) - self.bearer_field_provider = _AsyncCustomOAuthClient(custom_function, custom_config) + self.bearer_field_provider = _AsyncCustomOAuthClient(custom_function, custom_config) # type: ignore[assignment] else: raise ValueError('Unrecognized bearer.auth.credentials.source') @@ -379,7 +379,7 @@ def __init__(self, conf: dict): ) async def handle_bearer_auth(self, headers: dict) -> None: - bearer_fields = await self.bearer_field_provider.get_bearer_fields() + bearer_fields = await self.bearer_field_provider.get_bearer_fields() # type: ignore[union-attr] required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] missing_fields = [] @@ -437,9 +437,10 @@ async def send_request( " application/vnd.schemaregistry+json," " application/json"} + body_str: Optional[str] = None if body is not None: - body = json.dumps(body) - headers = {'Content-Length': str(len(body)), + body_str = json.dumps(body) # type: ignore[assignment] + headers = {'Content-Length': str(len(body_str)), 'Content-Type': "application/vnd.schemaregistry.v1+json"} if self.bearer_auth_credentials_source: @@ -449,7 +450,7 @@ async def send_request( for i, base_url in enumerate(self.base_urls): try: response = await self.send_http_request( - base_url, url, method, headers, body, query) + base_url, url, method, headers, body_str, query) if is_success(response.status_code): return response.json() @@ -462,15 +463,15 @@ async def send_request( raise e try: - raise SchemaRegistryError(response.status_code, - response.json().get('error_code'), - 
response.json().get('message')) + raise SchemaRegistryError(response.status_code, # type: ignore[union-attr] + response.json().get('error_code'), # type: ignore[union-attr] + response.json().get('message')) # type: ignore[union-attr] # Schema Registry may return malformed output when it hits unexpected errors except (ValueError, KeyError, AttributeError): - raise SchemaRegistryError(response.status_code, + raise SchemaRegistryError(response.status_code, # type: ignore[union-attr] -1, "Unknown Schema Registry Error: " - + str(response.content)) + + str(response.content)) # type: ignore[union-attr] async def send_http_request( self, base_url: str, url: str, method: str, headers: Optional[dict], @@ -514,7 +515,7 @@ async def send_http_request( return response await asyncio.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) - return response + return response # type: ignore[return-value] class AsyncSchemaRegistryClient(object): @@ -598,11 +599,11 @@ def __init__(self, conf: dict): cache_capacity = self._rest_client.cache_capacity cache_ttl = self._rest_client.cache_latest_ttl_sec if cache_ttl is not None: - self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) - self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) + self._latest_version_cache: TTLCache[Any, Any] = TTLCache(cache_capacity, cache_ttl) + self._latest_with_metadata_cache: TTLCache[Any, Any] = TTLCache(cache_capacity, cache_ttl) else: - self._latest_version_cache = LRUCache(cache_capacity) - self._latest_with_metadata_cache = LRUCache(cache_capacity) + self._latest_version_cache = LRUCache[Any, Any](cache_capacity) # type: ignore[assignment] + self._latest_with_metadata_cache = LRUCache[Any, Any](cache_capacity) # type: ignore[assignment] async def __aenter__(self): return self @@ -639,7 +640,7 @@ async def register_schema( registered_schema = await self.register_schema_full_response( subject_name, schema, normalize_schemas=normalize_schemas) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] async def register_schema_full_response( self, subject_name: str, schema: 'Schema', @@ -674,7 +675,7 @@ async def register_schema_full_response( subject=subject_name, version=None, schema=result[1] - ) + ) # type: ignore[arg-type] request = schema.to_dict() @@ -682,20 +683,20 @@ async def register_schema_full_response( 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), body=request) - result = RegisteredSchema.from_dict(response) + result = RegisteredSchema.from_dict(response) # type: ignore[assignment] registered_schema = RegisteredSchema( - schema_id=result.schema_id, - guid=result.guid, - subject=result.subject or subject_name, - version=result.version, - schema=result.schema + schema_id=result.schema_id, # type: ignore[union-attr] + guid=result.guid, # type: ignore[union-attr] + subject=result.subject or subject_name, # type: ignore[union-attr] + version=result.version, # type: ignore[union-attr] + schema=result.schema, # type: ignore[union-attr] ) # The registered schema may not be fully populated - s = registered_schema.schema if registered_schema.schema.schema_str is not None else schema + s = registered_schema.schema if registered_schema.schema.schema_str is not None else schema # type: ignore[union-attr] self._cache.set_schema(subject_name, registered_schema.schema_id, - registered_schema.guid, s) + registered_schema.guid, s) # type: ignore[arg-type] return registered_schema @@ -724,7 +725,7 
@@ async def get_schema( `GET Schema API Reference `_ """ # noqa: E501 - result = self._cache.get_schema_by_id(subject_name, schema_id) + result = self._cache.get_schema_by_id(subject_name, schema_id) # type: ignore[arg-type] if result is not None: return result[1] @@ -740,9 +741,9 @@ async def get_schema( registered_schema = RegisteredSchema.from_dict(response) self._cache.set_schema(subject_name, schema_id, - registered_schema.guid, registered_schema.schema) + registered_schema.guid, registered_schema.schema) # type: ignore[arg-type] - return registered_schema.schema + return registered_schema.schema # type: ignore[return-value] async def get_schema_by_guid( self, guid: str, fmt: Optional[str] = None @@ -778,9 +779,9 @@ async def get_schema_by_guid( registered_schema = RegisteredSchema.from_dict(response) self._cache.set_schema(None, registered_schema.schema_id, - registered_schema.guid, registered_schema.schema) + registered_schema.guid, registered_schema.schema) # type: ignore[arg-type] - return registered_schema.schema + return registered_schema.schema # type: ignore[return-value] async def get_schema_types(self) -> List[str]: """ @@ -820,7 +821,7 @@ async def get_subjects_by_schema_id( """ query = {'offset': offset, 'limit': limit} if subject_name is not None: - query['subject'] = subject_name + query['subject'] = subject_name # type: ignore[assignment] if deleted: query['deleted'] = deleted return await self._rest_client.get('schemas/ids/{}/subjects'.format(schema_id), query) @@ -853,7 +854,7 @@ async def get_schema_versions( query = {'offset': offset, 'limit': limit} if subject_name is not None: - query['subject'] = subject_name + query['subject'] = subject_name # type: ignore[assignment] if deleted: query['deleted'] = deleted response = await self._rest_client.get('schemas/ids/{}/versions'.format(schema_id), query) @@ -894,7 +895,7 @@ async def lookup_schema( 'deleted': deleted } if fmt is not None: - query_params['format'] = fmt + query_params['format'] = fmt # type: ignore[assignment] query_string = '&'.join(f"{key}={value}" for key, value in query_params.items()) @@ -944,7 +945,7 @@ async def get_subjects( query = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} if subject_prefix is not None: - query['subject'] = subject_prefix + query['subject'] = subject_prefix # type: ignore[assignment] return await self._rest_client.get('subjects', query) async def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: @@ -1041,11 +1042,11 @@ async def get_latest_with_metadata( query = {'deleted': deleted} if fmt is not None: - query['format'] = fmt + query['format'] = fmt # type: ignore[assignment] keys = metadata.keys() if keys: - query['key'] = [_urlencode(key) for key in keys] - query['value'] = [_urlencode(metadata[key]) for key in keys] + query['key'] = [_urlencode(key) for key in keys] # type: ignore[assignment] + query['value'] = [_urlencode(metadata[key]) for key in keys] # type: ignore[assignment] response = await self._rest_client.get( 'subjects/{}/metadata'.format(_urlencode(subject_name)), query @@ -1080,7 +1081,7 @@ async def get_version( `GET Subject Versions API Reference `_ """ # noqa: E501 - registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) + registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) # type: ignore[arg-type] if registered_schema is not None: return registered_schema @@ -1091,7 +1092,7 @@ async def get_version( registered_schema = 
RegisteredSchema.from_dict(response) - self._cache.set_registered_schema(registered_schema.schema, registered_schema) + self._cache.set_registered_schema(registered_schema.schema, registered_schema) # type: ignore[arg-type] return registered_schema @@ -1516,6 +1517,6 @@ def clear_caches(self): def new_client(conf: dict) -> 'AsyncSchemaRegistryClient': from .mock_schema_registry_client import AsyncMockSchemaRegistryClient url = conf.get("url") - if url.startswith("mock://"): + if url.startswith("mock://"): # type: ignore[union-attr] return AsyncMockSchemaRegistryClient(conf) return AsyncSchemaRegistryClient(conf) diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py index 40f72b721..ef6baab2d 100644 --- a/src/confluent_kafka/schema_registry/_async/serde.py +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -17,7 +17,7 @@ # import logging -from typing import List, Optional, Set, Dict, Any +from typing import List, Optional, Set, Dict, Any, Callable from confluent_kafka.schema_registry import RegisteredSchema from confluent_kafka.schema_registry.common.schema_registry_client import \ @@ -44,6 +44,14 @@ class AsyncBaseSerde(object): '_registry', '_rule_registry', '_subject_name_func', '_field_transformer'] + _use_schema_id: Optional[int] + _use_latest_version: bool + _use_latest_with_metadata: Optional[Dict[str, str]] + _registry: Any # AsyncSchemaRegistryClient + _rule_registry: Any # RuleRegistry + _subject_name_func: Callable[[Any, Optional[str]], Optional[str]] + _field_transformer: Optional[FieldTransformer] + async def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: if self._use_schema_id is not None: schema = await self._registry.get_schema(self._use_schema_id, subject, fmt) @@ -114,6 +122,11 @@ def _execute_rules_with_phase( ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule, index, rules, inline_tags, field_transformer) + if rule.type is None: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, + RuleError(f"Rule type is None for rule {rule.name}"), + 'ERROR') + return message rule_executor = self._rule_registry.get_executor(rule.type.upper()) if rule_executor is None: self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, @@ -140,18 +153,24 @@ def _execute_rules_with_phase( return message def _get_on_success(self, rule: Rule) -> Optional[str]: + if rule.type is None: + return rule.on_success override = self._rule_registry.get_override(rule.type) if override is not None and override.on_success is not None: return override.on_success return rule.on_success def _get_on_failure(self, rule: Rule) -> Optional[str]: + if rule.type is None: + return rule.on_failure override = self._rule_registry.get_override(rule.type) if override is not None and override.on_failure is not None: return override.on_failure return rule.on_failure def _is_disabled(self, rule: Rule) -> Optional[bool]: + if rule.type is None: + return rule.disabled override = self._rule_registry.get_override(rule.type) if override is not None and override.disabled is not None: return override.disabled @@ -200,10 +219,16 @@ def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleA class AsyncBaseSerializer(AsyncBaseSerde, Serializer): __slots__ = ['_auto_register', '_normalize_schemas', '_schema_id_serializer'] + _auto_register: bool + _normalize_schemas: bool + _schema_id_serializer: Callable[[bytes, Any, Any], bytes] + 
class AsyncBaseDeserializer(AsyncBaseSerde, Deserializer): __slots__ = ['_schema_id_deserializer'] + _schema_id_deserializer: Callable[[bytes, Any, Any], Any] + async def _get_writer_schema( self, schema_id: SchemaId, subject: Optional[str] = None, fmt: Optional[str] = None) -> Schema: @@ -241,7 +266,7 @@ async def _get_migrations( ) -> List[Migration]: source = await self._registry.lookup_schema( subject, source_info, normalize_schemas=False, deleted=True) - migrations = [] + migrations: List[Migration] = [] if source.version < target.version: migration_mode = RuleMode.UPGRADE first = source @@ -259,13 +284,14 @@ async def _get_migrations( if i == 0: previous = version continue - if version.schema.rule_set is not None and self._has_rules( + if version.schema is not None and version.schema.rule_set is not None and self._has_rules( version.schema.rule_set, RulePhase.MIGRATION, migration_mode): - if migration_mode == RuleMode.UPGRADE: - migration = Migration(migration_mode, previous, version) - else: - migration = Migration(migration_mode, version, previous) - migrations.append(migration) + if previous is not None: # previous is always set after first iteration + if migration_mode == RuleMode.UPGRADE: + migration = Migration(migration_mode, previous, version) + else: + migration = Migration(migration_mode, version, previous) + migrations.append(migration) previous = version if migration_mode == RuleMode.DOWNGRADE: migrations.reverse() @@ -275,6 +301,8 @@ async def _get_schemas_between( self, subject: str, first: RegisteredSchema, last: RegisteredSchema, fmt: Optional[str] = None ) -> List[RegisteredSchema]: + if first.version is None or last.version is None: + return [first, last] if last.version - first.version <= 1: return [first, last] version1 = first.version @@ -290,8 +318,9 @@ def _execute_migrations( migrations: List[Migration], message: Any ) -> Any: for migration in migrations: - message = self._execute_rules_with_phase( - ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode, - migration.source.schema, migration.target.schema, - message, None, None) + if migration.source is not None and migration.target is not None: + message = self._execute_rules_with_phase( + ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode, + migration.source.schema, migration.target.schema, + message, None, None) return message diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 686369099..1a06a5b05 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -57,12 +57,13 @@ def _resolve_named_schema( named_schemas = {} if schema.references is not None: for ref in schema.references: - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) - ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client) + # References in registered schemas are validated by server to be complete + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) # type: ignore[arg-type] + ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client) # type: ignore[arg-type] parsed_schema = parse_schema_with_repo( - referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) + referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) # type: ignore[union-attr,arg-type] named_schemas.update(ref_named_schemas) - named_schemas[ref.name] = parsed_schema + 
named_schemas[ref.name] = parsed_schema # type: ignore[index] return named_schemas @@ -204,7 +205,7 @@ def __init_impl( schema = None self._registry = schema_registry_client - self._schema_id = None + self._schema_id: Optional[SchemaId] = None self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() @@ -219,26 +220,26 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") @@ -276,8 +277,11 @@ def __init_impl( # i.e. {"type": "string"} has a name of string. # This function does not comply. 
# https://github.com/fastavro/fastavro/issues/415 - schema_dict = json.loads(schema.schema_str) - schema_name = parsed_schema.get("name", schema_dict.get("type")) + if schema.schema_str is not None: + schema_dict = json.loads(schema.schema_str) + schema_name = parsed_schema.get("name", schema_dict.get("type")) # type: ignore[union-attr] + else: + schema_name = None else: schema_name = None parsed_schema = None @@ -292,7 +296,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] return self.__serialize(obj, ctx) def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -319,10 +323,10 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = self._get_reader_schema(subject) + latest_schema = self._get_reader_schema(subject) if subject else None # type: ignore[arg-type] if latest_schema is not None: self._schema_id = SchemaId(AVRO_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects: + elif subject is not None and subject not in self._known_subjects: # Check to ensure this schema has been registered under subject_name. if self._auto_register: # The schema name will always be the same. We can't however register @@ -339,26 +343,26 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - self._known_subjects.add(subject) if self._to_dict is not None: - value = self._to_dict(obj, ctx) + value = self._to_dict(obj, ctx) # type: ignore[arg-type] else: - value = obj + value = obj # type: ignore[assignment] if latest_schema is not None: - parsed_schema = self._get_parsed_schema(latest_schema.schema) + parsed_schema = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, # type: ignore[arg-type] latest_schema.schema, value, get_inline_tags(parsed_schema), field_transformer) else: - parsed_schema = self._parsed_schema + parsed_schema = self._parsed_schema # type: ignore[assignment] with _ContextStringIO() as fo: # write the record to the rest of the buffer schemaless_writer(fo, parsed_schema, value) buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -371,7 +375,11 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema named_schemas = _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) + if prepared_schema.schema_str is None: + raise ValueError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) @@ -483,11 +491,11 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = 
conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") @@ -511,9 +519,9 @@ def __init_impl( .format(", ".join(conf_copy.keys()))) if schema: - self._reader_schema = self._get_parsed_schema(self._schema) + self._reader_schema = self._get_parsed_schema(self._schema) # type: ignore[arg-type] else: - self._reader_schema = None + self._reader_schema = None # type: ignore[assignment] if from_dict is not None and not callable(from_dict): raise ValueError("from_dict must be callable with the signature " @@ -530,7 +538,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Union[dict, object, None]]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: return self.__deserialize(data, ctx) def __deserialize( @@ -571,23 +579,24 @@ def __deserialize( payload = self._schema_id_deserializer(data, ctx, schema_id) writer_schema_raw = self._get_writer_schema(schema_id, subject) - writer_schema = self._get_parsed_schema(writer_schema_raw) + writer_schema = self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] if subject is None: - subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None + subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None # type: ignore[union-attr] if subject is not None: latest_schema = self._get_reader_schema(subject) - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + if latest_schema is not None and subject is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] reader_schema_raw = latest_schema.schema - reader_schema = self._get_parsed_schema(latest_schema.schema) + reader_schema = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] elif self._schema is not None: migrations = None reader_schema_raw = self._schema @@ -597,7 +606,7 @@ def __deserialize( reader_schema_raw = writer_schema_raw reader_schema = writer_schema - if migrations: + if migrations and ctx is not None and subject is not None: obj_dict = schemaless_reader(payload, writer_schema, None, @@ -611,12 +620,13 @@ def __deserialize( def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, get_inline_tags(reader_schema), - field_transformer) + if ctx is not None and 
subject is not None: + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, # type: ignore[arg-type] + reader_schema_raw, obj_dict, get_inline_tags(reader_schema), + field_transformer) if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) + return self._from_dict(obj_dict, ctx) # type: ignore[arg-type] return obj_dict @@ -626,7 +636,11 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema named_schemas = _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) + if prepared_schema.schema_str is None: + raise ValueError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 36877a6b0..08f147876 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -61,15 +61,15 @@ def _resolve_named_schema( """ if ref_registry is None: # Retrieve external schemas for backward compatibility - ref_registry = Registry(retrieve=_retrieve_via_httpx) + ref_registry = Registry(retrieve=_retrieve_via_httpx) # type: ignore[call-arg] if schema.references is not None: for ref in schema.references: - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) - ref_registry = _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) - referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) # type: ignore[arg-type] + ref_registry = _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) # type: ignore[arg-type] + referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) # type: ignore[union-attr,arg-type] resource = Resource.from_contents( referenced_schema_dict, default_specification=DEFAULT_SPEC) - ref_registry = ref_registry.with_resource(ref.name, resource) + ref_registry = ref_registry.with_resource(ref.name, resource) # type: ignore[arg-type] return ref_registry @@ -213,6 +213,7 @@ def __init_impl( json_encode: Optional[Callable] = None, ): super().__init__() + self._schema: Optional[Schema] if isinstance(schema_str, str): self._schema = Schema(schema_str, schema_type="JSON") elif isinstance(schema_str, Schema): @@ -225,10 +226,10 @@ def __init_impl( self._rule_registry = ( rule_registry if rule_registry else RuleRegistry.get_global_instance() ) - self._schema_id = None + self._schema_id: Optional[SchemaId] = None self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) + self._validators: LRUCache[Schema, Validator] = LRUCache(1000) if to_dict is not None and not callable(to_dict): raise ValueError("to_dict must be callable with the signature " @@ -240,26 +241,26 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, 
conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") @@ -278,7 +279,7 @@ def __init_impl( if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") - self._validate = conf_copy.pop('validate') + self._validate = cast(bool, conf_copy.pop('validate')) if not isinstance(self._validate, bool): raise ValueError("validate must be a boolean value") @@ -286,8 +287,8 @@ def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - schema_dict, ref_registry = self._get_parsed_schema(self._schema) - if schema_dict: + schema_dict, ref_registry = self._get_parsed_schema(self._schema) # type: ignore[arg-type] + if schema_dict and isinstance(schema_dict, dict): schema_name = schema_dict.get('title', None) else: schema_name = None @@ -302,7 +303,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] return self.__serialize(obj, ctx) def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -328,10 +329,10 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = self._get_reader_schema(subject) + latest_schema = self._get_reader_schema(subject) if subject else None # type: ignore[arg-type] if latest_schema is not None: self._schema_id = SchemaId(JSON_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects: + elif subject is not None and subject not in self._known_subjects: # Check to ensure this schema has been registered under subject_name. if self._auto_register: # The schema name will always be the same. 
We can't however register @@ -348,26 +349,27 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - self._known_subjects.add(subject) if self._to_dict is not None: - value = self._to_dict(obj, ctx) + value = self._to_dict(obj, ctx) # type: ignore[arg-type] else: - value = obj + value = obj # type: ignore[assignment] if latest_schema is not None: schema = latest_schema.schema - parsed_schema, ref_registry = self._get_parsed_schema(latest_schema.schema) - root_resource = Resource.from_contents( - parsed_schema, default_specification=DEFAULT_SPEC) - ref_resolver = ref_registry.resolver_with_root(root_resource) - def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 - transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, value, None, - field_transformer) + parsed_schema, ref_registry = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + if ref_registry is not None: + root_resource = Resource.from_contents( + parsed_schema, default_specification=DEFAULT_SPEC) + ref_resolver = ref_registry.resolver_with_root(root_resource) + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 + transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, # type: ignore[arg-type] + latest_schema.schema, value, None, + field_transformer) else: schema = self._schema parsed_schema, ref_registry = self._parsed_schema, self._ref_registry - if self._validate: + if self._validate and schema is not None and parsed_schema is not None and ref_registry is not None: try: validator = self._get_validator(schema, parsed_schema, ref_registry) validator.validate(value) @@ -383,7 +385,7 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 fo.write(encoded_value) buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -399,6 +401,8 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Opti return result ref_registry = _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) @@ -494,6 +498,7 @@ def __init_impl( json_decode: Optional[Callable] = None, ): super().__init__() + schema: Optional[Schema] if isinstance(schema_str, str): schema = Schema(schema_str, schema_type="JSON") elif isinstance(schema_str, Schema): @@ -510,11 +515,11 @@ def __init_impl( else: raise TypeError('You must pass either str or Schema') - self._schema = schema + self._schema: Optional[Schema] = schema self._registry = schema_registry_client self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) + self._validators: LRUCache[Schema, Validator] = LRUCache(1000) self._json_decode = json_decode or orjson.loads self._use_schema_id = None @@ -522,11 +527,11 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = conf_copy.pop('use.latest.version') + 
self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") @@ -545,7 +550,7 @@ def __init_impl( if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") - self._validate = conf_copy.pop('validate') + self._validate = cast(bool, conf_copy.pop('validate')) if not isinstance(self._validate, bool): raise ValueError("validate must be a boolean value") @@ -554,7 +559,7 @@ def __init_impl( .format(", ".join(conf_copy.keys()))) if schema: - self._reader_schema, self._ref_registry = self._get_parsed_schema(self._schema) + self._reader_schema, self._ref_registry = self._get_parsed_schema(self._schema) # type: ignore[arg-type] else: self._reader_schema, self._ref_registry = None, None @@ -570,7 +575,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__deserialize(data, ctx) def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -604,8 +609,8 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex if self._registry is not None: writer_schema_raw = self._get_writer_schema(schema_id, subject) - writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw) - if subject is None: + writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] + if subject is None and isinstance(writer_schema, dict): subject = self._subject_name_func(ctx, writer_schema.get("title")) if subject is not None: latest_schema = self._get_reader_schema(subject) @@ -613,19 +618,20 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex writer_schema_raw = None writer_schema, writer_ref_registry = None, None - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) # JSON documents are self-describing; no need to query schema obj_dict = self._json_decode(payload.read()) - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + if latest_schema is not None and subject is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] reader_schema_raw = latest_schema.schema - reader_schema, reader_ref_registry = self._get_parsed_schema(latest_schema.schema) + reader_schema, reader_ref_registry = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] elif self._schema is not None: migrations = None reader_schema_raw = self._schema @@ -635,21 
+641,23 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex reader_schema_raw = writer_schema_raw reader_schema, reader_ref_registry = writer_schema, writer_ref_registry - if migrations: + if migrations and ctx is not None and subject is not None: obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - reader_root_resource = Resource.from_contents( - reader_schema, default_specification=DEFAULT_SPEC) - reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) + if reader_ref_registry is not None: + reader_root_resource = Resource.from_contents( + reader_schema, default_specification=DEFAULT_SPEC) + reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) - def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 - transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, - "$", message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, None, - field_transformer) + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 + transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, + "$", message, field_transform)) + if ctx is not None and subject is not None: + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, None, + field_transformer) # type: ignore[arg-type] - if self._validate: + if self._validate and reader_schema_raw is not None and reader_schema is not None and reader_ref_registry is not None: try: validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) validator.validate(obj_dict) @@ -657,7 +665,7 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 raise SerializationError(ve.message) if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) + return self._from_dict(obj_dict, ctx) # type: ignore[arg-type,return-value] return obj_dict @@ -670,6 +678,8 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Opti return result ref_registry = _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) diff --git a/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py index a7f39f20d..09340ac33 100644 --- a/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py @@ -18,7 +18,7 @@ import uuid from collections import defaultdict from threading import Lock -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union from .schema_registry_client import SchemaRegistryClient from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig @@ -157,7 +157,7 @@ def register_schema( ) -> int: registered_schema = self.register_schema_full_response( subject_name, schema, normalize_schemas=normalize_schemas) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] def register_schema_full_response( self, subject_name: str, schema: 'Schema', @@ -168,7 +168,7 @@ def register_schema_full_response( return registered_schema latest_schema = 
self._store.get_latest_version(subject_name) - latest_version = 1 if latest_schema is None else latest_schema.version + 1 + latest_version = 1 if latest_schema is None or latest_schema.version is None else latest_schema.version + 1 registered_schema = RegisteredSchema( schema_id=1, @@ -184,7 +184,7 @@ def register_schema_full_response( def get_schema( self, schema_id: int, subject_name: Optional[str] = None, - fmt: Optional[str] = None + fmt: Optional[str] = None, reference_format: Optional[str] = None ) -> 'Schema': schema = self._store.get_schema(schema_id) if schema is not None: @@ -212,7 +212,10 @@ def lookup_schema( raise SchemaRegistryError(404, 40400, "Schema Not Found") - def get_subjects(self) -> List[str]: + def get_subjects( + self, subject_prefix: Optional[str] = None, deleted: bool = False, + deleted_only: bool = False, offset: int = 0, limit: int = -1 + ) -> List[str]: return self._store.get_subjects() def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: @@ -236,30 +239,36 @@ def get_latest_with_metadata( raise SchemaRegistryError(404, 40400, "Schema Not Found") def get_version( - self, subject_name: str, version: int, + self, subject_name: str, version: Union[int, str] = "latest", deleted: bool = False, fmt: Optional[str] = None ) -> 'RegisteredSchema': - registered_schema = self._store.get_version(subject_name, version) + if version == "latest": + registered_schema = self._store.get_latest_version(subject_name) + else: + registered_schema = self._store.get_version(subject_name, version) # type: ignore[arg-type] if registered_schema is not None: return registered_schema raise SchemaRegistryError(404, 40400, "Schema Not Found") - def get_versions(self, subject_name: str) -> List[int]: + def get_versions( + self, subject_name: str, deleted: bool = False, deleted_only: bool = False, + offset: int = 0, limit: int = -1 + ) -> List[int]: return self._store.get_versions(subject_name) def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: registered_schema = self._store.get_version(subject_name, version) if registered_schema is not None: self._store.remove_by_schema(registered_schema) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] raise SchemaRegistryError(404, 40400, "Schema Not Found") def set_config( - self, subject_name: Optional[str] = None, config: 'ServerConfig' = None # noqa F821 + self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None # noqa F821 ) -> 'ServerConfig': # noqa F821 - return None + return None # type: ignore[return-value] def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': # noqa F821 - return None + return None # type: ignore[return-value] diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 0c3addcb5..f13ce7555 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -67,12 +67,13 @@ def _resolve_named_schema( visited = set() if schema.references is not None: for ref in schema.references: - if _is_builtin(ref.name) or ref.name in visited: + # References in registered schemas are validated by server to be complete + if _is_builtin(ref.name) or ref.name in visited: # type: ignore[arg-type] continue - visited.add(ref.name) - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') - 
_resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) - file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) + visited.add(ref.name) # type: ignore[arg-type] + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') # type: ignore[arg-type] + _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) # type: ignore[arg-type] + file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) # type: ignore[arg-type,union-attr] pool.Add(file_descriptor_proto) @@ -218,50 +219,58 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._skip_known_types = conf_copy.pop('skip.known.types') + self._skip_known_types = cast(bool, conf_copy.pop('skip.known.types')) if not isinstance(self._skip_known_types, bool): raise ValueError("skip.known.types must be a boolean value") - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + self._use_deprecated_format = cast(bool, conf_copy.pop('use.deprecated.format')) if not isinstance(self._use_deprecated_format, bool): raise ValueError("use.deprecated.format must be a boolean value") if self._use_deprecated_format: raise ValueError("use.deprecated.format is no longer supported") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._ref_reference_subject_func = conf_copy.pop( - 'reference.subject.name.strategy') + self._ref_reference_subject_func = cast( + Callable[[Optional[SerializationContext], Any], Optional[str]], + conf_copy.pop('reference.subject.name.strategy') + ) if not callable(self._ref_reference_subject_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = 
conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") @@ -271,7 +280,7 @@ def __init_impl( self._registry = schema_registry_client self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._schema_id = None + self._schema_id: Optional[SchemaId] = None self._known_subjects: set[str] = set() self._msg_class = msg_type self._parsed_schemas = ParsedSchemaCache() @@ -360,7 +369,7 @@ def _resolve_dependencies( reference.version)) return schema_refs - def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] return self.__serialize(message, ctx) def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -397,7 +406,7 @@ def __serialize(self, message: Message, ctx: Optional[SerializationContext] = No if latest_schema is not None: self._schema_id = SchemaId(PROTOBUF_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects and ctx is not None: + elif subject is not None and subject not in self._known_subjects and ctx is not None: references = self._resolve_dependencies(ctx, message.DESCRIPTOR.file) self._schema = Schema( self._schema.schema_str, @@ -417,21 +426,23 @@ def __serialize(self, message: Message, ctx: Optional[SerializationContext] = No self._known_subjects.add(subject) if latest_schema is not None: - fd_proto, pool = self._get_parsed_schema(latest_schema.schema) + fd_proto, pool = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] fd = pool.FindFileByName(fd_proto.name) desc = fd.message_types_by_name[message.DESCRIPTOR.name] def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, desc, msg, field_transform)) - message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, message, None, - field_transformer) + if ctx is not None and subject is not None: + message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, message, None, + field_transformer) with _ContextStringIO() as fo: fo.write(message.SerializeToString()) - self._schema_id.message_indexes = self._index_array + if self._schema_id is not None: + self._schema_id.message_indexes = self._index_array buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -446,6 +457,8 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip pool = DescriptorPool() _init_pool(pool) _resolve_named_schema(schema, self._registry, pool) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") fd_proto = _str_to_proto("default", schema.schema_str) pool.Add(fd_proto) self._parsed_schemas.set(schema, (fd_proto, pool)) @@ -526,11 +539,11 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, 
conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") @@ -549,7 +562,7 @@ def __init_impl( if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + self._use_deprecated_format = cast(bool, conf_copy.pop('use.deprecated.format')) if not isinstance(self._use_deprecated_format, bool): raise ValueError("use.deprecated.format must be a boolean value") if self._use_deprecated_format: @@ -564,10 +577,10 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[object, None]: return self.__deserialize(data, ctx) - def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[object, None]: """ Deserialize a serialized protobuf message with Confluent Schema Registry framing. @@ -601,9 +614,9 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex if self._registry is not None: writer_schema_raw = self._get_writer_schema(schema_id, subject, fmt='serialized') - fd_proto, pool = self._get_parsed_schema(writer_schema_raw) + fd_proto, pool = self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] writer_schema = pool.FindFileByName(fd_proto.name) - writer_desc = self._get_message_desc(pool, writer_schema, msg_index) + writer_desc = self._get_message_desc(pool, writer_schema, msg_index) # type: ignore[arg-type] if subject is None: subject = self._subject_name_func(ctx, writer_desc.full_name) if subject is not None: @@ -612,16 +625,17 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex writer_schema_raw = None writer_schema = None - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + if latest_schema is not None and subject is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] reader_schema_raw = latest_schema.schema - fd_proto, pool = self._get_parsed_schema(latest_schema.schema) + fd_proto, pool = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] reader_schema = pool.FindFileByName(fd_proto.name) else: migrations = None @@ -634,7 +648,7 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex # Attempt to find a reader desc with the same name as the writer 
reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) - if migrations: + if migrations and ctx is not None and subject is not None: msg = GetMessageClass(writer_desc)() try: msg.ParseFromString(payload.read()) @@ -655,9 +669,10 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_desc, message, field_transform)) - msg = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, msg, None, - field_transformer) + if ctx is not None and subject is not None: + msg = self._execute_rules(ctx, subject, RuleMode.READ, None, # type: ignore[arg-type] + reader_schema_raw, msg, None, + field_transformer) return msg def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: @@ -668,6 +683,8 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip pool = DescriptorPool() _init_pool(pool) _resolve_named_schema(schema, self._registry, pool) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") fd_proto = _str_to_proto("default", schema.schema_str) pool.Add(fd_proto) self._parsed_schemas.set(schema, (fd_proto, pool)) diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 7c4eb7060..eeef46c52 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -65,10 +65,10 @@ # six: https://pypi.org/project/six/ # compat file : https://github.com/psf/requests/blob/master/requests/compat.py try: - string_type = basestring # noqa + string_type = basestring # type: ignore[name-defined] # noqa def _urlencode(value: str) -> str: - return urllib.quote(value, safe='') + return urllib.quote(value, safe='') # type: ignore[attr-defined] except NameError: string_type = str @@ -83,8 +83,8 @@ def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict) self.custom_function = custom_function self.custom_config = custom_config - def get_bearer_fields(self) -> dict: - return self.custom_function(self.custom_config) + def get_bearer_fields(self) -> dict: # type: ignore[override] + return self.custom_function(self.custom_config) # type: ignore[misc] class _OAuthClient(_BearerFieldProvider): @@ -100,7 +100,7 @@ def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoin self.retries_max_wait_ms = retries_max_wait_ms self.token_expiry_threshold = 0.8 - def get_bearer_fields(self) -> dict: + def get_bearer_fields(self) -> dict: # type: ignore[override] return { 'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, @@ -108,15 +108,15 @@ def get_bearer_fields(self) -> dict: } def token_expired(self) -> bool: - expiry_window = self.token['expires_in'] * self.token_expiry_threshold + expiry_window = self.token['expires_in'] * self.token_expiry_threshold # type: ignore[index] - return self.token['expires_at'] < time.time() + expiry_window + return self.token['expires_at'] < time.time() + expiry_window # type: ignore[index] def get_access_token(self) -> str: if not self.token or self.token_expired(): self.generate_access_token() - return self.token['access_token'] + return self.token['access_token'] # type: ignore[index] def generate_access_token(self) -> None: for i in 
range(self.max_retries + 1): @@ -227,7 +227,7 @@ def __init__(self, conf: dict): if cache_capacity is not None: if not isinstance(cache_capacity, (int, float)): raise TypeError("cache.capacity must be a number, not " + str(type(cache_capacity))) - self.cache_capacity = cache_capacity + self.cache_capacity = int(cache_capacity) self.cache_latest_ttl_sec = None cache_latest_ttl_sec = conf_copy.pop('cache.latest.ttl.sec', None) @@ -241,7 +241,7 @@ def __init__(self, conf: dict): if max_retries is not None: if not isinstance(max_retries, (int, float)): raise TypeError("max.retries must be a number, not " + str(type(max_retries))) - self.max_retries = max_retries + self.max_retries = int(max_retries) self.retries_wait_ms = 1000 retries_wait_ms = conf_copy.pop('retries.wait.ms', None) @@ -249,7 +249,7 @@ def __init__(self, conf: dict): if not isinstance(retries_wait_ms, (int, float)): raise TypeError("retries.wait.ms must be a number, not " + str(type(retries_wait_ms))) - self.retries_wait_ms = retries_wait_ms + self.retries_wait_ms = int(retries_wait_ms) self.retries_max_wait_ms = 20000 retries_max_wait_ms = conf_copy.pop('retries.max.wait.ms', None) @@ -257,7 +257,7 @@ def __init__(self, conf: dict): if not isinstance(retries_max_wait_ms, (int, float)): raise TypeError("retries.max.wait.ms must be a number, not " + str(type(retries_max_wait_ms))) - self.retries_max_wait_ms = retries_max_wait_ms + self.retries_max_wait_ms = int(retries_max_wait_ms) self.bearer_field_provider = None logical_cluster = None @@ -308,14 +308,14 @@ def __init__(self, conf: dict): self.bearer_field_provider = _OAuthClient( self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, + self.token_endpoint, logical_cluster, identity_pool, # type: ignore[arg-type] self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms) elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': if 'bearer.auth.token' not in conf_copy: raise ValueError("Missing bearer.auth.token") static_token = conf_copy.pop('bearer.auth.token') - self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) + self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) # type: ignore[assignment,arg-type] if not isinstance(static_token, string_type): raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) elif self.bearer_auth_credentials_source == 'CUSTOM': @@ -336,7 +336,7 @@ def __init__(self, conf: dict): raise TypeError("bearer.auth.custom.provider.config must be a dict, not " + str(type(custom_config))) - self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config) + self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config) # type: ignore[assignment] else: raise ValueError('Unrecognized bearer.auth.credentials.source') @@ -379,7 +379,7 @@ def __init__(self, conf: dict): ) def handle_bearer_auth(self, headers: dict) -> None: - bearer_fields = self.bearer_field_provider.get_bearer_fields() + bearer_fields = self.bearer_field_provider.get_bearer_fields() # type: ignore[union-attr] required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] missing_fields = [] @@ -437,9 +437,10 @@ def send_request( " application/vnd.schemaregistry+json," " application/json"} + body_str: Optional[str] = None if body is not None: - body = json.dumps(body) - headers = {'Content-Length': str(len(body)), + body_str = json.dumps(body) # type: 
ignore[assignment] + headers = {'Content-Length': str(len(body_str)), 'Content-Type': "application/vnd.schemaregistry.v1+json"} if self.bearer_auth_credentials_source: @@ -449,7 +450,7 @@ def send_request( for i, base_url in enumerate(self.base_urls): try: response = self.send_http_request( - base_url, url, method, headers, body, query) + base_url, url, method, headers, body_str, query) if is_success(response.status_code): return response.json() @@ -462,15 +463,15 @@ def send_request( raise e try: - raise SchemaRegistryError(response.status_code, - response.json().get('error_code'), - response.json().get('message')) + raise SchemaRegistryError(response.status_code, # type: ignore[union-attr] + response.json().get('error_code'), # type: ignore[union-attr] + response.json().get('message')) # type: ignore[union-attr] # Schema Registry may return malformed output when it hits unexpected errors except (ValueError, KeyError, AttributeError): - raise SchemaRegistryError(response.status_code, + raise SchemaRegistryError(response.status_code, # type: ignore[union-attr] -1, "Unknown Schema Registry Error: " - + str(response.content)) + + str(response.content)) # type: ignore[union-attr] def send_http_request( self, base_url: str, url: str, method: str, headers: Optional[dict], @@ -514,7 +515,7 @@ def send_http_request( return response time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) - return response + return response # type: ignore[return-value] class SchemaRegistryClient(object): @@ -598,11 +599,11 @@ def __init__(self, conf: dict): cache_capacity = self._rest_client.cache_capacity cache_ttl = self._rest_client.cache_latest_ttl_sec if cache_ttl is not None: - self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) - self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) + self._latest_version_cache: TTLCache[Any, Any] = TTLCache(cache_capacity, cache_ttl) + self._latest_with_metadata_cache: TTLCache[Any, Any] = TTLCache(cache_capacity, cache_ttl) else: - self._latest_version_cache = LRUCache(cache_capacity) - self._latest_with_metadata_cache = LRUCache(cache_capacity) + self._latest_version_cache = LRUCache[Any, Any](cache_capacity) # type: ignore[assignment] + self._latest_with_metadata_cache = LRUCache[Any, Any](cache_capacity) # type: ignore[assignment] def __enter__(self): return self @@ -639,7 +640,7 @@ def register_schema( registered_schema = self.register_schema_full_response( subject_name, schema, normalize_schemas=normalize_schemas) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] def register_schema_full_response( self, subject_name: str, schema: 'Schema', @@ -674,7 +675,7 @@ def register_schema_full_response( subject=subject_name, version=None, schema=result[1] - ) + ) # type: ignore[arg-type] request = schema.to_dict() @@ -682,20 +683,20 @@ def register_schema_full_response( 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), body=request) - result = RegisteredSchema.from_dict(response) + result = RegisteredSchema.from_dict(response) # type: ignore[assignment] registered_schema = RegisteredSchema( - schema_id=result.schema_id, - guid=result.guid, - subject=result.subject or subject_name, - version=result.version, - schema=result.schema + schema_id=result.schema_id, # type: ignore[union-attr] + guid=result.guid, # type: ignore[union-attr] + subject=result.subject or subject_name, # type: ignore[union-attr] + version=result.version, # type: 
ignore[union-attr] + schema=result.schema, # type: ignore[union-attr] ) # The registered schema may not be fully populated - s = registered_schema.schema if registered_schema.schema.schema_str is not None else schema + s = registered_schema.schema if registered_schema.schema.schema_str is not None else schema # type: ignore[union-attr] self._cache.set_schema(subject_name, registered_schema.schema_id, - registered_schema.guid, s) + registered_schema.guid, s) # type: ignore[arg-type] return registered_schema @@ -724,7 +725,7 @@ def get_schema( `GET Schema API Reference `_ """ # noqa: E501 - result = self._cache.get_schema_by_id(subject_name, schema_id) + result = self._cache.get_schema_by_id(subject_name, schema_id) # type: ignore[arg-type] if result is not None: return result[1] @@ -740,9 +741,9 @@ def get_schema( registered_schema = RegisteredSchema.from_dict(response) self._cache.set_schema(subject_name, schema_id, - registered_schema.guid, registered_schema.schema) + registered_schema.guid, registered_schema.schema) # type: ignore[arg-type] - return registered_schema.schema + return registered_schema.schema # type: ignore[return-value] def get_schema_by_guid( self, guid: str, fmt: Optional[str] = None @@ -778,9 +779,9 @@ def get_schema_by_guid( registered_schema = RegisteredSchema.from_dict(response) self._cache.set_schema(None, registered_schema.schema_id, - registered_schema.guid, registered_schema.schema) + registered_schema.guid, registered_schema.schema) # type: ignore[arg-type] - return registered_schema.schema + return registered_schema.schema # type: ignore[return-value] def get_schema_types(self) -> List[str]: """ @@ -820,7 +821,7 @@ def get_subjects_by_schema_id( """ query = {'offset': offset, 'limit': limit} if subject_name is not None: - query['subject'] = subject_name + query['subject'] = subject_name # type: ignore[assignment] if deleted: query['deleted'] = deleted return self._rest_client.get('schemas/ids/{}/subjects'.format(schema_id), query) @@ -853,7 +854,7 @@ def get_schema_versions( query = {'offset': offset, 'limit': limit} if subject_name is not None: - query['subject'] = subject_name + query['subject'] = subject_name # type: ignore[assignment] if deleted: query['deleted'] = deleted response = self._rest_client.get('schemas/ids/{}/versions'.format(schema_id), query) @@ -894,7 +895,7 @@ def lookup_schema( 'deleted': deleted } if fmt is not None: - query_params['format'] = fmt + query_params['format'] = fmt # type: ignore[assignment] query_string = '&'.join(f"{key}={value}" for key, value in query_params.items()) @@ -944,7 +945,7 @@ def get_subjects( query = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} if subject_prefix is not None: - query['subject'] = subject_prefix + query['subject'] = subject_prefix # type: ignore[assignment] return self._rest_client.get('subjects', query) def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: @@ -1041,11 +1042,11 @@ def get_latest_with_metadata( query = {'deleted': deleted} if fmt is not None: - query['format'] = fmt + query['format'] = fmt # type: ignore[assignment] keys = metadata.keys() if keys: - query['key'] = [_urlencode(key) for key in keys] - query['value'] = [_urlencode(metadata[key]) for key in keys] + query['key'] = [_urlencode(key) for key in keys] # type: ignore[assignment] + query['value'] = [_urlencode(metadata[key]) for key in keys] # type: ignore[assignment] response = self._rest_client.get( 
'subjects/{}/metadata'.format(_urlencode(subject_name)), query @@ -1080,7 +1081,7 @@ def get_version( `GET Subject Versions API Reference `_ """ # noqa: E501 - registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) + registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) # type: ignore[arg-type] if registered_schema is not None: return registered_schema @@ -1091,7 +1092,7 @@ def get_version( registered_schema = RegisteredSchema.from_dict(response) - self._cache.set_registered_schema(registered_schema.schema, registered_schema) + self._cache.set_registered_schema(registered_schema.schema, registered_schema) # type: ignore[arg-type] return registered_schema @@ -1516,6 +1517,6 @@ def clear_caches(self): def new_client(conf: dict) -> 'SchemaRegistryClient': from .mock_schema_registry_client import MockSchemaRegistryClient url = conf.get("url") - if url.startswith("mock://"): + if url.startswith("mock://"): # type: ignore[union-attr] return MockSchemaRegistryClient(conf) return SchemaRegistryClient(conf) diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py index 49957ce5c..dd58fce07 100644 --- a/src/confluent_kafka/schema_registry/_sync/serde.py +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -17,7 +17,7 @@ # import logging -from typing import List, Optional, Set, Dict, Any +from typing import List, Optional, Set, Dict, Any, Callable from confluent_kafka.schema_registry import RegisteredSchema from confluent_kafka.schema_registry.common.schema_registry_client import \ @@ -44,6 +44,14 @@ class BaseSerde(object): '_registry', '_rule_registry', '_subject_name_func', '_field_transformer'] + _use_schema_id: Optional[int] + _use_latest_version: bool + _use_latest_with_metadata: Optional[Dict[str, str]] + _registry: Any # SchemaRegistryClient + _rule_registry: Any # RuleRegistry + _subject_name_func: Callable[[Any, Optional[str]], Optional[str]] + _field_transformer: Optional[FieldTransformer] + def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: if self._use_schema_id is not None: schema = self._registry.get_schema(self._use_schema_id, subject, fmt) @@ -114,6 +122,11 @@ def _execute_rules_with_phase( ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule, index, rules, inline_tags, field_transformer) + if rule.type is None: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, + RuleError(f"Rule type is None for rule {rule.name}"), + 'ERROR') + return message rule_executor = self._rule_registry.get_executor(rule.type.upper()) if rule_executor is None: self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, @@ -140,18 +153,24 @@ def _execute_rules_with_phase( return message def _get_on_success(self, rule: Rule) -> Optional[str]: + if rule.type is None: + return rule.on_success override = self._rule_registry.get_override(rule.type) if override is not None and override.on_success is not None: return override.on_success return rule.on_success def _get_on_failure(self, rule: Rule) -> Optional[str]: + if rule.type is None: + return rule.on_failure override = self._rule_registry.get_override(rule.type) if override is not None and override.on_failure is not None: return override.on_failure return rule.on_failure def _is_disabled(self, rule: Rule) -> Optional[bool]: + if rule.type is None: + return rule.disabled override = 
self._rule_registry.get_override(rule.type) if override is not None and override.disabled is not None: return override.disabled @@ -200,10 +219,16 @@ def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleA class BaseSerializer(BaseSerde, Serializer): __slots__ = ['_auto_register', '_normalize_schemas', '_schema_id_serializer'] + _auto_register: bool + _normalize_schemas: bool + _schema_id_serializer: Callable[[bytes, Any, Any], bytes] + class BaseDeserializer(BaseSerde, Deserializer): __slots__ = ['_schema_id_deserializer'] + _schema_id_deserializer: Callable[[bytes, Any, Any], Any] + def _get_writer_schema( self, schema_id: SchemaId, subject: Optional[str] = None, fmt: Optional[str] = None) -> Schema: @@ -241,7 +266,7 @@ def _get_migrations( ) -> List[Migration]: source = self._registry.lookup_schema( subject, source_info, normalize_schemas=False, deleted=True) - migrations = [] + migrations: List[Migration] = [] if source.version < target.version: migration_mode = RuleMode.UPGRADE first = source @@ -259,13 +284,14 @@ def _get_migrations( if i == 0: previous = version continue - if version.schema.rule_set is not None and self._has_rules( + if version.schema is not None and version.schema.rule_set is not None and self._has_rules( version.schema.rule_set, RulePhase.MIGRATION, migration_mode): - if migration_mode == RuleMode.UPGRADE: - migration = Migration(migration_mode, previous, version) - else: - migration = Migration(migration_mode, version, previous) - migrations.append(migration) + if previous is not None: # previous is always set after first iteration + if migration_mode == RuleMode.UPGRADE: + migration = Migration(migration_mode, previous, version) + else: + migration = Migration(migration_mode, version, previous) + migrations.append(migration) previous = version if migration_mode == RuleMode.DOWNGRADE: migrations.reverse() @@ -275,6 +301,8 @@ def _get_schemas_between( self, subject: str, first: RegisteredSchema, last: RegisteredSchema, fmt: Optional[str] = None ) -> List[RegisteredSchema]: + if first.version is None or last.version is None: + return [first, last] if last.version - first.version <= 1: return [first, last] version1 = first.version @@ -290,8 +318,9 @@ def _execute_migrations( migrations: List[Migration], message: Any ) -> Any: for migration in migrations: - message = self._execute_rules_with_phase( - ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode, - migration.source.schema, migration.target.schema, - message, None, None) + if migration.source is not None and migration.target is not None: + message = self._execute_rules_with_phase( + ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode, + migration.source.schema, migration.target.schema, + message, None, None) return message diff --git a/src/confluent_kafka/schema_registry/common/avro.py b/src/confluent_kafka/schema_registry/common/avro.py index 70eff464c..d9e2345e6 100644 --- a/src/confluent_kafka/schema_registry/common/avro.py +++ b/src/confluent_kafka/schema_registry/common/avro.py @@ -114,10 +114,10 @@ def transform( schema_type = schema.get("type") if schema_type == 'array': return [transform(ctx, schema["items"], item, field_transform) - for item in message] + for item in message] # type: ignore[union-attr] elif schema_type == 'map': return {key: transform(ctx, schema["values"], value, field_transform) - for key, value in message.items()} + for key, value in message.items()} # type: ignore[union-attr] elif schema_type == 'record': fields = schema["fields"] for field in 
fields: @@ -137,7 +137,7 @@ def _transform_field( ): field_type = field["type"] name = field["name"] - full_name = schema["name"] + "." + name + full_name = schema["name"] + "." + name # type: ignore[call-overload,index] try: ctx.enter_field( message, @@ -146,13 +146,13 @@ def _transform_field( get_type(field_type), None ) - value = message[name] + value = message[name] # type: ignore[index] new_value = transform(ctx, field_type, value, field_transform) if ctx.rule.kind == RuleKind.CONDITION: if new_value is False: raise RuleConditionError(ctx.rule) else: - message[name] = new_value + message[name] = new_value # type: ignore[index] finally: ctx.exit_field() @@ -216,7 +216,7 @@ def _resolve_union(schema: AvroSchema, message: AvroMessage) -> Optional[AvroSch def get_inline_tags(schema: AvroSchema) -> Dict[str, Set[str]]: - inline_tags = defaultdict(set) + inline_tags: Dict[str, Set[str]] = defaultdict(set) _get_inline_tags_recursively('', '', schema, inline_tags) return inline_tags @@ -246,7 +246,7 @@ def _get_inline_tags_recursively( record_ns = _implied_namespace(name) if record_ns is None: record_ns = ns - if record_ns != '' and not record_name.startswith(record_ns): + if record_ns != '' and not record_name.startswith(record_ns): # type: ignore[union-attr] record_name = f"{record_ns}.{record_name}" fields = schema["fields"] for field in fields: @@ -254,9 +254,9 @@ def _get_inline_tags_recursively( field_name = field.get("name") field_type = field.get("type") if field_tags is not None and field_name is not None: - tags[record_name + '.' + field_name].update(field_tags) + tags[record_name + '.' + field_name].update(field_tags) # type: ignore[operator] if field_type is not None: - _get_inline_tags_recursively(record_ns, record_name, field_type, tags) + _get_inline_tags_recursively(record_ns, record_name, field_type, tags) # type: ignore[arg-type] def _implied_namespace(name: str) -> Optional[str]: diff --git a/src/confluent_kafka/schema_registry/common/json_schema.py b/src/confluent_kafka/schema_registry/common/json_schema.py index 0e3564688..4f3cf5920 100644 --- a/src/confluent_kafka/schema_registry/common/json_schema.py +++ b/src/confluent_kafka/schema_registry/common/json_schema.py @@ -42,7 +42,7 @@ JsonSchema = Union[bool, dict] -DEFAULT_SPEC = referencing.jsonschema.DRAFT7 +DEFAULT_SPEC = referencing.jsonschema.DRAFT7 # type: ignore[attr-defined] class _ContextStringIO(BytesIO): @@ -133,14 +133,14 @@ def _transform_field( get_type(prop_schema), get_inline_tags(prop_schema) ) - value = message.get(prop_name) + value = message.get(prop_name) # type: ignore[union-attr] if value is not None: new_value = transform(ctx, prop_schema, ref_registry, ref_resolver, full_name, value, field_transform) if ctx.rule.kind == RuleKind.CONDITION: if new_value is False: raise RuleConditionError(ctx.rule) else: - message[prop_name] = new_value + message[prop_name] = new_value # type: ignore[index,call-overload] finally: ctx.exit_field() @@ -148,11 +148,11 @@ def _transform_field( def _validate_subtypes( schema: JsonSchema, message: JsonMessage, registry: Registry ) -> Optional[JsonSchema]: - schema_type = schema.get("type") + schema_type = schema.get("type") # type: ignore[union-attr] if not isinstance(schema_type, list) or len(schema_type) == 0: return None for typ in schema_type: - schema["type"] = typ + schema["type"] = typ # type: ignore[index] try: validate(instance=message, schema=schema, registry=registry) return schema @@ -166,10 +166,10 @@ def _validate_subschemas( message: JsonMessage, registry: 
Registry, resolver: Resolver, -) -> Optional[JsonSchema]: +)-> Optional[JsonSchema]: for subschema in subschemas: try: - ref = subschema.get("$ref") + ref = subschema.get("$ref") # type: ignore[union-attr] if ref is not None: # resolve $ref before validating subschema = resolver.lookup(ref).contents @@ -189,10 +189,10 @@ def get_type(schema: JsonSchema) -> FieldType: # string schemas; this could be either a named schema or a primitive type schema_type = schema - if schema.get("const") is not None or schema.get("enum") is not None: + if schema.get("const") is not None or schema.get("enum") is not None: # type: ignore[union-attr] return FieldType.ENUM if schema_type == "object": - props = schema.get("properties") + props = schema.get("properties") # type: ignore[union-attr] if not props: return FieldType.MAP return FieldType.RECORD @@ -209,7 +209,7 @@ def get_type(schema: JsonSchema) -> FieldType: if schema_type == "null": return FieldType.NULL - props = schema.get("properties") + props = schema.get("properties") # type: ignore[union-attr] if props is not None: return FieldType.RECORD @@ -224,7 +224,7 @@ def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: def get_inline_tags(schema: JsonSchema) -> Set[str]: - tags = schema.get("confluent:tags") + tags = schema.get("confluent:tags") # type: ignore[union-attr] if tags is None: return set() else: diff --git a/src/confluent_kafka/schema_registry/common/protobuf.py b/src/confluent_kafka/schema_registry/common/protobuf.py index 3a55bc5de..889b58020 100644 --- a/src/confluent_kafka/schema_registry/common/protobuf.py +++ b/src/confluent_kafka/schema_registry/common/protobuf.py @@ -3,7 +3,7 @@ import base64 from collections import deque from decimal import Context, Decimal, MAX_PREC -from typing import Set, List, Any +from typing import Set, List, Any, Deque from google.protobuf import descriptor_pb2, any_pb2, api_pb2, empty_pb2, \ duration_pb2, field_mask_pb2, source_context_pb2, struct_pb2, timestamp_pb2, \ @@ -57,7 +57,7 @@ def _bytes(v: int) -> bytes: """ return bytes((v,)) else: - def _bytes(v: int) -> str: + def _bytes(v: int) -> str: # type: ignore[misc] """ Convert int to bytes @@ -97,7 +97,7 @@ def _create_index_array(msg_desc: Descriptor) -> List[int]: ValueError: If the message descriptor is malformed. """ - msg_idx = deque() + msg_idx: Deque[int] = deque() # Walk the nested MessageDescriptor tree up to the root. current = msg_desc @@ -310,7 +310,7 @@ def is_map_field(fd: FieldDescriptor): def get_inline_tags(fd: FieldDescriptor) -> Set[str]: - meta = fd.GetOptions().Extensions[meta_pb2.field_meta] + meta = fd.GetOptions().Extensions[meta_pb2.field_meta] # type: ignore[attr-defined] if meta is None: return set() else: @@ -330,7 +330,7 @@ def _is_builtin(name: str) -> bool: name.startswith('google/type/') -def decimal_to_protobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: +def decimal_to_protobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: # type: ignore[name-defined] """ Converts a Decimal to a Protobuf value. 
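Note on the json_schema.get_type hunks above: the edits only add type-ignore comments; the classification order itself (const/enum first, then object vs. map, then primitives, then a properties fallback) is unchanged. A simplified, self-contained sketch of that ordering, using a stand-in FieldType enum rather than the library's own type, to show what the guarded .get() calls are deciding:

    from enum import Enum
    from typing import Union

    class FieldType(Enum):
        RECORD = "RECORD"
        ENUM = "ENUM"
        MAP = "MAP"
        STRING = "STRING"
        INT = "INT"
        DOUBLE = "DOUBLE"
        BOOLEAN = "BOOLEAN"
        NULL = "NULL"
        COMBINED = "COMBINED"

    def classify(schema: Union[bool, dict]) -> FieldType:
        # Booleans are valid JSON Schemas; treat them as "anything goes".
        if isinstance(schema, bool):
            return FieldType.COMBINED
        if schema.get("const") is not None or schema.get("enum") is not None:
            return FieldType.ENUM
        schema_type = schema.get("type")
        if schema_type == "object":
            # An object with no declared properties behaves like a map.
            return FieldType.RECORD if schema.get("properties") else FieldType.MAP
        primitives = {"string": FieldType.STRING, "integer": FieldType.INT,
                      "number": FieldType.DOUBLE, "boolean": FieldType.BOOLEAN,
                      "null": FieldType.NULL}
        if isinstance(schema_type, str) and schema_type in primitives:
            return primitives[schema_type]
        # Fallback: schemas with properties but no explicit type are records.
        return FieldType.RECORD if schema.get("properties") else FieldType.COMBINED

    assert classify({"enum": ["A", "B"]}) is FieldType.ENUM
    assert classify({"type": "object"}) is FieldType.MAP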
@@ -343,7 +343,7 @@ def decimal_to_protobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: """ sign, digits, exp = value.as_tuple() - delta = exp + scale + delta = exp + scale # type: ignore[operator] if delta < 0: raise ValueError( @@ -362,7 +362,7 @@ def decimal_to_protobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: bytes = unscaled_datum.to_bytes(bytes_req, byteorder="big", signed=True) - result = decimal_pb2.Decimal() + result = decimal_pb2.Decimal() # type: ignore[attr-defined] result.value = bytes result.precision = 0 result.scale = scale @@ -372,7 +372,7 @@ def decimal_to_protobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: decimal_context = Context() -def protobuf_to_decimal(value: decimal_pb2.Decimal) -> Decimal: +def protobuf_to_decimal(value: decimal_pb2.Decimal) -> Decimal: # type: ignore[name-defined] """ Converts a Protobuf value to Decimal. diff --git a/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py b/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py index 0a0542ccd..0b385dc4d 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py +++ b/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py @@ -55,24 +55,24 @@ def transform(self, ctx: RuleContext, msg: Any) -> Any: def execute(self, ctx: RuleContext, msg: Any, args: Any) -> Any: expr = ctx.rule.expr try: - index = expr.index(";") + index = expr.index(";") # type: ignore[union-attr] except ValueError: index = -1 if index >= 0: - guard = expr[:index] + guard = expr[:index] # type: ignore[index] if len(guard.strip()) > 0: guard_result = self.execute_rule(ctx, guard, args) if not guard_result: if ctx.rule.kind == RuleKind.CONDITION: return True return msg - expr = expr[index+1:] + expr = expr[index+1:] # type: ignore[index] - return self.execute_rule(ctx, expr, args) + return self.execute_rule(ctx, expr, args) # type: ignore[arg-type] def execute_rule(self, ctx: RuleContext, expr: str, args: Any) -> Any: schema = ctx.target - script_type = ctx.target.schema_type + script_type = ctx.target.schema_type # type: ignore[union-attr] prog = self._cache.get_program(expr, script_type, schema) if prog is None: ast = self._env.compile(expr) @@ -158,7 +158,7 @@ def _dict_to_cel(val: dict) -> Dict[str, celtypes.Value]: result = celtypes.MapType() for key, val in val.items(): result[key] = _value_to_cel(val) - return result + return result # type: ignore[return-value] def _array_to_cel(val: list) -> List[celtypes.Value]: diff --git a/src/confluent_kafka/schema_registry/rules/cel/constraints.py b/src/confluent_kafka/schema_registry/rules/cel/constraints.py index 98b739cd3..a50dc0210 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/constraints.py +++ b/src/confluent_kafka/schema_registry/rules/cel/constraints.py @@ -27,7 +27,7 @@ class CompilationError(Exception): def make_key_path(field_name: str, key: celtypes.Value) -> str: - return f"{field_name}[{string_format.format_value(key)}]" + return f"{field_name}[{string_format.format_value(key)}]" # type: ignore[str-bytes-safe] def make_duration(msg: message.Message) -> celtypes.DurationType: @@ -38,7 +38,7 @@ def make_duration(msg: message.Message) -> celtypes.DurationType: def make_timestamp(msg: message.Message) -> celtypes.TimestampType: - return make_duration(msg) + celtypes.TimestampType(1970, 1, 1) + return make_duration(msg) + celtypes.TimestampType(1970, 1, 1) # type: ignore[return-value] def unwrap(msg: message.Message) -> celtypes.Value: @@ -87,8 +87,8 @@ def __getitem__(self, name): def 
_msg_to_cel(msg: message.Message) -> typing.Dict[str, celtypes.Value]: ctor = _MSG_TYPE_URL_TO_CTOR.get(msg.DESCRIPTOR.full_name) if ctor is not None: - return ctor(msg) - return MessageType(msg) + return ctor(msg) # type: ignore[return-value] + return MessageType(msg) # type: ignore[return-value] _TYPE_TO_CTOR = { @@ -132,7 +132,7 @@ def _scalar_field_value_to_cel(val: typing.Any, field: descriptor.FieldDescripto if ctor is None: msg = "unknown field type" raise CompilationError(msg) - return ctor(val) + return ctor(val) # type: ignore[operator] def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value: diff --git a/src/confluent_kafka/schema_registry/rules/cel/extra_func.py b/src/confluent_kafka/schema_registry/rules/cel/extra_func.py index 497e9632a..194ce826f 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/extra_func.py +++ b/src/confluent_kafka/schema_registry/rules/cel/extra_func.py @@ -143,14 +143,14 @@ def is_email(string: celtypes.Value) -> celpy.Result: def is_uri(string: celtypes.Value) -> celpy.Result: - url = urlparse.urlparse(string) + url = urlparse.urlparse(string) # type: ignore[arg-type] if not all([url.scheme, url.netloc, url.path]): return celtypes.BoolType(False) return celtypes.BoolType(True) def is_uri_ref(string: celtypes.Value) -> celpy.Result: - url = urlparse.urlparse(string) + url = urlparse.urlparse(string) # type: ignore[arg-type] if not all([url.scheme, url.path]) and url.fragment: return celtypes.BoolType(False) return celtypes.BoolType(True) diff --git a/src/confluent_kafka/schema_registry/rules/cel/string_format.py b/src/confluent_kafka/schema_registry/rules/cel/string_format.py index 3e295849c..39db450b1 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/string_format.py +++ b/src/confluent_kafka/schema_registry/rules/cel/string_format.py @@ -43,9 +43,9 @@ def __init__(self, locale: str): def format(self, fmt: celtypes.Value, args: celtypes.Value) -> celpy.Result: if not isinstance(fmt, celtypes.StringType): - return celpy.native_to_cel(celpy.new_error("format() requires a string as the first argument")) + return celpy.native_to_cel(celpy.new_error("format() requires a string as the first argument")) # type: ignore[attr-defined] if not isinstance(args, celtypes.ListType): - return celpy.native_to_cel(celpy.new_error("format() requires a list as the second argument")) + return celpy.native_to_cel(celpy.new_error("format() requires a list as the second argument")) # type: ignore[attr-defined] # printf style formatting i = 0 j = 0 @@ -77,21 +77,21 @@ def format(self, fmt: celtypes.Value, args: celtypes.Value) -> celpy.Result: if i >= len(fmt): return celpy.CELEvalError("format() incomplete format specifier") if fmt[i] == "f": - result += self.format_float(arg, precision) + result += self.format_float(arg, precision) # type: ignore[operator,assignment] if fmt[i] == "e": - result += self.format_exponential(arg, precision) + result += self.format_exponential(arg, precision) # type: ignore[operator,assignment] elif fmt[i] == "d": - result += self.format_int(arg) + result += self.format_int(arg) # type: ignore[operator,assignment] elif fmt[i] == "s": - result += self.format_string(arg) + result += self.format_string(arg) # type: ignore[operator,assignment] elif fmt[i] == "x": - result += self.format_hex(arg) + result += self.format_hex(arg) # type: ignore[operator,assignment] elif fmt[i] == "X": - result += self.format_hex(arg).upper() + result += self.format_hex(arg).upper() # type: 
ignore[operator,assignment,union-attr,call-arg] elif fmt[i] == "o": - result += self.format_oct(arg) + result += self.format_oct(arg) # type: ignore[operator,assignment] elif fmt[i] == "b": - result += self.format_bin(arg) + result += self.format_bin(arg) # type: ignore[operator,assignment] else: return celpy.CELEvalError("format() unknown format specifier: " + fmt[i]) i += 1 @@ -111,9 +111,9 @@ def format_exponential(self, arg: celtypes.Value, precision: int) -> celpy.Resul def format_int(self, arg: celtypes.Value) -> celpy.Result: if isinstance(arg, celtypes.IntType): - return celtypes.StringType(arg) + return celtypes.StringType(arg) # type: ignore[arg-type] if isinstance(arg, celtypes.UintType): - return celtypes.StringType(arg) + return celtypes.StringType(arg) # type: ignore[arg-type] return celpy.CELEvalError("format_int() requires an integer argument") def format_hex(self, arg: celtypes.Value) -> celpy.Result: @@ -150,13 +150,13 @@ def format_string(self, arg: celtypes.Value) -> celpy.Result: return celtypes.StringType(arg.hex()) if isinstance(arg, celtypes.ListType): return self.format_list(arg) - return celtypes.StringType(arg) + return celtypes.StringType(arg) # type: ignore[arg-type] def format_value(self, arg: celtypes.Value) -> celpy.Result: if isinstance(arg, (celtypes.StringType, str)): return celtypes.StringType(quote(arg)) if isinstance(arg, celtypes.UintType): - return celtypes.StringType(arg) + return celtypes.StringType(arg) # type: ignore[arg-type] return self.format_string(arg) def format_list(self, arg: celtypes.ListType) -> celpy.Result: @@ -164,7 +164,7 @@ def format_list(self, arg: celtypes.ListType) -> celpy.Result: for i in range(len(arg)): if i > 0: result += ", " - result += self.format_value(arg[i]) + result += self.format_value(arg[i]) # type: ignore[operator,assignment] result += "]" return celtypes.StringType(result) diff --git a/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py b/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py index cf33ca2d6..d99b65ba4 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py @@ -47,12 +47,13 @@ def __init__( if not key_uri: self._key_uri = None + self._client = None # type: ignore[assignment] elif key_uri.startswith(AZURE_KEYURI_PREFIX): self._key_uri = key_uri + key_id = key_uri[len(AZURE_KEYURI_PREFIX):] + self._client = CryptographyClient(key_id, credentials) else: raise tink.TinkError('Invalid key_uri.') - key_id = key_uri[len(AZURE_KEYURI_PREFIX):] - self._client = CryptographyClient(key_id, credentials) def does_support(self, key_uri: str) -> bool: """Returns true iff this client supports KMS key specified in 'key_uri'. 
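The azure_client.py hunk above moves the CryptographyClient construction into the prefix-checked branch, so an empty key_uri leaves the client unset instead of deriving a key id from an invalid string. The control flow, reduced to its shape (the prefix constant and client factory below are stand-ins, not the library's values):

    from typing import Callable, Optional

    AZURE_KEYURI_PREFIX = "azure-kms://"  # assumed placeholder value

    def bind_client(key_uri: Optional[str],
                    make_client: Callable[[str], object]) -> Optional[object]:
        # Empty/None URI: stay unbound rather than building a client from garbage.
        if not key_uri:
            return None
        if not key_uri.startswith(AZURE_KEYURI_PREFIX):
            raise ValueError("Invalid key_uri.")
        # Only a validated URI reaches the point where the key id is extracted.
        key_id = key_uri[len(AZURE_KEYURI_PREFIX):]
        return make_client(key_id)

    assert bind_client(None, lambda kid: kid) is None
    assert bind_client("azure-kms://vault.example/keys/k1", lambda kid: kid) == "vault.example/keys/k1"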
@@ -83,4 +84,4 @@ def get_aead(self, key_uri: str) -> aead.Aead: ) if not key_uri.startswith(AZURE_KEYURI_PREFIX): raise tink.TinkError('Invalid key_uri.') - return AzureKmsAead(self._client, EncryptionAlgorithm.rsa_oaep_256) + return AzureKmsAead(self._client, EncryptionAlgorithm.rsa_oaep_256) # type: ignore[arg-type] diff --git a/src/confluent_kafka/schema_registry/rules/encryption/dek_registry/dek_registry_client.py b/src/confluent_kafka/schema_registry/rules/encryption/dek_registry/dek_registry_client.py index 55195407c..1c130a8fa 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/dek_registry/dek_registry_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/dek_registry/dek_registry_client.py @@ -46,7 +46,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: d = src_dict.copy() kek_kms_props = cls() - kek_kms_props.properties = d + kek_kms_props.properties = d # type: ignore[attr-defined] return kek_kms_props @@ -124,7 +124,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: deleted = d.pop("deleted", None) - kek = cls( + kek = cls( # type: ignore[call-arg] name=name, kms_type=kms_type, kms_key_id=kms_key_id, @@ -198,7 +198,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: shared = d.pop("shared", None) - create_kek_request = cls( + create_kek_request = cls( # type: ignore[call-arg] name=name, kms_type=kms_type, kms_key_id=kms_key_id, @@ -321,7 +321,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: deleted = d.pop("deleted", None) - dek = cls( + dek = cls( # type: ignore[call-arg] kek_name=kek_name, subject=subject, version=version, @@ -381,7 +381,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: encrypted_key_material = d.pop("encryptedKeyMaterial", None) - create_dek_request = cls( + create_dek_request = cls( # type: ignore[call-arg] subject=subject, version=version, algorithm=algorithm, diff --git a/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py b/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py index 45a4ab99a..5a6b76c5f 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py @@ -58,8 +58,8 @@ def now(self) -> int: class EncryptionExecutor(RuleExecutor): def __init__(self, clock: Clock = Clock()): - self.client = None - self.config = None + self.client: Optional[DekRegistryClient] = None + self.config: Optional[dict] = None self.clock = clock def configure(self, client_conf: dict, rule_conf: dict): @@ -203,7 +203,7 @@ def __init__(self, executor: EncryptionExecutor, cryptor: Cryptor, kek_name: str self._executor = executor self._cryptor = cryptor self._kek_name = kek_name - self._kek = None + self._kek: Optional[Kek] = None self._dek_expiry_days = dek_expiry_days def _is_dek_rotated(self): @@ -212,7 +212,7 @@ def _is_dek_rotated(self): def _get_kek(self, ctx: RuleContext) -> Kek: if self._kek is None: self._kek = self._get_or_create_kek(ctx) - return self._kek + return self._kek # type: ignore[return-value] def _get_or_create_kek(self, ctx: RuleContext) -> Kek: is_read = ctx.rule_mode == RuleMode.READ @@ -243,7 +243,7 @@ def _get_or_create_kek(self, ctx: RuleContext) -> Kek: def _retrieve_kek_from_registry(self, kek_id: KekId) -> Optional[Kek]: try: - return self._executor.client.get_kek(kek_id.name, kek_id.deleted) + return self._executor.client.get_kek(kek_id.name, kek_id.deleted) # type: ignore[union-attr] except Exception 
as e: if isinstance(e, SchemaRegistryError) and e.http_status_code == 404: return None @@ -254,7 +254,7 @@ def _store_kek_to_registry( kms_key_id: str, shared: bool ) -> Optional[Kek]: try: - return self._executor.client.register_kek(kek_id.name, kms_type, kms_key_id, shared) + return self._executor.client.register_kek(kek_id.name, kms_type, kms_key_id, shared) # type: ignore[union-attr] except Exception as e: if isinstance(e, SchemaRegistryError) and e.http_status_code == 409: return None @@ -266,7 +266,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: if version is None or version == 0: version = 1 dek_id = DekId( - kek.name, + kek.name, # type: ignore[arg-type] ctx.subject, version, self._cryptor.dek_format, @@ -280,10 +280,10 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: raise RuleError(f"no dek found for {dek_id.kek_name} during consume") encrypted_dek = None if not kek.shared: - primitive = AeadWrapper(self._executor.config, self._kek) + primitive = AeadWrapper(self._executor.config, self._kek) # type: ignore[arg-type] raw_dek = self._cryptor.generate_key() encrypted_dek = primitive.encrypt(raw_dek, self._cryptor.EMPTY_AAD) - new_version = dek.version + 1 if is_expired else 1 + new_version = dek.version + 1 if is_expired else 1 # type: ignore[union-attr,operator] try: dek = self._create_dek(dek_id, new_version, encrypted_dek) except RuleError as e: @@ -294,9 +294,9 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: key_bytes = dek.get_key_material_bytes() if key_bytes is None: if primitive is None: - primitive = AeadWrapper(self._executor.config, self._kek) + primitive = AeadWrapper(self._executor.config, self._kek) # type: ignore[arg-type] encrypted_dek = dek.get_encrypted_key_material_bytes() - raw_dek = primitive.decrypt(encrypted_dek, self._cryptor.EMPTY_AAD) + raw_dek = primitive.decrypt(encrypted_dek, self._cryptor.EMPTY_AAD) # type: ignore[arg-type] dek.set_key_material(raw_dek) return dek @@ -304,7 +304,7 @@ def _create_dek(self, dek_id: DekId, new_version: Optional[int], encrypted_dek: new_dek_id = DekId( dek_id.kek_name, dek_id.subject, - new_version, + new_version, # type: ignore[arg-type] dek_id.algorithm, dek_id.deleted, ) @@ -321,7 +321,7 @@ def _retrieve_dek_from_registry(self, key: DekId) -> Optional[Dek]: version = key.version if not version: version = 1 - dek = self._executor.client.get_dek( + dek = self._executor.client.get_dek( # type: ignore[union-attr] key.kek_name, key.subject, key.algorithm, version, key.deleted) return dek if dek and dek.encrypted_key_material else None except Exception as e: @@ -332,8 +332,8 @@ def _retrieve_dek_from_registry(self, key: DekId) -> Optional[Dek]: def _store_dek_to_registry(self, key: DekId, encrypted_dek: Optional[bytes]) -> Optional[Dek]: try: encrypted_dek_str = base64.b64encode(encrypted_dek).decode("utf-8") if encrypted_dek else None - dek = self._executor.client.register_dek( - key.kek_name, key.subject, encrypted_dek_str, key.algorithm, key.version) + dek = self._executor.client.register_dek( # type: ignore[union-attr] + key.kek_name, key.subject, encrypted_dek_str, key.algorithm, key.version) # type: ignore[arg-type] return dek except Exception as e: if isinstance(e, SchemaRegistryError) and e.http_status_code == 409: @@ -345,7 +345,7 @@ def _is_expired(self, ctx: RuleContext, dek: Optional[Dek]) -> bool: return (ctx.rule_mode != RuleMode.READ and self._dek_expiry_days > 0 and dek is not None - and (now - dek.ts) / 
MILLIS_IN_DAY > self._dek_expiry_days) + and (now - dek.ts) / MILLIS_IN_DAY > self._dek_expiry_days) # type: ignore[operator] def transform(self, ctx: RuleContext, field_type: FieldType, field_value: Any) -> Any: if field_value is None: @@ -359,9 +359,9 @@ def transform(self, ctx: RuleContext, field_type: FieldType, field_value: Any) - version = -1 dek = self._get_or_create_dek(ctx, version) key_material_bytes = dek.get_key_material_bytes() - ciphertext = self._cryptor.encrypt(key_material_bytes, plaintext, Cryptor.EMPTY_AAD) + ciphertext = self._cryptor.encrypt(key_material_bytes, plaintext, Cryptor.EMPTY_AAD) # type: ignore[arg-type] if self._is_dek_rotated(): - ciphertext = self._prefix_version(dek.version, ciphertext) + ciphertext = self._prefix_version(dek.version, ciphertext) # type: ignore[arg-type] if field_type == FieldType.STRING: return base64.b64encode(ciphertext).decode("utf-8") else: @@ -370,7 +370,7 @@ def transform(self, ctx: RuleContext, field_type: FieldType, field_value: Any) - if field_type == FieldType.STRING: ciphertext = base64.b64decode(field_value) else: - ciphertext = self._to_bytes(field_type, field_value) + ciphertext = self._to_bytes(field_type, field_value) # type: ignore[assignment] if ciphertext is None: return field_value @@ -381,7 +381,7 @@ def transform(self, ctx: RuleContext, field_type: FieldType, field_value: Any) - raise RuleError("no version found in ciphertext") dek = self._get_or_create_dek(ctx, version) key_material_bytes = dek.get_key_material_bytes() - plaintext = self._cryptor.decrypt(key_material_bytes, ciphertext, Cryptor.EMPTY_AAD) + plaintext = self._cryptor.decrypt(key_material_bytes, ciphertext, Cryptor.EMPTY_AAD) # type: ignore[arg-type] return self._to_object(field_type, plaintext) else: raise RuleError(f"unsupported rule mode {ctx.rule_mode}") @@ -421,7 +421,7 @@ def __init__(self, config: dict, kek: Kek): def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes: for index, kms_key_id in enumerate(self._kms_key_ids): try: - aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) + aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) # type: ignore[arg-type] return aead.encrypt(plaintext, associated_data) except Exception as e: log.warning("failed to encrypt with kek %s and kms key id %s", @@ -433,7 +433,7 @@ def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes: def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes: for index, kms_key_id in enumerate(self._kms_key_ids): try: - aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) + aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) # type: ignore[arg-type] return aead.decrypt(ciphertext, associated_data) except Exception as e: log.warning("failed to decrypt with kek %s and kms key id %s", @@ -443,7 +443,7 @@ def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes: raise RuleError("No KEK found for decryption") def _get_kms_key_ids(self) -> List[str]: - kms_key_ids = [self._kek.kms_key_id] + kms_key_ids = [self._kek.kms_key_id] # type: ignore[list-item] alternate_kms_key_ids = None if self._kek.kms_props is not None: alternate_kms_key_ids = self._kek.kms_props.properties.get(ENCRYPT_ALTERNATE_KMS_KEY_IDS) @@ -452,7 +452,7 @@ def _get_kms_key_ids(self) -> List[str]: if alternate_kms_key_ids is not None: # Split the comma-separated list of alternate KMS key IDs and append to kms_key_ids kms_key_ids.extend([id.strip() for id in alternate_kms_key_ids.split(',') if id.strip()]) 
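The AeadWrapper changes above keep the existing behaviour: the primary KMS key id is tried first, then each comma-separated alternate in order, logging and falling through on failure. A stripped-down sketch of that try-each-key pattern (the AEAD callable and key-id names here are stand-ins, not the library's types):

    import logging
    from typing import Callable, List, Optional

    log = logging.getLogger(__name__)

    def split_key_ids(primary: str, alternates: Optional[str]) -> List[str]:
        # Primary key id first, then any non-empty comma-separated alternates.
        ids = [primary]
        if alternates:
            ids.extend(a.strip() for a in alternates.split(",") if a.strip())
        return ids

    def encrypt_with_fallback(key_ids: List[str],
                              get_aead: Callable[[str], Callable[[bytes, bytes], bytes]],
                              plaintext: bytes, aad: bytes) -> bytes:
        # Try each KMS key id in order; the first one that works wins.
        for key_id in key_ids:
            try:
                return get_aead(key_id)(plaintext, aad)
            except Exception:
                log.warning("failed to encrypt with kms key id %s, trying next", key_id)
        raise RuntimeError("no KMS key could encrypt the payload")

    ids = split_key_ids("projects/p/keys/primary", "projects/p/keys/alt1, projects/p/keys/alt2")
    assert ids == ["projects/p/keys/primary", "projects/p/keys/alt1", "projects/p/keys/alt2"]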
- return kms_key_ids + return kms_key_ids # type: ignore[return-value] def _get_aead(self, config: dict, kms_type: str, kms_key_id: str) -> aead.Aead: kek_url = kms_type + "://" + kms_key_id diff --git a/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py b/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py index f8523d73f..d85a11b64 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py @@ -48,19 +48,20 @@ def __init__( if not key_uri: self._key_uri = None + self._client = None # type: ignore[assignment] elif key_uri.startswith(VAULT_KEYURI_PREFIX): self._key_uri = key_uri + parsed = urlparse(key_uri[len(VAULT_KEYURI_PREFIX):]) + vault_url = parsed.scheme + '://' + parsed.netloc + self._client = hvac.Client( + url=vault_url, + token=token, + namespace=ns, + verify=False + ) else: raise tink.TinkError('Invalid key_uri.') - parsed = urlparse(key_uri[len(VAULT_KEYURI_PREFIX):]) - vault_url = parsed.scheme + '://' + parsed.netloc - self._client = hvac.Client( - url=vault_url, - token=token, - namespace=ns, - verify=False - ) - if role_id and secret_id: + if role_id and secret_id and self._client is not None: self._client.auth.approle.login(role_id=role_id, secret_id=secret_id) def does_support(self, key_uri: str) -> bool: diff --git a/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py b/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py index ba92e0478..f9b0bb452 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py @@ -23,7 +23,7 @@ class LocalKmsClient(KmsClient): def __init__(self, secret: Optional[str] = None): - self._aead = self._get_primitive(secret) + self._aead = self._get_primitive(secret) # type: ignore[arg-type] def _get_primitive(self, secret: str) -> aead.Aead: key = self._get_key(secret) From d536a71185ef65406fc93e435907e57732473b39 Mon Sep 17 00:00:00 2001 From: Naxin Date: Mon, 20 Oct 2025 18:46:44 -0400 Subject: [PATCH 19/31] encryption clients --- .../rules/encryption/azurekms/azure_client.py | 15 +++++------ .../encryption/hcvault/hcvault_client.py | 27 +++++++++---------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py b/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py index d99b65ba4..a98ab5116 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py @@ -30,30 +30,27 @@ class AzureKmsClient(tink.KmsClient): """Basic Azure client for AEAD.""" def __init__( - self, key_uri: Optional[str], credentials: TokenCredential + self, key_uri: str, credentials: TokenCredential ) -> None: """Creates a new AzureKmsClient that is bound to the key specified in 'key_uri'. Uses the specified credentials when communicating with the KMS. Args: - key_uri: The URI of the key the client should be bound to. If it is None - or empty, then the client is not bound to any particular key. + key_uri: The URI of the key the client should be bound to. credentials: The token credentials. Raises: TinkError: If the key uri is not valid. 
""" - if not key_uri: - self._key_uri = None - self._client = None # type: ignore[assignment] - elif key_uri.startswith(AZURE_KEYURI_PREFIX): + if key_uri.startswith(AZURE_KEYURI_PREFIX): self._key_uri = key_uri - key_id = key_uri[len(AZURE_KEYURI_PREFIX):] - self._client = CryptographyClient(key_id, credentials) else: raise tink.TinkError('Invalid key_uri.') + + key_id = key_uri[len(AZURE_KEYURI_PREFIX):] + self._client = CryptographyClient(key_id, credentials) def does_support(self, key_uri: str) -> bool: """Returns true iff this client supports KMS key specified in 'key_uri'. diff --git a/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py b/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py index d85a11b64..adee38241 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py @@ -29,7 +29,7 @@ class HcVaultKmsClient(tink.KmsClient): """Basic HashiCorp Vault client for AEAD.""" def __init__( - self, key_uri: Optional[str], token: Optional[str], ns: Optional[str] = None, + self, key_uri: str, token: Optional[str], ns: Optional[str] = None, role_id: Optional[str] = None, secret_id: Optional[str] = None ) -> None: """Creates a new HcVaultKmsClient that is bound to the key specified in 'key_uri'. @@ -37,8 +37,7 @@ def __init__( Uses the specified credentials when communicating with the KMS. Args: - key_uri: The URI of the key the client should be bound to. If it is None - or empty, then the client is not bound to any particular key. + key_uri: The URI of the key the client should be bound to. token: The Vault token. ns: The Vault namespace. @@ -46,21 +45,19 @@ def __init__( TinkError: If the key uri is not valid. 
""" - if not key_uri: - self._key_uri = None - self._client = None # type: ignore[assignment] - elif key_uri.startswith(VAULT_KEYURI_PREFIX): + if key_uri.startswith(VAULT_KEYURI_PREFIX): self._key_uri = key_uri - parsed = urlparse(key_uri[len(VAULT_KEYURI_PREFIX):]) - vault_url = parsed.scheme + '://' + parsed.netloc - self._client = hvac.Client( - url=vault_url, - token=token, - namespace=ns, - verify=False - ) else: raise tink.TinkError('Invalid key_uri.') + + parsed = urlparse(key_uri[len(VAULT_KEYURI_PREFIX):]) + vault_url = parsed.scheme + '://' + parsed.netloc + self._client = hvac.Client( + url=vault_url, + token=token, + namespace=ns, + verify=False + ) if role_id and secret_id and self._client is not None: self._client.auth.approle.login(role_id=role_id, secret_id=secret_id) From cdbd20300006c082f5585456cec36a020ebe33bd Mon Sep 17 00:00:00 2001 From: Naxin Date: Thu, 23 Oct 2025 00:15:24 -0400 Subject: [PATCH 20/31] fix --- src/confluent_kafka/admin/__init__.py | 3 +-- src/confluent_kafka/admin/_metadata.py | 4 +++- src/confluent_kafka/cimpl.pyi | 11 +++++++++-- src/confluent_kafka/experimental/aio/_common.py | 2 +- src/confluent_kafka/serialization/__init__.py | 4 ++-- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/confluent_kafka/admin/__init__.py b/src/confluent_kafka/admin/__init__.py index b4705fa03..552c9a417 100644 --- a/src/confluent_kafka/admin/__init__.py +++ b/src/confluent_kafka/admin/__init__.py @@ -652,8 +652,7 @@ def list_topics(self, *args: Any, **kwargs: Any) -> ClusterMetadata: return super(AdminClient, self).list_topics(*args, **kwargs) - def list_groups(self, *args: Any, **kwargs: Any) -> GroupMetadata: - + def list_groups(self, *args: Any, **kwargs: Any) -> List[GroupMetadata]: return super(AdminClient, self).list_groups(*args, **kwargs) def create_partitions( # type: ignore[override] diff --git a/src/confluent_kafka/admin/_metadata.py b/src/confluent_kafka/admin/_metadata.py index 8132e3bc1..f5d58b01b 100644 --- a/src/confluent_kafka/admin/_metadata.py +++ b/src/confluent_kafka/admin/_metadata.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Optional + +from confluent_kafka.cimpl import KafkaError class ClusterMetadata(object): diff --git a/src/confluent_kafka/cimpl.pyi b/src/confluent_kafka/cimpl.pyi index 332d552bb..68d04b936 100644 --- a/src/confluent_kafka/cimpl.pyi +++ b/src/confluent_kafka/cimpl.pyi @@ -38,6 +38,8 @@ from typing import Any, Optional, Callable, List, Tuple, Dict, Union, overload from typing_extensions import Self, Literal import builtins +from confluent_kafka.admin._metadata import ClusterMetadata, GroupMetadata + from ._types import HeadersType # Callback types with proper class references (defined locally to avoid circular imports) @@ -199,6 +201,11 @@ class Consumer: message: Optional['Message'] = None, offsets: Optional[List[TopicPartition]] = None ) -> None: ... + def committed( + self, + partitions: List[TopicPartition], + timeout: float = -1 + ) -> List[TopicPartition]: ... def close(self) -> None: ... def list_topics(self, topic: Optional[str] = None, timeout: float = -1) -> Any: ... def offsets_for_times( @@ -255,12 +262,12 @@ class _AdminClientImpl: self, topic: Optional[str] = None, timeout: float = -1 - ) -> Any: ... + ) -> ClusterMetadata: ... def list_groups( self, group: Optional[str] = None, timeout: float = -1 - ) -> Any: ... + ) -> List[GroupMetadata]: ... 
def describe_consumer_groups( self, group_ids: List[str], diff --git a/src/confluent_kafka/experimental/aio/_common.py b/src/confluent_kafka/experimental/aio/_common.py index 24659ed9f..3bf274064 100644 --- a/src/confluent_kafka/experimental/aio/_common.py +++ b/src/confluent_kafka/experimental/aio/_common.py @@ -32,7 +32,7 @@ def __init__( self.logger = logger def log(self, *args: Any, **kwargs: Any) -> None: - self.loop.call_soon_threadsafe(self.logger.log, *args, **kwargs) + self.loop.call_soon_threadsafe(lambda: self.logger.log(*args, **kwargs)) def wrap_callback( diff --git a/src/confluent_kafka/serialization/__init__.py b/src/confluent_kafka/serialization/__init__.py index ed59f3c1e..d90817190 100644 --- a/src/confluent_kafka/serialization/__init__.py +++ b/src/confluent_kafka/serialization/__init__.py @@ -17,7 +17,7 @@ # import struct as _struct from enum import Enum -from typing import Any, Optional +from typing import Any, List, Optional from confluent_kafka.error import KafkaException from confluent_kafka._types import HeadersType @@ -114,7 +114,7 @@ class Serializer(object): - unicode(encoding) """ - __slots__ = [] + __slots__: List[str] = [] def __call__(self, obj: Any, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ From 6fa87301e48b16a0b9f1769e841b87e88f99d673 Mon Sep 17 00:00:00 2001 From: Naxin Date: Thu, 23 Oct 2025 00:26:54 -0400 Subject: [PATCH 21/31] revert incorrect merge conflict changes --- src/confluent_kafka/_model/__init__.py | 16 ++++++++-------- src/confluent_kafka/admin/_acl.py | 12 ++++++------ src/confluent_kafka/admin/_config.py | 4 +++- src/confluent_kafka/admin/_listoffsets.py | 2 +- src/confluent_kafka/admin/_resource.py | 8 ++++---- src/confluent_kafka/admin/_scram.py | 4 ++-- .../experimental/aio/producer/_AIOProducer.py | 2 +- .../aio/producer/_kafka_batch_executor.py | 16 ++++++++++++++++ src/confluent_kafka/serialization/__init__.py | 2 +- 9 files changed, 42 insertions(+), 24 deletions(-) diff --git a/src/confluent_kafka/_model/__init__.py b/src/confluent_kafka/_model/__init__.py index 8c775dbd9..0f072b5a4 100644 --- a/src/confluent_kafka/_model/__init__.py +++ b/src/confluent_kafka/_model/__init__.py @@ -91,8 +91,8 @@ class ConsumerGroupState(Enum): #: Consumer Group is empty. EMPTY = cimpl.CONSUMER_GROUP_STATE_EMPTY - def __lt__(self, other) -> Any: - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, ConsumerGroupState): return NotImplemented return self.value < other.value @@ -111,8 +111,8 @@ class ConsumerGroupType(Enum): #: Classic Type CLASSIC = cimpl.CONSUMER_GROUP_TYPE_CLASSIC - def __lt__(self, other) -> Any: - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, ConsumerGroupType): return NotImplemented return self.value < other.value @@ -167,8 +167,8 @@ class IsolationLevel(Enum): READ_UNCOMMITTED = cimpl.ISOLATION_LEVEL_READ_UNCOMMITTED #: Receive all the offsets. READ_COMMITTED = cimpl.ISOLATION_LEVEL_READ_COMMITTED #: Skip offsets belonging to an aborted transaction. 
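The aio/_common.py change above wraps the log call in a lambda because loop.call_soon_threadsafe forwards only positional arguments to the callback and does not accept arbitrary keyword arguments; binding them in a closure sidesteps that. A minimal, standalone reproduction of the pattern (logger name and message are arbitrary):

    import asyncio
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("example")

    async def main() -> None:
        loop = asyncio.get_running_loop()
        # call_soon_threadsafe(callback, *args) has no **kwargs parameter,
        # so keyword arguments must be captured in the callable itself.
        loop.call_soon_threadsafe(
            lambda: logger.log(logging.INFO, "delivered %s messages", 3))
        await asyncio.sleep(0)  # let the scheduled callback run

    asyncio.run(main())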
- def __lt__(self, other) -> Any: - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, IsolationLevel): return NotImplemented return self.value < other.value @@ -186,7 +186,7 @@ class ElectionType(Enum): #: Unclean election UNCLEAN = cimpl.ELECTION_TYPE_UNCLEAN - def __lt__(self, other) -> Any: - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, ElectionType): return NotImplemented return self.value < other.value diff --git a/src/confluent_kafka/admin/_acl.py b/src/confluent_kafka/admin/_acl.py index d318c97ec..75adc0c8f 100644 --- a/src/confluent_kafka/admin/_acl.py +++ b/src/confluent_kafka/admin/_acl.py @@ -44,8 +44,8 @@ class AclOperation(Enum): ALTER_CONFIGS = _cimpl.ACL_OPERATION_ALTER_CONFIGS #: ALTER_CONFIGS operation IDEMPOTENT_WRITE = _cimpl.ACL_OPERATION_IDEMPOTENT_WRITE #: IDEMPOTENT_WRITE operation - def __lt__(self, other: 'AclOperation') -> Any: - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, AclOperation): return NotImplemented return self.value < other.value @@ -59,8 +59,8 @@ class AclPermissionType(Enum): DENY = _cimpl.ACL_PERMISSION_TYPE_DENY #: Disallows access ALLOW = _cimpl.ACL_PERMISSION_TYPE_ALLOW #: Grants access - def __lt__(self, other: 'AclPermissionType') -> Any: - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, AclPermissionType): return NotImplemented return self.value < other.value @@ -161,8 +161,8 @@ def _to_tuple(self) -> Tuple[ResourceType, str, ResourcePatternType, str, str, A def __hash__(self) -> int: return hash(self._to_tuple()) - def __lt__(self, other: 'AclBinding') -> Any: - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, AclBinding): return NotImplemented return self._to_tuple() < other._to_tuple() diff --git a/src/confluent_kafka/admin/_config.py b/src/confluent_kafka/admin/_config.py index c303f7bbf..70b19c6aa 100644 --- a/src/confluent_kafka/admin/_config.py +++ b/src/confluent_kafka/admin/_config.py @@ -185,7 +185,9 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash((self.restype, self.name)) - def __lt__(self, other: 'ConfigResource') -> bool: + def __lt__(self, other: object) -> bool: + if not isinstance(other, ConfigResource): + return NotImplemented if self.restype < other.restype: return True return self.name.__lt__(other.name) diff --git a/src/confluent_kafka/admin/_listoffsets.py b/src/confluent_kafka/admin/_listoffsets.py index 205e852be..e23d75257 100644 --- a/src/confluent_kafka/admin/_listoffsets.py +++ b/src/confluent_kafka/admin/_listoffsets.py @@ -68,7 +68,7 @@ def __new__(cls, index: int): else: return cls.for_timestamp(index) - def __lt__(self, other) -> Any: + def __lt__(self, other: object) -> bool: if not isinstance(other, OffsetSpec): return NotImplemented return self._value < other._value diff --git a/src/confluent_kafka/admin/_resource.py b/src/confluent_kafka/admin/_resource.py index 8fa6dd19b..131c56407 100644 --- a/src/confluent_kafka/admin/_resource.py +++ b/src/confluent_kafka/admin/_resource.py @@ -28,8 +28,8 @@ class ResourceType(Enum): BROKER = _cimpl.RESOURCE_BROKER #: Broker resource. Resource name is broker id. TRANSACTIONAL_ID = _cimpl.RESOURCE_TRANSACTIONAL_ID #: Transactional ID resource. 
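These ordering methods all move from comparing self.__class__ against other.__class__ to an isinstance check with an object-typed parameter, returning NotImplemented for foreign types so Python can fall back to the reflected comparison. The same pattern in isolation (the Color enum is only an illustration):

    from enum import Enum

    class Color(Enum):
        RED = 1
        GREEN = 2

        def __lt__(self, other: object) -> bool:
            # Returning NotImplemented (rather than raising) lets the
            # interpreter try the reflected operation or raise TypeError itself.
            if not isinstance(other, Color):
                return NotImplemented
            return self.value < other.value

    assert Color.RED < Color.GREEN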
- def __lt__(self, other: 'ResourceType') -> Any: - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, ResourceType): return NotImplemented return self.value < other.value @@ -44,7 +44,7 @@ class ResourcePatternType(Enum): LITERAL = _cimpl.RESOURCE_PATTERN_LITERAL #: Literal: A literal resource name PREFIXED = _cimpl.RESOURCE_PATTERN_PREFIXED #: Prefixed: A prefixed resource name - def __lt__(self, other: 'ResourcePatternType') -> Any: - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, ResourcePatternType): return NotImplemented return self.value < other.value diff --git a/src/confluent_kafka/admin/_scram.py b/src/confluent_kafka/admin/_scram.py index 76c999dbc..e0ba07249 100644 --- a/src/confluent_kafka/admin/_scram.py +++ b/src/confluent_kafka/admin/_scram.py @@ -26,8 +26,8 @@ class ScramMechanism(Enum): SCRAM_SHA_256 = cimpl.SCRAM_MECHANISM_SHA_256 #: SCRAM-SHA-256 mechanism SCRAM_SHA_512 = cimpl.SCRAM_MECHANISM_SHA_512 #: SCRAM-SHA-512 mechanism - def __lt__(self, other) -> Any: - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, ScramMechanism): return NotImplemented return self.value < other.value diff --git a/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py b/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py index b2c7b86df..5e2fd8fb6 100644 --- a/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py +++ b/src/confluent_kafka/experimental/aio/producer/_AIOProducer.py @@ -224,7 +224,7 @@ async def flush(self, *args: Any, **kwargs: Any) -> Any: # Update buffer activity since we just flushed self._buffer_timeout_manager.mark_activity() - # Then flush underlying producer and wait for delivery confirmation + # Then flush the underlying producer and wait for delivery confirmation return await self._call(self._producer.flush, *args, **kwargs) async def purge(self, *args: Any, **kwargs: Any) -> Any: diff --git a/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py b/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py index 3253bb083..8af98f3c3 100644 --- a/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py +++ b/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py @@ -19,6 +19,8 @@ import confluent_kafka +from .. import _common + logger = logging.getLogger(__name__) @@ -106,6 +108,20 @@ def _produce_batch_and_poll() -> int: loop = asyncio.get_running_loop() return await loop.run_in_executor(self._executor, _produce_batch_and_poll) + async def flush_librdkafka_queue(self, timeout=-1): + """Flush the librdkafka queue and wait for all messages to be delivered + This method awaits until all outstanding produce requests are completed + or the timeout is reached, unless the timeout is set to 0 (non-blocking). 
+ Args: + timeout: Maximum time to wait in seconds: + - -1 = wait indefinitely (default) + - 0 = non-blocking, return immediately + - >0 = wait up to timeout seconds + Returns: + Number of messages still in queue after flush attempt + """ + return await _common.async_call(self._executor, self._producer.flush, timeout) + def _handle_partial_failures( self, batch_messages: List[Dict[str, Any]] diff --git a/src/confluent_kafka/serialization/__init__.py b/src/confluent_kafka/serialization/__init__.py index d90817190..315ac0e99 100644 --- a/src/confluent_kafka/serialization/__init__.py +++ b/src/confluent_kafka/serialization/__init__.py @@ -171,7 +171,7 @@ class Deserializer(object): - unicode(encoding) """ - __slots__ = [] + __slots__: List[str] = [] def __call__(self, value: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Any: """ From 485532f55a5afe3f024673a8169726367e80f30f Mon Sep 17 00:00:00 2001 From: Naxin Date: Thu, 23 Oct 2025 16:15:31 -0400 Subject: [PATCH 22/31] fix many things --- .../schema_registry/__init__.py | 40 ++++++++-- .../schema_registry/_async/avro.py | 43 ++++++----- .../schema_registry/_async/json_schema.py | 46 ++++++++---- .../_async/mock_schema_registry_client.py | 8 +- .../schema_registry/_async/protobuf.py | 32 +++++--- .../_async/schema_registry_client.py | 75 ++++++++++--------- .../schema_registry/_async/serde.py | 2 +- .../schema_registry/_sync/avro.py | 45 ++++++----- .../schema_registry/_sync/json_schema.py | 48 +++++++----- .../_sync/mock_schema_registry_client.py | 8 +- .../schema_registry/_sync/protobuf.py | 36 +++++---- .../_sync/schema_registry_client.py | 75 ++++++++++--------- .../schema_registry/_sync/serde.py | 2 +- .../common/schema_registry_client.py | 4 +- tools/unasync.py | 34 ++++++++- 15 files changed, 309 insertions(+), 189 deletions(-) diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index 582cda859..2055d2c05 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -72,58 +72,82 @@ ] -def topic_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: +def topic_subject_name_strategy(ctx: Optional[SerializationContext], record_name: Optional[str]) -> Optional[str]: """ Constructs a subject name in the form of {topic}-key|value. Args: ctx (SerializationContext): Metadata pertaining to the serialization - operation. + operation. **Required** - will raise ValueError if None. record_name (Optional[str]): Record name. + Raises: + ValueError: If ctx is None. + """ + if ctx is None: + raise ValueError( + "SerializationContext is required for topic_subject_name_strategy. " + "Either provide a SerializationContext or use record_subject_name_strategy." + ) return ctx.topic + "-" + ctx.field -def topic_record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: +def topic_record_subject_name_strategy(ctx: Optional[SerializationContext], record_name: Optional[str]) -> Optional[str]: """ Constructs a subject name in the form of {topic}-{record_name}. Args: ctx (SerializationContext): Metadata pertaining to the serialization - operation. + operation. **Required** - will raise ValueError if None. record_name (Optional[str]): Record name. + Raises: + ValueError: If ctx is None. + """ + if ctx is None: + raise ValueError( + "SerializationContext is required for topic_record_subject_name_strategy. " + "Either provide a SerializationContext or use record_subject_name_strategy." 
+ ) return ctx.topic + "-" + record_name if record_name is not None else None -def record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: +def record_subject_name_strategy(ctx: Optional[SerializationContext], record_name: Optional[str]) -> Optional[str]: """ Constructs a subject name in the form of {record_name}. Args: ctx (SerializationContext): Metadata pertaining to the serialization - operation. + operation. **Not used** by this strategy. record_name (Optional[str]): Record name. + Note: + This strategy does not require SerializationContext and can be used + when ctx is None. + """ return record_name if record_name is not None else None -def reference_subject_name_strategy(ctx, schema_ref: SchemaReference) -> Optional[str]: +def reference_subject_name_strategy(ctx: Optional[SerializationContext], schema_ref: SchemaReference) -> Optional[str]: """ Constructs a subject reference name in the form of {reference name}. Args: ctx (SerializationContext): Metadata pertaining to the serialization - operation. + operation. **Not used** by this strategy. schema_ref (SchemaReference): SchemaReference instance. + Note: + This strategy does not require SerializationContext and can be used + when ctx is None. + """ return schema_ref.name if schema_ref is not None else None diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index d700d8dd6..1f45f9e7d 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -57,13 +57,19 @@ async def _resolve_named_schema( named_schemas = {} if schema.references is not None: for ref in schema.references: - # References in registered schemas are validated by server to be complete - referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) # type: ignore[arg-type] - ref_named_schemas = await _resolve_named_schema(referenced_schema.schema, schema_registry_client) # type: ignore[arg-type] + if ref.subject is None or ref.version is None: + raise ValueError("Subject or version cannot be None") + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) + ref_named_schemas = await _resolve_named_schema(referenced_schema.schema, schema_registry_client) + if referenced_schema.schema.schema_str is None: + raise ValueError("Schema string cannot be None") + parsed_schema = parse_schema_with_repo( - referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) # type: ignore[union-attr,arg-type] + referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) named_schemas.update(ref_named_schemas) - named_schemas[ref.name] = parsed_schema # type: ignore[index] + if ref.name is None: + raise ValueError("Name cannot be None") + named_schemas[ref.name] = parsed_schema return named_schemas @@ -323,7 +329,7 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = await self._get_reader_schema(subject) if subject else None # type: ignore[arg-type] + latest_schema = await self._get_reader_schema(subject) if subject else None if latest_schema is not None: self._schema_id = SchemaId(AVRO_TYPE, latest_schema.schema_id, latest_schema.guid) elif subject is not None and subject not in self._known_subjects: @@ -343,15 +349,17 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N self._known_subjects.add(subject) 
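The subject-name-strategy docstrings earlier in this patch now spell out which strategies need a SerializationContext: the topic-based ones raise ValueError without it, while record_subject_name_strategy ignores it. A short illustration of the difference (topic and record names are made up):

    from confluent_kafka.serialization import MessageField, SerializationContext
    from confluent_kafka.schema_registry import (
        record_subject_name_strategy,
        topic_subject_name_strategy,
    )

    ctx = SerializationContext("orders", MessageField.VALUE)

    # Topic-based naming needs the context to know the topic and field.
    assert topic_subject_name_strategy(ctx, "com.example.Order") == "orders-value"

    # Record-based naming works even when no context is available.
    assert record_subject_name_strategy(None, "com.example.Order") == "com.example.Order"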
if self._to_dict is not None: - value = self._to_dict(obj, ctx) # type: ignore[arg-type] + if ctx is None: + raise ValueError("SerializationContext cannot be None") + value = self._to_dict(obj, ctx) else: value = obj # type: ignore[assignment] - if latest_schema is not None: - parsed_schema = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + if latest_schema is not None and ctx is not None and subject is not None: + parsed_schema = await self._get_parsed_schema(latest_schema.schema) def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, # type: ignore[arg-type] + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, latest_schema.schema, value, get_inline_tags(parsed_schema), field_transformer) else: @@ -521,7 +529,7 @@ async def __init_impl( if schema: self._reader_schema = await self._get_parsed_schema(self._schema) # type: ignore[arg-type] else: - self._reader_schema = None # type: ignore[assignment] + self._reader_schema = None # type: ignore[assignment] if from_dict is not None and not callable(from_dict): raise ValueError("from_dict must be callable with the signature " @@ -579,8 +587,7 @@ async def __deserialize( payload = self._schema_id_deserializer(data, ctx, schema_id) writer_schema_raw = await self._get_writer_schema(schema_id, subject) - writer_schema = await self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] - + writer_schema = await self._get_parsed_schema(writer_schema_raw) if subject is None: subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None # type: ignore[union-attr] if subject is not None: @@ -594,9 +601,9 @@ async def __deserialize( payload = io.BytesIO(payload) if latest_schema is not None and subject is not None: - migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] + migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema - reader_schema = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + reader_schema = await self._get_parsed_schema(latest_schema.schema) elif self._schema is not None: migrations = None reader_schema_raw = self._schema @@ -621,12 +628,14 @@ async def __deserialize( def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) if ctx is not None and subject is not None: - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, # type: ignore[arg-type] + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, obj_dict, get_inline_tags(reader_schema), field_transformer) if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) # type: ignore[arg-type] + if ctx is None: + raise ValueError("SerializationContext cannot be None") + return self._from_dict(obj_dict, ctx) return obj_dict diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py index 66378f9ee..448add6b7 100644 --- a/src/confluent_kafka/schema_registry/_async/json_schema.py +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -64,12 +64,19 @@ async def _resolve_named_schema( ref_registry = Registry(retrieve=_retrieve_via_httpx) # type: ignore[call-arg] if schema.references is not 
None: for ref in schema.references: - referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) # type: ignore[arg-type] - ref_registry = await _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) # type: ignore[arg-type] - referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) # type: ignore[union-attr,arg-type] + if ref.subject is None or ref.version is None: + raise ValueError("Subject or version cannot be None") + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) + ref_registry = await _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) + if referenced_schema.schema.schema_str is None: + raise ValueError("Schema string cannot be None") + + referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) resource = Resource.from_contents( referenced_schema_dict, default_specification=DEFAULT_SPEC) - ref_registry = ref_registry.with_resource(ref.name, resource) # type: ignore[arg-type] + if ref.name is None: + raise ValueError("Name cannot be None") + ref_registry = ref_registry.with_resource(ref.name, resource) return ref_registry @@ -329,7 +336,7 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = await self._get_reader_schema(subject) if subject else None # type: ignore[arg-type] + latest_schema = await self._get_reader_schema(subject) if subject else None if latest_schema is not None: self._schema_id = SchemaId(JSON_TYPE, latest_schema.schema_id, latest_schema.guid) elif subject is not None and subject not in self._known_subjects: @@ -349,22 +356,26 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N self._known_subjects.add(subject) if self._to_dict is not None: - value = self._to_dict(obj, ctx) # type: ignore[arg-type] + if ctx is None: + raise ValueError("SerializationContext cannot be None") + value = self._to_dict(obj, ctx) else: value = obj # type: ignore[assignment] + schema: Optional[Schema] = None if latest_schema is not None: schema = latest_schema.schema - parsed_schema, ref_registry = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + parsed_schema, ref_registry = await self._get_parsed_schema(latest_schema.schema) if ref_registry is not None: root_resource = Resource.from_contents( parsed_schema, default_specification=DEFAULT_SPEC) ref_resolver = ref_registry.resolver_with_root(root_resource) def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, # type: ignore[arg-type] - latest_schema.schema, value, None, - field_transformer) + if ctx is not None and subject is not None: + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, None, + field_transformer) else: schema = self._schema parsed_schema, ref_registry = self._parsed_schema, self._ref_registry @@ -609,7 +620,7 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization if self._registry is not None: writer_schema_raw = await self._get_writer_schema(schema_id, subject) - writer_schema, writer_ref_registry = await self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] + writer_schema, writer_ref_registry = await 
self._get_parsed_schema(writer_schema_raw) if subject is None and isinstance(writer_schema, dict): subject = self._subject_name_func(ctx, writer_schema.get("title")) if subject is not None: @@ -628,10 +639,11 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization # JSON documents are self-describing; no need to query schema obj_dict = self._json_decode(payload.read()) - if latest_schema is not None and subject is not None: - migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] + reader_schema_raw: Optional[Schema] = None + if latest_schema is not None and subject is not None and writer_schema_raw is not None: + migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema - reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema) elif self._schema is not None: migrations = None reader_schema_raw = self._schema @@ -655,7 +667,7 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 if ctx is not None and subject is not None: obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, obj_dict, None, - field_transformer) # type: ignore[arg-type] + field_transformer) if self._validate and reader_schema_raw is not None and reader_schema is not None and reader_ref_registry is not None: try: @@ -665,7 +677,9 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 raise SerializationError(ve.message) if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) # type: ignore[arg-type,return-value] + if ctx is None: + raise ValueError("SerializationContext cannot be None") + return self._from_dict(obj_dict, ctx) # type: ignore[return-value] return obj_dict diff --git a/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py index ddfe68805..5c8bffa1f 100644 --- a/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py @@ -18,7 +18,7 @@ import uuid from collections import defaultdict from threading import Lock -from typing import List, Dict, Optional, Union +from typing import List, Dict, Optional, Union, Literal from .schema_registry_client import AsyncSchemaRegistryClient from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig @@ -73,7 +73,7 @@ def get_registered_schema_by_schema( return rs return None - def get_version(self, subject_name: str, version: int) -> Optional[RegisteredSchema]: + def get_version(self, subject_name: str, version: Union[int, str]) -> Optional[RegisteredSchema]: with self.lock: if subject_name in self.subject_schemas: for rs in self.subject_schemas[subject_name]: @@ -239,13 +239,13 @@ async def get_latest_with_metadata( raise SchemaRegistryError(404, 40400, "Schema Not Found") async def get_version( - self, subject_name: str, version: Union[int, str] = "latest", + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", deleted: bool = False, fmt: Optional[str] = None ) -> 'RegisteredSchema': if version == "latest": registered_schema = self._store.get_latest_version(subject_name) else: - registered_schema = self._store.get_version(subject_name, 
version) # type: ignore[arg-type] + registered_schema = self._store.get_version(subject_name, version) if registered_schema is not None: return registered_schema diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index 91589822c..73146f167 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -67,13 +67,20 @@ async def _resolve_named_schema( visited = set() if schema.references is not None: for ref in schema.references: - # References in registered schemas are validated by server to be complete - if _is_builtin(ref.name) or ref.name in visited: # type: ignore[arg-type] + if ref.name is None: + raise ValueError("Name cannot be None") + + if _is_builtin(ref.name) or ref.name in visited: continue - visited.add(ref.name) # type: ignore[arg-type] - referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') # type: ignore[arg-type] - await _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) # type: ignore[arg-type] - file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) # type: ignore[arg-type,union-attr] + visited.add(ref.name) + + if ref.subject is None or ref.version is None: + raise ValueError("Subject or version cannot be None") + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') + if referenced_schema.schema.schema_str is None: + raise ValueError("Schema string cannot be None") + await _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) + file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) pool.Add(file_descriptor_proto) @@ -426,7 +433,7 @@ async def __serialize(self, message: Message, ctx: Optional[SerializationContext self._known_subjects.add(subject) if latest_schema is not None: - fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) fd = pool.FindFileByName(fd_proto.name) desc = fd.message_types_by_name[message.DESCRIPTOR.name] def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 @@ -614,7 +621,7 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization if self._registry is not None: writer_schema_raw = await self._get_writer_schema(schema_id, subject, fmt='serialized') - fd_proto, pool = await self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] + fd_proto, pool = await self._get_parsed_schema(writer_schema_raw) writer_schema = pool.FindFileByName(fd_proto.name) writer_desc = self._get_message_desc(pool, writer_schema, msg_index) # type: ignore[arg-type] if subject is None: @@ -632,10 +639,11 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization if isinstance(payload, bytes): payload = io.BytesIO(payload) - if latest_schema is not None and subject is not None: - migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] + reader_schema_raw: Optional[Schema] = None + if latest_schema is not None and subject is not None and writer_schema_raw is not None: + migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema - fd_proto, pool = await 
self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) reader_schema = pool.FindFileByName(fd_proto.name) else: migrations = None @@ -670,7 +678,7 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_desc, message, field_transform)) if ctx is not None and subject is not None: - msg = self._execute_rules(ctx, subject, RuleMode.READ, None, # type: ignore[arg-type] + msg = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, msg, None, field_transformer) return msg diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index 44200ae0e..3d7a08b70 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -26,7 +26,7 @@ from urllib.parse import unquote, urlparse import httpx -from typing import List, Dict, Optional, Union, Any, Callable +from typing import List, Dict, Optional, Union, Any, Callable, Literal from cachetools import TTLCache, LRUCache from httpx import Response @@ -675,7 +675,7 @@ async def register_schema_full_response( subject=subject_name, version=None, schema=result[1] - ) # type: ignore[arg-type] + ) request = schema.to_dict() @@ -683,20 +683,20 @@ async def register_schema_full_response( 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), body=request) - result = RegisteredSchema.from_dict(response) # type: ignore[assignment] + response_schema = RegisteredSchema.from_dict(response) registered_schema = RegisteredSchema( - schema_id=result.schema_id, # type: ignore[union-attr] - guid=result.guid, # type: ignore[union-attr] - subject=result.subject or subject_name, # type: ignore[union-attr] - version=result.version, # type: ignore[union-attr] - schema=result.schema, # type: ignore[union-attr] + schema_id=response_schema.schema_id, + guid=response_schema.guid, + subject=response_schema.subject or subject_name, + version=response_schema.version, + schema=response_schema.schema, ) # The registered schema may not be fully populated - s = registered_schema.schema if registered_schema.schema.schema_str is not None else schema # type: ignore[union-attr] + s = registered_schema.schema if registered_schema.schema.schema_str is not None else schema self._cache.set_schema(subject_name, registered_schema.schema_id, - registered_schema.guid, s) # type: ignore[arg-type] + registered_schema.guid, s) return registered_schema @@ -725,7 +725,8 @@ async def get_schema( `GET Schema API Reference `_ """ # noqa: E501 - result = self._cache.get_schema_by_id(subject_name, schema_id) # type: ignore[arg-type] + if subject_name is not None: + result = self._cache.get_schema_by_id(subject_name, schema_id) if result is not None: return result[1] @@ -741,9 +742,9 @@ async def get_schema( registered_schema = RegisteredSchema.from_dict(response) self._cache.set_schema(subject_name, schema_id, - registered_schema.guid, registered_schema.schema) # type: ignore[arg-type] + registered_schema.guid, registered_schema.schema) - return registered_schema.schema # type: ignore[return-value] + return registered_schema.schema async def get_schema_by_guid( self, guid: str, fmt: Optional[str] = None @@ -779,9 +780,9 @@ async def get_schema_by_guid( 
registered_schema = RegisteredSchema.from_dict(response) self._cache.set_schema(None, registered_schema.schema_id, - registered_schema.guid, registered_schema.schema) # type: ignore[arg-type] + registered_schema.guid, registered_schema.schema) - return registered_schema.schema # type: ignore[return-value] + return registered_schema.schema async def get_schema_types(self) -> List[str]: """ @@ -819,9 +820,9 @@ async def get_subjects_by_schema_id( Raises: SchemaRegistryError: if subjects can't be found """ - query = {'offset': offset, 'limit': limit} + query: dict[str, Any] = {'offset': offset, 'limit': limit} if subject_name is not None: - query['subject'] = subject_name # type: ignore[assignment] + query['subject'] = subject_name if deleted: query['deleted'] = deleted return await self._rest_client.get('schemas/ids/{}/subjects'.format(schema_id), query) @@ -852,9 +853,8 @@ async def get_schema_versions( `GET Schema Versions API Reference `_ """ # noqa: E501 - query = {'offset': offset, 'limit': limit} - if subject_name is not None: - query['subject'] = subject_name # type: ignore[assignment] + query: dict[str, Any] = {'offset': offset, 'limit': limit} + if subject_name is not None: query['subject'] = subject_name if deleted: query['deleted'] = deleted response = await self._rest_client.get('schemas/ids/{}/versions'.format(schema_id), query) @@ -943,9 +943,9 @@ async def get_subjects( `GET subjects API Reference `_ """ # noqa: E501 - query = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} + query: dict[str, Any] = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} if subject_prefix is not None: - query['subject'] = subject_prefix # type: ignore[assignment] + query['subject'] = subject_prefix return await self._rest_client.get('subjects', query) async def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: @@ -1040,13 +1040,13 @@ async def get_latest_with_metadata( if registered_schema is not None: return registered_schema - query = {'deleted': deleted} + query: dict[str, Any] = {'deleted': deleted} if fmt is not None: - query['format'] = fmt # type: ignore[assignment] + query['format'] = fmt keys = metadata.keys() if keys: - query['key'] = [_urlencode(key) for key in keys] # type: ignore[assignment] - query['value'] = [_urlencode(metadata[key]) for key in keys] # type: ignore[assignment] + query['key'] = [_urlencode(key) for key in keys] + query['value'] = [_urlencode(metadata[key]) for key in keys] response = await self._rest_client.get( 'subjects/{}/metadata'.format(_urlencode(subject_name)), query @@ -1059,7 +1059,7 @@ async def get_latest_with_metadata( return registered_schema async def get_version( - self, subject_name: str, version: Union[int, str] = "latest", + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", deleted: bool = False, fmt: Optional[str] = None ) -> 'RegisteredSchema': """ @@ -1067,7 +1067,7 @@ async def get_version( Args: subject_name (str): Subject name. - version (Union[int, str]): Version of the schema or string "latest". Defaults to latest version. + version (Union[int, Literal["latest"]]): Version of the schema or string "latest". Defaults to latest version. deleted (bool): Whether to include deleted schemas. fmt (str): Format of the schema. 
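
[Editor's note, not part of the diff: a quick usage sketch of the narrowed `version` parameter documented above. The registry URL and subject name are invented, and the import path is taken from this file's location in the diff rather than from a documented public export.]

    import asyncio
    from confluent_kafka.schema_registry._async.schema_registry_client import AsyncSchemaRegistryClient

    async def main() -> None:
        client = AsyncSchemaRegistryClient({'url': 'http://localhost:8081'})
        # "latest" is always re-fetched from the registry (see the next hunk),
        # while a pinned version number is eligible for the local cache.
        latest = await client.get_version('orders-value', 'latest')
        pinned = await client.get_version('orders-value', 3)
        print(latest.version, pinned.version)

    asyncio.run(main())
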
@@ -1081,23 +1081,24 @@ async def get_version( `GET Subject Versions API Reference `_ """ # noqa: E501 - registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) # type: ignore[arg-type] - if registered_schema is not None: - return registered_schema + if version != "latest": + registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) + if registered_schema is not None: + return registered_schema - query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + query: dict[str, Any] = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} response = await self._rest_client.get( 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query ) registered_schema = RegisteredSchema.from_dict(response) - self._cache.set_registered_schema(registered_schema.schema, registered_schema) # type: ignore[arg-type] + self._cache.set_registered_schema(registered_schema.schema, registered_schema) return registered_schema async def get_referenced_by( - self, subject_name: str, version: Union[int, str] = "latest", + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", offset: int = 0, limit: int = -1 ) -> List[int]: """ @@ -1105,7 +1106,7 @@ async def get_referenced_by( Args: subject_name (str): Subject name - version (int or str): Version number or "latest" + version (Union[int, Literal["latest"]]): Version number or "latest" offset (int): Pagination offset for results. limit (int): Pagination size for results. Ignored if negative. @@ -1119,7 +1120,7 @@ async def get_referenced_by( `GET Subject Versions (ReferenceBy) API Reference `_ """ # noqa: E501 - query = {'offset': offset, 'limit': limit} + query: dict[str, Any] = {'offset': offset, 'limit': limit} return await self._rest_client.get('subjects/{}/versions/{}/referencedby'.format( _urlencode(subject_name), version), query) @@ -1147,7 +1148,7 @@ async def get_versions( `GET Subject All Versions API Reference `_ """ # noqa: E501 - query = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} + query: dict[str, Any] = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} return await self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name)), query) async def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py index ef6baab2d..bb1be66e1 100644 --- a/src/confluent_kafka/schema_registry/_async/serde.py +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -49,7 +49,7 @@ class AsyncBaseSerde(object): _use_latest_with_metadata: Optional[Dict[str, str]] _registry: Any # AsyncSchemaRegistryClient _rule_registry: Any # RuleRegistry - _subject_name_func: Callable[[Any, Optional[str]], Optional[str]] + _subject_name_func: Callable[[Optional['SerializationContext'], Optional[str]], Optional[str]] _field_transformer: Optional[FieldTransformer] async def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 1a06a5b05..a54847a6f 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -57,13 +57,19 @@ def _resolve_named_schema( named_schemas = {} if 
schema.references is not None: for ref in schema.references: - # References in registered schemas are validated by server to be complete - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) # type: ignore[arg-type] - ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client) # type: ignore[arg-type] + if ref.subject is None or ref.version is None: + raise ValueError("Subject or version cannot be None") + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) + ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client) + if referenced_schema.schema.schema_str is None: + raise ValueError("Schema string cannot be None") + parsed_schema = parse_schema_with_repo( - referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) # type: ignore[union-attr,arg-type] + referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) named_schemas.update(ref_named_schemas) - named_schemas[ref.name] = parsed_schema # type: ignore[index] + if ref.name is None: + raise ValueError("Name cannot be None") + named_schemas[ref.name] = parsed_schema return named_schemas @@ -296,7 +302,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__serialize(obj, ctx) def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -323,7 +329,7 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = self._get_reader_schema(subject) if subject else None # type: ignore[arg-type] + latest_schema = self._get_reader_schema(subject) if subject else None if latest_schema is not None: self._schema_id = SchemaId(AVRO_TYPE, latest_schema.schema_id, latest_schema.guid) elif subject is not None and subject not in self._known_subjects: @@ -343,15 +349,17 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - self._known_subjects.add(subject) if self._to_dict is not None: - value = self._to_dict(obj, ctx) # type: ignore[arg-type] + if ctx is None: + raise ValueError("SerializationContext cannot be None") + value = self._to_dict(obj, ctx) else: value = obj # type: ignore[assignment] - if latest_schema is not None: - parsed_schema = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + if latest_schema is not None and ctx is not None and subject is not None: + parsed_schema = self._get_parsed_schema(latest_schema.schema) def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, # type: ignore[arg-type] + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, latest_schema.schema, value, get_inline_tags(parsed_schema), field_transformer) else: @@ -521,7 +529,7 @@ def __init_impl( if schema: self._reader_schema = self._get_parsed_schema(self._schema) # type: ignore[arg-type] else: - self._reader_schema = None # type: ignore[assignment] + self._reader_schema = None # type: ignore[assignment] if from_dict is not None and not callable(from_dict): raise ValueError("from_dict must be callable with the signature " @@ -579,8 +587,7 
@@ def __deserialize( payload = self._schema_id_deserializer(data, ctx, schema_id) writer_schema_raw = self._get_writer_schema(schema_id, subject) - writer_schema = self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] - + writer_schema = self._get_parsed_schema(writer_schema_raw) if subject is None: subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None # type: ignore[union-attr] if subject is not None: @@ -594,9 +601,9 @@ def __deserialize( payload = io.BytesIO(payload) if latest_schema is not None and subject is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema - reader_schema = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + reader_schema = self._get_parsed_schema(latest_schema.schema) elif self._schema is not None: migrations = None reader_schema_raw = self._schema @@ -621,12 +628,14 @@ def __deserialize( def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) if ctx is not None and subject is not None: - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, # type: ignore[arg-type] + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, obj_dict, get_inline_tags(reader_schema), field_transformer) if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) # type: ignore[arg-type] + if ctx is None: + raise ValueError("SerializationContext cannot be None") + return self._from_dict(obj_dict, ctx) return obj_dict diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 08f147876..a47c5bcd2 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -64,12 +64,19 @@ def _resolve_named_schema( ref_registry = Registry(retrieve=_retrieve_via_httpx) # type: ignore[call-arg] if schema.references is not None: for ref in schema.references: - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) # type: ignore[arg-type] - ref_registry = _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) # type: ignore[arg-type] - referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) # type: ignore[union-attr,arg-type] + if ref.subject is None or ref.version is None: + raise ValueError("Subject or version cannot be None") + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) + ref_registry = _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) + if referenced_schema.schema.schema_str is None: + raise ValueError("Schema string cannot be None") + + referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) resource = Resource.from_contents( referenced_schema_dict, default_specification=DEFAULT_SPEC) - ref_registry = ref_registry.with_resource(ref.name, resource) # type: ignore[arg-type] + if ref.name is None: + raise ValueError("Name cannot be None") + ref_registry = ref_registry.with_resource(ref.name, resource) return ref_registry @@ -303,7 +310,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: 
ignore[override] + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__serialize(obj, ctx) def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -329,7 +336,7 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = self._get_reader_schema(subject) if subject else None # type: ignore[arg-type] + latest_schema = self._get_reader_schema(subject) if subject else None if latest_schema is not None: self._schema_id = SchemaId(JSON_TYPE, latest_schema.schema_id, latest_schema.guid) elif subject is not None and subject not in self._known_subjects: @@ -349,22 +356,26 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - self._known_subjects.add(subject) if self._to_dict is not None: - value = self._to_dict(obj, ctx) # type: ignore[arg-type] + if ctx is None: + raise ValueError("SerializationContext cannot be None") + value = self._to_dict(obj, ctx) else: value = obj # type: ignore[assignment] + schema: Optional[Schema] = None if latest_schema is not None: schema = latest_schema.schema - parsed_schema, ref_registry = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + parsed_schema, ref_registry = self._get_parsed_schema(latest_schema.schema) if ref_registry is not None: root_resource = Resource.from_contents( parsed_schema, default_specification=DEFAULT_SPEC) ref_resolver = ref_registry.resolver_with_root(root_resource) def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, # type: ignore[arg-type] - latest_schema.schema, value, None, - field_transformer) + if ctx is not None and subject is not None: + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, None, + field_transformer) else: schema = self._schema parsed_schema, ref_registry = self._parsed_schema, self._ref_registry @@ -609,7 +620,7 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex if self._registry is not None: writer_schema_raw = self._get_writer_schema(schema_id, subject) - writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] + writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw) if subject is None and isinstance(writer_schema, dict): subject = self._subject_name_func(ctx, writer_schema.get("title")) if subject is not None: @@ -628,10 +639,11 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex # JSON documents are self-describing; no need to query schema obj_dict = self._json_decode(payload.read()) - if latest_schema is not None and subject is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] + reader_schema_raw: Optional[Schema] = None + if latest_schema is not None and subject is not None and writer_schema_raw is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema - reader_schema, reader_ref_registry = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + reader_schema, reader_ref_registry = self._get_parsed_schema(latest_schema.schema) 
elif self._schema is not None: migrations = None reader_schema_raw = self._schema @@ -655,7 +667,7 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 if ctx is not None and subject is not None: obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, obj_dict, None, - field_transformer) # type: ignore[arg-type] + field_transformer) if self._validate and reader_schema_raw is not None and reader_schema is not None and reader_ref_registry is not None: try: @@ -665,7 +677,9 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 raise SerializationError(ve.message) if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) # type: ignore[arg-type,return-value] + if ctx is None: + raise ValueError("SerializationContext cannot be None") + return self._from_dict(obj_dict, ctx) # type: ignore[return-value] return obj_dict diff --git a/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py index 09340ac33..3ca77fbf3 100644 --- a/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py @@ -18,7 +18,7 @@ import uuid from collections import defaultdict from threading import Lock -from typing import List, Dict, Optional, Union +from typing import List, Dict, Optional, Union, Literal from .schema_registry_client import SchemaRegistryClient from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig @@ -73,7 +73,7 @@ def get_registered_schema_by_schema( return rs return None - def get_version(self, subject_name: str, version: int) -> Optional[RegisteredSchema]: + def get_version(self, subject_name: str, version: Union[int, str]) -> Optional[RegisteredSchema]: with self.lock: if subject_name in self.subject_schemas: for rs in self.subject_schemas[subject_name]: @@ -239,13 +239,13 @@ def get_latest_with_metadata( raise SchemaRegistryError(404, 40400, "Schema Not Found") def get_version( - self, subject_name: str, version: Union[int, str] = "latest", + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", deleted: bool = False, fmt: Optional[str] = None ) -> 'RegisteredSchema': if version == "latest": registered_schema = self._store.get_latest_version(subject_name) else: - registered_schema = self._store.get_version(subject_name, version) # type: ignore[arg-type] + registered_schema = self._store.get_version(subject_name, version) if registered_schema is not None: return registered_schema diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index f13ce7555..05651b739 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -67,13 +67,20 @@ def _resolve_named_schema( visited = set() if schema.references is not None: for ref in schema.references: - # References in registered schemas are validated by server to be complete - if _is_builtin(ref.name) or ref.name in visited: # type: ignore[arg-type] + if ref.name is None: + raise ValueError("Name cannot be None") + + if _is_builtin(ref.name) or ref.name in visited: continue - visited.add(ref.name) # type: ignore[arg-type] - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') # type: ignore[arg-type] - _resolve_named_schema(referenced_schema.schema, 
schema_registry_client, pool, visited) # type: ignore[arg-type] - file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) # type: ignore[arg-type,union-attr] + visited.add(ref.name) + + if ref.subject is None or ref.version is None: + raise ValueError("Subject or version cannot be None") + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') + if referenced_schema.schema.schema_str is None: + raise ValueError("Schema string cannot be None") + _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) + file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) pool.Add(file_descriptor_proto) @@ -369,7 +376,7 @@ def _resolve_dependencies( reference.version)) return schema_refs - def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__serialize(message, ctx) def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -426,7 +433,7 @@ def __serialize(self, message: Message, ctx: Optional[SerializationContext] = No self._known_subjects.add(subject) if latest_schema is not None: - fd_proto, pool = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + fd_proto, pool = self._get_parsed_schema(latest_schema.schema) fd = pool.FindFileByName(fd_proto.name) desc = fd.message_types_by_name[message.DESCRIPTOR.name] def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 @@ -580,7 +587,7 @@ def __init_impl( def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[object, None]: return self.__deserialize(data, ctx) - def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[object, None]: + def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a serialized protobuf message with Confluent Schema Registry framing. 
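
[Editor's note, not part of the diff: the guard-then-raise pattern used in `_resolve_named_schema` above is what lets mypy drop the `# type: ignore[arg-type]` comments. A minimal standalone sketch of the same idea; the helper name and sample values are hypothetical.]

    from typing import Optional

    def require(value: Optional[str], field: str) -> str:
        """Narrow Optional[str] to str for the type checker, failing loudly on None."""
        if value is None:
            raise ValueError(f"{field} cannot be None")
        return value

    # After the check, mypy treats the result as str, so the call site no longer
    # needs an arg-type ignore.
    name = require("customer.proto", "reference name")
    print(name.upper())
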
@@ -614,7 +621,7 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex if self._registry is not None: writer_schema_raw = self._get_writer_schema(schema_id, subject, fmt='serialized') - fd_proto, pool = self._get_parsed_schema(writer_schema_raw) # type: ignore[arg-type] + fd_proto, pool = self._get_parsed_schema(writer_schema_raw) writer_schema = pool.FindFileByName(fd_proto.name) writer_desc = self._get_message_desc(pool, writer_schema, msg_index) # type: ignore[arg-type] if subject is None: @@ -632,10 +639,11 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex if isinstance(payload, bytes): payload = io.BytesIO(payload) - if latest_schema is not None and subject is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) # type: ignore[arg-type] + reader_schema_raw: Optional[Schema] = None + if latest_schema is not None and subject is not None and writer_schema_raw is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema - fd_proto, pool = self._get_parsed_schema(latest_schema.schema) # type: ignore[arg-type] + fd_proto, pool = self._get_parsed_schema(latest_schema.schema) reader_schema = pool.FindFileByName(fd_proto.name) else: migrations = None @@ -670,7 +678,7 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_desc, message, field_transform)) if ctx is not None and subject is not None: - msg = self._execute_rules(ctx, subject, RuleMode.READ, None, # type: ignore[arg-type] + msg = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, msg, None, field_transformer) return msg diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index eeef46c52..a8db636a3 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -26,7 +26,7 @@ from urllib.parse import unquote, urlparse import httpx -from typing import List, Dict, Optional, Union, Any, Callable +from typing import List, Dict, Optional, Union, Any, Callable, Literal from cachetools import TTLCache, LRUCache from httpx import Response @@ -675,7 +675,7 @@ def register_schema_full_response( subject=subject_name, version=None, schema=result[1] - ) # type: ignore[arg-type] + ) request = schema.to_dict() @@ -683,20 +683,20 @@ def register_schema_full_response( 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), body=request) - result = RegisteredSchema.from_dict(response) # type: ignore[assignment] + response_schema = RegisteredSchema.from_dict(response) registered_schema = RegisteredSchema( - schema_id=result.schema_id, # type: ignore[union-attr] - guid=result.guid, # type: ignore[union-attr] - subject=result.subject or subject_name, # type: ignore[union-attr] - version=result.version, # type: ignore[union-attr] - schema=result.schema, # type: ignore[union-attr] + schema_id=response_schema.schema_id, + guid=response_schema.guid, + subject=response_schema.subject or subject_name, + version=response_schema.version, + schema=response_schema.schema, ) # The registered schema may not be fully populated - s = registered_schema.schema if registered_schema.schema.schema_str is not None 
else schema # type: ignore[union-attr] + s = registered_schema.schema if registered_schema.schema.schema_str is not None else schema self._cache.set_schema(subject_name, registered_schema.schema_id, - registered_schema.guid, s) # type: ignore[arg-type] + registered_schema.guid, s) return registered_schema @@ -725,7 +725,8 @@ def get_schema( `GET Schema API Reference `_ """ # noqa: E501 - result = self._cache.get_schema_by_id(subject_name, schema_id) # type: ignore[arg-type] + if subject_name is not None: + result = self._cache.get_schema_by_id(subject_name, schema_id) if result is not None: return result[1] @@ -741,9 +742,9 @@ def get_schema( registered_schema = RegisteredSchema.from_dict(response) self._cache.set_schema(subject_name, schema_id, - registered_schema.guid, registered_schema.schema) # type: ignore[arg-type] + registered_schema.guid, registered_schema.schema) - return registered_schema.schema # type: ignore[return-value] + return registered_schema.schema def get_schema_by_guid( self, guid: str, fmt: Optional[str] = None @@ -779,9 +780,9 @@ def get_schema_by_guid( registered_schema = RegisteredSchema.from_dict(response) self._cache.set_schema(None, registered_schema.schema_id, - registered_schema.guid, registered_schema.schema) # type: ignore[arg-type] + registered_schema.guid, registered_schema.schema) - return registered_schema.schema # type: ignore[return-value] + return registered_schema.schema def get_schema_types(self) -> List[str]: """ @@ -819,9 +820,9 @@ def get_subjects_by_schema_id( Raises: SchemaRegistryError: if subjects can't be found """ - query = {'offset': offset, 'limit': limit} + query: dict[str, Any] = {'offset': offset, 'limit': limit} if subject_name is not None: - query['subject'] = subject_name # type: ignore[assignment] + query['subject'] = subject_name if deleted: query['deleted'] = deleted return self._rest_client.get('schemas/ids/{}/subjects'.format(schema_id), query) @@ -852,9 +853,8 @@ def get_schema_versions( `GET Schema Versions API Reference `_ """ # noqa: E501 - query = {'offset': offset, 'limit': limit} - if subject_name is not None: - query['subject'] = subject_name # type: ignore[assignment] + query: dict[str, Any] = {'offset': offset, 'limit': limit} + if subject_name is not None: query['subject'] = subject_name if deleted: query['deleted'] = deleted response = self._rest_client.get('schemas/ids/{}/versions'.format(schema_id), query) @@ -943,9 +943,9 @@ def get_subjects( `GET subjects API Reference `_ """ # noqa: E501 - query = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} + query: dict[str, Any] = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} if subject_prefix is not None: - query['subject'] = subject_prefix # type: ignore[assignment] + query['subject'] = subject_prefix return self._rest_client.get('subjects', query) def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: @@ -1040,13 +1040,13 @@ def get_latest_with_metadata( if registered_schema is not None: return registered_schema - query = {'deleted': deleted} + query: dict[str, Any] = {'deleted': deleted} if fmt is not None: - query['format'] = fmt # type: ignore[assignment] + query['format'] = fmt keys = metadata.keys() if keys: - query['key'] = [_urlencode(key) for key in keys] # type: ignore[assignment] - query['value'] = [_urlencode(metadata[key]) for key in keys] # type: ignore[assignment] + query['key'] = [_urlencode(key) for key in keys] + query['value'] = 
[_urlencode(metadata[key]) for key in keys] response = self._rest_client.get( 'subjects/{}/metadata'.format(_urlencode(subject_name)), query @@ -1059,7 +1059,7 @@ def get_latest_with_metadata( return registered_schema def get_version( - self, subject_name: str, version: Union[int, str] = "latest", + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", deleted: bool = False, fmt: Optional[str] = None ) -> 'RegisteredSchema': """ @@ -1067,7 +1067,7 @@ def get_version( Args: subject_name (str): Subject name. - version (Union[int, str]): Version of the schema or string "latest". Defaults to latest version. + version (Union[int, Literal["latest"]]): Version of the schema or string "latest". Defaults to latest version. deleted (bool): Whether to include deleted schemas. fmt (str): Format of the schema. @@ -1081,23 +1081,24 @@ def get_version( `GET Subject Versions API Reference `_ """ # noqa: E501 - registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) # type: ignore[arg-type] - if registered_schema is not None: - return registered_schema + if version != "latest": + registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) + if registered_schema is not None: + return registered_schema - query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + query: dict[str, Any] = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} response = self._rest_client.get( 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query ) registered_schema = RegisteredSchema.from_dict(response) - self._cache.set_registered_schema(registered_schema.schema, registered_schema) # type: ignore[arg-type] + self._cache.set_registered_schema(registered_schema.schema, registered_schema) return registered_schema def get_referenced_by( - self, subject_name: str, version: Union[int, str] = "latest", + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", offset: int = 0, limit: int = -1 ) -> List[int]: """ @@ -1105,7 +1106,7 @@ def get_referenced_by( Args: subject_name (str): Subject name - version (int or str): Version number or "latest" + version (Union[int, Literal["latest"]]): Version number or "latest" offset (int): Pagination offset for results. limit (int): Pagination size for results. Ignored if negative. 
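
[Editor's note, not part of the diff: context for the repeated `query: dict[str, Any]` annotations in this file. Without the annotation, mypy infers the value type from the first literal and rejects later assignments, which is why the old code carried `# type: ignore[assignment]`. A small self-contained illustration with invented values.]

    from typing import Any

    # Inferred as dict[str, int] from the literal, so assigning a str value
    # later would need `# type: ignore[assignment]`.
    narrow = {'offset': 0, 'limit': -1}

    # The explicit annotation keeps every later assignment type-clean.
    query: dict[str, Any] = {'offset': 0, 'limit': -1}
    query['subject'] = 'orders-value'
    query['deleted'] = True
    print(query)
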
@@ -1119,7 +1120,7 @@ def get_referenced_by( `GET Subject Versions (ReferenceBy) API Reference `_ """ # noqa: E501 - query = {'offset': offset, 'limit': limit} + query: dict[str, Any] = {'offset': offset, 'limit': limit} return self._rest_client.get('subjects/{}/versions/{}/referencedby'.format( _urlencode(subject_name), version), query) @@ -1147,7 +1148,7 @@ def get_versions( `GET Subject All Versions API Reference `_ """ # noqa: E501 - query = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} + query: dict[str, Any] = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} return self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name)), query) def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py index dd58fce07..bdf4f8b02 100644 --- a/src/confluent_kafka/schema_registry/_sync/serde.py +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -49,7 +49,7 @@ class BaseSerde(object): _use_latest_with_metadata: Optional[Dict[str, str]] _registry: Any # SchemaRegistryClient _rule_registry: Any # RuleRegistry - _subject_name_func: Callable[[Any, Optional[str]], Optional[str]] + _subject_name_func: Callable[[Optional['SerializationContext'], Optional[str]], Optional[str]] _field_transformer: Optional[FieldTransformer] def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index e72237cdd..bdfe3e51d 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -22,7 +22,7 @@ from collections import defaultdict from enum import Enum from threading import Lock -from typing import List, Dict, Type, TypeVar, \ +from typing import List, Dict, Type, TypeVar, Union, \ cast, Optional, Any, Tuple __all__ = [ @@ -943,7 +943,7 @@ class RegisteredSchema: version: Optional[int] schema_id: Optional[int] guid: Optional[str] - schema: Optional[Schema] + schema: Schema def to_dict(self) -> Dict[str, Any]: schema = self.schema diff --git a/tools/unasync.py b/tools/unasync.py index feb616c59..162591551 100644 --- a/tools/unasync.py +++ b/tools/unasync.py @@ -189,5 +189,37 @@ def unasync(dir_pairs=None, check=False): '--check', action='store_true', help='Exit with non-zero status if sync directory has any differences') + parser.add_argument( + '--file', + type=str, + help='Convert a single file instead of all directories') args = parser.parse_args() - unasync(check=args.check) + + if args.file: + # Single file mode + async_file = args.file + if not os.path.exists(async_file): + print(f"Error: File {async_file} does not exist") + sys.exit(1) + + # Determine the sync file path + sync_file = None + for async_dir, sync_dir in ASYNC_TO_SYNC: + if async_file.startswith(async_dir): + sync_file = async_file.replace(async_dir, sync_dir, 1) + break + + if not sync_file: + print(f"Error: File {async_file} is not in a known async directory") + print(f"Known async directories: {[d[0] for d in ASYNC_TO_SYNC]}") + sys.exit(1) + + # Create the output directory if needed + os.makedirs(os.path.dirname(sync_file), exist_ok=True) + + print(f"Converting: {async_file} -> {sync_file}") + 
unasync_file(async_file, sync_file) + print("✅ Done!") + else: + # Directory mode (original behavior) + unasync(check=args.check) From 6262a73572acd0c21427bc556178d8940f7ff640 Mon Sep 17 00:00:00 2001 From: Naxin Date: Thu, 23 Oct 2025 16:25:07 -0400 Subject: [PATCH 23/31] more fixes in non sr modules --- .../avro/cached_schema_registry_client.py | 2 +- src/confluent_kafka/avro/load.py | 10 +++++----- .../schema_registry/_async/schema_registry_client.py | 4 ++-- .../schema_registry/_sync/schema_registry_client.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/confluent_kafka/avro/cached_schema_registry_client.py b/src/confluent_kafka/avro/cached_schema_registry_client.py index b0ea6c388..3478ba9fb 100644 --- a/src/confluent_kafka/avro/cached_schema_registry_client.py +++ b/src/confluent_kafka/avro/cached_schema_registry_client.py @@ -32,7 +32,7 @@ # Python 2 considers int an instance of str try: - string_type = basestring # noqa + string_type = basestring # type: ignore[name-defined] # noqa except NameError: string_type = str diff --git a/src/confluent_kafka/avro/load.py b/src/confluent_kafka/avro/load.py index 9db8660e1..d774a0513 100644 --- a/src/confluent_kafka/avro/load.py +++ b/src/confluent_kafka/avro/load.py @@ -47,11 +47,11 @@ def _hash_func(self): from avro.errors import SchemaParseException except ImportError: # avro < 1.11.0 - from avro.schema import SchemaParseException + from avro.schema import SchemaParseException # type: ignore[attr-defined,no-redef] - schema.RecordSchema.__hash__ = _hash_func - schema.PrimitiveSchema.__hash__ = _hash_func - schema.UnionSchema.__hash__ = _hash_func + schema.RecordSchema.__hash__ = _hash_func # type: ignore[method-assign] + schema.PrimitiveSchema.__hash__ = _hash_func # type: ignore[method-assign] + schema.UnionSchema.__hash__ = _hash_func # type: ignore[method-assign] except ImportError: - schema = None + schema = None # type: ignore[assignment] diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index 3d7a08b70..6ac29dfe4 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -890,12 +890,12 @@ async def lookup_schema( request = schema.to_dict() - query_params = { + query_params: dict[str, Any] = { 'normalize': normalize_schemas, 'deleted': deleted } if fmt is not None: - query_params['format'] = fmt # type: ignore[assignment] + query_params['format'] = fmt query_string = '&'.join(f"{key}={value}" for key, value in query_params.items()) diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index a8db636a3..ab34e2fbe 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -890,12 +890,12 @@ def lookup_schema( request = schema.to_dict() - query_params = { + query_params: dict[str, Any] = { 'normalize': normalize_schemas, 'deleted': deleted } if fmt is not None: - query_params['format'] = fmt # type: ignore[assignment] + query_params['format'] = fmt query_string = '&'.join(f"{key}={value}" for key, value in query_params.items()) From 5ffe30124a54a65d3b9b1b8cae7a77c5ebb5c314 Mon Sep 17 00:00:00 2001 From: Naxin Date: Thu, 23 Oct 2025 19:27:16 -0400 Subject: [PATCH 24/31] type encrypt_executor.py --- 
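[Editor's note on the hunks below, not applied by git am: the OAuth client now checks `self.token` for None before indexing into it. A standalone sketch of the expiry check it protects; the threshold and token fields mirror the diff, the sample token values are invented.]

    import time
    from typing import Optional

    def token_expired(token: Optional[dict], threshold: float = 0.8) -> bool:
        # Mirrors the guard added below: fail fast if the token was never fetched,
        # then treat it as expired once 80% of its lifetime has elapsed.
        if token is None:
            raise ValueError("Token is not set")
        expiry_window = token['expires_in'] * threshold
        return token['expires_at'] < time.time() + expiry_window

    # Invented sample: a token issued just now with a 300 s lifetime is not yet expired.
    print(token_expired({'expires_in': 300, 'expires_at': time.time() + 300}))
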
.../_async/schema_registry_client.py | 145 ++++++++++-------- .../_sync/schema_registry_client.py | 136 ++++++++-------- .../common/schema_registry_client.py | 44 ++++-- .../schema_registry/rule_registry.py | 1 + .../rules/encryption/encrypt_executor.py | 62 ++++++-- 5 files changed, 232 insertions(+), 156 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index 6ac29dfe4..1320cdec9 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -28,7 +28,7 @@ import httpx from typing import List, Dict, Optional, Union, Any, Callable, Literal -from cachetools import TTLCache, LRUCache +from cachetools import Cache, TTLCache, LRUCache from httpx import Response from authlib.integrations.httpx_client import AsyncOAuth2Client @@ -40,11 +40,11 @@ ServerConfig, is_success, is_retriable, - _BearerFieldProvider, + _AsyncBearerFieldProvider, full_jitter, _SchemaCache, Schema, - _StaticFieldProvider, + _AsyncStaticFieldProvider, ) __all__ = [ @@ -78,16 +78,16 @@ def _urlencode(value: str) -> str: log = logging.getLogger(__name__) -class _AsyncCustomOAuthClient(_BearerFieldProvider): +class _AsyncCustomOAuthClient(_AsyncBearerFieldProvider): def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict): self.custom_function = custom_function self.custom_config = custom_config - async def get_bearer_fields(self) -> dict: # type: ignore[override] + async def get_bearer_fields(self) -> dict: return await self.custom_function(self.custom_config) # type: ignore[misc] -class _AsyncOAuthClient(_BearerFieldProvider): +class _AsyncOAuthClient(_AsyncBearerFieldProvider): def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str, identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): self.token = None @@ -100,7 +100,7 @@ def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoin self.retries_max_wait_ms = retries_max_wait_ms self.token_expiry_threshold = 0.8 - async def get_bearer_fields(self) -> dict: # type: ignore[override] + async def get_bearer_fields(self) -> dict: return { 'bearer.auth.token': await self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, @@ -108,15 +108,19 @@ async def get_bearer_fields(self) -> dict: # type: ignore[override] } def token_expired(self) -> bool: - expiry_window = self.token['expires_in'] * self.token_expiry_threshold # type: ignore[index] + if self.token is None: + raise ValueError("Token is not set") - return self.token['expires_at'] < time.time() + expiry_window # type: ignore[index] + expiry_window = self.token['expires_in'] * self.token_expiry_threshold + + return self.token['expires_at'] < time.time() + expiry_window async def get_access_token(self) -> str: - if not self.token or self.token_expired(): + if self.token is None or self.token_expired(): await self.generate_access_token() - - return self.token['access_token'] # type: ignore[index] + if self.token is None: + raise ValueError("Token is not set") + return self.token['access_token'] async def generate_access_token(self) -> None: for i in range(self.max_retries + 1): @@ -259,9 +263,9 @@ def __init__(self, conf: dict): + str(type(retries_max_wait_ms))) self.retries_max_wait_ms = int(retries_max_wait_ms) - self.bearer_field_provider = None logical_cluster = None 
identity_pool = None + self.bearer_field_provider: Optional[_AsyncBearerFieldProvider] = None self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) if self.bearer_auth_credentials_source is not None: self.auth = None @@ -281,43 +285,43 @@ def __init__(self, conf: dict): if not isinstance(identity_pool, str): raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) - if self.bearer_auth_credentials_source == 'OAUTHBEARER': - properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', - 'bearer.auth.issuer.endpoint.url'] - missing_properties = [prop for prop in properties_list if prop not in conf_copy] - if missing_properties: - raise ValueError("Missing required OAuth configuration properties: {}". - format(", ".join(missing_properties))) - - self.client_id = conf_copy.pop('bearer.auth.client.id') - if not isinstance(self.client_id, string_type): - raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) - - self.client_secret = conf_copy.pop('bearer.auth.client.secret') - if not isinstance(self.client_secret, string_type): - raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) - - self.scope = conf_copy.pop('bearer.auth.scope') - if not isinstance(self.scope, string_type): - raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) - - self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') - if not isinstance(self.token_endpoint, string_type): - raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " - + str(type(self.token_endpoint))) - - self.bearer_field_provider = _AsyncOAuthClient( - self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, # type: ignore[arg-type] - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) - elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': - if 'bearer.auth.token' not in conf_copy: - raise ValueError("Missing bearer.auth.token") - static_token = conf_copy.pop('bearer.auth.token') - self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) # type: ignore[assignment,arg-type] - if not isinstance(static_token, string_type): - raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) + if self.bearer_auth_credentials_source == 'OAUTHBEARER': + properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', + 'bearer.auth.issuer.endpoint.url'] + missing_properties = [prop for prop in properties_list if prop not in conf_copy] + if missing_properties: + raise ValueError("Missing required OAuth configuration properties: {}". 
+ format(", ".join(missing_properties))) + + self.client_id = conf_copy.pop('bearer.auth.client.id') + if not isinstance(self.client_id, string_type): + raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) + + self.client_secret = conf_copy.pop('bearer.auth.client.secret') + if not isinstance(self.client_secret, string_type): + raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) + + self.scope = conf_copy.pop('bearer.auth.scope') + if not isinstance(self.scope, string_type): + raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) + + self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') + if not isinstance(self.token_endpoint, string_type): + raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + + str(type(self.token_endpoint))) + + self.bearer_field_provider = _AsyncOAuthClient( + self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) + else: # STATIC_TOKEN + if 'bearer.auth.token' not in conf_copy: + raise ValueError("Missing bearer.auth.token") + static_token = conf_copy.pop('bearer.auth.token') + self.bearer_field_provider = _AsyncStaticFieldProvider(static_token, logical_cluster, identity_pool) + if not isinstance(static_token, string_type): + raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) elif self.bearer_auth_credentials_source == 'CUSTOM': custom_bearer_properties = ['bearer.auth.custom.provider.function', 'bearer.auth.custom.provider.config'] @@ -336,7 +340,7 @@ def __init__(self, conf: dict): raise TypeError("bearer.auth.custom.provider.config must be a dict, not " + str(type(custom_config))) - self.bearer_field_provider = _AsyncCustomOAuthClient(custom_function, custom_config) # type: ignore[assignment] + self.bearer_field_provider = _AsyncCustomOAuthClient(custom_function, custom_config) else: raise ValueError('Unrecognized bearer.auth.credentials.source') @@ -379,7 +383,9 @@ def __init__(self, conf: dict): ) async def handle_bearer_auth(self, headers: dict) -> None: - bearer_fields = await self.bearer_field_provider.get_bearer_fields() # type: ignore[union-attr] + if self.bearer_field_provider is None: + raise ValueError("Bearer field provider is not set") + bearer_fields = await self.bearer_field_provider.get_bearer_fields() required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] missing_fields = [] @@ -439,7 +445,7 @@ async def send_request( body_str: Optional[str] = None if body is not None: - body_str = json.dumps(body) # type: ignore[assignment] + body_str = json.dumps(body) headers = {'Content-Length': str(len(body_str)), 'Content-Type': "application/vnd.schemaregistry.v1+json"} @@ -462,16 +468,19 @@ async def send_request( # Raise the exception since we have no more urls to try raise e - try: - raise SchemaRegistryError(response.status_code, # type: ignore[union-attr] - response.json().get('error_code'), # type: ignore[union-attr] - response.json().get('message')) # type: ignore[union-attr] - # Schema Registry may return malformed output when it hits unexpected errors - except (ValueError, KeyError, AttributeError): - raise SchemaRegistryError(response.status_code, # type: ignore[union-attr] - -1, - "Unknown Schema Registry Error: " - + str(response.content)) # type: ignore[union-attr] + if isinstance(response, Response): + try: + 
raise SchemaRegistryError(response.status_code, + response.json().get('error_code'), + response.json().get('message')) + # Schema Registry may return malformed output when it hits unexpected errors + except (ValueError, KeyError, AttributeError): + raise SchemaRegistryError(response.status_code, + -1, + "Unknown Schema Registry Error: " + + str(response.content)) + else: + raise TypeError("Unexpected response of unsupported type: " + str(type(response))) async def send_http_request( self, base_url: str, url: str, method: str, headers: Optional[dict], @@ -598,12 +607,14 @@ def __init__(self, conf: dict): self._cache = _SchemaCache() cache_capacity = self._rest_client.cache_capacity cache_ttl = self._rest_client.cache_latest_ttl_sec + self._latest_version_cache: Cache[Any, Any] + self._latest_with_metadata_cache: Cache[Any, Any] if cache_ttl is not None: - self._latest_version_cache: TTLCache[Any, Any] = TTLCache(cache_capacity, cache_ttl) - self._latest_with_metadata_cache: TTLCache[Any, Any] = TTLCache(cache_capacity, cache_ttl) + self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) + self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) else: - self._latest_version_cache = LRUCache[Any, Any](cache_capacity) # type: ignore[assignment] - self._latest_with_metadata_cache = LRUCache[Any, Any](cache_capacity) # type: ignore[assignment] + self._latest_version_cache = LRUCache(cache_capacity) + self._latest_with_metadata_cache = LRUCache(cache_capacity) async def __aenter__(self): return self diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index ab34e2fbe..0738d9e46 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -28,7 +28,7 @@ import httpx from typing import List, Dict, Optional, Union, Any, Callable, Literal -from cachetools import TTLCache, LRUCache +from cachetools import Cache, TTLCache, LRUCache from httpx import Response from authlib.integrations.httpx_client import OAuth2Client @@ -83,8 +83,8 @@ def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict) self.custom_function = custom_function self.custom_config = custom_config - def get_bearer_fields(self) -> dict: # type: ignore[override] - return self.custom_function(self.custom_config) # type: ignore[misc] + def get_bearer_fields(self) -> dict: + return self.custom_function(self.custom_config) class _OAuthClient(_BearerFieldProvider): @@ -100,7 +100,7 @@ def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoin self.retries_max_wait_ms = retries_max_wait_ms self.token_expiry_threshold = 0.8 - def get_bearer_fields(self) -> dict: # type: ignore[override] + def get_bearer_fields(self) -> dict: return { 'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, @@ -108,15 +108,19 @@ def get_bearer_fields(self) -> dict: # type: ignore[override] } def token_expired(self) -> bool: - expiry_window = self.token['expires_in'] * self.token_expiry_threshold # type: ignore[index] + if self.token is None: + raise ValueError("Token is not set") - return self.token['expires_at'] < time.time() + expiry_window # type: ignore[index] + expiry_window = self.token['expires_in'] * self.token_expiry_threshold + + return self.token['expires_at'] < time.time() + expiry_window def get_access_token(self) -> str: if not self.token or 
self.token_expired(): self.generate_access_token() - - return self.token['access_token'] # type: ignore[index] + if self.token is None: + raise ValueError("Token is not set") + return self.token['access_token'] def generate_access_token(self) -> None: for i in range(self.max_retries + 1): @@ -259,9 +263,9 @@ def __init__(self, conf: dict): + str(type(retries_max_wait_ms))) self.retries_max_wait_ms = int(retries_max_wait_ms) - self.bearer_field_provider = None logical_cluster = None identity_pool = None + self.bearer_field_provider: Optional[_BearerFieldProvider] = None self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) if self.bearer_auth_credentials_source is not None: self.auth = None @@ -281,43 +285,43 @@ def __init__(self, conf: dict): if not isinstance(identity_pool, str): raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) - if self.bearer_auth_credentials_source == 'OAUTHBEARER': - properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', - 'bearer.auth.issuer.endpoint.url'] - missing_properties = [prop for prop in properties_list if prop not in conf_copy] - if missing_properties: - raise ValueError("Missing required OAuth configuration properties: {}". - format(", ".join(missing_properties))) - - self.client_id = conf_copy.pop('bearer.auth.client.id') - if not isinstance(self.client_id, string_type): - raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) - - self.client_secret = conf_copy.pop('bearer.auth.client.secret') - if not isinstance(self.client_secret, string_type): - raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) - - self.scope = conf_copy.pop('bearer.auth.scope') - if not isinstance(self.scope, string_type): - raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) - - self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') - if not isinstance(self.token_endpoint, string_type): - raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " - + str(type(self.token_endpoint))) - - self.bearer_field_provider = _OAuthClient( - self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, # type: ignore[arg-type] - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) - elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': - if 'bearer.auth.token' not in conf_copy: - raise ValueError("Missing bearer.auth.token") - static_token = conf_copy.pop('bearer.auth.token') - self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) # type: ignore[assignment,arg-type] - if not isinstance(static_token, string_type): - raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) + if self.bearer_auth_credentials_source == 'OAUTHBEARER': + properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', + 'bearer.auth.issuer.endpoint.url'] + missing_properties = [prop for prop in properties_list if prop not in conf_copy] + if missing_properties: + raise ValueError("Missing required OAuth configuration properties: {}". 
+ format(", ".join(missing_properties))) + + self.client_id = conf_copy.pop('bearer.auth.client.id') + if not isinstance(self.client_id, string_type): + raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) + + self.client_secret = conf_copy.pop('bearer.auth.client.secret') + if not isinstance(self.client_secret, string_type): + raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) + + self.scope = conf_copy.pop('bearer.auth.scope') + if not isinstance(self.scope, string_type): + raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) + + self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') + if not isinstance(self.token_endpoint, string_type): + raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + + str(type(self.token_endpoint))) + + self.bearer_field_provider = _OAuthClient( + self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) + else: # STATIC_TOKEN + if 'bearer.auth.token' not in conf_copy: + raise ValueError("Missing bearer.auth.token") + static_token = conf_copy.pop('bearer.auth.token') + self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) + if not isinstance(static_token, string_type): + raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) elif self.bearer_auth_credentials_source == 'CUSTOM': custom_bearer_properties = ['bearer.auth.custom.provider.function', 'bearer.auth.custom.provider.config'] @@ -336,7 +340,7 @@ def __init__(self, conf: dict): raise TypeError("bearer.auth.custom.provider.config must be a dict, not " + str(type(custom_config))) - self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config) # type: ignore[assignment] + self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config) else: raise ValueError('Unrecognized bearer.auth.credentials.source') @@ -379,7 +383,9 @@ def __init__(self, conf: dict): ) def handle_bearer_auth(self, headers: dict) -> None: - bearer_fields = self.bearer_field_provider.get_bearer_fields() # type: ignore[union-attr] + if self.bearer_field_provider is None: + raise ValueError("Bearer field provider is not set") + bearer_fields = self.bearer_field_provider.get_bearer_fields() required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] missing_fields = [] @@ -439,7 +445,7 @@ def send_request( body_str: Optional[str] = None if body is not None: - body_str = json.dumps(body) # type: ignore[assignment] + body_str = json.dumps(body) headers = {'Content-Length': str(len(body_str)), 'Content-Type': "application/vnd.schemaregistry.v1+json"} @@ -462,16 +468,18 @@ def send_request( # Raise the exception since we have no more urls to try raise e - try: - raise SchemaRegistryError(response.status_code, # type: ignore[union-attr] - response.json().get('error_code'), # type: ignore[union-attr] - response.json().get('message')) # type: ignore[union-attr] - # Schema Registry may return malformed output when it hits unexpected errors - except (ValueError, KeyError, AttributeError): - raise SchemaRegistryError(response.status_code, # type: ignore[union-attr] - -1, - "Unknown Schema Registry Error: " - + str(response.content)) # type: ignore[union-attr] + if isinstance(response, Response): + try: + raise SchemaRegistryError(response.status_code, + 
response.json().get('error_code'), + response.json().get('message')) + except (ValueError, KeyError, AttributeError): + raise SchemaRegistryError(response.status_code, + -1, + "Unknown Schema Registry Error: " + + str(response.content)) + else: + raise TypeError("Unexpected response of unsupported type: " + str(type(response))) def send_http_request( self, base_url: str, url: str, method: str, headers: Optional[dict], @@ -598,12 +606,14 @@ def __init__(self, conf: dict): self._cache = _SchemaCache() cache_capacity = self._rest_client.cache_capacity cache_ttl = self._rest_client.cache_latest_ttl_sec + self._latest_version_cache: Cache[Any, Any] + self._latest_with_metadata_cache: Cache[Any, Any] if cache_ttl is not None: - self._latest_version_cache: TTLCache[Any, Any] = TTLCache(cache_capacity, cache_ttl) - self._latest_with_metadata_cache: TTLCache[Any, Any] = TTLCache(cache_capacity, cache_ttl) + self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) + self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) else: - self._latest_version_cache = LRUCache[Any, Any](cache_capacity) # type: ignore[assignment] - self._latest_with_metadata_cache = LRUCache[Any, Any](cache_capacity) # type: ignore[assignment] + self._latest_version_cache = LRUCache(cache_capacity) + self._latest_with_metadata_cache = LRUCache(cache_capacity) def __enter__(self): return self diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index bdfe3e51d..70511d842 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -28,10 +28,12 @@ __all__ = [ 'VALID_AUTH_PROVIDERS', '_BearerFieldProvider', + '_AsyncBearerFieldProvider', 'is_success', 'is_retriable', 'full_jitter', '_StaticFieldProvider', + '_AsyncStaticFieldProvider', '_SchemaCache', 'RuleKind', 'RuleMode', @@ -52,35 +54,55 @@ class _BearerFieldProvider(metaclass=abc.ABCMeta): + """Base class for synchronous bearer field providers.""" @abc.abstractmethod def get_bearer_fields(self) -> dict: raise NotImplementedError -def is_success(status_code: int) -> bool: - return 200 <= status_code <= 299 - - -def is_retriable(status_code: int) -> bool: - return status_code in (408, 429, 500, 502, 503, 504) +class _AsyncBearerFieldProvider(metaclass=abc.ABCMeta): + """Base class for asynchronous bearer field providers.""" + @abc.abstractmethod + async def get_bearer_fields(self) -> dict: + raise NotImplementedError +class _StaticFieldProvider(_BearerFieldProvider): + """Synchronous static token bearer field provider.""" + def __init__(self, token: str, logical_cluster: str, identity_pool: str): + self.token = token + self.logical_cluster = logical_cluster + self.identity_pool = identity_pool -def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) -> float: - no_jitter_delay = base_delay_ms * (2.0 ** retries_attempted) - return random.random() * min(no_jitter_delay, max_delay_ms) + def get_bearer_fields(self) -> dict: + return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, + 'bearer.auth.identity.pool.id': self.identity_pool} -class _StaticFieldProvider(_BearerFieldProvider): +class _AsyncStaticFieldProvider(_AsyncBearerFieldProvider): + """Asynchronous static token bearer field provider.""" def __init__(self, token: str, logical_cluster: str, identity_pool: str): self.token = token 
self.logical_cluster = logical_cluster self.identity_pool = identity_pool - def get_bearer_fields(self) -> dict: + async def get_bearer_fields(self) -> dict: return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, 'bearer.auth.identity.pool.id': self.identity_pool} +def is_success(status_code: int) -> bool: + return 200 <= status_code <= 299 + + +def is_retriable(status_code: int) -> bool: + return status_code in (408, 429, 500, 502, 503, 504) + + +def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) -> float: + no_jitter_delay = base_delay_ms * (2.0 ** retries_attempted) + return random.random() * min(no_jitter_delay, max_delay_ms) + + class _SchemaCache(object): """ Thread-safe cache for use with the Schema Registry Client. diff --git a/src/confluent_kafka/schema_registry/rule_registry.py b/src/confluent_kafka/schema_registry/rule_registry.py index f1d90756f..93a41235b 100644 --- a/src/confluent_kafka/schema_registry/rule_registry.py +++ b/src/confluent_kafka/schema_registry/rule_registry.py @@ -1,3 +1,4 @@ + #!/usr/bin/env python # -*- coding: utf-8 -*- # diff --git a/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py b/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py index 5a6b76c5f..4c9642c4d 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py @@ -212,7 +212,7 @@ def _is_dek_rotated(self): def _get_kek(self, ctx: RuleContext) -> Kek: if self._kek is None: self._kek = self._get_or_create_kek(ctx) - return self._kek # type: ignore[return-value] + return self._kek def _get_or_create_kek(self, ctx: RuleContext) -> Kek: is_read = ctx.rule_mode == RuleMode.READ @@ -242,8 +242,10 @@ def _get_or_create_kek(self, ctx: RuleContext) -> Kek: return kek def _retrieve_kek_from_registry(self, kek_id: KekId) -> Optional[Kek]: + if self._executor.client is None: + raise RuleError("client not configured") try: - return self._executor.client.get_kek(kek_id.name, kek_id.deleted) # type: ignore[union-attr] + return self._executor.client.get_kek(kek_id.name, kek_id.deleted) except Exception as e: if isinstance(e, SchemaRegistryError) and e.http_status_code == 404: return None @@ -253,8 +255,10 @@ def _store_kek_to_registry( self, kek_id: KekId, kms_type: str, kms_key_id: str, shared: bool ) -> Optional[Kek]: + if self._executor.client is None: + raise RuleError("client not configured") try: - return self._executor.client.register_kek(kek_id.name, kms_type, kms_key_id, shared) # type: ignore[union-attr] + return self._executor.client.register_kek(kek_id.name, kms_type, kms_key_id, shared) except Exception as e: if isinstance(e, SchemaRegistryError) and e.http_status_code == 409: return None @@ -265,6 +269,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: is_read = ctx.rule_mode == RuleMode.READ if version is None or version == 0: version = 1 + # TODO: fallback value for name? 
dek_id = DekId( kek.name, # type: ignore[arg-type] ctx.subject, @@ -278,12 +283,19 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: if dek is None or is_expired: if is_read: raise RuleError(f"no dek found for {dek_id.kek_name} during consume") + if self._kek is None: + raise RuleError("no kek found") encrypted_dek = None if not kek.shared: - primitive = AeadWrapper(self._executor.config, self._kek) # type: ignore[arg-type] + if self._executor.config is None: + raise RuleError("config not found in executor") + primitive = AeadWrapper(self._executor.config, self._kek) raw_dek = self._cryptor.generate_key() encrypted_dek = primitive.encrypt(raw_dek, self._cryptor.EMPTY_AAD) - new_version = dek.version + 1 if is_expired else 1 # type: ignore[union-attr,operator] + if dek is None or dek.version is None: + new_version = 1 + else: + new_version = dek.version + 1 if is_expired else 1 try: dek = self._create_dek(dek_id, new_version, encrypted_dek) except RuleError as e: @@ -301,6 +313,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: return dek def _create_dek(self, dek_id: DekId, new_version: Optional[int], encrypted_dek: Optional[bytes]) -> Dek: + # TODO: fallback value for version? new_dek_id = DekId( dek_id.kek_name, dek_id.subject, @@ -321,7 +334,9 @@ def _retrieve_dek_from_registry(self, key: DekId) -> Optional[Dek]: version = key.version if not version: version = 1 - dek = self._executor.client.get_dek( # type: ignore[union-attr] + if self._executor.client is None: + raise RuleError("client not configured") + dek = self._executor.client.get_dek( key.kek_name, key.subject, key.algorithm, version, key.deleted) return dek if dek and dek.encrypted_key_material else None except Exception as e: @@ -332,8 +347,14 @@ def _retrieve_dek_from_registry(self, key: DekId) -> Optional[Dek]: def _store_dek_to_registry(self, key: DekId, encrypted_dek: Optional[bytes]) -> Optional[Dek]: try: encrypted_dek_str = base64.b64encode(encrypted_dek).decode("utf-8") if encrypted_dek else None - dek = self._executor.client.register_dek( # type: ignore[union-attr] - key.kek_name, key.subject, encrypted_dek_str, key.algorithm, key.version) # type: ignore[arg-type] + if self._executor.client is None: + raise RuleError("client not configured") + dek = self._executor.client.register_dek( + key.kek_name, + key.subject, + encrypted_dek_str, # type: ignore[arg-type] + key.algorithm, + key.version) return dek except Exception as e: if isinstance(e, SchemaRegistryError) and e.http_status_code == 409: @@ -359,18 +380,23 @@ def transform(self, ctx: RuleContext, field_type: FieldType, field_value: Any) - version = -1 dek = self._get_or_create_dek(ctx, version) key_material_bytes = dek.get_key_material_bytes() - ciphertext = self._cryptor.encrypt(key_material_bytes, plaintext, Cryptor.EMPTY_AAD) # type: ignore[arg-type] + if key_material_bytes is None: + raise RuleError("no key material bytes found for dek") + ciphertext = self._cryptor.encrypt(key_material_bytes, plaintext, Cryptor.EMPTY_AAD) if self._is_dek_rotated(): - ciphertext = self._prefix_version(dek.version, ciphertext) # type: ignore[arg-type] + if dek.version is None: + raise RuleError("no version found for dek") + ciphertext = self._prefix_version(dek.version, ciphertext) if field_type == FieldType.STRING: return base64.b64encode(ciphertext).decode("utf-8") else: return self._to_object(field_type, ciphertext) elif ctx.rule_mode == RuleMode.READ: + ciphertext = None if field_type == FieldType.STRING: 
ciphertext = base64.b64decode(field_value) else: - ciphertext = self._to_bytes(field_type, field_value) # type: ignore[assignment] + ciphertext = self._to_bytes(field_type, field_value) if ciphertext is None: return field_value @@ -381,7 +407,9 @@ def transform(self, ctx: RuleContext, field_type: FieldType, field_value: Any) - raise RuleError("no version found in ciphertext") dek = self._get_or_create_dek(ctx, version) key_material_bytes = dek.get_key_material_bytes() - plaintext = self._cryptor.decrypt(key_material_bytes, ciphertext, Cryptor.EMPTY_AAD) # type: ignore[arg-type] + if key_material_bytes is None: + raise RuleError("no key material bytes found for dek") + plaintext = self._cryptor.decrypt(key_material_bytes, ciphertext, Cryptor.EMPTY_AAD) return self._to_object(field_type, plaintext) else: raise RuleError(f"unsupported rule mode {ctx.rule_mode}") @@ -421,7 +449,9 @@ def __init__(self, config: dict, kek: Kek): def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes: for index, kms_key_id in enumerate(self._kms_key_ids): try: - aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) # type: ignore[arg-type] + if self._kek.kms_type is None: + raise RuleError("no kms type found for kek") + aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) return aead.encrypt(plaintext, associated_data) except Exception as e: log.warning("failed to encrypt with kek %s and kms key id %s", @@ -433,7 +463,9 @@ def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes: def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes: for index, kms_key_id in enumerate(self._kms_key_ids): try: - aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) # type: ignore[arg-type] + if self._kek.kms_type is None: + raise RuleError("no kms type found for kek") + aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) return aead.decrypt(ciphertext, associated_data) except Exception as e: log.warning("failed to decrypt with kek %s and kms key id %s", @@ -443,7 +475,7 @@ def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes: raise RuleError("No KEK found for decryption") def _get_kms_key_ids(self) -> List[str]: - kms_key_ids = [self._kek.kms_key_id] # type: ignore[list-item] + kms_key_ids = [self._kek.kms_key_id] alternate_kms_key_ids = None if self._kek.kms_props is not None: alternate_kms_key_ids = self._kek.kms_props.properties.get(ENCRYPT_ALTERNATE_KMS_KEY_IDS) From 0f512477ec84d37948566589ec382a57e6f8eb8e Mon Sep 17 00:00:00 2001 From: Naxin Date: Fri, 24 Oct 2025 00:31:54 -0400 Subject: [PATCH 25/31] more typeignore removals --- .../schema_registry/_async/avro.py | 43 +++++++++++-------- .../schema_registry/_async/json_schema.py | 32 ++++++++------ .../schema_registry/_async/protobuf.py | 2 +- .../schema_registry/_sync/avro.py | 21 ++++++--- .../schema_registry/_sync/json_schema.py | 31 +++++++------ .../schema_registry/_sync/protobuf.py | 2 +- .../rules/cel/cel_field_presence.py | 2 +- .../schema_registry/rules/cel/constraints.py | 12 +++--- .../schema_registry/rules/cel/extra_func.py | 14 ++++-- .../rules/cel/string_format.py | 4 +- .../rules/encryption/azurekms/azure_client.py | 4 +- .../rules/encryption/localkms/local_client.py | 4 +- 12 files changed, 101 insertions(+), 70 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index 1f45f9e7d..1f82d3721 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ 
b/src/confluent_kafka/schema_registry/_async/avro.py @@ -58,17 +58,17 @@ async def _resolve_named_schema( if schema.references is not None: for ref in schema.references: if ref.subject is None or ref.version is None: - raise ValueError("Subject or version cannot be None") + raise TypeError("Subject or version cannot be None") referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) ref_named_schemas = await _resolve_named_schema(referenced_schema.schema, schema_registry_client) if referenced_schema.schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") parsed_schema = parse_schema_with_repo( referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) named_schemas.update(ref_named_schemas) if ref.name is None: - raise ValueError("Name cannot be None") + raise TypeError("Name cannot be None") named_schemas[ref.name] = parsed_schema return named_schemas @@ -278,16 +278,18 @@ async def __init_impl( # and union types should use topic_subject_name_strategy, which # just discards the schema name anyway schema_name = None - else: + elif isinstance(parsed_schema, dict): # The Avro spec states primitives have a name equal to their type # i.e. {"type": "string"} has a name of string. # This function does not comply. # https://github.com/fastavro/fastavro/issues/415 if schema.schema_str is not None: schema_dict = json.loads(schema.schema_str) - schema_name = parsed_schema.get("name", schema_dict.get("type")) # type: ignore[union-attr] + schema_name = parsed_schema.get("name", schema_dict.get("type")) else: schema_name = None + else: + schema_name = None else: schema_name = None parsed_schema = None @@ -348,12 +350,14 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N self._known_subjects.add(subject) + value: Any + parsed_schema: Any if self._to_dict is not None: if ctx is None: - raise ValueError("SerializationContext cannot be None") + raise TypeError("SerializationContext cannot be None") value = self._to_dict(obj, ctx) else: - value = obj # type: ignore[assignment] + value = obj if latest_schema is not None and ctx is not None and subject is not None: parsed_schema = await self._get_parsed_schema(latest_schema.schema) @@ -363,7 +367,7 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 latest_schema.schema, value, get_inline_tags(parsed_schema), field_transformer) else: - parsed_schema = self._parsed_schema # type: ignore[assignment] + parsed_schema = self._parsed_schema with _ContextStringIO() as fo: # write the record to the rest of the buffer @@ -384,10 +388,10 @@ async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: named_schemas = await _resolve_named_schema(schema, self._registry) if schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) if prepared_schema.schema_str is None: - raise ValueError("Prepared schema string cannot be None") + raise TypeError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) @@ -526,10 +530,11 @@ async def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - if schema: - self._reader_schema = await self._get_parsed_schema(self._schema) # type: ignore[arg-type] + self._reader_schema: Optional[AvroSchema] + if schema and 
self._schema is not None: + self._reader_schema = await self._get_parsed_schema(self._schema) else: - self._reader_schema = None # type: ignore[assignment] + self._reader_schema = None if from_dict is not None and not callable(from_dict): raise ValueError("from_dict must be callable with the signature " @@ -600,6 +605,7 @@ async def __deserialize( if isinstance(payload, bytes): payload = io.BytesIO(payload) + reader_schema: Optional[AvroSchema] if latest_schema is not None and subject is not None: migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema @@ -628,13 +634,14 @@ async def __deserialize( def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) if ctx is not None and subject is not None: + inline_tags = get_inline_tags(reader_schema) if reader_schema is not None else None obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, get_inline_tags(reader_schema), - field_transformer) + reader_schema_raw, obj_dict, + inline_tags,field_transformer) if self._from_dict is not None: if ctx is None: - raise ValueError("SerializationContext cannot be None") + raise TypeError("SerializationContext cannot be None") return self._from_dict(obj_dict, ctx) return obj_dict @@ -646,10 +653,10 @@ async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: named_schemas = await _resolve_named_schema(schema, self._registry) if schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) if prepared_schema.schema_str is None: - raise ValueError("Prepared schema string cannot be None") + raise TypeError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py index 448add6b7..c87c35e5d 100644 --- a/src/confluent_kafka/schema_registry/_async/json_schema.py +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -65,17 +65,17 @@ async def _resolve_named_schema( if schema.references is not None: for ref in schema.references: if ref.subject is None or ref.version is None: - raise ValueError("Subject or version cannot be None") + raise TypeError("Subject or version cannot be None") referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) ref_registry = await _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) if referenced_schema.schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) resource = Resource.from_contents( referenced_schema_dict, default_specification=DEFAULT_SPEC) if ref.name is None: - raise ValueError("Name cannot be None") + raise TypeError("Name cannot be None") ref_registry = ref_registry.with_resource(ref.name, resource) return ref_registry @@ -293,11 +293,14 @@ async def __init_impl( if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - - schema_dict, ref_registry = await self._get_parsed_schema(self._schema) # type: ignore[arg-type] - if schema_dict and isinstance(schema_dict, 
dict): - schema_name = schema_dict.get('title', None) + if self._schema: + schema_dict, ref_registry = await self._get_parsed_schema(self._schema) + if schema_dict and isinstance(schema_dict, dict): + schema_name = schema_dict.get('title', None) + else: + schema_name = None else: + schema_dict = None schema_name = None self._schema_name = schema_name @@ -355,12 +358,13 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N self._known_subjects.add(subject) + value: Any if self._to_dict is not None: if ctx is None: - raise ValueError("SerializationContext cannot be None") + raise TypeError("SerializationContext cannot be None") value = self._to_dict(obj, ctx) else: - value = obj # type: ignore[assignment] + value = obj schema: Optional[Schema] = None if latest_schema is not None: @@ -413,7 +417,7 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema] ref_registry = await _resolve_named_schema(schema, self._registry) if schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) @@ -569,8 +573,8 @@ async def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - if schema: - self._reader_schema, self._ref_registry = await self._get_parsed_schema(self._schema) # type: ignore[arg-type] + if schema and self._schema is not None: + self._reader_schema, self._ref_registry = await self._get_parsed_schema(self._schema) else: self._reader_schema, self._ref_registry = None, None @@ -678,7 +682,7 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 if self._from_dict is not None: if ctx is None: - raise ValueError("SerializationContext cannot be None") + raise TypeError("SerializationContext cannot be None") return self._from_dict(obj_dict, ctx) # type: ignore[return-value] return obj_dict @@ -693,7 +697,7 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema] ref_registry = await _resolve_named_schema(schema, self._registry) if schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index 73146f167..1d2a1a7bb 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -623,7 +623,7 @@ async def __deserialize(self, data: Optional[bytes], ctx: Optional[Serialization writer_schema_raw = await self._get_writer_schema(schema_id, subject, fmt='serialized') fd_proto, pool = await self._get_parsed_schema(writer_schema_raw) writer_schema = pool.FindFileByName(fd_proto.name) - writer_desc = self._get_message_desc(pool, writer_schema, msg_index) # type: ignore[arg-type] + writer_desc = self._get_message_desc(pool, writer_schema, msg_index if msg_index is not None else []) if subject is None: subject = self._subject_name_func(ctx, writer_desc.full_name) if subject is not None: diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index a54847a6f..4e7928573 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ 
b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -278,16 +278,18 @@ def __init_impl( # and union types should use topic_subject_name_strategy, which # just discards the schema name anyway schema_name = None - else: + elif isinstance(parsed_schema, dict): # The Avro spec states primitives have a name equal to their type # i.e. {"type": "string"} has a name of string. # This function does not comply. # https://github.com/fastavro/fastavro/issues/415 if schema.schema_str is not None: schema_dict = json.loads(schema.schema_str) - schema_name = parsed_schema.get("name", schema_dict.get("type")) # type: ignore[union-attr] + schema_name = parsed_schema.get("name", schema_dict.get("type")) else: schema_name = None + else: + schema_name = None else: schema_name = None parsed_schema = None @@ -348,6 +350,8 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - self._known_subjects.add(subject) + value: Any + parsed_schema: Any if self._to_dict is not None: if ctx is None: raise ValueError("SerializationContext cannot be None") @@ -526,10 +530,11 @@ def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - if schema: - self._reader_schema = self._get_parsed_schema(self._schema) # type: ignore[arg-type] + self._reader_schema: Optional[AvroSchema] + if schema and self._schema is not None: + self._reader_schema = self._get_parsed_schema(self._schema) else: - self._reader_schema = None # type: ignore[assignment] + self._reader_schema = None if from_dict is not None and not callable(from_dict): raise ValueError("from_dict must be callable with the signature " @@ -600,6 +605,7 @@ def __deserialize( if isinstance(payload, bytes): payload = io.BytesIO(payload) + reader_schema: Optional[AvroSchema] if latest_schema is not None and subject is not None: migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema @@ -628,9 +634,10 @@ def __deserialize( def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) if ctx is not None and subject is not None: + inline_tags = get_inline_tags(reader_schema) if reader_schema is not None else None obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, get_inline_tags(reader_schema), - field_transformer) + reader_schema_raw, obj_dict, + inline_tags, field_transformer) if self._from_dict is not None: if ctx is None: diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index a47c5bcd2..f8137887b 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -65,17 +65,17 @@ def _resolve_named_schema( if schema.references is not None: for ref in schema.references: if ref.subject is None or ref.version is None: - raise ValueError("Subject or version cannot be None") + raise TypeError("Subject or version cannot be None") referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) ref_registry = _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) if referenced_schema.schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) resource = Resource.from_contents( 
referenced_schema_dict, default_specification=DEFAULT_SPEC) if ref.name is None: - raise ValueError("Name cannot be None") + raise TypeError("Name cannot be None") ref_registry = ref_registry.with_resource(ref.name, resource) return ref_registry @@ -294,10 +294,14 @@ def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - schema_dict, ref_registry = self._get_parsed_schema(self._schema) # type: ignore[arg-type] - if schema_dict and isinstance(schema_dict, dict): - schema_name = schema_dict.get('title', None) + if self._schema: + schema_dict, ref_registry = self._get_parsed_schema(self._schema) + if schema_dict and isinstance(schema_dict, dict): + schema_name = schema_dict.get('title', None) + else: + schema_name = None else: + schema_dict = None schema_name = None self._schema_name = schema_name @@ -355,12 +359,13 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - self._known_subjects.add(subject) + value: Any if self._to_dict is not None: if ctx is None: - raise ValueError("SerializationContext cannot be None") + raise TypeError("SerializationContext cannot be None") value = self._to_dict(obj, ctx) else: - value = obj # type: ignore[assignment] + value = obj schema: Optional[Schema] = None if latest_schema is not None: @@ -413,7 +418,7 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Opti ref_registry = _resolve_named_schema(schema, self._registry) if schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) @@ -569,8 +574,8 @@ def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - if schema: - self._reader_schema, self._ref_registry = self._get_parsed_schema(self._schema) # type: ignore[arg-type] + if schema and self._schema is not None: + self._reader_schema, self._ref_registry = self._get_parsed_schema(self._schema) else: self._reader_schema, self._ref_registry = None, None @@ -678,7 +683,7 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 if self._from_dict is not None: if ctx is None: - raise ValueError("SerializationContext cannot be None") + raise TypeError("SerializationContext cannot be None") return self._from_dict(obj_dict, ctx) # type: ignore[return-value] return obj_dict @@ -693,7 +698,7 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Opti ref_registry = _resolve_named_schema(schema, self._registry) if schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 05651b739..f48068c70 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -623,7 +623,7 @@ def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContex writer_schema_raw = self._get_writer_schema(schema_id, subject, fmt='serialized') fd_proto, pool = self._get_parsed_schema(writer_schema_raw) writer_schema = pool.FindFileByName(fd_proto.name) - writer_desc = self._get_message_desc(pool, writer_schema, msg_index) # type: 
ignore[arg-type] + writer_desc = self._get_message_desc(pool, writer_schema, msg_index if msg_index is not None else []) if subject is None: subject = self._subject_name_func(ctx, writer_desc.full_name) if subject is not None: diff --git a/src/confluent_kafka/schema_registry/rules/cel/cel_field_presence.py b/src/confluent_kafka/schema_registry/rules/cel/cel_field_presence.py index 329077a3e..597f1b5fb 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/cel_field_presence.py +++ b/src/confluent_kafka/schema_registry/rules/cel/cel_field_presence.py @@ -16,7 +16,7 @@ import threading from typing import Any -import celpy # type: ignore +import celpy _has_state = threading.local() diff --git a/src/confluent_kafka/schema_registry/rules/cel/constraints.py b/src/confluent_kafka/schema_registry/rules/cel/constraints.py index a50dc0210..32e87fdcf 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/constraints.py +++ b/src/confluent_kafka/schema_registry/rules/cel/constraints.py @@ -15,7 +15,7 @@ import typing -from celpy import celtypes # type: ignore +from celpy import celtypes from google.protobuf import descriptor, message, message_factory from confluent_kafka.schema_registry.rules.cel import string_format @@ -32,8 +32,8 @@ def make_key_path(field_name: str, key: celtypes.Value) -> str: def make_duration(msg: message.Message) -> celtypes.DurationType: return celtypes.DurationType( - seconds=msg.seconds, # type: ignore - nanos=msg.nanos, # type: ignore + seconds=msg.seconds, + nanos=msg.nanos, ) @@ -115,14 +115,14 @@ def _msg_to_cel(msg: message.Message) -> typing.Dict[str, celtypes.Value]: def _proto_message_has_field(msg: message.Message, field: descriptor.FieldDescriptor) -> typing.Any: if field.is_extension: - return msg.HasExtension(field) # type: ignore + return msg.HasExtension(field) else: return msg.HasField(field.name) def _proto_message_get_field(msg: message.Message, field: descriptor.FieldDescriptor) -> typing.Any: if field.is_extension: - return msg.Extensions[field] # type: ignore + return msg.Extensions[field] else: return getattr(msg, field.name) @@ -144,7 +144,7 @@ def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> c def _is_empty_field(msg: message.Message, field: descriptor.FieldDescriptor) -> bool: - if field.has_presence: # type: ignore[attr-defined] + if field.has_presence: return not _proto_message_has_field(msg, field) if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: return len(_proto_message_get_field(msg, field)) == 0 diff --git a/src/confluent_kafka/schema_registry/rules/cel/extra_func.py b/src/confluent_kafka/schema_registry/rules/cel/extra_func.py index 194ce826f..95c016fc2 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/extra_func.py +++ b/src/confluent_kafka/schema_registry/rules/cel/extra_func.py @@ -19,8 +19,8 @@ from ipaddress import IPv4Address, IPv6Address, ip_address from urllib import parse as urlparse -import celpy # type: ignore -from celpy import celtypes # type: ignore +import celpy +from celpy import celtypes from confluent_kafka.schema_registry.rules.cel import string_format @@ -143,14 +143,20 @@ def is_email(string: celtypes.Value) -> celpy.Result: def is_uri(string: celtypes.Value) -> celpy.Result: - url = urlparse.urlparse(string) # type: ignore[arg-type] + if not isinstance(string, celtypes.StringType): + msg = "invalid argument, expected string" + raise celpy.CELEvalError(msg) + url = urlparse.urlparse(string) if not all([url.scheme, url.netloc, url.path]): return 
celtypes.BoolType(False) return celtypes.BoolType(True) def is_uri_ref(string: celtypes.Value) -> celpy.Result: - url = urlparse.urlparse(string) # type: ignore[arg-type] + if not isinstance(string, celtypes.StringType): + msg = "invalid argument, expected string" + raise celpy.CELEvalError(msg) + url = urlparse.urlparse(string) if not all([url.scheme, url.path]) and url.fragment: return celtypes.BoolType(False) return celtypes.BoolType(True) diff --git a/src/confluent_kafka/schema_registry/rules/cel/string_format.py b/src/confluent_kafka/schema_registry/rules/cel/string_format.py index 39db450b1..9aaaf5eaf 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/string_format.py +++ b/src/confluent_kafka/schema_registry/rules/cel/string_format.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import celpy # type: ignore -from celpy import celtypes # type: ignore +import celpy +from celpy import celtypes QUOTE_TRANS = str.maketrans( { diff --git a/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py b/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py index a98ab5116..629cb78fd 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py @@ -48,7 +48,7 @@ def __init__( self._key_uri = key_uri else: raise tink.TinkError('Invalid key_uri.') - + key_id = key_uri[len(AZURE_KEYURI_PREFIX):] self._client = CryptographyClient(key_id, credentials) @@ -81,4 +81,4 @@ def get_aead(self, key_uri: str) -> aead.Aead: ) if not key_uri.startswith(AZURE_KEYURI_PREFIX): raise tink.TinkError('Invalid key_uri.') - return AzureKmsAead(self._client, EncryptionAlgorithm.rsa_oaep_256) # type: ignore[arg-type] + return AzureKmsAead(self._client, EncryptionAlgorithm.rsa_oaep_256) diff --git a/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py b/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py index f9b0bb452..51f5508cb 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py @@ -23,7 +23,9 @@ class LocalKmsClient(KmsClient): def __init__(self, secret: Optional[str] = None): - self._aead = self._get_primitive(secret) # type: ignore[arg-type] + if secret is None: + raise TypeError("secret cannot be None") + self._aead = self._get_primitive(secret) def _get_primitive(self, secret: str) -> aead.Aead: key = self._get_key(secret) From b4bf42cd29daf74c20428d623b58dd0d172777ca Mon Sep 17 00:00:00 2001 From: Naxin Date: Fri, 24 Oct 2025 11:30:32 -0400 Subject: [PATCH 26/31] update --- .../schema_registry/rules/cel/cel_executor.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py b/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py index 0b385dc4d..adf260250 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py +++ b/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py @@ -13,6 +13,7 @@ # limitations under the License. 
import datetime +import logging import uuid import celpy @@ -33,6 +34,9 @@ from google.protobuf import message + +log = logging.getLogger(__name__) + # A date logical type annotates an Avro int, where the int stores the number # of days from the unix epoch, 1 January 1970 (ISO calendar). DAYS_SHIFT = datetime.date(1970, 1, 1).toordinal() @@ -54,25 +58,30 @@ def transform(self, ctx: RuleContext, msg: Any) -> Any: def execute(self, ctx: RuleContext, msg: Any, args: Any) -> Any: expr = ctx.rule.expr + if expr is None: + log.warning("Expression from rule %s is None", ctx.rule.name) + return msg try: - index = expr.index(";") # type: ignore[union-attr] + index = expr.index(";") except ValueError: index = -1 if index >= 0: - guard = expr[:index] # type: ignore[index] + guard = expr[:index] if len(guard.strip()) > 0: guard_result = self.execute_rule(ctx, guard, args) if not guard_result: if ctx.rule.kind == RuleKind.CONDITION: return True return msg - expr = expr[index+1:] # type: ignore[index] + expr = expr[index+1:] - return self.execute_rule(ctx, expr, args) # type: ignore[arg-type] + return self.execute_rule(ctx, expr, args) def execute_rule(self, ctx: RuleContext, expr: str, args: Any) -> Any: schema = ctx.target - script_type = ctx.target.schema_type # type: ignore[union-attr] + if schema is None: + raise ValueError("Target schema is None") # TODO: check whether we should raise or return fallback + script_type = ctx.target.schema_type prog = self._cache.get_program(expr, script_type, schema) if prog is None: ast = self._env.compile(expr)
From 1451647c9614bd42da56b9087a08d49290f2ff99 Mon Sep 17 00:00:00 2001 From: Naxin Date: Fri, 24 Oct 2025 15:12:02 -0400 Subject: [PATCH 27/31] handle union types in schemas --- .../_async/schema_registry_client.py | 4 +- .../_sync/schema_registry_client.py | 2 +- .../schema_registry/common/avro.py | 26 ++++-- .../schema_registry/common/json_schema.py | 82 ++++++++++++------- .../schema_registry/rules/cel/cel_executor.py | 4 +- .../rules/cel/string_format.py | 44 +++++++--- 6 files changed, 110 insertions(+), 52 deletions(-)
diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index 1320cdec9..d67501a06 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -116,10 +116,10 @@ def token_expired(self) -> bool: return self.token['expires_at'] < time.time() + expiry_window async def get_access_token(self) -> str: - if self.token is None or self.token_expired(): + if not self.token or self.token_expired(): await self.generate_access_token() if self.token is None: - raise ValueError("Token is not set") + raise ValueError("Token is not set after the attempt to generate it") return self.token['access_token'] async def generate_access_token(self) -> None:
diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 0738d9e46..e21d79340 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -119,7 +119,7 @@ def get_access_token(self) -> str: if not self.token or self.token_expired(): self.generate_access_token() if self.token is None: - raise ValueError("Token is not set") + raise ValueError("Token is not set after the attempt to generate it") return self.token['access_token'] def
generate_access_token(self) -> None: diff --git a/src/confluent_kafka/schema_registry/common/avro.py b/src/confluent_kafka/schema_registry/common/avro.py index d9e2345e6..97cdcc518 100644 --- a/src/confluent_kafka/schema_registry/common/avro.py +++ b/src/confluent_kafka/schema_registry/common/avro.py @@ -1,4 +1,5 @@ import decimal +import logging import re from collections import defaultdict from copy import deepcopy @@ -44,6 +45,8 @@ ] AvroSchema = Union[str, list, dict] +log = logging.getLogger(__name__) + class _ContextStringIO(BytesIO): """ @@ -113,12 +116,21 @@ def transform( elif isinstance(schema, dict): schema_type = schema.get("type") if schema_type == 'array': + if not isinstance(message, list): + log.warning("Incompatible message type for array schema") + return message return [transform(ctx, schema["items"], item, field_transform) - for item in message] # type: ignore[union-attr] + for item in message] elif schema_type == 'map': + if not isinstance(message, dict): + log.warning("Incompatible message type for map schema") + return message return {key: transform(ctx, schema["values"], value, field_transform) - for key, value in message.items()} # type: ignore[union-attr] + for key, value in message.items()} elif schema_type == 'record': + if not isinstance(message, dict): + log.warning("Incompatible message type for record schema") + return message fields = schema["fields"] for field in fields: _transform_field(ctx, schema, field, message, field_transform) @@ -132,12 +144,12 @@ def transform( def _transform_field( - ctx: RuleContext, schema: AvroSchema, field: dict, - message: AvroMessage, field_transform: FieldTransform + ctx: RuleContext, schema: dict, field: dict, + message: dict, field_transform: FieldTransform ): field_type = field["type"] name = field["name"] - full_name = schema["name"] + "." + name # type: ignore[call-overload,index] + full_name = schema["name"] + "." 
+ name try: ctx.enter_field( message, @@ -146,13 +158,13 @@ def _transform_field( get_type(field_type), None ) - value = message[name] # type: ignore[index] + value = message[name] new_value = transform(ctx, field_type, value, field_transform) if ctx.rule.kind == RuleKind.CONDITION: if new_value is False: raise RuleConditionError(ctx.rule) else: - message[name] = new_value # type: ignore[index] + message[name] = new_value finally: ctx.exit_field() diff --git a/src/confluent_kafka/schema_registry/common/json_schema.py b/src/confluent_kafka/schema_registry/common/json_schema.py index 4f3cf5920..ceeacf7bc 100644 --- a/src/confluent_kafka/schema_registry/common/json_schema.py +++ b/src/confluent_kafka/schema_registry/common/json_schema.py @@ -2,6 +2,7 @@ import decimal from io import BytesIO +import logging from typing import Union, Optional, List, Set import httpx @@ -44,6 +45,7 @@ DEFAULT_SPEC = referencing.jsonschema.DRAFT7 # type: ignore[attr-defined] +log = logging.getLogger(__name__) class _ContextStringIO(BytesIO): """ @@ -68,8 +70,10 @@ def transform( ctx: RuleContext, schema: JsonSchema, ref_registry: Registry, ref_resolver: Resolver, path: str, message: JsonMessage, field_transform: FieldTransform ) -> Optional[JsonMessage]: + # Only proceed to transform the message if schema is of dict type if message is None or schema is None or isinstance(schema, bool): return message + field_ctx = ctx.current_field() if field_ctx is not None: field_ctx.field_type = get_type(schema) @@ -104,13 +108,18 @@ def transform( if ref is not None: ref_schema = ref_resolver.lookup(ref) return transform(ctx, ref_schema.contents, ref_registry, ref_resolver, path, message, field_transform) + schema_type = get_type(schema) if schema_type == FieldType.RECORD: props = schema.get("properties") + if not isinstance(message, dict): + log.warning("Incompatible message type for record schema") + return message if props is not None: for prop_name, prop_schema in props.items(): - _transform_field(ctx, path, prop_name, message, - prop_schema, ref_registry, ref_resolver, field_transform) + if isinstance(prop_schema, dict): + _transform_field(ctx, path, prop_name, message, + prop_schema, ref_registry, ref_resolver, field_transform) return message if schema_type in (FieldType.ENUM, FieldType.STRING, FieldType.INT, FieldType.DOUBLE, FieldType.BOOLEAN): if field_ctx is not None: @@ -121,8 +130,8 @@ def transform( def _transform_field( - ctx: RuleContext, path: str, prop_name: str, message: JsonMessage, - prop_schema: JsonSchema, ref_registry: Registry, ref_resolver: Resolver, field_transform: FieldTransform + ctx: RuleContext, path: str, prop_name: str, message: dict, + prop_schema: dict, ref_registry: Registry, ref_resolver: Resolver, field_transform: FieldTransform ): full_name = path + "." 
+ prop_name try: @@ -133,26 +142,35 @@ def _transform_field( get_type(prop_schema), get_inline_tags(prop_schema) ) - value = message.get(prop_name) # type: ignore[union-attr] + value = message.get(prop_name) if value is not None: new_value = transform(ctx, prop_schema, ref_registry, ref_resolver, full_name, value, field_transform) if ctx.rule.kind == RuleKind.CONDITION: if new_value is False: raise RuleConditionError(ctx.rule) else: - message[prop_name] = new_value # type: ignore[index,call-overload] + message[prop_name] = new_value finally: ctx.exit_field() def _validate_subtypes( - schema: JsonSchema, message: JsonMessage, registry: Registry + schema: dict, message: JsonMessage, registry: Registry ) -> Optional[JsonSchema]: - schema_type = schema.get("type") # type: ignore[union-attr] + """ + Validate the message against the subtypes. + Args: + schema: The schema to validate the message against. + message: The message to validate. + registry: The registry to use for the validation. + Returns: + The validated schema if the message is valid against the subtypes, otherwise None. + """ + schema_type = schema.get("type") if not isinstance(schema_type, list) or len(schema_type) == 0: return None for typ in schema_type: - schema["type"] = typ # type: ignore[index] + schema["type"] = typ try: validate(instance=message, schema=schema, registry=registry) return schema @@ -167,32 +185,38 @@ def _validate_subschemas( registry: Registry, resolver: Resolver, )-> Optional[JsonSchema]: + """ + Validate the message against the subschemas. + Args: + subschemas: The list of subschemas to validate the message against. + message: The message to validate. + registry: The registry to use for the validation. + resolver: The resolver to use for the validation. + Returns: + The validated schema if the message is valid against the subschemas, otherwise None. 
+ """ for subschema in subschemas: - try: - ref = subschema.get("$ref") # type: ignore[union-attr] - if ref is not None: - # resolve $ref before validating - subschema = resolver.lookup(ref).contents - validate(instance=message, schema=subschema, registry=registry, resolver=resolver) - return subschema - except ValidationError: - pass + if isinstance(subschema, dict): + try: + ref = subschema.get("$ref") + if ref is not None: + subschema = resolver.lookup(ref).contents + validate(instance=message, schema=subschema, registry=registry, resolver=resolver) + return subschema + except ValidationError: + pass return None def get_type(schema: JsonSchema) -> FieldType: - if isinstance(schema, list): + if isinstance(schema, bool): return FieldType.COMBINED - elif isinstance(schema, dict): - schema_type = schema.get("type") - else: - # string schemas; this could be either a named schema or a primitive type - schema_type = schema - if schema.get("const") is not None or schema.get("enum") is not None: # type: ignore[union-attr] + schema_type = schema.get("type") + if schema.get("const") is not None or schema.get("enum") is not None: return FieldType.ENUM if schema_type == "object": - props = schema.get("properties") # type: ignore[union-attr] + props = schema.get("properties") if not props: return FieldType.MAP return FieldType.RECORD @@ -209,7 +233,7 @@ def get_type(schema: JsonSchema) -> FieldType: if schema_type == "null": return FieldType.NULL - props = schema.get("properties") # type: ignore[union-attr] + props = schema.get("properties") if props is not None: return FieldType.RECORD @@ -223,8 +247,8 @@ def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: return True -def get_inline_tags(schema: JsonSchema) -> Set[str]: - tags = schema.get("confluent:tags") # type: ignore[union-attr] +def get_inline_tags(schema: dict) -> Set[str]: + tags = schema.get("confluent:tags") if tags is None: return set() else: diff --git a/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py b/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py index adf260250..c9cdaf92f 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py +++ b/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py @@ -81,7 +81,9 @@ def execute_rule(self, ctx: RuleContext, expr: str, args: Any) -> Any: schema = ctx.target if schema is None: raise ValueError("Target schema is None") # TODO: check whether we should raise or return fallback - script_type = ctx.target.schema_type + script_type = schema.schema_type + if script_type is None: + raise ValueError("Target schema type is None") # TODO: check whether we should raise or return fallback prog = self._cache.get_program(expr, script_type, schema) if prog is None: ast = self._env.compile(expr) diff --git a/src/confluent_kafka/schema_registry/rules/cel/string_format.py b/src/confluent_kafka/schema_registry/rules/cel/string_format.py index 9aaaf5eaf..a596c623e 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/string_format.py +++ b/src/confluent_kafka/schema_registry/rules/cel/string_format.py @@ -49,7 +49,7 @@ def format(self, fmt: celtypes.Value, args: celtypes.Value) -> celpy.Result: # printf style formatting i = 0 j = 0 - result = "" + result: str = "" while i < len(fmt): if fmt[i] != "%": result += fmt[i] @@ -76,25 +76,41 @@ def format(self, fmt: celtypes.Value, args: celtypes.Value) -> celpy.Result: i += 1 if i >= len(fmt): return celpy.CELEvalError("format() incomplete format specifier") + + # Format the argument and handle errors + 
formatted: celpy.Result if fmt[i] == "f": - result += self.format_float(arg, precision) # type: ignore[operator,assignment] - if fmt[i] == "e": - result += self.format_exponential(arg, precision) # type: ignore[operator,assignment] + formatted = self.format_float(arg, precision) + elif fmt[i] == "e": + formatted = self.format_exponential(arg, precision) elif fmt[i] == "d": - result += self.format_int(arg) # type: ignore[operator,assignment] + formatted = self.format_int(arg) elif fmt[i] == "s": - result += self.format_string(arg) # type: ignore[operator,assignment] + formatted = self.format_string(arg) elif fmt[i] == "x": - result += self.format_hex(arg) # type: ignore[operator,assignment] + formatted = self.format_hex(arg) elif fmt[i] == "X": - result += self.format_hex(arg).upper() # type: ignore[operator,assignment,union-attr,call-arg] + formatted = self.format_hex(arg) + if isinstance(formatted, celpy.CELEvalError): + return formatted + result += str(formatted).upper() + i += 1 + continue elif fmt[i] == "o": - result += self.format_oct(arg) # type: ignore[operator,assignment] + formatted = self.format_oct(arg) elif fmt[i] == "b": - result += self.format_bin(arg) # type: ignore[operator,assignment] + formatted = self.format_bin(arg) else: return celpy.CELEvalError("format() unknown format specifier: " + fmt[i]) + + # Check if formatting returned an error + if isinstance(formatted, celpy.CELEvalError): + return formatted + + # Append the formatted string + result += str(formatted) i += 1 + if j < len(args): return celpy.CELEvalError("format() too many arguments for format string") return celtypes.StringType(result) @@ -109,6 +125,7 @@ def format_exponential(self, arg: celtypes.Value, precision: int) -> celpy.Resul return celtypes.StringType(f"{arg:.{precision}e}") return self.format_int(arg) + # TODO: check if celtypes.StringType() supports int conversion def format_int(self, arg: celtypes.Value) -> celpy.Result: if isinstance(arg, celtypes.IntType): return celtypes.StringType(arg) # type: ignore[arg-type] @@ -160,11 +177,14 @@ def format_value(self, arg: celtypes.Value) -> celpy.Result: return self.format_string(arg) def format_list(self, arg: celtypes.ListType) -> celpy.Result: - result = "[" + result: str = "[" for i in range(len(arg)): if i > 0: result += ", " - result += self.format_value(arg[i]) # type: ignore[operator,assignment] + formatted = self.format_value(arg[i]) + if isinstance(formatted, celpy.CELEvalError): + return formatted + result += str(formatted) result += "]" return celtypes.StringType(result) From 5e718d62bd78b923913de3eaaefcd34f0c4a931f Mon Sep 17 00:00:00 2001 From: Naxin Date: Fri, 24 Oct 2025 15:28:54 -0400 Subject: [PATCH 28/31] a bit more --- src/confluent_kafka/schema_registry/common/avro.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/confluent_kafka/schema_registry/common/avro.py b/src/confluent_kafka/schema_registry/common/avro.py index 97cdcc518..ab5982f4e 100644 --- a/src/confluent_kafka/schema_registry/common/avro.py +++ b/src/confluent_kafka/schema_registry/common/avro.py @@ -258,17 +258,19 @@ def _get_inline_tags_recursively( record_ns = _implied_namespace(name) if record_ns is None: record_ns = ns - if record_ns != '' and not record_name.startswith(record_ns): # type: ignore[union-attr] + # Ensure record_name is not None and doesn't already have namespace prefix + if record_name is not None and record_ns != '' and not record_name.startswith(record_ns): record_name = f"{record_ns}.{record_name}" fields = 
schema["fields"] for field in fields: field_tags = field.get("confluent:tags") field_name = field.get("name") field_type = field.get("type") - if field_tags is not None and field_name is not None: - tags[record_name + '.' + field_name].update(field_tags) # type: ignore[operator] - if field_type is not None: - _get_inline_tags_recursively(record_ns, record_name, field_type, tags) # type: ignore[arg-type] + # Ensure all required fields are present before building tag key + if field_tags is not None and field_name is not None and record_name is not None: + tags[record_name + '.' + field_name].update(field_tags) + if field_type is not None and record_name is not None: + _get_inline_tags_recursively(record_ns, record_name, field_type, tags) def _implied_namespace(name: str) -> Optional[str]: From 0ee5103bfe90558d32cb56dd8f68ef03a1c153e8 Mon Sep 17 00:00:00 2001 From: Naxin Date: Fri, 24 Oct 2025 15:53:18 -0400 Subject: [PATCH 29/31] revert some bad changes during merge, address copilot comments --- src/confluent_kafka/_model/__init__.py | 2 +- src/confluent_kafka/admin/__init__.py | 1 + src/confluent_kafka/admin/_group.py | 2 +- src/confluent_kafka/admin/_listoffsets.py | 2 +- src/confluent_kafka/admin/_metadata.py | 2 +- src/confluent_kafka/admin/_scram.py | 2 +- src/confluent_kafka/schema_registry/_async/avro.py | 2 +- .../schema_registry/_async/schema_registry_client.py | 5 +++-- .../schema_registry/_sync/schema_registry_client.py | 3 ++- 9 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/confluent_kafka/_model/__init__.py b/src/confluent_kafka/_model/__init__.py index 0f072b5a4..f3bce031f 100644 --- a/src/confluent_kafka/_model/__init__.py +++ b/src/confluent_kafka/_model/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Any +from typing import List, Optional from enum import Enum from .. import cimpl from ..cimpl import TopicPartition diff --git a/src/confluent_kafka/admin/__init__.py b/src/confluent_kafka/admin/__init__.py index 552c9a417..2e3c80184 100644 --- a/src/confluent_kafka/admin/__init__.py +++ b/src/confluent_kafka/admin/__init__.py @@ -653,6 +653,7 @@ def list_topics(self, *args: Any, **kwargs: Any) -> ClusterMetadata: return super(AdminClient, self).list_topics(*args, **kwargs) def list_groups(self, *args: Any, **kwargs: Any) -> List[GroupMetadata]: + return super(AdminClient, self).list_groups(*args, **kwargs) def create_partitions( # type: ignore[override] diff --git a/src/confluent_kafka/admin/_group.py b/src/confluent_kafka/admin/_group.py index 3ff3fa8be..af4db4ab0 100644 --- a/src/confluent_kafka/admin/_group.py +++ b/src/confluent_kafka/admin/_group.py @@ -80,7 +80,7 @@ class MemberAssignment: The topic partitions assigned to a group member. """ - def __init__(self, topic_partitions: List[TopicPartition] = []) -> None: + def __init__(self, topic_partitions: Optional[List[TopicPartition]]) -> None: self.topic_partitions = topic_partitions or [] diff --git a/src/confluent_kafka/admin/_listoffsets.py b/src/confluent_kafka/admin/_listoffsets.py index e23d75257..6b567088e 100644 --- a/src/confluent_kafka/admin/_listoffsets.py +++ b/src/confluent_kafka/admin/_listoffsets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any, Optional +from typing import Dict, Optional from abc import ABC, abstractmethod from .. 
import cimpl diff --git a/src/confluent_kafka/admin/_metadata.py b/src/confluent_kafka/admin/_metadata.py index f5d58b01b..90dc061c0 100644 --- a/src/confluent_kafka/admin/_metadata.py +++ b/src/confluent_kafka/admin/_metadata.py @@ -79,7 +79,7 @@ class TopicMetadata(object): # on other classes which raises a warning/error. def __init__(self) -> None: - self.topic = None + self.topic: Optional[str] = None """Topic name""" self.partitions: Dict[int, 'PartitionMetadata'] = {} """Map of partitions indexed by partition id. Value is a PartitionMetadata object.""" diff --git a/src/confluent_kafka/admin/_scram.py b/src/confluent_kafka/admin/_scram.py index e0ba07249..2bb19a414 100644 --- a/src/confluent_kafka/admin/_scram.py +++ b/src/confluent_kafka/admin/_scram.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Any +from typing import List, Optional from enum import Enum from .. import cimpl diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index 1f82d3721..2126bbc8a 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -637,7 +637,7 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 inline_tags = get_inline_tags(reader_schema) if reader_schema is not None else None obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, obj_dict, - inline_tags,field_transformer) + inline_tags, field_transformer) if self._from_dict is not None: if ctx is None: diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index d67501a06..f378d7bdb 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -119,7 +119,7 @@ async def get_access_token(self) -> str: if not self.token or self.token_expired(): await self.generate_access_token() if self.token is None: - raise ValueError("Token is not set after the at") + raise ValueError("Token is not set after the attempt to generate it") return self.token['access_token'] async def generate_access_token(self) -> None: @@ -865,7 +865,8 @@ async def get_schema_versions( """ # noqa: E501 query: dict[str, Any] = {'offset': offset, 'limit': limit} - if subject_name is not None: query['subject'] = subject_name + if subject_name is not None: + query['subject'] = subject_name if deleted: query['deleted'] = deleted response = await self._rest_client.get('schemas/ids/{}/versions'.format(schema_id), query) diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index e21d79340..d156b78a5 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -864,7 +864,8 @@ def get_schema_versions( """ # noqa: E501 query: dict[str, Any] = {'offset': offset, 'limit': limit} - if subject_name is not None: query['subject'] = subject_name + if subject_name is not None: + query['subject'] = subject_name if deleted: query['deleted'] = deleted response = self._rest_client.get('schemas/ids/{}/versions'.format(schema_id), query) From 402688929d1d3260861088c52896102e96a36fed Mon Sep 17 00:00:00 2001 From: Naxin Date: 
Thu, 30 Oct 2025 12:15:50 -0500 Subject: [PATCH 30/31] minor --- .../schema_registry/_async/schema_registry_client.py | 2 +- .../schema_registry/_sync/schema_registry_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index f378d7bdb..1d8fb4b87 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -1530,6 +1530,6 @@ def clear_caches(self): def new_client(conf: dict) -> 'AsyncSchemaRegistryClient': from .mock_schema_registry_client import AsyncMockSchemaRegistryClient url = conf.get("url") - if url.startswith("mock://"): # type: ignore[union-attr] + if url and isinstance(url, str) and url.startswith("mock://"): return AsyncMockSchemaRegistryClient(conf) return AsyncSchemaRegistryClient(conf) diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index d156b78a5..8725c3d00 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -1529,6 +1529,6 @@ def clear_caches(self): def new_client(conf: dict) -> 'SchemaRegistryClient': from .mock_schema_registry_client import MockSchemaRegistryClient url = conf.get("url") - if url.startswith("mock://"): # type: ignore[union-attr] + if url and isinstance(url, str) and url.startswith("mock://"): return MockSchemaRegistryClient(conf) return SchemaRegistryClient(conf) From 4f0a609298e4f412aec0d542ef7c4d5ed793100c Mon Sep 17 00:00:00 2001 From: Naxin Date: Fri, 31 Oct 2025 16:59:14 -0400 Subject: [PATCH 31/31] support type hint substitution for unasync --- .../_async/schema_registry_client.py | 13 +++++----- .../schema_registry/_sync/avro.py | 24 +++++++++---------- .../schema_registry/_sync/json_schema.py | 3 +-- .../schema_registry/_sync/protobuf.py | 4 ++-- tools/unasync.py | 24 +++++++++++++++++++ 5 files changed, 45 insertions(+), 23 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index 1d8fb4b87..7e495ad11 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -84,7 +84,7 @@ def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict) self.custom_config = custom_config async def get_bearer_fields(self) -> dict: - return await self.custom_function(self.custom_config) # type: ignore[misc] + return await self.custom_function(self.custom_config) class _AsyncOAuthClient(_AsyncBearerFieldProvider): @@ -471,14 +471,13 @@ async def send_request( if isinstance(response, Response): try: raise SchemaRegistryError(response.status_code, - response.json().get('error_code'), - response.json().get('message')) - # Schema Registry may return malformed output when it hits unexpected errors + response.json().get('error_code'), + response.json().get('message')) except (ValueError, KeyError, AttributeError): raise SchemaRegistryError(response.status_code, - -1, - "Unknown Schema Registry Error: " - + str(response.content)) + -1, + "Unknown Schema Registry Error: " + + str(response.content)) else: raise TypeError("Unexpected response of unsupported type: " + 
str(type(response))) diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 4e7928573..dbb7d43b2 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -58,17 +58,17 @@ def _resolve_named_schema( if schema.references is not None: for ref in schema.references: if ref.subject is None or ref.version is None: - raise ValueError("Subject or version cannot be None") + raise TypeError("Subject or version cannot be None") referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client) if referenced_schema.schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") parsed_schema = parse_schema_with_repo( referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) named_schemas.update(ref_named_schemas) if ref.name is None: - raise ValueError("Name cannot be None") + raise TypeError("Name cannot be None") named_schemas[ref.name] = parsed_schema return named_schemas @@ -304,7 +304,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] return self.__serialize(obj, ctx) def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -354,10 +354,10 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - parsed_schema: Any if self._to_dict is not None: if ctx is None: - raise ValueError("SerializationContext cannot be None") + raise TypeError("SerializationContext cannot be None") value = self._to_dict(obj, ctx) else: - value = obj # type: ignore[assignment] + value = obj if latest_schema is not None and ctx is not None and subject is not None: parsed_schema = self._get_parsed_schema(latest_schema.schema) @@ -367,7 +367,7 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 latest_schema.schema, value, get_inline_tags(parsed_schema), field_transformer) else: - parsed_schema = self._parsed_schema # type: ignore[assignment] + parsed_schema = self._parsed_schema with _ContextStringIO() as fo: # write the record to the rest of the buffer @@ -388,10 +388,10 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: named_schemas = _resolve_named_schema(schema, self._registry) if schema.schema_str is None: - raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) if prepared_schema.schema_str is None: - raise ValueError("Prepared schema string cannot be None") + raise TypeError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) @@ -641,7 +641,7 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 if self._from_dict is not None: if ctx is None: - raise ValueError("SerializationContext cannot be None") + raise TypeError("SerializationContext cannot be None") return self._from_dict(obj_dict, ctx) return obj_dict @@ -653,10 +653,10 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: named_schemas = _resolve_named_schema(schema, self._registry) if schema.schema_str is None: 
- raise ValueError("Schema string cannot be None") + raise TypeError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) if prepared_schema.schema_str is None: - raise ValueError("Prepared schema string cannot be None") + raise TypeError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index f8137887b..1b8ad67d1 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -293,7 +293,6 @@ def __init_impl( if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - if self._schema: schema_dict, ref_registry = self._get_parsed_schema(self._schema) if schema_dict and isinstance(schema_dict, dict): @@ -314,7 +313,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] return self.__serialize(obj, ctx) def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index f48068c70..72c936b97 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -376,7 +376,7 @@ def _resolve_dependencies( reference.version)) return schema_refs - def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] return self.__serialize(message, ctx) def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -584,7 +584,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[object, None]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__deserialize(data, ctx) def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: diff --git a/tools/unasync.py b/tools/unasync.py index 162591551..68691d889 100644 --- a/tools/unasync.py +++ b/tools/unasync.py @@ -16,6 +16,18 @@ ("tests/schema_registry/_async", "tests/schema_registry/_sync"), ] +# Type hint patterns that should NOT have word boundaries (they contain brackets) +TYPE_HINT_SUBS = [ + (r'Coroutine\[Any, Any, ([^\]]+)\]', r'\1'), + (r'Coroutine\[None, None, ([^\]]+)\]', r'\1'), + (r'Awaitable\[([^\]]+)\]', r'\1'), + (r'AsyncIterator\[([^\]]+)\]', r'Iterator[\1]'), + (r'AsyncIterable\[([^\]]+)\]', r'Iterable[\1]'), + (r'AsyncGenerator\[([^,]+), ([^\]]+)\]', r'Generator[\1, \2, None]'), + (r'AsyncContextManager\[([^\]]+)\]', r'ContextManager[\1]'), +] + +# Regular substitutions that need word boundaries SUBS = [ ('from confluent_kafka.schema_registry.common import asyncinit', ''), ('@asyncinit', ''), @@ -36,6 +48,13 @@ (r'asyncio.run\((.*)\)', r'\2'), ] +# Compile type hint patterns without word boundaries +COMPILED_TYPE_HINT_SUBS = [ + 
(re.compile(regex), repl) + for regex, repl in TYPE_HINT_SUBS +] + +# Compile regular patterns with word boundaries COMPILED_SUBS = [ (re.compile(r'(^|\b)' + regex + r'($|\b)'), repl) for regex, repl in SUBS @@ -45,6 +64,11 @@ def unasync_line(line): + # First apply type hint transformations (without word boundaries) + for regex, repl in COMPILED_TYPE_HINT_SUBS: + line = re.sub(regex, repl, line) + + # Then apply regular transformations (with word boundaries) for index, (regex, repl) in enumerate(COMPILED_SUBS): old_line = line line = re.sub(regex, repl, line)
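
The effect of the new TYPE_HINT_SUBS pass in tools/unasync.py is easiest to see on a couple of concrete stub lines. The sketch below is illustrative only and is not part of the patch: it copies a few of the patterns added above and applies them the same way unasync_line() does, i.e. before the word-bounded SUBS pass (which separately handles things like 'async def' -> 'def'). The sample annotations (Schema, Message) are hypothetical stand-ins, not names taken from the repository.

    # Standalone sketch, assuming the TYPE_HINT_SUBS patterns introduced in this patch.
    import re

    TYPE_HINT_SUBS = [
        (r'Coroutine\[Any, Any, ([^\]]+)\]', r'\1'),
        (r'Awaitable\[([^\]]+)\]', r'\1'),
        (r'AsyncIterator\[([^\]]+)\]', r'Iterator[\1]'),
        (r'AsyncGenerator\[([^,]+), ([^\]]+)\]', r'Generator[\1, \2, None]'),
    ]

    def rewrite_type_hints(line: str) -> str:
        # No word boundaries here: the bracketed generics contain characters
        # that would defeat the r'(^|\b)...($|\b)' anchoring used for SUBS.
        for pattern, repl in TYPE_HINT_SUBS:
            line = re.sub(pattern, repl, line)
        return line

    # Coroutine return types collapse to their result type.
    assert rewrite_type_hints(
        "def lookup(self) -> Coroutine[Any, Any, Schema]: ...") == \
        "def lookup(self) -> Schema: ..."

    # Async iterator types map onto their sync counterparts; the 'async'
    # keyword itself is rewritten later by the regular SUBS pass.
    assert rewrite_type_hints(
        "async def stream(self) -> AsyncIterator[Message]: ...") == \
        "async def stream(self) -> Iterator[Message]: ..."

Running the type-hint pass first keeps the two rule sets independent: the bracketed patterns never need word-boundary anchoring, and the compiled SUBS list stays unchanged.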