diff --git a/pyproject.toml b/pyproject.toml index bf382ef2d..e33ffbf1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,21 @@ Homepage = "https://github.com/confluentinc/confluent-kafka-python" [tool.mypy] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = [ + "confluent_kafka.schema_registry.avro", + "confluent_kafka.schema_registry.json_schema", + "confluent_kafka.schema_registry.protobuf", +] +disable_error_code = ["assignment", "no-redef"] + +[[tool.mypy.overrides]] +module = [ + "confluent_kafka.schema_registry.confluent.meta_pb2", + "confluent_kafka.schema_registry.confluent.types.decimal_pb2", +] +ignore_errors = true + [tool.setuptools] include-package-data = false diff --git a/requirements/requirements-tests.txt b/requirements/requirements-tests.txt index e9cc5ca35..730bd2be5 100644 --- a/requirements/requirements-tests.txt +++ b/requirements/requirements-tests.txt @@ -1,6 +1,8 @@ # core test requirements urllib3<3 flake8 +mypy +types-cachetools orjson pytest pytest-timeout diff --git a/src/confluent_kafka/admin/__init__.py b/src/confluent_kafka/admin/__init__.py index 4a8394ce0..2e3c80184 100644 --- a/src/confluent_kafka/admin/__init__.py +++ b/src/confluent_kafka/admin/__init__.py @@ -653,7 +653,7 @@ def list_topics(self, *args: Any, **kwargs: Any) -> ClusterMetadata: return super(AdminClient, self).list_topics(*args, **kwargs) def list_groups(self, *args: Any, **kwargs: Any) -> List[GroupMetadata]: - + return super(AdminClient, self).list_groups(*args, **kwargs) def create_partitions( # type: ignore[override] diff --git a/src/confluent_kafka/admin/_group.py b/src/confluent_kafka/admin/_group.py index 1db7923bb..af4db4ab0 100644 --- a/src/confluent_kafka/admin/_group.py +++ b/src/confluent_kafka/admin/_group.py @@ -80,7 +80,7 @@ class MemberAssignment: The topic partitions assigned to a group member. 
""" - def __init__(self, topic_partitions: Optional[List[TopicPartition]] = None) -> None: + def __init__(self, topic_partitions: Optional[List[TopicPartition]]) -> None: self.topic_partitions = topic_partitions or [] diff --git a/src/confluent_kafka/admin/_resource.py b/src/confluent_kafka/admin/_resource.py index 55c6d783d..131c56407 100644 --- a/src/confluent_kafka/admin/_resource.py +++ b/src/confluent_kafka/admin/_resource.py @@ -28,7 +28,7 @@ class ResourceType(Enum): BROKER = _cimpl.RESOURCE_BROKER #: Broker resource. Resource name is broker id. TRANSACTIONAL_ID = _cimpl.RESOURCE_TRANSACTIONAL_ID #: Transactional ID resource. - def __lt__(self, other: object) -> Any: + def __lt__(self, other: object) -> bool: if not isinstance(other, ResourceType): return NotImplemented return self.value < other.value @@ -44,7 +44,7 @@ class ResourcePatternType(Enum): LITERAL = _cimpl.RESOURCE_PATTERN_LITERAL #: Literal: A literal resource name PREFIXED = _cimpl.RESOURCE_PATTERN_PREFIXED #: Prefixed: A prefixed resource name - def __lt__(self, other: object) -> Any: + def __lt__(self, other: object) -> bool: if not isinstance(other, ResourcePatternType): return NotImplemented return self.value < other.value diff --git a/src/confluent_kafka/avro/cached_schema_registry_client.py b/src/confluent_kafka/avro/cached_schema_registry_client.py index b0ea6c388..3478ba9fb 100644 --- a/src/confluent_kafka/avro/cached_schema_registry_client.py +++ b/src/confluent_kafka/avro/cached_schema_registry_client.py @@ -32,7 +32,7 @@ # Python 2 considers int an instance of str try: - string_type = basestring # noqa + string_type = basestring # type: ignore[name-defined] # noqa except NameError: string_type = str diff --git a/src/confluent_kafka/avro/load.py b/src/confluent_kafka/avro/load.py index 9db8660e1..d774a0513 100644 --- a/src/confluent_kafka/avro/load.py +++ b/src/confluent_kafka/avro/load.py @@ -47,11 +47,11 @@ def _hash_func(self): from avro.errors import SchemaParseException except 
ImportError: # avro < 1.11.0 - from avro.schema import SchemaParseException + from avro.schema import SchemaParseException # type: ignore[attr-defined,no-redef] - schema.RecordSchema.__hash__ = _hash_func - schema.PrimitiveSchema.__hash__ = _hash_func - schema.UnionSchema.__hash__ = _hash_func + schema.RecordSchema.__hash__ = _hash_func # type: ignore[method-assign] + schema.PrimitiveSchema.__hash__ = _hash_func # type: ignore[method-assign] + schema.UnionSchema.__hash__ = _hash_func # type: ignore[method-assign] except ImportError: - schema = None + schema = None # type: ignore[assignment] diff --git a/src/confluent_kafka/cimpl.pyi b/src/confluent_kafka/cimpl.pyi index 99a888acc..68d04b936 100644 --- a/src/confluent_kafka/cimpl.pyi +++ b/src/confluent_kafka/cimpl.pyi @@ -34,14 +34,13 @@ TODO: Consider migrating to Cython in the future to eliminate this dual maintenance burden and get type hints directly from the implementation. """ -from typing import Any, Optional, Callable, List, Tuple, Dict, Union, overload, TYPE_CHECKING +from typing import Any, Optional, Callable, List, Tuple, Dict, Union, overload from typing_extensions import Self, Literal import builtins -from ._types import HeadersType +from confluent_kafka.admin._metadata import ClusterMetadata, GroupMetadata -if TYPE_CHECKING: - from confluent_kafka.admin._metadata import ClusterMetadata, GroupMetadata +from ._types import HeadersType # Callback types with proper class references (defined locally to avoid circular imports) DeliveryCallback = Callable[[Optional['KafkaError'], 'Message'], None] diff --git a/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py b/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py index 925849c58..8af98f3c3 100644 --- a/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py +++ b/src/confluent_kafka/experimental/aio/producer/_kafka_batch_executor.py @@ -110,22 +110,22 @@ def _produce_batch_and_poll() -> int: async def 
flush_librdkafka_queue(self, timeout=-1): """Flush the librdkafka queue and wait for all messages to be delivered - This method awaits until all outstanding produce requests are completed or the timeout is reached, unless the timeout is set to 0 (non-blocking). - Args: timeout: Maximum time to wait in seconds: - -1 = wait indefinitely (default) - 0 = non-blocking, return immediately - >0 = wait up to timeout seconds - Returns: Number of messages still in queue after flush attempt """ return await _common.async_call(self._executor, self._producer.flush, timeout) - def _handle_partial_failures(self, batch_messages: List[Dict[str, Any]]) -> None: + def _handle_partial_failures( + self, + batch_messages: List[Dict[str, Any]] + ) -> None: """Handle messages that failed during produce_batch When produce_batch encounters messages that fail immediately (e.g., diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index 0cf16eef3..2055d2c05 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -72,58 +72,82 @@ ] -def topic_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: +def topic_subject_name_strategy(ctx: Optional[SerializationContext], record_name: Optional[str]) -> Optional[str]: """ Constructs a subject name in the form of {topic}-key|value. Args: ctx (SerializationContext): Metadata pertaining to the serialization - operation. + operation. **Required** - will raise ValueError if None. record_name (Optional[str]): Record name. + Raises: + ValueError: If ctx is None. + """ + if ctx is None: + raise ValueError( + "SerializationContext is required for topic_subject_name_strategy. " + "Either provide a SerializationContext or use record_subject_name_strategy." 
+ ) return ctx.topic + "-" + ctx.field -def topic_record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: +def topic_record_subject_name_strategy(ctx: Optional[SerializationContext], record_name: Optional[str]) -> Optional[str]: """ Constructs a subject name in the form of {topic}-{record_name}. Args: ctx (SerializationContext): Metadata pertaining to the serialization - operation. + operation. **Required** - will raise ValueError if None. record_name (Optional[str]): Record name. + Raises: + ValueError: If ctx is None. + """ + if ctx is None: + raise ValueError( + "SerializationContext is required for topic_record_subject_name_strategy. " + "Either provide a SerializationContext or use record_subject_name_strategy." + ) return ctx.topic + "-" + record_name if record_name is not None else None -def record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: +def record_subject_name_strategy(ctx: Optional[SerializationContext], record_name: Optional[str]) -> Optional[str]: """ Constructs a subject name in the form of {record_name}. Args: ctx (SerializationContext): Metadata pertaining to the serialization - operation. + operation. **Not used** by this strategy. record_name (Optional[str]): Record name. + Note: + This strategy does not require SerializationContext and can be used + when ctx is None. + """ return record_name if record_name is not None else None -def reference_subject_name_strategy(ctx, schema_ref: SchemaReference) -> Optional[str]: +def reference_subject_name_strategy(ctx: Optional[SerializationContext], schema_ref: SchemaReference) -> Optional[str]: """ Constructs a subject reference name in the form of {reference name}. Args: ctx (SerializationContext): Metadata pertaining to the serialization - operation. + operation. **Not used** by this strategy. schema_ref (SchemaReference): SchemaReference instance. + Note: + This strategy does not require SerializationContext and can be used + when ctx is None. 
+ """ return schema_ref.name if schema_ref is not None else None @@ -205,7 +229,7 @@ def dual_schema_id_deserializer(payload: bytes, ctx: Optional[SerializationConte # Parse schema ID from determined source and return appropriate payload if header_value is not None: - schema_id.from_bytes(io.BytesIO(header_value)) + schema_id.from_bytes(io.BytesIO(header_value)) # type: ignore[arg-type] return io.BytesIO(payload) # Return full payload when schema ID is in header else: return schema_id.from_bytes(io.BytesIO(payload)) # Parse from payload, return remainder diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index d29611853..2126bbc8a 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -16,7 +16,7 @@ # limitations under the License. import io import json -from typing import Dict, Union, Optional, Callable +from typing import Any, Coroutine, Dict, Union, Optional, Callable, cast from fastavro import schemaless_reader, schemaless_writer from confluent_kafka.schema_registry.common import asyncinit @@ -57,11 +57,18 @@ async def _resolve_named_schema( named_schemas = {} if schema.references is not None: for ref in schema.references: + if ref.subject is None or ref.version is None: + raise TypeError("Subject or version cannot be None") referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) ref_named_schemas = await _resolve_named_schema(referenced_schema.schema, schema_registry_client) + if referenced_schema.schema.schema_str is None: + raise TypeError("Schema string cannot be None") + parsed_schema = parse_schema_with_repo( referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) named_schemas.update(ref_named_schemas) + if ref.name is None: + raise TypeError("Name cannot be None") named_schemas[ref.name] = parsed_schema return named_schemas @@ -204,9 +211,9 @@ async def __init_impl( 
schema = None self._registry = schema_registry_client - self._schema_id = None + self._schema_id: Optional[SchemaId] = None self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._known_subjects = set() + self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() if to_dict is not None and not callable(to_dict): @@ -219,35 +226,41 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - 
self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") @@ -265,13 +278,18 @@ async def __init_impl( # and union types should use topic_subject_name_strategy, which # just discards the schema name anyway schema_name = None - else: + elif isinstance(parsed_schema, dict): # The Avro spec states primitives have a name equal to their type # i.e. {"type": "string"} has a name of string. # This function does not comply. # https://github.com/fastavro/fastavro/issues/415 - schema_dict = json.loads(schema.schema_str) - schema_name = parsed_schema.get("name", schema_dict.get("type")) + if schema.schema_str is not None: + schema_dict = json.loads(schema.schema_str) + schema_name = parsed_schema.get("name", schema_dict.get("type")) + else: + schema_name = None + else: + schema_name = None else: schema_name = None parsed_schema = None @@ -286,7 +304,7 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: # type: ignore[override] return self.__serialize(obj, ctx) async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -313,10 +331,10 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N return None subject = 
self._subject_name_func(ctx, self._schema_name) - latest_schema = await self._get_reader_schema(subject) + latest_schema = await self._get_reader_schema(subject) if subject else None if latest_schema is not None: self._schema_id = SchemaId(AVRO_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects: + elif subject is not None and subject not in self._known_subjects: # Check to ensure this schema has been registered under subject_name. if self._auto_register: # The schema name will always be the same. We can't however register @@ -332,12 +350,16 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N self._known_subjects.add(subject) + value: Any + parsed_schema: Any if self._to_dict is not None: + if ctx is None: + raise TypeError("SerializationContext cannot be None") value = self._to_dict(obj, ctx) else: value = obj - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: parsed_schema = await self._get_parsed_schema(latest_schema.schema) def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, msg, field_transform)) @@ -352,7 +374,7 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 schemaless_writer(fo, parsed_schema, value) buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -365,7 +387,11 @@ async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema named_schemas = await _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise TypeError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) + if prepared_schema.schema_str is None: + raise TypeError("Prepared 
schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) @@ -477,20 +503,26 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") @@ -498,7 +530,8 @@ async def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - if schema: + self._reader_schema: Optional[AvroSchema] + if schema and self._schema is not None: self._reader_schema = await self._get_parsed_schema(self._schema) else: self._reader_schema = None @@ -518,11 +551,11 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + def 
__call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Union[dict, object, None]]: return self.__deserialize(data, ctx) async def __deserialize( - self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: """ Deserialize Avro binary encoded data with Confluent Schema Registry framing to a dict, or object instance according to from_dict, if specified. @@ -560,19 +593,20 @@ async def __deserialize( writer_schema_raw = await self._get_writer_schema(schema_id, subject) writer_schema = await self._get_parsed_schema(writer_schema_raw) - if subject is None: - subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None + subject = subject if subject is not None else (self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None) # type: ignore[union-attr] if subject is not None: latest_schema = await self._get_reader_schema(subject) - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) - if latest_schema is not None: + reader_schema: Optional[AvroSchema] + if latest_schema is not None and subject is not None: migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema reader_schema = await self._get_parsed_schema(latest_schema.schema) @@ -585,7 +619,7 @@ async def __deserialize( reader_schema_raw = writer_schema_raw reader_schema = writer_schema - if migrations: + if migrations and ctx is not None and subject is not None: obj_dict = schemaless_reader(payload, writer_schema, None, @@
-599,11 +633,15 @@ async def __deserialize( def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, get_inline_tags(reader_schema), - field_transformer) + if ctx is not None and subject is not None: + inline_tags = get_inline_tags(reader_schema) if reader_schema is not None else None + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, + inline_tags, field_transformer) if self._from_dict is not None: + if ctx is None: + raise TypeError("SerializationContext cannot be None") return self._from_dict(obj_dict, ctx) return obj_dict @@ -614,7 +652,11 @@ async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema named_schemas = await _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise TypeError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) + if prepared_schema.schema_str is None: + raise TypeError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py index d522838b0..c87c35e5d 100644 --- a/src/confluent_kafka/schema_registry/_async/json_schema.py +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -16,7 +16,7 @@ # limitations under the License. 
import io import orjson -from typing import Union, Optional, Tuple, Callable +from typing import Any, Coroutine, Union, Optional, Tuple, Callable, cast from cachetools import LRUCache from jsonschema import ValidationError @@ -61,14 +61,21 @@ async def _resolve_named_schema( """ if ref_registry is None: # Retrieve external schemas for backward compatibility - ref_registry = Registry(retrieve=_retrieve_via_httpx) + ref_registry = Registry(retrieve=_retrieve_via_httpx) # type: ignore[call-arg] if schema.references is not None: for ref in schema.references: + if ref.subject is None or ref.version is None: + raise TypeError("Subject or version cannot be None") referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) ref_registry = await _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) + if referenced_schema.schema.schema_str is None: + raise TypeError("Schema string cannot be None") + referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) resource = Resource.from_contents( referenced_schema_dict, default_specification=DEFAULT_SPEC) + if ref.name is None: + raise TypeError("Name cannot be None") ref_registry = ref_registry.with_resource(ref.name, resource) return ref_registry @@ -213,6 +220,7 @@ async def __init_impl( json_encode: Optional[Callable] = None, ): super().__init__() + self._schema: Optional[Schema] if isinstance(schema_str, str): self._schema = Schema(schema_str, schema_type="JSON") elif isinstance(schema_str, Schema): @@ -225,10 +233,10 @@ async def __init_impl( self._rule_registry = ( rule_registry if rule_registry else RuleRegistry.get_global_instance() ) - self._schema_id = None - self._known_subjects = set() + self._schema_id: Optional[SchemaId] = None + self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) + self._validators: LRUCache[Schema, Validator] = LRUCache(1000) if to_dict is not None and 
not callable(to_dict): raise ValueError("to_dict must be callable with the signature " @@ -240,50 +248,59 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be 
callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") - self._validate = conf_copy.pop('validate') + self._validate = cast(bool, conf_copy.pop('validate')) if not isinstance(self._validate, bool): raise ValueError("validate must be a boolean value") if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - - schema_dict, ref_registry = await self._get_parsed_schema(self._schema) - if schema_dict: - schema_name = schema_dict.get('title', None) + if self._schema: + schema_dict, ref_registry = await self._get_parsed_schema(self._schema) + if schema_dict and isinstance(schema_dict, dict): + schema_name = schema_dict.get('title', None) + else: + schema_name = None else: + schema_dict = None schema_name = None self._schema_name = schema_name @@ -296,7 +313,7 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: # type: ignore[override] return self.__serialize(obj, ctx) async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -322,10 +339,10 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = await self._get_reader_schema(subject) + latest_schema = await self._get_reader_schema(subject) if subject else None if latest_schema is not None: self._schema_id = SchemaId(JSON_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects: + elif subject 
is not None and subject not in self._known_subjects: # Check to ensure this schema has been registered under subject_name. if self._auto_register: # The schema name will always be the same. We can't however register @@ -341,27 +358,33 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N self._known_subjects.add(subject) + value: Any if self._to_dict is not None: + if ctx is None: + raise TypeError("SerializationContext cannot be None") value = self._to_dict(obj, ctx) else: value = obj + schema: Optional[Schema] = None if latest_schema is not None: schema = latest_schema.schema parsed_schema, ref_registry = await self._get_parsed_schema(latest_schema.schema) - root_resource = Resource.from_contents( - parsed_schema, default_specification=DEFAULT_SPEC) - ref_resolver = ref_registry.resolver_with_root(root_resource) - def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 - transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, value, None, - field_transformer) + if ref_registry is not None: + root_resource = Resource.from_contents( + parsed_schema, default_specification=DEFAULT_SPEC) + ref_resolver = ref_registry.resolver_with_root(root_resource) + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 + transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) + if ctx is not None and subject is not None: + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, None, + field_transformer) else: schema = self._schema parsed_schema, ref_registry = self._parsed_schema, self._ref_registry - if self._validate: + if self._validate and schema is not None and parsed_schema is not None and ref_registry is not None: try: validator = self._get_validator(schema, parsed_schema, ref_registry) 
validator.validate(value) @@ -377,7 +400,7 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 fo.write(encoded_value) buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -393,6 +416,8 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema] return result ref_registry = await _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise TypeError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) @@ -488,6 +513,7 @@ async def __init_impl( json_decode: Optional[Callable] = None, ): super().__init__() + schema: Optional[Schema] if isinstance(schema_str, str): schema = Schema(schema_str, schema_type="JSON") elif isinstance(schema_str, Schema): @@ -504,11 +530,11 @@ async def __init_impl( else: raise TypeError('You must pass either str or Schema') - self._schema = schema + self._schema: Optional[Schema] = schema self._registry = schema_registry_client self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) + self._validators: LRUCache[Schema, Validator] = LRUCache(1000) self._json_decode = json_decode or orjson.loads self._use_schema_id = None @@ -516,24 +542,30 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + 
self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') - if not callable(self._subject_name_func): + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) + if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") - self._validate = conf_copy.pop('validate') + self._validate = cast(bool, conf_copy.pop('validate')) if not isinstance(self._validate, bool): raise ValueError("validate must be a boolean value") @@ -541,7 +573,7 @@ async def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - if schema: + if schema and self._schema is not None: self._reader_schema, self._ref_registry = await self._get_parsed_schema(self._schema) else: self._reader_schema, self._ref_registry = None, None @@ -558,10 +590,10 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__deserialize(data, ctx) - async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + async def __deserialize(self, data: Optional[bytes], ctx: 
Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a JSON encoded record with Confluent Schema Registry framing to a dict, or object instance according to from_dict if from_dict is specified. @@ -593,7 +625,7 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = if self._registry is not None: writer_schema_raw = await self._get_writer_schema(schema_id, subject) writer_schema, writer_ref_registry = await self._get_parsed_schema(writer_schema_raw) - if subject is None: + if subject is None and isinstance(writer_schema, dict): subject = self._subject_name_func(ctx, writer_schema.get("title")) if subject is not None: latest_schema = await self._get_reader_schema(subject) @@ -601,16 +633,18 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = writer_schema_raw = None writer_schema, writer_ref_registry = None, None - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) # JSON documents are self-describing; no need to query schema obj_dict = self._json_decode(payload.read()) - if latest_schema is not None: + reader_schema_raw: Optional[Schema] = None + if latest_schema is not None and subject is not None and writer_schema_raw is not None: migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema) @@ -623,21 +657,23 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = reader_schema_raw = writer_schema_raw reader_schema, reader_ref_registry = writer_schema, 
writer_ref_registry - if migrations: + if migrations and ctx is not None and subject is not None: obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - reader_root_resource = Resource.from_contents( - reader_schema, default_specification=DEFAULT_SPEC) - reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) + if reader_ref_registry is not None: + reader_root_resource = Resource.from_contents( + reader_schema, default_specification=DEFAULT_SPEC) + reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) - def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 - transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, - "$", message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, None, - field_transformer) + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 + transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, + "$", message, field_transform)) + if ctx is not None and subject is not None: + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, None, + field_transformer) - if self._validate: + if self._validate and reader_schema_raw is not None and reader_schema is not None and reader_ref_registry is not None: try: validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) validator.validate(obj_dict) @@ -645,7 +681,9 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 raise SerializationError(ve.message) if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) + if ctx is None: + raise TypeError("SerializationContext cannot be None") + return self._from_dict(obj_dict, ctx) # type: ignore[return-value] return obj_dict @@ -658,6 +696,8 @@ async def _get_parsed_schema(self, schema: Schema) -> 
Tuple[Optional[JsonSchema] return result ref_registry = await _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise TypeError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) diff --git a/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py index 1a395ade4..5c8bffa1f 100644 --- a/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/mock_schema_registry_client.py @@ -18,7 +18,7 @@ import uuid from collections import defaultdict from threading import Lock -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union, Literal from .schema_registry_client import AsyncSchemaRegistryClient from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig @@ -73,7 +73,7 @@ def get_registered_schema_by_schema( return rs return None - def get_version(self, subject_name: str, version: int) -> Optional[RegisteredSchema]: + def get_version(self, subject_name: str, version: Union[int, str]) -> Optional[RegisteredSchema]: with self.lock: if subject_name in self.subject_schemas: for rs in self.subject_schemas[subject_name]: @@ -157,7 +157,7 @@ async def register_schema( ) -> int: registered_schema = await self.register_schema_full_response( subject_name, schema, normalize_schemas=normalize_schemas) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] async def register_schema_full_response( self, subject_name: str, schema: 'Schema', @@ -168,7 +168,7 @@ async def register_schema_full_response( return registered_schema latest_schema = self._store.get_latest_version(subject_name) - latest_version = 1 if latest_schema is None else latest_schema.version + 1 + latest_version = 1 if latest_schema is None 
or latest_schema.version is None else latest_schema.version + 1 registered_schema = RegisteredSchema( schema_id=1, @@ -184,7 +184,7 @@ async def register_schema_full_response( async def get_schema( self, schema_id: int, subject_name: Optional[str] = None, - fmt: Optional[str] = None + fmt: Optional[str] = None, reference_format: Optional[str] = None ) -> 'Schema': schema = self._store.get_schema(schema_id) if schema is not None: @@ -212,7 +212,10 @@ async def lookup_schema( raise SchemaRegistryError(404, 40400, "Schema Not Found") - async def get_subjects(self) -> List[str]: + async def get_subjects( + self, subject_prefix: Optional[str] = None, deleted: bool = False, + deleted_only: bool = False, offset: int = 0, limit: int = -1 + ) -> List[str]: return self._store.get_subjects() async def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: @@ -236,30 +239,36 @@ async def get_latest_with_metadata( raise SchemaRegistryError(404, 40400, "Schema Not Found") async def get_version( - self, subject_name: str, version: int, + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", deleted: bool = False, fmt: Optional[str] = None ) -> 'RegisteredSchema': - registered_schema = self._store.get_version(subject_name, version) + if version == "latest": + registered_schema = self._store.get_latest_version(subject_name) + else: + registered_schema = self._store.get_version(subject_name, version) if registered_schema is not None: return registered_schema raise SchemaRegistryError(404, 40400, "Schema Not Found") - async def get_versions(self, subject_name: str) -> List[int]: + async def get_versions( + self, subject_name: str, deleted: bool = False, deleted_only: bool = False, + offset: int = 0, limit: int = -1 + ) -> List[int]: return self._store.get_versions(subject_name) async def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: registered_schema = self._store.get_version(subject_name, version) 
if registered_schema is not None: self._store.remove_by_schema(registered_schema) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] raise SchemaRegistryError(404, 40400, "Schema Not Found") async def set_config( - self, subject_name: Optional[str] = None, config: 'ServerConfig' = None # noqa F821 + self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None # noqa F821 ) -> 'ServerConfig': # noqa F821 - return None + return None # type: ignore[return-value] async def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': # noqa F821 - return None + return None # type: ignore[return-value] diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index 41b8df970..1d2a1a7bb 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -16,7 +16,7 @@ # limitations under the License. 
import io -from typing import Set, List, Union, Optional, Tuple +from typing import Any, Coroutine, Set, List, Union, Optional, Tuple, Callable, cast from google.protobuf import json_format, descriptor_pb2 from google.protobuf.descriptor_pool import DescriptorPool @@ -67,10 +67,18 @@ async def _resolve_named_schema( visited = set() if schema.references is not None: for ref in schema.references: + if ref.name is None: + raise ValueError("Name cannot be None") + if _is_builtin(ref.name) or ref.name in visited: continue visited.add(ref.name) + + if ref.subject is None or ref.version is None: + raise ValueError("Subject or version cannot be None") referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') + if referenced_schema.schema.schema_str is None: + raise ValueError("Schema string cannot be None") await _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) pool.Add(file_descriptor_proto) @@ -218,50 +226,58 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + 
self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._skip_known_types = conf_copy.pop('skip.known.types') + self._skip_known_types = cast(bool, conf_copy.pop('skip.known.types')) if not isinstance(self._skip_known_types, bool): raise ValueError("skip.known.types must be a boolean value") - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + self._use_deprecated_format = cast(bool, conf_copy.pop('use.deprecated.format')) if not isinstance(self._use_deprecated_format, bool): raise ValueError("use.deprecated.format must be a boolean value") if self._use_deprecated_format: raise ValueError("use.deprecated.format is no longer supported") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._ref_reference_subject_func = conf_copy.pop( - 'reference.subject.name.strategy') + self._ref_reference_subject_func = cast( + Callable[[Optional[SerializationContext], Any], Optional[str]], + conf_copy.pop('reference.subject.name.strategy') + ) if not callable(self._ref_reference_subject_func): raise ValueError("subject.name.strategy must be callable") - 
self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") @@ -271,8 +287,8 @@ async def __init_impl( self._registry = schema_registry_client self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._schema_id = None - self._known_subjects = set() + self._schema_id: Optional[SchemaId] = None + self._known_subjects: set[str] = set() self._msg_class = msg_type self._parsed_schemas = ParsedSchemaCache() @@ -360,7 +376,7 @@ async def _resolve_dependencies( reference.version)) return schema_refs - def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: # type: ignore[override] return self.__serialize(message, ctx) async def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -397,7 +413,7 @@ async def __serialize(self, message: Message, ctx: Optional[SerializationContext if latest_schema is not None: self._schema_id = SchemaId(PROTOBUF_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects and ctx is not None: + elif subject is not None and subject not in self._known_subjects and ctx is not None: references = await self._resolve_dependencies(ctx, message.DESCRIPTOR.file) self._schema = Schema( self._schema.schema_str, @@ -422,16 +438,18 @@ async def __serialize(self, message: Message, ctx: Optional[SerializationContext desc = fd.message_types_by_name[message.DESCRIPTOR.name] def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, desc, msg, field_transform)) - message = 
self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, message, None, - field_transformer) + if ctx is not None and subject is not None: + message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, message, None, + field_transformer) with _ContextStringIO() as fo: fo.write(message.SerializeToString()) - self._schema_id.message_indexes = self._index_array + if self._schema_id is not None: + self._schema_id.message_indexes = self._index_array buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -446,6 +464,8 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileD pool = DescriptorPool() _init_pool(pool) await _resolve_named_schema(schema, self._registry, pool) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") fd_proto = _str_to_proto("default", schema.schema_str) pool.Add(fd_proto) self._parsed_schemas.set(schema, (fd_proto, pool)) @@ -526,24 +546,30 @@ async def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + 
Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + self._use_deprecated_format = cast(bool, conf_copy.pop('use.deprecated.format')) if not isinstance(self._use_deprecated_format, bool): raise ValueError("use.deprecated.format must be a boolean value") if self._use_deprecated_format: @@ -558,10 +584,10 @@ async def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]: return self.__deserialize(data, ctx) - async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + async def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a serialized protobuf message with Confluent Schema Registry framing. 
@@ -597,22 +623,24 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = writer_schema_raw = await self._get_writer_schema(schema_id, subject, fmt='serialized') fd_proto, pool = await self._get_parsed_schema(writer_schema_raw) writer_schema = pool.FindFileByName(fd_proto.name) - writer_desc = self._get_message_desc(pool, writer_schema, msg_index) + writer_desc = self._get_message_desc(pool, writer_schema, msg_index if msg_index is not None else []) if subject is None: subject = self._subject_name_func(ctx, writer_desc.full_name) if subject is not None: - latest_schema = self._get_reader_schema(subject, fmt='serialized') + latest_schema = await self._get_reader_schema(subject, fmt='serialized') else: writer_schema_raw = None writer_schema = None - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) - if latest_schema is not None: + reader_schema_raw: Optional[Schema] = None + if latest_schema is not None and subject is not None and writer_schema_raw is not None: migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) @@ -628,7 +656,7 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = # Attempt to find a reader desc with the same name as the writer reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) - if migrations: + if migrations and ctx is not None and subject is not None: msg = GetMessageClass(writer_desc)() try: msg.ParseFromString(payload.read()) @@ -649,9 +677,10 @@ async def 
__deserialize(self, data: bytes, ctx: Optional[SerializationContext] = def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_desc, message, field_transform)) - msg = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, msg, None, - field_transformer) + if ctx is not None and subject is not None: + msg = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, msg, None, + field_transformer) return msg async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: @@ -662,6 +691,8 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileD pool = DescriptorPool() _init_pool(pool) await _resolve_named_schema(schema, self._registry, pool) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") fd_proto = _str_to_proto("default", schema.schema_str) pool.Add(fd_proto) self._parsed_schemas.set(schema, (fd_proto, pool)) diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index efc7d72ff..7e495ad11 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -26,9 +26,9 @@ from urllib.parse import unquote, urlparse import httpx -from typing import List, Dict, Optional, Union, Any, Callable +from typing import List, Dict, Optional, Union, Any, Callable, Literal -from cachetools import TTLCache, LRUCache +from cachetools import Cache, TTLCache, LRUCache from httpx import Response from authlib.integrations.httpx_client import AsyncOAuth2Client @@ -40,11 +40,11 @@ ServerConfig, is_success, is_retriable, - _BearerFieldProvider, + _AsyncBearerFieldProvider, full_jitter, _SchemaCache, Schema, - _StaticFieldProvider, + _AsyncStaticFieldProvider, ) __all__ = [ @@ -65,10 +65,10 @@ # six: 
https://pypi.org/project/six/ # compat file : https://github.com/psf/requests/blob/master/requests/compat.py try: - string_type = basestring # noqa + string_type = basestring # type: ignore[name-defined] # noqa def _urlencode(value: str) -> str: - return urllib.quote(value, safe='') + return urllib.quote(value, safe='') # type: ignore[attr-defined] except NameError: string_type = str @@ -78,7 +78,7 @@ def _urlencode(value: str) -> str: log = logging.getLogger(__name__) -class _AsyncCustomOAuthClient(_BearerFieldProvider): +class _AsyncCustomOAuthClient(_AsyncBearerFieldProvider): def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict): self.custom_function = custom_function self.custom_config = custom_config @@ -87,7 +87,7 @@ async def get_bearer_fields(self) -> dict: return await self.custom_function(self.custom_config) -class _AsyncOAuthClient(_BearerFieldProvider): +class _AsyncOAuthClient(_AsyncBearerFieldProvider): def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str, identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): self.token = None @@ -108,6 +108,9 @@ async def get_bearer_fields(self) -> dict: } def token_expired(self) -> bool: + if self.token is None: + raise ValueError("Token is not set") + expiry_window = self.token['expires_in'] * self.token_expiry_threshold return self.token['expires_at'] < time.time() + expiry_window @@ -115,7 +118,8 @@ def token_expired(self) -> bool: async def get_access_token(self) -> str: if not self.token or self.token_expired(): await self.generate_access_token() - + if self.token is None: + raise ValueError("Token is not set after the attempt to generate it") return self.token['access_token'] async def generate_access_token(self) -> None: @@ -227,7 +231,7 @@ def __init__(self, conf: dict): if cache_capacity is not None: if not isinstance(cache_capacity, (int, float)): raise TypeError("cache.capacity must be a 
number, not " + str(type(cache_capacity))) - self.cache_capacity = cache_capacity + self.cache_capacity = int(cache_capacity) self.cache_latest_ttl_sec = None cache_latest_ttl_sec = conf_copy.pop('cache.latest.ttl.sec', None) @@ -241,7 +245,7 @@ def __init__(self, conf: dict): if max_retries is not None: if not isinstance(max_retries, (int, float)): raise TypeError("max.retries must be a number, not " + str(type(max_retries))) - self.max_retries = max_retries + self.max_retries = int(max_retries) self.retries_wait_ms = 1000 retries_wait_ms = conf_copy.pop('retries.wait.ms', None) @@ -249,7 +253,7 @@ def __init__(self, conf: dict): if not isinstance(retries_wait_ms, (int, float)): raise TypeError("retries.wait.ms must be a number, not " + str(type(retries_wait_ms))) - self.retries_wait_ms = retries_wait_ms + self.retries_wait_ms = int(retries_wait_ms) self.retries_max_wait_ms = 20000 retries_max_wait_ms = conf_copy.pop('retries.max.wait.ms', None) @@ -257,11 +261,11 @@ def __init__(self, conf: dict): if not isinstance(retries_max_wait_ms, (int, float)): raise TypeError("retries.max.wait.ms must be a number, not " + str(type(retries_max_wait_ms))) - self.retries_max_wait_ms = retries_max_wait_ms + self.retries_max_wait_ms = int(retries_max_wait_ms) - self.bearer_field_provider = None logical_cluster = None identity_pool = None + self.bearer_field_provider: Optional[_AsyncBearerFieldProvider] = None self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) if self.bearer_auth_credentials_source is not None: self.auth = None @@ -281,43 +285,43 @@ def __init__(self, conf: dict): if not isinstance(identity_pool, str): raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) - if self.bearer_auth_credentials_source == 'OAUTHBEARER': - properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', - 'bearer.auth.issuer.endpoint.url'] - missing_properties = [prop for prop in 
properties_list if prop not in conf_copy] - if missing_properties: - raise ValueError("Missing required OAuth configuration properties: {}". - format(", ".join(missing_properties))) - - self.client_id = conf_copy.pop('bearer.auth.client.id') - if not isinstance(self.client_id, string_type): - raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) - - self.client_secret = conf_copy.pop('bearer.auth.client.secret') - if not isinstance(self.client_secret, string_type): - raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) - - self.scope = conf_copy.pop('bearer.auth.scope') - if not isinstance(self.scope, string_type): - raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) - - self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') - if not isinstance(self.token_endpoint, string_type): - raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " - + str(type(self.token_endpoint))) - - self.bearer_field_provider = _AsyncOAuthClient( - self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) - elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': - if 'bearer.auth.token' not in conf_copy: - raise ValueError("Missing bearer.auth.token") - static_token = conf_copy.pop('bearer.auth.token') - self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) - if not isinstance(static_token, string_type): - raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) + if self.bearer_auth_credentials_source == 'OAUTHBEARER': + properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', + 'bearer.auth.issuer.endpoint.url'] + missing_properties = [prop for prop in properties_list if prop not in conf_copy] + if missing_properties: + raise 
ValueError("Missing required OAuth configuration properties: {}". + format(", ".join(missing_properties))) + + self.client_id = conf_copy.pop('bearer.auth.client.id') + if not isinstance(self.client_id, string_type): + raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) + + self.client_secret = conf_copy.pop('bearer.auth.client.secret') + if not isinstance(self.client_secret, string_type): + raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) + + self.scope = conf_copy.pop('bearer.auth.scope') + if not isinstance(self.scope, string_type): + raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) + + self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') + if not isinstance(self.token_endpoint, string_type): + raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + + str(type(self.token_endpoint))) + + self.bearer_field_provider = _AsyncOAuthClient( + self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) + else: # STATIC_TOKEN + if 'bearer.auth.token' not in conf_copy: + raise ValueError("Missing bearer.auth.token") + static_token = conf_copy.pop('bearer.auth.token') + self.bearer_field_provider = _AsyncStaticFieldProvider(static_token, logical_cluster, identity_pool) + if not isinstance(static_token, string_type): + raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) elif self.bearer_auth_credentials_source == 'CUSTOM': custom_bearer_properties = ['bearer.auth.custom.provider.function', 'bearer.auth.custom.provider.config'] @@ -379,6 +383,8 @@ def __init__(self, conf: dict): ) async def handle_bearer_auth(self, headers: dict) -> None: + if self.bearer_field_provider is None: + raise ValueError("Bearer field provider is not set") bearer_fields = await 
self.bearer_field_provider.get_bearer_fields() required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] @@ -437,9 +443,10 @@ async def send_request( " application/vnd.schemaregistry+json," " application/json"} + body_str: Optional[str] = None if body is not None: - body = json.dumps(body) - headers = {'Content-Length': str(len(body)), + body_str = json.dumps(body) + headers = {'Content-Length': str(len(body_str)), 'Content-Type': "application/vnd.schemaregistry.v1+json"} if self.bearer_auth_credentials_source: @@ -449,7 +456,7 @@ async def send_request( for i, base_url in enumerate(self.base_urls): try: response = await self.send_http_request( - base_url, url, method, headers, body, query) + base_url, url, method, headers, body_str, query) if is_success(response.status_code): return response.json() @@ -461,16 +468,18 @@ async def send_request( # Raise the exception since we have no more urls to try raise e - try: - raise SchemaRegistryError(response.status_code, - response.json().get('error_code'), - response.json().get('message')) - # Schema Registry may return malformed output when it hits unexpected errors - except (ValueError, KeyError, AttributeError): - raise SchemaRegistryError(response.status_code, - -1, - "Unknown Schema Registry Error: " - + str(response.content)) + if isinstance(response, Response): + try: + raise SchemaRegistryError(response.status_code, + response.json().get('error_code'), + response.json().get('message')) + except (ValueError, KeyError, AttributeError): + raise SchemaRegistryError(response.status_code, + -1, + "Unknown Schema Registry Error: " + + str(response.content)) + else: + raise TypeError("Unexpected response of unsupported type: " + str(type(response))) async def send_http_request( self, base_url: str, url: str, method: str, headers: Optional[dict], @@ -514,7 +523,7 @@ async def send_http_request( return response await asyncio.sleep(full_jitter(self.retries_wait_ms, 
self.retries_max_wait_ms, i) / 1000) - return response + return response # type: ignore[return-value] class AsyncSchemaRegistryClient(object): @@ -597,6 +606,8 @@ def __init__(self, conf: dict): self._cache = _SchemaCache() cache_capacity = self._rest_client.cache_capacity cache_ttl = self._rest_client.cache_latest_ttl_sec + self._latest_version_cache: Cache[Any, Any] + self._latest_with_metadata_cache: Cache[Any, Any] if cache_ttl is not None: self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) @@ -639,7 +650,7 @@ async def register_schema( registered_schema = await self.register_schema_full_response( subject_name, schema, normalize_schemas=normalize_schemas) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] async def register_schema_full_response( self, subject_name: str, schema: 'Schema', @@ -682,14 +693,14 @@ async def register_schema_full_response( 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), body=request) - result = RegisteredSchema.from_dict(response) + response_schema = RegisteredSchema.from_dict(response) registered_schema = RegisteredSchema( - schema_id=result.schema_id, - guid=result.guid, - subject=result.subject or subject_name, - version=result.version, - schema=result.schema + schema_id=response_schema.schema_id, + guid=response_schema.guid, + subject=response_schema.subject or subject_name, + version=response_schema.version, + schema=response_schema.schema, ) # The registered schema may not be fully populated @@ -724,7 +735,8 @@ async def get_schema( `GET Schema API Reference `_ """ # noqa: E501 - result = self._cache.get_schema_by_id(subject_name, schema_id) + if subject_name is not None: + result = self._cache.get_schema_by_id(subject_name, schema_id) if result is not None: return result[1] @@ -818,7 +830,7 @@ async def get_subjects_by_schema_id( Raises: 
SchemaRegistryError: if subjects can't be found """ - query = {'offset': offset, 'limit': limit} + query: dict[str, Any] = {'offset': offset, 'limit': limit} if subject_name is not None: query['subject'] = subject_name if deleted: @@ -851,7 +863,7 @@ async def get_schema_versions( `GET Schema Versions API Reference `_ """ # noqa: E501 - query = {'offset': offset, 'limit': limit} + query: dict[str, Any] = {'offset': offset, 'limit': limit} if subject_name is not None: query['subject'] = subject_name if deleted: @@ -889,7 +901,7 @@ async def lookup_schema( request = schema.to_dict() - query_params = { + query_params: dict[str, Any] = { 'normalize': normalize_schemas, 'deleted': deleted } @@ -942,7 +954,7 @@ async def get_subjects( `GET subjects API Reference `_ """ # noqa: E501 - query = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} + query: dict[str, Any] = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} if subject_prefix is not None: query['subject'] = subject_prefix return await self._rest_client.get('subjects', query) @@ -1039,7 +1051,7 @@ async def get_latest_with_metadata( if registered_schema is not None: return registered_schema - query = {'deleted': deleted} + query: dict[str, Any] = {'deleted': deleted} if fmt is not None: query['format'] = fmt keys = metadata.keys() @@ -1058,7 +1070,7 @@ async def get_latest_with_metadata( return registered_schema async def get_version( - self, subject_name: str, version: Union[int, str] = "latest", + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", deleted: bool = False, fmt: Optional[str] = None ) -> 'RegisteredSchema': """ @@ -1066,7 +1078,7 @@ async def get_version( Args: subject_name (str): Subject name. - version (Union[int, str]): Version of the schema or string "latest". Defaults to latest version. + version (Union[int, Literal["latest"]]): Version of the schema or string "latest". Defaults to latest version. 
deleted (bool): Whether to include deleted schemas. fmt (str): Format of the schema. @@ -1080,11 +1092,12 @@ async def get_version( `GET Subject Versions API Reference `_ """ # noqa: E501 - registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) - if registered_schema is not None: - return registered_schema + if version != "latest": + registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) + if registered_schema is not None: + return registered_schema - query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + query: dict[str, Any] = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} response = await self._rest_client.get( 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query ) @@ -1096,7 +1109,7 @@ async def get_version( return registered_schema async def get_referenced_by( - self, subject_name: str, version: Union[int, str] = "latest", + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", offset: int = 0, limit: int = -1 ) -> List[int]: """ @@ -1104,7 +1117,7 @@ async def get_referenced_by( Args: subject_name (str): Subject name - version (int or str): Version number or "latest" + version (Union[int, Literal["latest"]]): Version number or "latest" offset (int): Pagination offset for results. limit (int): Pagination size for results. Ignored if negative. 
@@ -1118,7 +1131,7 @@ async def get_referenced_by( `GET Subject Versions (ReferenceBy) API Reference `_ """ # noqa: E501 - query = {'offset': offset, 'limit': limit} + query: dict[str, Any] = {'offset': offset, 'limit': limit} return await self._rest_client.get('subjects/{}/versions/{}/referencedby'.format( _urlencode(subject_name), version), query) @@ -1146,7 +1159,7 @@ async def get_versions( `GET Subject All Versions API Reference `_ """ # noqa: E501 - query = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} + query: dict[str, Any] = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} return await self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name)), query) async def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: @@ -1516,6 +1529,6 @@ def clear_caches(self): def new_client(conf: dict) -> 'AsyncSchemaRegistryClient': from .mock_schema_registry_client import AsyncMockSchemaRegistryClient url = conf.get("url") - if url.startswith("mock://"): + if url and isinstance(url, str) and url.startswith("mock://"): return AsyncMockSchemaRegistryClient(conf) return AsyncSchemaRegistryClient(conf) diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py index 40f72b721..bb1be66e1 100644 --- a/src/confluent_kafka/schema_registry/_async/serde.py +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -17,7 +17,7 @@ # import logging -from typing import List, Optional, Set, Dict, Any +from typing import List, Optional, Set, Dict, Any, Callable from confluent_kafka.schema_registry import RegisteredSchema from confluent_kafka.schema_registry.common.schema_registry_client import \ @@ -44,6 +44,14 @@ class AsyncBaseSerde(object): '_registry', '_rule_registry', '_subject_name_func', '_field_transformer'] + _use_schema_id: Optional[int] + _use_latest_version: bool + 
_use_latest_with_metadata: Optional[Dict[str, str]] + _registry: Any # AsyncSchemaRegistryClient + _rule_registry: Any # RuleRegistry + _subject_name_func: Callable[[Optional['SerializationContext'], Optional[str]], Optional[str]] + _field_transformer: Optional[FieldTransformer] + async def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: if self._use_schema_id is not None: schema = await self._registry.get_schema(self._use_schema_id, subject, fmt) @@ -114,6 +122,11 @@ def _execute_rules_with_phase( ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule, index, rules, inline_tags, field_transformer) + if rule.type is None: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, + RuleError(f"Rule type is None for rule {rule.name}"), + 'ERROR') + return message rule_executor = self._rule_registry.get_executor(rule.type.upper()) if rule_executor is None: self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, @@ -140,18 +153,24 @@ def _execute_rules_with_phase( return message def _get_on_success(self, rule: Rule) -> Optional[str]: + if rule.type is None: + return rule.on_success override = self._rule_registry.get_override(rule.type) if override is not None and override.on_success is not None: return override.on_success return rule.on_success def _get_on_failure(self, rule: Rule) -> Optional[str]: + if rule.type is None: + return rule.on_failure override = self._rule_registry.get_override(rule.type) if override is not None and override.on_failure is not None: return override.on_failure return rule.on_failure def _is_disabled(self, rule: Rule) -> Optional[bool]: + if rule.type is None: + return rule.disabled override = self._rule_registry.get_override(rule.type) if override is not None and override.disabled is not None: return override.disabled @@ -200,10 +219,16 @@ def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleA class 
AsyncBaseSerializer(AsyncBaseSerde, Serializer): __slots__ = ['_auto_register', '_normalize_schemas', '_schema_id_serializer'] + _auto_register: bool + _normalize_schemas: bool + _schema_id_serializer: Callable[[bytes, Any, Any], bytes] + class AsyncBaseDeserializer(AsyncBaseSerde, Deserializer): __slots__ = ['_schema_id_deserializer'] + _schema_id_deserializer: Callable[[bytes, Any, Any], Any] + async def _get_writer_schema( self, schema_id: SchemaId, subject: Optional[str] = None, fmt: Optional[str] = None) -> Schema: @@ -241,7 +266,7 @@ async def _get_migrations( ) -> List[Migration]: source = await self._registry.lookup_schema( subject, source_info, normalize_schemas=False, deleted=True) - migrations = [] + migrations: List[Migration] = [] if source.version < target.version: migration_mode = RuleMode.UPGRADE first = source @@ -259,13 +284,14 @@ async def _get_migrations( if i == 0: previous = version continue - if version.schema.rule_set is not None and self._has_rules( + if version.schema is not None and version.schema.rule_set is not None and self._has_rules( version.schema.rule_set, RulePhase.MIGRATION, migration_mode): - if migration_mode == RuleMode.UPGRADE: - migration = Migration(migration_mode, previous, version) - else: - migration = Migration(migration_mode, version, previous) - migrations.append(migration) + if previous is not None: # previous is always set after first iteration + if migration_mode == RuleMode.UPGRADE: + migration = Migration(migration_mode, previous, version) + else: + migration = Migration(migration_mode, version, previous) + migrations.append(migration) previous = version if migration_mode == RuleMode.DOWNGRADE: migrations.reverse() @@ -275,6 +301,8 @@ async def _get_schemas_between( self, subject: str, first: RegisteredSchema, last: RegisteredSchema, fmt: Optional[str] = None ) -> List[RegisteredSchema]: + if first.version is None or last.version is None: + return [first, last] if last.version - first.version <= 1: return [first, 
last] version1 = first.version @@ -290,8 +318,9 @@ def _execute_migrations( migrations: List[Migration], message: Any ) -> Any: for migration in migrations: - message = self._execute_rules_with_phase( - ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode, - migration.source.schema, migration.target.schema, - message, None, None) + if migration.source is not None and migration.target is not None: + message = self._execute_rules_with_phase( + ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode, + migration.source.schema, migration.target.schema, + message, None, None) return message diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index b4282624e..dbb7d43b2 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -16,7 +16,7 @@ # limitations under the License. import io import json -from typing import Dict, Union, Optional, Callable +from typing import Any, Coroutine, Dict, Union, Optional, Callable, cast from fastavro import schemaless_reader, schemaless_writer @@ -57,11 +57,18 @@ def _resolve_named_schema( named_schemas = {} if schema.references is not None: for ref in schema.references: + if ref.subject is None or ref.version is None: + raise TypeError("Subject or version cannot be None") referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client) + if referenced_schema.schema.schema_str is None: + raise TypeError("Schema string cannot be None") + parsed_schema = parse_schema_with_repo( referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) named_schemas.update(ref_named_schemas) + if ref.name is None: + raise TypeError("Name cannot be None") named_schemas[ref.name] = parsed_schema return named_schemas @@ -204,9 +211,9 @@ def __init_impl( schema = None self._registry = 
schema_registry_client - self._schema_id = None + self._schema_id: Optional[SchemaId] = None self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._known_subjects = set() + self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() if to_dict is not None and not callable(to_dict): @@ -219,35 +226,41 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = 
conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") @@ -265,13 +278,18 @@ def __init_impl( # and union types should use topic_subject_name_strategy, which # just discards the schema name anyway schema_name = None - else: + elif isinstance(parsed_schema, dict): # The Avro spec states primitives have a name equal to their type # i.e. {"type": "string"} has a name of string. # This function does not comply. # https://github.com/fastavro/fastavro/issues/415 - schema_dict = json.loads(schema.schema_str) - schema_name = parsed_schema.get("name", schema_dict.get("type")) + if schema.schema_str is not None: + schema_dict = json.loads(schema.schema_str) + schema_name = parsed_schema.get("name", schema_dict.get("type")) + else: + schema_name = None + else: + schema_name = None else: schema_name = None parsed_schema = None @@ -286,7 +304,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] return self.__serialize(obj, ctx) def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -313,10 +331,10 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = 
self._get_reader_schema(subject) + latest_schema = self._get_reader_schema(subject) if subject else None if latest_schema is not None: self._schema_id = SchemaId(AVRO_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects: + elif subject is not None and subject not in self._known_subjects: # Check to ensure this schema has been registered under subject_name. if self._auto_register: # The schema name will always be the same. We can't however register @@ -332,12 +350,16 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - self._known_subjects.add(subject) + value: Any + parsed_schema: Any if self._to_dict is not None: + if ctx is None: + raise TypeError("SerializationContext cannot be None") value = self._to_dict(obj, ctx) else: value = obj - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: parsed_schema = self._get_parsed_schema(latest_schema.schema) def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, msg, field_transform)) @@ -352,7 +374,7 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 schemaless_writer(fo, parsed_schema, value) buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -365,7 +387,11 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema named_schemas = _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise TypeError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) + if prepared_schema.schema_str is None: + raise TypeError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, 
named_schemas=named_schemas) @@ -477,20 +503,26 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") @@ -498,7 +530,8 @@ def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - if schema: + self._reader_schema: Optional[AvroSchema] + if schema and self._schema is not None: self._reader_schema = self._get_parsed_schema(self._schema) else: self._reader_schema = None @@ -518,11 +551,11 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: return 
self.__deserialize(data, ctx) def __deserialize( - self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: """ Deserialize Avro binary encoded data with Confluent Schema Registry framing to a dict, or object instance according to from_dict, if specified. @@ -560,19 +593,20 @@ def __deserialize( writer_schema_raw = self._get_writer_schema(schema_id, subject) writer_schema = self._get_parsed_schema(writer_schema_raw) - if subject is None: - subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None + subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None # type: ignore[union-attr] if subject is not None: latest_schema = self._get_reader_schema(subject) - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) - if latest_schema is not None: + reader_schema: Optional[AvroSchema] + if latest_schema is not None and subject is not None: migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema reader_schema = self._get_parsed_schema(latest_schema.schema) @@ -585,7 +619,7 @@ def __deserialize( reader_schema_raw = writer_schema_raw reader_schema = writer_schema - if migrations: + if migrations and ctx is not None and subject is not None: obj_dict = schemaless_reader(payload, writer_schema, None, @@ -599,11 +633,15 @@ def __deserialize( def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) - obj_dict = 
self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, get_inline_tags(reader_schema), - field_transformer) + if ctx is not None and subject is not None: + inline_tags = get_inline_tags(reader_schema) if reader_schema is not None else None + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, + inline_tags, field_transformer) if self._from_dict is not None: + if ctx is None: + raise TypeError("SerializationContext cannot be None") return self._from_dict(obj_dict, ctx) return obj_dict @@ -614,7 +652,11 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema named_schemas = _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise TypeError("Schema string cannot be None") prepared_schema = _schema_loads(schema.schema_str) + if prepared_schema.schema_str is None: + raise TypeError("Prepared schema string cannot be None") parsed_schema = parse_schema_with_repo( prepared_schema.schema_str, named_schemas=named_schemas) diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 9d809386c..1b8ad67d1 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -16,7 +16,7 @@ # limitations under the License. 
import io import orjson -from typing import Union, Optional, Tuple, Callable +from typing import Any, Coroutine, Union, Optional, Tuple, Callable, cast from cachetools import LRUCache from jsonschema import ValidationError @@ -61,14 +61,21 @@ def _resolve_named_schema( """ if ref_registry is None: # Retrieve external schemas for backward compatibility - ref_registry = Registry(retrieve=_retrieve_via_httpx) + ref_registry = Registry(retrieve=_retrieve_via_httpx) # type: ignore[call-arg] if schema.references is not None: for ref in schema.references: + if ref.subject is None or ref.version is None: + raise TypeError("Subject or version cannot be None") referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) ref_registry = _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) + if referenced_schema.schema.schema_str is None: + raise TypeError("Schema string cannot be None") + referenced_schema_dict = orjson.loads(referenced_schema.schema.schema_str) resource = Resource.from_contents( referenced_schema_dict, default_specification=DEFAULT_SPEC) + if ref.name is None: + raise TypeError("Name cannot be None") ref_registry = ref_registry.with_resource(ref.name, resource) return ref_registry @@ -213,6 +220,7 @@ def __init_impl( json_encode: Optional[Callable] = None, ): super().__init__() + self._schema: Optional[Schema] if isinstance(schema_str, str): self._schema = Schema(schema_str, schema_type="JSON") elif isinstance(schema_str, Schema): @@ -225,10 +233,10 @@ def __init_impl( self._rule_registry = ( rule_registry if rule_registry else RuleRegistry.get_global_instance() ) - self._schema_id = None - self._known_subjects = set() + self._schema_id: Optional[SchemaId] = None + self._known_subjects: set[str] = set() self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) + self._validators: LRUCache[Schema, Validator] = LRUCache(1000) if to_dict is not None and not callable(to_dict): raise 
ValueError("to_dict must be callable with the signature " @@ -240,50 +248,59 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - 
self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") - self._validate = conf_copy.pop('validate') + self._validate = cast(bool, conf_copy.pop('validate')) if not isinstance(self._validate, bool): raise ValueError("validate must be a boolean value") if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - - schema_dict, ref_registry = self._get_parsed_schema(self._schema) - if schema_dict: - schema_name = schema_dict.get('title', None) + if self._schema: + schema_dict, ref_registry = self._get_parsed_schema(self._schema) + if schema_dict and isinstance(schema_dict, dict): + schema_name = schema_dict.get('title', None) + else: + schema_name = None else: + schema_dict = None schema_name = None self._schema_name = schema_name @@ -296,7 +313,7 @@ def __init_impl( __init__ = __init_impl - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] return self.__serialize(obj, ctx) def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -322,10 +339,10 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - return None subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = self._get_reader_schema(subject) + latest_schema = self._get_reader_schema(subject) if subject else None if latest_schema is not None: self._schema_id = SchemaId(JSON_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects: + elif subject is not None and subject not in self._known_subjects: # Check to 
ensure this schema has been registered under subject_name. if self._auto_register: # The schema name will always be the same. We can't however register @@ -341,27 +358,33 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - self._known_subjects.add(subject) + value: Any if self._to_dict is not None: + if ctx is None: + raise TypeError("SerializationContext cannot be None") value = self._to_dict(obj, ctx) else: value = obj + schema: Optional[Schema] = None if latest_schema is not None: schema = latest_schema.schema parsed_schema, ref_registry = self._get_parsed_schema(latest_schema.schema) - root_resource = Resource.from_contents( - parsed_schema, default_specification=DEFAULT_SPEC) - ref_resolver = ref_registry.resolver_with_root(root_resource) - def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 - transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, value, None, - field_transformer) + if ref_registry is not None: + root_resource = Resource.from_contents( + parsed_schema, default_specification=DEFAULT_SPEC) + ref_resolver = ref_registry.resolver_with_root(root_resource) + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 + transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) + if ctx is not None and subject is not None: + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, None, + field_transformer) else: schema = self._schema parsed_schema, ref_registry = self._parsed_schema, self._ref_registry - if self._validate: + if self._validate and schema is not None and parsed_schema is not None and ref_registry is not None: try: validator = self._get_validator(schema, parsed_schema, ref_registry) validator.validate(value) @@ -377,7 +400,7 @@ def field_transformer(rule_ctx, 
field_transform, msg): return ( # noqa: E731 fo.write(encoded_value) buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -393,6 +416,8 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Opti return result ref_registry = _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise TypeError("Schema string cannot be None") parsed_schema = orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) @@ -488,6 +513,7 @@ def __init_impl( json_decode: Optional[Callable] = None, ): super().__init__() + schema: Optional[Schema] if isinstance(schema_str, str): schema = Schema(schema_str, schema_type="JSON") elif isinstance(schema_str, Schema): @@ -504,11 +530,11 @@ def __init_impl( else: raise TypeError('You must pass either str or Schema') - self._schema = schema + self._schema: Optional[Schema] = schema self._registry = schema_registry_client self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) + self._validators: LRUCache[Schema, Validator] = LRUCache(1000) self._json_decode = json_decode or orjson.loads self._use_schema_id = None @@ -516,24 +542,30 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if 
(self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') - if not callable(self._subject_name_func): + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) + if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") - self._validate = conf_copy.pop('validate') + self._validate = cast(bool, conf_copy.pop('validate')) if not isinstance(self._validate, bool): raise ValueError("validate must be a boolean value") @@ -541,7 +573,7 @@ def __init_impl( raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - if schema: + if schema and self._schema is not None: self._reader_schema, self._ref_registry = self._get_parsed_schema(self._schema) else: self._reader_schema, self._ref_registry = None, None @@ -558,10 +590,10 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__deserialize(data, ctx) - def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a JSON encoded record with Confluent Schema Registry framing to a dict, or 
object instance according to from_dict if from_dict is specified. @@ -593,7 +625,7 @@ def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) if self._registry is not None: writer_schema_raw = self._get_writer_schema(schema_id, subject) writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw) - if subject is None: + if subject is None and isinstance(writer_schema, dict): subject = self._subject_name_func(ctx, writer_schema.get("title")) if subject is not None: latest_schema = self._get_reader_schema(subject) @@ -601,16 +633,18 @@ def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) writer_schema_raw = None writer_schema, writer_ref_registry = None, None - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) # JSON documents are self-describing; no need to query schema obj_dict = self._json_decode(payload.read()) - if latest_schema is not None: + reader_schema_raw: Optional[Schema] = None + if latest_schema is not None and subject is not None and writer_schema_raw is not None: migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema reader_schema, reader_ref_registry = self._get_parsed_schema(latest_schema.schema) @@ -623,21 +657,23 @@ def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) reader_schema_raw = writer_schema_raw reader_schema, reader_ref_registry = writer_schema, writer_ref_registry - if migrations: + if migrations and ctx is not None and subject is not None: obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - reader_root_resource = 
Resource.from_contents( - reader_schema, default_specification=DEFAULT_SPEC) - reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) + if reader_ref_registry is not None: + reader_root_resource = Resource.from_contents( + reader_schema, default_specification=DEFAULT_SPEC) + reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) - def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 - transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, - "$", message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, None, - field_transformer) + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 + transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, + "$", message, field_transform)) + if ctx is not None and subject is not None: + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, None, + field_transformer) - if self._validate: + if self._validate and reader_schema_raw is not None and reader_schema is not None and reader_ref_registry is not None: try: validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) validator.validate(obj_dict) @@ -645,7 +681,9 @@ def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E73 raise SerializationError(ve.message) if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) + if ctx is None: + raise TypeError("SerializationContext cannot be None") + return self._from_dict(obj_dict, ctx) # type: ignore[return-value] return obj_dict @@ -658,6 +696,8 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Opti return result ref_registry = _resolve_named_schema(schema, self._registry) + if schema.schema_str is None: + raise TypeError("Schema string cannot be None") parsed_schema = 
orjson.loads(schema.schema_str) self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) diff --git a/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py index a7f39f20d..3ca77fbf3 100644 --- a/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py @@ -18,7 +18,7 @@ import uuid from collections import defaultdict from threading import Lock -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union, Literal from .schema_registry_client import SchemaRegistryClient from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig @@ -73,7 +73,7 @@ def get_registered_schema_by_schema( return rs return None - def get_version(self, subject_name: str, version: int) -> Optional[RegisteredSchema]: + def get_version(self, subject_name: str, version: Union[int, str]) -> Optional[RegisteredSchema]: with self.lock: if subject_name in self.subject_schemas: for rs in self.subject_schemas[subject_name]: @@ -157,7 +157,7 @@ def register_schema( ) -> int: registered_schema = self.register_schema_full_response( subject_name, schema, normalize_schemas=normalize_schemas) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] def register_schema_full_response( self, subject_name: str, schema: 'Schema', @@ -168,7 +168,7 @@ def register_schema_full_response( return registered_schema latest_schema = self._store.get_latest_version(subject_name) - latest_version = 1 if latest_schema is None else latest_schema.version + 1 + latest_version = 1 if latest_schema is None or latest_schema.version is None else latest_schema.version + 1 registered_schema = RegisteredSchema( schema_id=1, @@ -184,7 +184,7 @@ def register_schema_full_response( def get_schema( self, schema_id: int, subject_name: Optional[str] = 
None, - fmt: Optional[str] = None + fmt: Optional[str] = None, reference_format: Optional[str] = None ) -> 'Schema': schema = self._store.get_schema(schema_id) if schema is not None: @@ -212,7 +212,10 @@ def lookup_schema( raise SchemaRegistryError(404, 40400, "Schema Not Found") - def get_subjects(self) -> List[str]: + def get_subjects( + self, subject_prefix: Optional[str] = None, deleted: bool = False, + deleted_only: bool = False, offset: int = 0, limit: int = -1 + ) -> List[str]: return self._store.get_subjects() def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: @@ -236,30 +239,36 @@ def get_latest_with_metadata( raise SchemaRegistryError(404, 40400, "Schema Not Found") def get_version( - self, subject_name: str, version: int, + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", deleted: bool = False, fmt: Optional[str] = None ) -> 'RegisteredSchema': - registered_schema = self._store.get_version(subject_name, version) + if version == "latest": + registered_schema = self._store.get_latest_version(subject_name) + else: + registered_schema = self._store.get_version(subject_name, version) if registered_schema is not None: return registered_schema raise SchemaRegistryError(404, 40400, "Schema Not Found") - def get_versions(self, subject_name: str) -> List[int]: + def get_versions( + self, subject_name: str, deleted: bool = False, deleted_only: bool = False, + offset: int = 0, limit: int = -1 + ) -> List[int]: return self._store.get_versions(subject_name) def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: registered_schema = self._store.get_version(subject_name, version) if registered_schema is not None: self._store.remove_by_schema(registered_schema) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] raise SchemaRegistryError(404, 40400, "Schema Not Found") def set_config( - self, subject_name: Optional[str] = None, 
config: 'ServerConfig' = None # noqa F821 + self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None # noqa F821 ) -> 'ServerConfig': # noqa F821 - return None + return None # type: ignore[return-value] def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': # noqa F821 - return None + return None # type: ignore[return-value] diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 501823716..72c936b97 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -16,7 +16,7 @@ # limitations under the License. import io -from typing import Set, List, Union, Optional, Tuple +from typing import Any, Coroutine, Set, List, Union, Optional, Tuple, Callable, cast from google.protobuf import json_format, descriptor_pb2 from google.protobuf.descriptor_pool import DescriptorPool @@ -67,10 +67,18 @@ def _resolve_named_schema( visited = set() if schema.references is not None: for ref in schema.references: + if ref.name is None: + raise ValueError("Name cannot be None") + if _is_builtin(ref.name) or ref.name in visited: continue visited.add(ref.name) + + if ref.subject is None or ref.version is None: + raise ValueError("Subject or version cannot be None") referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') + if referenced_schema.schema.schema_str is None: + raise ValueError("Schema string cannot be None") _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) pool.Add(file_descriptor_proto) @@ -218,50 +226,58 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._auto_register = conf_copy.pop('auto.register.schemas') + self._auto_register = cast(bool, conf_copy.pop('auto.register.schemas')) if not 
isinstance(self._auto_register, bool): raise ValueError("auto.register.schemas must be a boolean value") - self._normalize_schemas = conf_copy.pop('normalize.schemas') + self._normalize_schemas = cast(bool, conf_copy.pop('normalize.schemas')) if not isinstance(self._normalize_schemas, bool): raise ValueError("normalize.schemas must be a boolean value") - self._use_schema_id = conf_copy.pop('use.schema.id') + self._use_schema_id = cast(Optional[int], conf_copy.pop('use.schema.id')) if (self._use_schema_id is not None and not isinstance(self._use_schema_id, int)): raise ValueError("use.schema.id must be an int value") - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._skip_known_types = conf_copy.pop('skip.known.types') + self._skip_known_types = cast(bool, conf_copy.pop('skip.known.types')) if not isinstance(self._skip_known_types, bool): raise ValueError("skip.known.types must be a boolean value") - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + self._use_deprecated_format = cast(bool, conf_copy.pop('use.deprecated.format')) if not isinstance(self._use_deprecated_format, bool): raise ValueError("use.deprecated.format must be a boolean value") if self._use_deprecated_format: raise ValueError("use.deprecated.format is no longer supported") - 
self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._ref_reference_subject_func = conf_copy.pop( - 'reference.subject.name.strategy') + self._ref_reference_subject_func = cast( + Callable[[Optional[SerializationContext], Any], Optional[str]], + conf_copy.pop('reference.subject.name.strategy') + ) if not callable(self._ref_reference_subject_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + self._schema_id_serializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], bytes], + conf_copy.pop('schema.id.serializer') + ) if not callable(self._schema_id_serializer): raise ValueError("schema.id.serializer must be callable") @@ -271,8 +287,8 @@ def __init_impl( self._registry = schema_registry_client self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._schema_id = None - self._known_subjects = set() + self._schema_id: Optional[SchemaId] = None + self._known_subjects: set[str] = set() self._msg_class = msg_type self._parsed_schemas = ParsedSchemaCache() @@ -360,7 +376,7 @@ def _resolve_dependencies( reference.version)) return schema_refs - def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: # type: ignore[override] return self.__serialize(message, ctx) def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: @@ -397,7 +413,7 @@ def __serialize(self, message: Message, ctx: Optional[SerializationContext] = No if latest_schema is not None: self._schema_id = 
SchemaId(PROTOBUF_TYPE, latest_schema.schema_id, latest_schema.guid) - elif subject not in self._known_subjects and ctx is not None: + elif subject is not None and subject not in self._known_subjects and ctx is not None: references = self._resolve_dependencies(ctx, message.DESCRIPTOR.file) self._schema = Schema( self._schema.schema_str, @@ -422,16 +438,18 @@ def __serialize(self, message: Message, ctx: Optional[SerializationContext] = No desc = fd.message_types_by_name[message.DESCRIPTOR.name] def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, desc, msg, field_transform)) - message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, message, None, - field_transformer) + if ctx is not None and subject is not None: + message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, message, None, + field_transformer) with _ContextStringIO() as fo: fo.write(message.SerializeToString()) - self._schema_id.message_indexes = self._index_array + if self._schema_id is not None: + self._schema_id.message_indexes = self._index_array buffer = fo.getvalue() - if latest_schema is not None: + if latest_schema is not None and ctx is not None and subject is not None: buffer = self._execute_rules_with_phase( ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, None, latest_schema.schema, buffer, None, None) @@ -446,6 +464,8 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip pool = DescriptorPool() _init_pool(pool) _resolve_named_schema(schema, self._registry, pool) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") fd_proto = _str_to_proto("default", schema.schema_str) pool.Add(fd_proto) self._parsed_schemas.set(schema, (fd_proto, pool)) @@ -526,24 +546,30 @@ def __init_impl( if conf is not None: conf_copy.update(conf) - self._use_latest_version = conf_copy.pop('use.latest.version') + self._use_latest_version = cast(bool, 
conf_copy.pop('use.latest.version')) if not isinstance(self._use_latest_version, bool): raise ValueError("use.latest.version must be a boolean value") - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + self._use_latest_with_metadata = cast(Optional[dict], conf_copy.pop('use.latest.with.metadata')) if (self._use_latest_with_metadata is not None and not isinstance(self._use_latest_with_metadata, dict)): raise ValueError("use.latest.with.metadata must be a dict value") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func = cast( + Callable[[Optional[SerializationContext], Optional[str]], Optional[str]], + conf_copy.pop('subject.name.strategy') + ) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + self._schema_id_deserializer = cast( + Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO], + conf_copy.pop('schema.id.deserializer') + ) if not callable(self._schema_id_deserializer): raise ValueError("schema.id.deserializer must be callable") - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + self._use_deprecated_format = cast(bool, conf_copy.pop('use.deprecated.format')) if not isinstance(self._use_deprecated_format, bool): raise ValueError("use.deprecated.format must be a boolean value") if self._use_deprecated_format: @@ -558,10 +584,10 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__deserialize(data, ctx) - def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize 
a serialized protobuf message with Confluent Schema Registry framing. @@ -597,7 +623,7 @@ def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) writer_schema_raw = self._get_writer_schema(schema_id, subject, fmt='serialized') fd_proto, pool = self._get_parsed_schema(writer_schema_raw) writer_schema = pool.FindFileByName(fd_proto.name) - writer_desc = self._get_message_desc(pool, writer_schema, msg_index) + writer_desc = self._get_message_desc(pool, writer_schema, msg_index if msg_index is not None else []) if subject is None: subject = self._subject_name_func(ctx, writer_desc.full_name) if subject is not None: @@ -606,13 +632,15 @@ def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) writer_schema_raw = None writer_schema = None - payload = self._execute_rules_with_phase( - ctx, subject, RulePhase.ENCODING, RuleMode.READ, - None, writer_schema_raw, payload, None, None) + if ctx is not None and subject is not None: + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) if isinstance(payload, bytes): payload = io.BytesIO(payload) - if latest_schema is not None: + reader_schema_raw: Optional[Schema] = None + if latest_schema is not None and subject is not None and writer_schema_raw is not None: migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema fd_proto, pool = self._get_parsed_schema(latest_schema.schema) @@ -628,7 +656,7 @@ def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) # Attempt to find a reader desc with the same name as the writer reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) - if migrations: + if migrations and ctx is not None and subject is not None: msg = GetMessageClass(writer_desc)() try: msg.ParseFromString(payload.read()) @@ -649,9 +677,10 @@ def __deserialize(self, data: 
bytes, ctx: Optional[SerializationContext] = None) def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_desc, message, field_transform)) - msg = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, msg, None, - field_transformer) + if ctx is not None and subject is not None: + msg = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, msg, None, + field_transformer) return msg def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: @@ -662,6 +691,8 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip pool = DescriptorPool() _init_pool(pool) _resolve_named_schema(schema, self._registry, pool) + if schema.schema_str is None: + raise ValueError("Schema string cannot be None") fd_proto = _str_to_proto("default", schema.schema_str) pool.Add(fd_proto) self._parsed_schemas.set(schema, (fd_proto, pool)) diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 7c4eb7060..8725c3d00 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -26,9 +26,9 @@ from urllib.parse import unquote, urlparse import httpx -from typing import List, Dict, Optional, Union, Any, Callable +from typing import List, Dict, Optional, Union, Any, Callable, Literal -from cachetools import TTLCache, LRUCache +from cachetools import Cache, TTLCache, LRUCache from httpx import Response from authlib.integrations.httpx_client import OAuth2Client @@ -65,10 +65,10 @@ # six: https://pypi.org/project/six/ # compat file : https://github.com/psf/requests/blob/master/requests/compat.py try: - string_type = basestring # noqa + string_type = basestring # type: ignore[name-defined] # noqa def _urlencode(value: str) -> str: - return 
urllib.quote(value, safe='') + return urllib.quote(value, safe='') # type: ignore[attr-defined] except NameError: string_type = str @@ -108,6 +108,9 @@ def get_bearer_fields(self) -> dict: } def token_expired(self) -> bool: + if self.token is None: + raise ValueError("Token is not set") + expiry_window = self.token['expires_in'] * self.token_expiry_threshold return self.token['expires_at'] < time.time() + expiry_window @@ -115,7 +118,8 @@ def token_expired(self) -> bool: def get_access_token(self) -> str: if not self.token or self.token_expired(): self.generate_access_token() - + if self.token is None: + raise ValueError("Token is not set after the attempt to generate it") return self.token['access_token'] def generate_access_token(self) -> None: @@ -227,7 +231,7 @@ def __init__(self, conf: dict): if cache_capacity is not None: if not isinstance(cache_capacity, (int, float)): raise TypeError("cache.capacity must be a number, not " + str(type(cache_capacity))) - self.cache_capacity = cache_capacity + self.cache_capacity = int(cache_capacity) self.cache_latest_ttl_sec = None cache_latest_ttl_sec = conf_copy.pop('cache.latest.ttl.sec', None) @@ -241,7 +245,7 @@ def __init__(self, conf: dict): if max_retries is not None: if not isinstance(max_retries, (int, float)): raise TypeError("max.retries must be a number, not " + str(type(max_retries))) - self.max_retries = max_retries + self.max_retries = int(max_retries) self.retries_wait_ms = 1000 retries_wait_ms = conf_copy.pop('retries.wait.ms', None) @@ -249,7 +253,7 @@ def __init__(self, conf: dict): if not isinstance(retries_wait_ms, (int, float)): raise TypeError("retries.wait.ms must be a number, not " + str(type(retries_wait_ms))) - self.retries_wait_ms = retries_wait_ms + self.retries_wait_ms = int(retries_wait_ms) self.retries_max_wait_ms = 20000 retries_max_wait_ms = conf_copy.pop('retries.max.wait.ms', None) @@ -257,11 +261,11 @@ def __init__(self, conf: dict): if not isinstance(retries_max_wait_ms, (int, float)): 
raise TypeError("retries.max.wait.ms must be a number, not " + str(type(retries_max_wait_ms))) - self.retries_max_wait_ms = retries_max_wait_ms + self.retries_max_wait_ms = int(retries_max_wait_ms) - self.bearer_field_provider = None logical_cluster = None identity_pool = None + self.bearer_field_provider: Optional[_BearerFieldProvider] = None self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) if self.bearer_auth_credentials_source is not None: self.auth = None @@ -281,43 +285,43 @@ def __init__(self, conf: dict): if not isinstance(identity_pool, str): raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) - if self.bearer_auth_credentials_source == 'OAUTHBEARER': - properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', - 'bearer.auth.issuer.endpoint.url'] - missing_properties = [prop for prop in properties_list if prop not in conf_copy] - if missing_properties: - raise ValueError("Missing required OAuth configuration properties: {}". 
- format(", ".join(missing_properties))) - - self.client_id = conf_copy.pop('bearer.auth.client.id') - if not isinstance(self.client_id, string_type): - raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) - - self.client_secret = conf_copy.pop('bearer.auth.client.secret') - if not isinstance(self.client_secret, string_type): - raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) - - self.scope = conf_copy.pop('bearer.auth.scope') - if not isinstance(self.scope, string_type): - raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) - - self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') - if not isinstance(self.token_endpoint, string_type): - raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " - + str(type(self.token_endpoint))) - - self.bearer_field_provider = _OAuthClient( - self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) - elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': - if 'bearer.auth.token' not in conf_copy: - raise ValueError("Missing bearer.auth.token") - static_token = conf_copy.pop('bearer.auth.token') - self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) - if not isinstance(static_token, string_type): - raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) + if self.bearer_auth_credentials_source == 'OAUTHBEARER': + properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', + 'bearer.auth.issuer.endpoint.url'] + missing_properties = [prop for prop in properties_list if prop not in conf_copy] + if missing_properties: + raise ValueError("Missing required OAuth configuration properties: {}". 
+ format(", ".join(missing_properties))) + + self.client_id = conf_copy.pop('bearer.auth.client.id') + if not isinstance(self.client_id, string_type): + raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) + + self.client_secret = conf_copy.pop('bearer.auth.client.secret') + if not isinstance(self.client_secret, string_type): + raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) + + self.scope = conf_copy.pop('bearer.auth.scope') + if not isinstance(self.scope, string_type): + raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) + + self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') + if not isinstance(self.token_endpoint, string_type): + raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + + str(type(self.token_endpoint))) + + self.bearer_field_provider = _OAuthClient( + self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) + else: # STATIC_TOKEN + if 'bearer.auth.token' not in conf_copy: + raise ValueError("Missing bearer.auth.token") + static_token = conf_copy.pop('bearer.auth.token') + self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) + if not isinstance(static_token, string_type): + raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) elif self.bearer_auth_credentials_source == 'CUSTOM': custom_bearer_properties = ['bearer.auth.custom.provider.function', 'bearer.auth.custom.provider.config'] @@ -379,6 +383,8 @@ def __init__(self, conf: dict): ) def handle_bearer_auth(self, headers: dict) -> None: + if self.bearer_field_provider is None: + raise ValueError("Bearer field provider is not set") bearer_fields = self.bearer_field_provider.get_bearer_fields() required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 
'bearer.auth.logical.cluster'] @@ -437,9 +443,10 @@ def send_request( " application/vnd.schemaregistry+json," " application/json"} + body_str: Optional[str] = None if body is not None: - body = json.dumps(body) - headers = {'Content-Length': str(len(body)), + body_str = json.dumps(body) + headers = {'Content-Length': str(len(body_str)), 'Content-Type': "application/vnd.schemaregistry.v1+json"} if self.bearer_auth_credentials_source: @@ -449,7 +456,7 @@ def send_request( for i, base_url in enumerate(self.base_urls): try: response = self.send_http_request( - base_url, url, method, headers, body, query) + base_url, url, method, headers, body_str, query) if is_success(response.status_code): return response.json() @@ -461,16 +468,18 @@ def send_request( # Raise the exception since we have no more urls to try raise e - try: - raise SchemaRegistryError(response.status_code, - response.json().get('error_code'), - response.json().get('message')) - # Schema Registry may return malformed output when it hits unexpected errors - except (ValueError, KeyError, AttributeError): - raise SchemaRegistryError(response.status_code, - -1, - "Unknown Schema Registry Error: " - + str(response.content)) + if isinstance(response, Response): + try: + raise SchemaRegistryError(response.status_code, + response.json().get('error_code'), + response.json().get('message')) + except (ValueError, KeyError, AttributeError): + raise SchemaRegistryError(response.status_code, + -1, + "Unknown Schema Registry Error: " + + str(response.content)) + else: + raise TypeError("Unexpected response of unsupported type: " + str(type(response))) def send_http_request( self, base_url: str, url: str, method: str, headers: Optional[dict], @@ -514,7 +523,7 @@ def send_http_request( return response time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) - return response + return response # type: ignore[return-value] class SchemaRegistryClient(object): @@ -597,6 +606,8 @@ def __init__(self, 
conf: dict): self._cache = _SchemaCache() cache_capacity = self._rest_client.cache_capacity cache_ttl = self._rest_client.cache_latest_ttl_sec + self._latest_version_cache: Cache[Any, Any] + self._latest_with_metadata_cache: Cache[Any, Any] if cache_ttl is not None: self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) @@ -639,7 +650,7 @@ def register_schema( registered_schema = self.register_schema_full_response( subject_name, schema, normalize_schemas=normalize_schemas) - return registered_schema.schema_id + return registered_schema.schema_id # type: ignore[return-value] def register_schema_full_response( self, subject_name: str, schema: 'Schema', @@ -682,14 +693,14 @@ def register_schema_full_response( 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), body=request) - result = RegisteredSchema.from_dict(response) + response_schema = RegisteredSchema.from_dict(response) registered_schema = RegisteredSchema( - schema_id=result.schema_id, - guid=result.guid, - subject=result.subject or subject_name, - version=result.version, - schema=result.schema + schema_id=response_schema.schema_id, + guid=response_schema.guid, + subject=response_schema.subject or subject_name, + version=response_schema.version, + schema=response_schema.schema, ) # The registered schema may not be fully populated @@ -724,7 +735,8 @@ def get_schema( `GET Schema API Reference `_ """ # noqa: E501 - result = self._cache.get_schema_by_id(subject_name, schema_id) + if subject_name is not None: + result = self._cache.get_schema_by_id(subject_name, schema_id) if result is not None: return result[1] @@ -818,7 +830,7 @@ def get_subjects_by_schema_id( Raises: SchemaRegistryError: if subjects can't be found """ - query = {'offset': offset, 'limit': limit} + query: dict[str, Any] = {'offset': offset, 'limit': limit} if subject_name is not None: query['subject'] = subject_name if deleted: @@ 
-851,7 +863,7 @@ def get_schema_versions( `GET Schema Versions API Reference `_ """ # noqa: E501 - query = {'offset': offset, 'limit': limit} + query: dict[str, Any] = {'offset': offset, 'limit': limit} if subject_name is not None: query['subject'] = subject_name if deleted: @@ -889,7 +901,7 @@ def lookup_schema( request = schema.to_dict() - query_params = { + query_params: dict[str, Any] = { 'normalize': normalize_schemas, 'deleted': deleted } @@ -942,7 +954,7 @@ def get_subjects( `GET subjects API Reference `_ """ # noqa: E501 - query = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} + query: dict[str, Any] = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} if subject_prefix is not None: query['subject'] = subject_prefix return self._rest_client.get('subjects', query) @@ -1039,7 +1051,7 @@ def get_latest_with_metadata( if registered_schema is not None: return registered_schema - query = {'deleted': deleted} + query: dict[str, Any] = {'deleted': deleted} if fmt is not None: query['format'] = fmt keys = metadata.keys() @@ -1058,7 +1070,7 @@ def get_latest_with_metadata( return registered_schema def get_version( - self, subject_name: str, version: Union[int, str] = "latest", + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", deleted: bool = False, fmt: Optional[str] = None ) -> 'RegisteredSchema': """ @@ -1066,7 +1078,7 @@ def get_version( Args: subject_name (str): Subject name. - version (Union[int, str]): Version of the schema or string "latest". Defaults to latest version. + version (Union[int, Literal["latest"]]): Version of the schema or string "latest". Defaults to latest version. deleted (bool): Whether to include deleted schemas. fmt (str): Format of the schema. 
@@ -1080,11 +1092,12 @@ def get_version( `GET Subject Versions API Reference `_ """ # noqa: E501 - registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) - if registered_schema is not None: - return registered_schema + if version != "latest": + registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) + if registered_schema is not None: + return registered_schema - query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + query: dict[str, Any] = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} response = self._rest_client.get( 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query ) @@ -1096,7 +1109,7 @@ def get_version( return registered_schema def get_referenced_by( - self, subject_name: str, version: Union[int, str] = "latest", + self, subject_name: str, version: Union[int, Literal["latest"]] = "latest", offset: int = 0, limit: int = -1 ) -> List[int]: """ @@ -1104,7 +1117,7 @@ def get_referenced_by( Args: subject_name (str): Subject name - version (int or str): Version number or "latest" + version (Union[int, Literal["latest"]]): Version number or "latest" offset (int): Pagination offset for results. limit (int): Pagination size for results. Ignored if negative. 
@@ -1118,7 +1131,7 @@ def get_referenced_by( `GET Subject Versions (ReferenceBy) API Reference `_ """ # noqa: E501 - query = {'offset': offset, 'limit': limit} + query: dict[str, Any] = {'offset': offset, 'limit': limit} return self._rest_client.get('subjects/{}/versions/{}/referencedby'.format( _urlencode(subject_name), version), query) @@ -1146,7 +1159,7 @@ def get_versions( `GET Subject All Versions API Reference `_ """ # noqa: E501 - query = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} + query: dict[str, Any] = {'deleted': deleted, 'deleted_only': deleted_only, 'offset': offset, 'limit': limit} return self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name)), query) def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: @@ -1516,6 +1529,6 @@ def clear_caches(self): def new_client(conf: dict) -> 'SchemaRegistryClient': from .mock_schema_registry_client import MockSchemaRegistryClient url = conf.get("url") - if url.startswith("mock://"): + if url and isinstance(url, str) and url.startswith("mock://"): return MockSchemaRegistryClient(conf) return SchemaRegistryClient(conf) diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py index 49957ce5c..bdf4f8b02 100644 --- a/src/confluent_kafka/schema_registry/_sync/serde.py +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -17,7 +17,7 @@ # import logging -from typing import List, Optional, Set, Dict, Any +from typing import List, Optional, Set, Dict, Any, Callable from confluent_kafka.schema_registry import RegisteredSchema from confluent_kafka.schema_registry.common.schema_registry_client import \ @@ -44,6 +44,14 @@ class BaseSerde(object): '_registry', '_rule_registry', '_subject_name_func', '_field_transformer'] + _use_schema_id: Optional[int] + _use_latest_version: bool + _use_latest_with_metadata: Optional[Dict[str, str]] + _registry: Any # 
SchemaRegistryClient + _rule_registry: Any # RuleRegistry + _subject_name_func: Callable[[Optional['SerializationContext'], Optional[str]], Optional[str]] + _field_transformer: Optional[FieldTransformer] + def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: if self._use_schema_id is not None: schema = self._registry.get_schema(self._use_schema_id, subject, fmt) @@ -114,6 +122,11 @@ def _execute_rules_with_phase( ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule, index, rules, inline_tags, field_transformer) + if rule.type is None: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, + RuleError(f"Rule type is None for rule {rule.name}"), + 'ERROR') + return message rule_executor = self._rule_registry.get_executor(rule.type.upper()) if rule_executor is None: self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, @@ -140,18 +153,24 @@ def _execute_rules_with_phase( return message def _get_on_success(self, rule: Rule) -> Optional[str]: + if rule.type is None: + return rule.on_success override = self._rule_registry.get_override(rule.type) if override is not None and override.on_success is not None: return override.on_success return rule.on_success def _get_on_failure(self, rule: Rule) -> Optional[str]: + if rule.type is None: + return rule.on_failure override = self._rule_registry.get_override(rule.type) if override is not None and override.on_failure is not None: return override.on_failure return rule.on_failure def _is_disabled(self, rule: Rule) -> Optional[bool]: + if rule.type is None: + return rule.disabled override = self._rule_registry.get_override(rule.type) if override is not None and override.disabled is not None: return override.disabled @@ -200,10 +219,16 @@ def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleA class BaseSerializer(BaseSerde, Serializer): __slots__ = ['_auto_register', '_normalize_schemas', 
'_schema_id_serializer'] + _auto_register: bool + _normalize_schemas: bool + _schema_id_serializer: Callable[[bytes, Any, Any], bytes] + class BaseDeserializer(BaseSerde, Deserializer): __slots__ = ['_schema_id_deserializer'] + _schema_id_deserializer: Callable[[bytes, Any, Any], Any] + def _get_writer_schema( self, schema_id: SchemaId, subject: Optional[str] = None, fmt: Optional[str] = None) -> Schema: @@ -241,7 +266,7 @@ def _get_migrations( ) -> List[Migration]: source = self._registry.lookup_schema( subject, source_info, normalize_schemas=False, deleted=True) - migrations = [] + migrations: List[Migration] = [] if source.version < target.version: migration_mode = RuleMode.UPGRADE first = source @@ -259,13 +284,14 @@ def _get_migrations( if i == 0: previous = version continue - if version.schema.rule_set is not None and self._has_rules( + if version.schema is not None and version.schema.rule_set is not None and self._has_rules( version.schema.rule_set, RulePhase.MIGRATION, migration_mode): - if migration_mode == RuleMode.UPGRADE: - migration = Migration(migration_mode, previous, version) - else: - migration = Migration(migration_mode, version, previous) - migrations.append(migration) + if previous is not None: # previous is always set after first iteration + if migration_mode == RuleMode.UPGRADE: + migration = Migration(migration_mode, previous, version) + else: + migration = Migration(migration_mode, version, previous) + migrations.append(migration) previous = version if migration_mode == RuleMode.DOWNGRADE: migrations.reverse() @@ -275,6 +301,8 @@ def _get_schemas_between( self, subject: str, first: RegisteredSchema, last: RegisteredSchema, fmt: Optional[str] = None ) -> List[RegisteredSchema]: + if first.version is None or last.version is None: + return [first, last] if last.version - first.version <= 1: return [first, last] version1 = first.version @@ -290,8 +318,9 @@ def _execute_migrations( migrations: List[Migration], message: Any ) -> Any: for migration 
in migrations: - message = self._execute_rules_with_phase( - ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode, - migration.source.schema, migration.target.schema, - message, None, None) + if migration.source is not None and migration.target is not None: + message = self._execute_rules_with_phase( + ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode, + migration.source.schema, migration.target.schema, + message, None, None) return message diff --git a/src/confluent_kafka/schema_registry/common/avro.py b/src/confluent_kafka/schema_registry/common/avro.py index 70eff464c..ab5982f4e 100644 --- a/src/confluent_kafka/schema_registry/common/avro.py +++ b/src/confluent_kafka/schema_registry/common/avro.py @@ -1,4 +1,5 @@ import decimal +import logging import re from collections import defaultdict from copy import deepcopy @@ -44,6 +45,8 @@ ] AvroSchema = Union[str, list, dict] +log = logging.getLogger(__name__) + class _ContextStringIO(BytesIO): """ @@ -113,12 +116,21 @@ def transform( elif isinstance(schema, dict): schema_type = schema.get("type") if schema_type == 'array': + if not isinstance(message, list): + log.warning("Incompatible message type for array schema") + return message return [transform(ctx, schema["items"], item, field_transform) for item in message] elif schema_type == 'map': + if not isinstance(message, dict): + log.warning("Incompatible message type for map schema") + return message return {key: transform(ctx, schema["values"], value, field_transform) for key, value in message.items()} elif schema_type == 'record': + if not isinstance(message, dict): + log.warning("Incompatible message type for record schema") + return message fields = schema["fields"] for field in fields: _transform_field(ctx, schema, field, message, field_transform) @@ -132,8 +144,8 @@ def transform( def _transform_field( - ctx: RuleContext, schema: AvroSchema, field: dict, - message: AvroMessage, field_transform: FieldTransform + ctx: RuleContext, schema: dict, 
field: dict, + message: dict, field_transform: FieldTransform ): field_type = field["type"] name = field["name"] @@ -216,7 +228,7 @@ def _resolve_union(schema: AvroSchema, message: AvroMessage) -> Optional[AvroSch def get_inline_tags(schema: AvroSchema) -> Dict[str, Set[str]]: - inline_tags = defaultdict(set) + inline_tags: Dict[str, Set[str]] = defaultdict(set) _get_inline_tags_recursively('', '', schema, inline_tags) return inline_tags @@ -246,16 +258,18 @@ def _get_inline_tags_recursively( record_ns = _implied_namespace(name) if record_ns is None: record_ns = ns - if record_ns != '' and not record_name.startswith(record_ns): + # Ensure record_name is not None and doesn't already have namespace prefix + if record_name is not None and record_ns != '' and not record_name.startswith(record_ns): record_name = f"{record_ns}.{record_name}" fields = schema["fields"] for field in fields: field_tags = field.get("confluent:tags") field_name = field.get("name") field_type = field.get("type") - if field_tags is not None and field_name is not None: + # Ensure all required fields are present before building tag key + if field_tags is not None and field_name is not None and record_name is not None: tags[record_name + '.' 
+ field_name].update(field_tags) - if field_type is not None: + if field_type is not None and record_name is not None: _get_inline_tags_recursively(record_ns, record_name, field_type, tags) diff --git a/src/confluent_kafka/schema_registry/common/json_schema.py b/src/confluent_kafka/schema_registry/common/json_schema.py index 0e3564688..ceeacf7bc 100644 --- a/src/confluent_kafka/schema_registry/common/json_schema.py +++ b/src/confluent_kafka/schema_registry/common/json_schema.py @@ -2,6 +2,7 @@ import decimal from io import BytesIO +import logging from typing import Union, Optional, List, Set import httpx @@ -42,8 +43,9 @@ JsonSchema = Union[bool, dict] -DEFAULT_SPEC = referencing.jsonschema.DRAFT7 +DEFAULT_SPEC = referencing.jsonschema.DRAFT7 # type: ignore[attr-defined] +log = logging.getLogger(__name__) class _ContextStringIO(BytesIO): """ @@ -68,8 +70,10 @@ def transform( ctx: RuleContext, schema: JsonSchema, ref_registry: Registry, ref_resolver: Resolver, path: str, message: JsonMessage, field_transform: FieldTransform ) -> Optional[JsonMessage]: + # Only proceed to transform the message if schema is of dict type if message is None or schema is None or isinstance(schema, bool): return message + field_ctx = ctx.current_field() if field_ctx is not None: field_ctx.field_type = get_type(schema) @@ -104,13 +108,18 @@ def transform( if ref is not None: ref_schema = ref_resolver.lookup(ref) return transform(ctx, ref_schema.contents, ref_registry, ref_resolver, path, message, field_transform) + schema_type = get_type(schema) if schema_type == FieldType.RECORD: props = schema.get("properties") + if not isinstance(message, dict): + log.warning("Incompatible message type for record schema") + return message if props is not None: for prop_name, prop_schema in props.items(): - _transform_field(ctx, path, prop_name, message, - prop_schema, ref_registry, ref_resolver, field_transform) + if isinstance(prop_schema, dict): + _transform_field(ctx, path, prop_name, message, + 
prop_schema, ref_registry, ref_resolver, field_transform) return message if schema_type in (FieldType.ENUM, FieldType.STRING, FieldType.INT, FieldType.DOUBLE, FieldType.BOOLEAN): if field_ctx is not None: @@ -121,8 +130,8 @@ def transform( def _transform_field( - ctx: RuleContext, path: str, prop_name: str, message: JsonMessage, - prop_schema: JsonSchema, ref_registry: Registry, ref_resolver: Resolver, field_transform: FieldTransform + ctx: RuleContext, path: str, prop_name: str, message: dict, + prop_schema: dict, ref_registry: Registry, ref_resolver: Resolver, field_transform: FieldTransform ): full_name = path + "." + prop_name try: @@ -146,8 +155,17 @@ def _transform_field( def _validate_subtypes( - schema: JsonSchema, message: JsonMessage, registry: Registry + schema: dict, message: JsonMessage, registry: Registry ) -> Optional[JsonSchema]: + """ + Validate the message against the subtypes. + Args: + schema: The schema to validate the message against. + message: The message to validate. + registry: The registry to use for the validation. + Returns: + The validated schema if the message is valid against the subtypes, otherwise None. + """ schema_type = schema.get("type") if not isinstance(schema_type, list) or len(schema_type) == 0: return None @@ -166,29 +184,35 @@ def _validate_subschemas( message: JsonMessage, registry: Registry, resolver: Resolver, -) -> Optional[JsonSchema]: +)-> Optional[JsonSchema]: + """ + Validate the message against the subschemas. + Args: + subschemas: The list of subschemas to validate the message against. + message: The message to validate. + registry: The registry to use for the validation. + resolver: The resolver to use for the validation. + Returns: + The validated schema if the message is valid against the subschemas, otherwise None. 
+ """ for subschema in subschemas: - try: - ref = subschema.get("$ref") - if ref is not None: - # resolve $ref before validating - subschema = resolver.lookup(ref).contents - validate(instance=message, schema=subschema, registry=registry, resolver=resolver) - return subschema - except ValidationError: - pass + if isinstance(subschema, dict): + try: + ref = subschema.get("$ref") + if ref is not None: + subschema = resolver.lookup(ref).contents + validate(instance=message, schema=subschema, registry=registry, resolver=resolver) + return subschema + except ValidationError: + pass return None def get_type(schema: JsonSchema) -> FieldType: - if isinstance(schema, list): + if isinstance(schema, bool): return FieldType.COMBINED - elif isinstance(schema, dict): - schema_type = schema.get("type") - else: - # string schemas; this could be either a named schema or a primitive type - schema_type = schema + schema_type = schema.get("type") if schema.get("const") is not None or schema.get("enum") is not None: return FieldType.ENUM if schema_type == "object": @@ -223,7 +247,7 @@ def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: return True -def get_inline_tags(schema: JsonSchema) -> Set[str]: +def get_inline_tags(schema: dict) -> Set[str]: tags = schema.get("confluent:tags") if tags is None: return set() diff --git a/src/confluent_kafka/schema_registry/common/protobuf.py b/src/confluent_kafka/schema_registry/common/protobuf.py index 3a55bc5de..889b58020 100644 --- a/src/confluent_kafka/schema_registry/common/protobuf.py +++ b/src/confluent_kafka/schema_registry/common/protobuf.py @@ -3,7 +3,7 @@ import base64 from collections import deque from decimal import Context, Decimal, MAX_PREC -from typing import Set, List, Any +from typing import Set, List, Any, Deque from google.protobuf import descriptor_pb2, any_pb2, api_pb2, empty_pb2, \ duration_pb2, field_mask_pb2, source_context_pb2, struct_pb2, timestamp_pb2, \ @@ -57,7 +57,7 @@ def _bytes(v: int) -> bytes: """ return 
bytes((v,)) else: - def _bytes(v: int) -> str: + def _bytes(v: int) -> str: # type: ignore[misc] """ Convert int to bytes @@ -97,7 +97,7 @@ def _create_index_array(msg_desc: Descriptor) -> List[int]: ValueError: If the message descriptor is malformed. """ - msg_idx = deque() + msg_idx: Deque[int] = deque() # Walk the nested MessageDescriptor tree up to the root. current = msg_desc @@ -310,7 +310,7 @@ def is_map_field(fd: FieldDescriptor): def get_inline_tags(fd: FieldDescriptor) -> Set[str]: - meta = fd.GetOptions().Extensions[meta_pb2.field_meta] + meta = fd.GetOptions().Extensions[meta_pb2.field_meta] # type: ignore[attr-defined] if meta is None: return set() else: @@ -330,7 +330,7 @@ def _is_builtin(name: str) -> bool: name.startswith('google/type/') -def decimal_to_protobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: +def decimal_to_protobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: # type: ignore[name-defined] """ Converts a Decimal to a Protobuf value. @@ -343,7 +343,7 @@ def decimal_to_protobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: """ sign, digits, exp = value.as_tuple() - delta = exp + scale + delta = exp + scale # type: ignore[operator] if delta < 0: raise ValueError( @@ -362,7 +362,7 @@ def decimal_to_protobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: bytes = unscaled_datum.to_bytes(bytes_req, byteorder="big", signed=True) - result = decimal_pb2.Decimal() + result = decimal_pb2.Decimal() # type: ignore[attr-defined] result.value = bytes result.precision = 0 result.scale = scale @@ -372,7 +372,7 @@ def decimal_to_protobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: decimal_context = Context() -def protobuf_to_decimal(value: decimal_pb2.Decimal) -> Decimal: +def protobuf_to_decimal(value: decimal_pb2.Decimal) -> Decimal: # type: ignore[name-defined] """ Converts a Protobuf value to Decimal. 
diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index e72237cdd..70511d842 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -22,16 +22,18 @@ from collections import defaultdict from enum import Enum from threading import Lock -from typing import List, Dict, Type, TypeVar, \ +from typing import List, Dict, Type, TypeVar, Union, \ cast, Optional, Any, Tuple __all__ = [ 'VALID_AUTH_PROVIDERS', '_BearerFieldProvider', + '_AsyncBearerFieldProvider', 'is_success', 'is_retriable', 'full_jitter', '_StaticFieldProvider', + '_AsyncStaticFieldProvider', '_SchemaCache', 'RuleKind', 'RuleMode', @@ -52,35 +54,55 @@ class _BearerFieldProvider(metaclass=abc.ABCMeta): + """Base class for synchronous bearer field providers.""" @abc.abstractmethod def get_bearer_fields(self) -> dict: raise NotImplementedError -def is_success(status_code: int) -> bool: - return 200 <= status_code <= 299 - - -def is_retriable(status_code: int) -> bool: - return status_code in (408, 429, 500, 502, 503, 504) +class _AsyncBearerFieldProvider(metaclass=abc.ABCMeta): + """Base class for asynchronous bearer field providers.""" + @abc.abstractmethod + async def get_bearer_fields(self) -> dict: + raise NotImplementedError +class _StaticFieldProvider(_BearerFieldProvider): + """Synchronous static token bearer field provider.""" + def __init__(self, token: str, logical_cluster: str, identity_pool: str): + self.token = token + self.logical_cluster = logical_cluster + self.identity_pool = identity_pool -def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) -> float: - no_jitter_delay = base_delay_ms * (2.0 ** retries_attempted) - return random.random() * min(no_jitter_delay, max_delay_ms) + def get_bearer_fields(self) -> dict: + return {'bearer.auth.token': self.token, 
'bearer.auth.logical.cluster': self.logical_cluster, + 'bearer.auth.identity.pool.id': self.identity_pool} -class _StaticFieldProvider(_BearerFieldProvider): +class _AsyncStaticFieldProvider(_AsyncBearerFieldProvider): + """Asynchronous static token bearer field provider.""" def __init__(self, token: str, logical_cluster: str, identity_pool: str): self.token = token self.logical_cluster = logical_cluster self.identity_pool = identity_pool - def get_bearer_fields(self) -> dict: + async def get_bearer_fields(self) -> dict: return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, 'bearer.auth.identity.pool.id': self.identity_pool} +def is_success(status_code: int) -> bool: + return 200 <= status_code <= 299 + + +def is_retriable(status_code: int) -> bool: + return status_code in (408, 429, 500, 502, 503, 504) + + +def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) -> float: + no_jitter_delay = base_delay_ms * (2.0 ** retries_attempted) + return random.random() * min(no_jitter_delay, max_delay_ms) + + class _SchemaCache(object): """ Thread-safe cache for use with the Schema Registry Client. 
@@ -943,7 +965,7 @@ class RegisteredSchema: version: Optional[int] schema_id: Optional[int] guid: Optional[str] - schema: Optional[Schema] + schema: Schema def to_dict(self) -> Dict[str, Any]: schema = self.schema diff --git a/src/confluent_kafka/schema_registry/rule_registry.py b/src/confluent_kafka/schema_registry/rule_registry.py index f1d90756f..93a41235b 100644 --- a/src/confluent_kafka/schema_registry/rule_registry.py +++ b/src/confluent_kafka/schema_registry/rule_registry.py @@ -1,3 +1,4 @@ + #!/usr/bin/env python # -*- coding: utf-8 -*- # diff --git a/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py b/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py index 0a0542ccd..c9cdaf92f 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py +++ b/src/confluent_kafka/schema_registry/rules/cel/cel_executor.py @@ -13,6 +13,7 @@ # limitations under the License. import datetime +import logging import uuid import celpy @@ -33,6 +34,9 @@ from google.protobuf import message + +log = logging.getLogger(__name__) + # A date logical type annotates an Avro int, where the int stores the number # of days from the unix epoch, 1 January 1970 (ISO calendar). 
DAYS_SHIFT = datetime.date(1970, 1, 1).toordinal() @@ -54,6 +58,9 @@ def transform(self, ctx: RuleContext, msg: Any) -> Any: def execute(self, ctx: RuleContext, msg: Any, args: Any) -> Any: expr = ctx.rule.expr + if expr is None: + log.warning("Expression from rule %s is None", ctx.rule.name) + return msg try: index = expr.index(";") except ValueError: @@ -72,7 +79,11 @@ def execute(self, ctx: RuleContext, msg: Any, args: Any) -> Any: def execute_rule(self, ctx: RuleContext, expr: str, args: Any) -> Any: schema = ctx.target - script_type = ctx.target.schema_type + if schema is None: + raise ValueError("Target schema is None") # TODO: check whether we should raise or return fallback + script_type = schema.schema_type + if script_type is None: + raise ValueError("Target schema type is None") # TODO: check whether we should raise or return fallback prog = self._cache.get_program(expr, script_type, schema) if prog is None: ast = self._env.compile(expr) @@ -158,7 +169,7 @@ def _dict_to_cel(val: dict) -> Dict[str, celtypes.Value]: result = celtypes.MapType() for key, val in val.items(): result[key] = _value_to_cel(val) - return result + return result # type: ignore[return-value] def _array_to_cel(val: list) -> List[celtypes.Value]: diff --git a/src/confluent_kafka/schema_registry/rules/cel/cel_field_presence.py b/src/confluent_kafka/schema_registry/rules/cel/cel_field_presence.py index 329077a3e..597f1b5fb 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/cel_field_presence.py +++ b/src/confluent_kafka/schema_registry/rules/cel/cel_field_presence.py @@ -16,7 +16,7 @@ import threading from typing import Any -import celpy # type: ignore +import celpy _has_state = threading.local() diff --git a/src/confluent_kafka/schema_registry/rules/cel/constraints.py b/src/confluent_kafka/schema_registry/rules/cel/constraints.py index 98b739cd3..32e87fdcf 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/constraints.py +++ 
b/src/confluent_kafka/schema_registry/rules/cel/constraints.py @@ -15,7 +15,7 @@ import typing -from celpy import celtypes # type: ignore +from celpy import celtypes from google.protobuf import descriptor, message, message_factory from confluent_kafka.schema_registry.rules.cel import string_format @@ -27,18 +27,18 @@ class CompilationError(Exception): def make_key_path(field_name: str, key: celtypes.Value) -> str: - return f"{field_name}[{string_format.format_value(key)}]" + return f"{field_name}[{string_format.format_value(key)}]" # type: ignore[str-bytes-safe] def make_duration(msg: message.Message) -> celtypes.DurationType: return celtypes.DurationType( - seconds=msg.seconds, # type: ignore - nanos=msg.nanos, # type: ignore + seconds=msg.seconds, + nanos=msg.nanos, ) def make_timestamp(msg: message.Message) -> celtypes.TimestampType: - return make_duration(msg) + celtypes.TimestampType(1970, 1, 1) + return make_duration(msg) + celtypes.TimestampType(1970, 1, 1) # type: ignore[return-value] def unwrap(msg: message.Message) -> celtypes.Value: @@ -87,8 +87,8 @@ def __getitem__(self, name): def _msg_to_cel(msg: message.Message) -> typing.Dict[str, celtypes.Value]: ctor = _MSG_TYPE_URL_TO_CTOR.get(msg.DESCRIPTOR.full_name) if ctor is not None: - return ctor(msg) - return MessageType(msg) + return ctor(msg) # type: ignore[return-value] + return MessageType(msg) # type: ignore[return-value] _TYPE_TO_CTOR = { @@ -115,14 +115,14 @@ def _msg_to_cel(msg: message.Message) -> typing.Dict[str, celtypes.Value]: def _proto_message_has_field(msg: message.Message, field: descriptor.FieldDescriptor) -> typing.Any: if field.is_extension: - return msg.HasExtension(field) # type: ignore + return msg.HasExtension(field) else: return msg.HasField(field.name) def _proto_message_get_field(msg: message.Message, field: descriptor.FieldDescriptor) -> typing.Any: if field.is_extension: - return msg.Extensions[field] # type: ignore + return msg.Extensions[field] else: return getattr(msg, 
field.name) @@ -132,7 +132,7 @@ def _scalar_field_value_to_cel(val: typing.Any, field: descriptor.FieldDescripto if ctor is None: msg = "unknown field type" raise CompilationError(msg) - return ctor(val) + return ctor(val) # type: ignore[operator] def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value: @@ -144,7 +144,7 @@ def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> c def _is_empty_field(msg: message.Message, field: descriptor.FieldDescriptor) -> bool: - if field.has_presence: # type: ignore[attr-defined] + if field.has_presence: return not _proto_message_has_field(msg, field) if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: return len(_proto_message_get_field(msg, field)) == 0 diff --git a/src/confluent_kafka/schema_registry/rules/cel/extra_func.py b/src/confluent_kafka/schema_registry/rules/cel/extra_func.py index 497e9632a..95c016fc2 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/extra_func.py +++ b/src/confluent_kafka/schema_registry/rules/cel/extra_func.py @@ -19,8 +19,8 @@ from ipaddress import IPv4Address, IPv6Address, ip_address from urllib import parse as urlparse -import celpy # type: ignore -from celpy import celtypes # type: ignore +import celpy +from celpy import celtypes from confluent_kafka.schema_registry.rules.cel import string_format @@ -143,6 +143,9 @@ def is_email(string: celtypes.Value) -> celpy.Result: def is_uri(string: celtypes.Value) -> celpy.Result: + if not isinstance(string, celtypes.StringType): + msg = "invalid argument, expected string" + raise celpy.CELEvalError(msg) url = urlparse.urlparse(string) if not all([url.scheme, url.netloc, url.path]): return celtypes.BoolType(False) @@ -150,6 +153,9 @@ def is_uri(string: celtypes.Value) -> celpy.Result: def is_uri_ref(string: celtypes.Value) -> celpy.Result: + if not isinstance(string, celtypes.StringType): + msg = "invalid argument, expected string" + raise celpy.CELEvalError(msg) url = 
urlparse.urlparse(string) if not all([url.scheme, url.path]) and url.fragment: return celtypes.BoolType(False) diff --git a/src/confluent_kafka/schema_registry/rules/cel/string_format.py b/src/confluent_kafka/schema_registry/rules/cel/string_format.py index 3e295849c..a596c623e 100644 --- a/src/confluent_kafka/schema_registry/rules/cel/string_format.py +++ b/src/confluent_kafka/schema_registry/rules/cel/string_format.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import celpy # type: ignore -from celpy import celtypes # type: ignore +import celpy +from celpy import celtypes QUOTE_TRANS = str.maketrans( { @@ -43,13 +43,13 @@ def __init__(self, locale: str): def format(self, fmt: celtypes.Value, args: celtypes.Value) -> celpy.Result: if not isinstance(fmt, celtypes.StringType): - return celpy.native_to_cel(celpy.new_error("format() requires a string as the first argument")) + return celpy.native_to_cel(celpy.new_error("format() requires a string as the first argument")) # type: ignore[attr-defined] if not isinstance(args, celtypes.ListType): - return celpy.native_to_cel(celpy.new_error("format() requires a list as the second argument")) + return celpy.native_to_cel(celpy.new_error("format() requires a list as the second argument")) # type: ignore[attr-defined] # printf style formatting i = 0 j = 0 - result = "" + result: str = "" while i < len(fmt): if fmt[i] != "%": result += fmt[i] @@ -76,25 +76,41 @@ def format(self, fmt: celtypes.Value, args: celtypes.Value) -> celpy.Result: i += 1 if i >= len(fmt): return celpy.CELEvalError("format() incomplete format specifier") + + # Format the argument and handle errors + formatted: celpy.Result if fmt[i] == "f": - result += self.format_float(arg, precision) - if fmt[i] == "e": - result += self.format_exponential(arg, precision) + formatted = self.format_float(arg, precision) + elif fmt[i] == "e": + formatted = self.format_exponential(arg, 
precision) elif fmt[i] == "d": - result += self.format_int(arg) + formatted = self.format_int(arg) elif fmt[i] == "s": - result += self.format_string(arg) + formatted = self.format_string(arg) elif fmt[i] == "x": - result += self.format_hex(arg) + formatted = self.format_hex(arg) elif fmt[i] == "X": - result += self.format_hex(arg).upper() + formatted = self.format_hex(arg) + if isinstance(formatted, celpy.CELEvalError): + return formatted + result += str(formatted).upper() + i += 1 + continue elif fmt[i] == "o": - result += self.format_oct(arg) + formatted = self.format_oct(arg) elif fmt[i] == "b": - result += self.format_bin(arg) + formatted = self.format_bin(arg) else: return celpy.CELEvalError("format() unknown format specifier: " + fmt[i]) + + # Check if formatting returned an error + if isinstance(formatted, celpy.CELEvalError): + return formatted + + # Append the formatted string + result += str(formatted) i += 1 + if j < len(args): return celpy.CELEvalError("format() too many arguments for format string") return celtypes.StringType(result) @@ -109,11 +125,12 @@ def format_exponential(self, arg: celtypes.Value, precision: int) -> celpy.Resul return celtypes.StringType(f"{arg:.{precision}e}") return self.format_int(arg) + # TODO: check if celtypes.StringType() supports int conversion def format_int(self, arg: celtypes.Value) -> celpy.Result: if isinstance(arg, celtypes.IntType): - return celtypes.StringType(arg) + return celtypes.StringType(arg) # type: ignore[arg-type] if isinstance(arg, celtypes.UintType): - return celtypes.StringType(arg) + return celtypes.StringType(arg) # type: ignore[arg-type] return celpy.CELEvalError("format_int() requires an integer argument") def format_hex(self, arg: celtypes.Value) -> celpy.Result: @@ -150,21 +167,24 @@ def format_string(self, arg: celtypes.Value) -> celpy.Result: return celtypes.StringType(arg.hex()) if isinstance(arg, celtypes.ListType): return self.format_list(arg) - return celtypes.StringType(arg) + return 
celtypes.StringType(arg) # type: ignore[arg-type] def format_value(self, arg: celtypes.Value) -> celpy.Result: if isinstance(arg, (celtypes.StringType, str)): return celtypes.StringType(quote(arg)) if isinstance(arg, celtypes.UintType): - return celtypes.StringType(arg) + return celtypes.StringType(arg) # type: ignore[arg-type] return self.format_string(arg) def format_list(self, arg: celtypes.ListType) -> celpy.Result: - result = "[" + result: str = "[" for i in range(len(arg)): if i > 0: result += ", " - result += self.format_value(arg[i]) + formatted = self.format_value(arg[i]) + if isinstance(formatted, celpy.CELEvalError): + return formatted + result += str(formatted) result += "]" return celtypes.StringType(result) diff --git a/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py b/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py index cf33ca2d6..629cb78fd 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/azurekms/azure_client.py @@ -30,27 +30,25 @@ class AzureKmsClient(tink.KmsClient): """Basic Azure client for AEAD.""" def __init__( - self, key_uri: Optional[str], credentials: TokenCredential + self, key_uri: str, credentials: TokenCredential ) -> None: """Creates a new AzureKmsClient that is bound to the key specified in 'key_uri'. Uses the specified credentials when communicating with the KMS. Args: - key_uri: The URI of the key the client should be bound to. If it is None - or empty, then the client is not bound to any particular key. + key_uri: The URI of the key the client should be bound to. credentials: The token credentials. Raises: TinkError: If the key uri is not valid. 
""" - if not key_uri: - self._key_uri = None - elif key_uri.startswith(AZURE_KEYURI_PREFIX): + if key_uri.startswith(AZURE_KEYURI_PREFIX): self._key_uri = key_uri else: raise tink.TinkError('Invalid key_uri.') + key_id = key_uri[len(AZURE_KEYURI_PREFIX):] self._client = CryptographyClient(key_id, credentials) diff --git a/src/confluent_kafka/schema_registry/rules/encryption/dek_registry/dek_registry_client.py b/src/confluent_kafka/schema_registry/rules/encryption/dek_registry/dek_registry_client.py index 55195407c..1c130a8fa 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/dek_registry/dek_registry_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/dek_registry/dek_registry_client.py @@ -46,7 +46,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: d = src_dict.copy() kek_kms_props = cls() - kek_kms_props.properties = d + kek_kms_props.properties = d # type: ignore[attr-defined] return kek_kms_props @@ -124,7 +124,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: deleted = d.pop("deleted", None) - kek = cls( + kek = cls( # type: ignore[call-arg] name=name, kms_type=kms_type, kms_key_id=kms_key_id, @@ -198,7 +198,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: shared = d.pop("shared", None) - create_kek_request = cls( + create_kek_request = cls( # type: ignore[call-arg] name=name, kms_type=kms_type, kms_key_id=kms_key_id, @@ -321,7 +321,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: deleted = d.pop("deleted", None) - dek = cls( + dek = cls( # type: ignore[call-arg] kek_name=kek_name, subject=subject, version=version, @@ -381,7 +381,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: encrypted_key_material = d.pop("encryptedKeyMaterial", None) - create_dek_request = cls( + create_dek_request = cls( # type: ignore[call-arg] subject=subject, version=version, algorithm=algorithm, diff --git 
a/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py b/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py index 45a4ab99a..4c9642c4d 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py @@ -58,8 +58,8 @@ def now(self) -> int: class EncryptionExecutor(RuleExecutor): def __init__(self, clock: Clock = Clock()): - self.client = None - self.config = None + self.client: Optional[DekRegistryClient] = None + self.config: Optional[dict] = None self.clock = clock def configure(self, client_conf: dict, rule_conf: dict): @@ -203,7 +203,7 @@ def __init__(self, executor: EncryptionExecutor, cryptor: Cryptor, kek_name: str self._executor = executor self._cryptor = cryptor self._kek_name = kek_name - self._kek = None + self._kek: Optional[Kek] = None self._dek_expiry_days = dek_expiry_days def _is_dek_rotated(self): @@ -242,6 +242,8 @@ def _get_or_create_kek(self, ctx: RuleContext) -> Kek: return kek def _retrieve_kek_from_registry(self, kek_id: KekId) -> Optional[Kek]: + if self._executor.client is None: + raise RuleError("client not configured") try: return self._executor.client.get_kek(kek_id.name, kek_id.deleted) except Exception as e: @@ -253,6 +255,8 @@ def _store_kek_to_registry( self, kek_id: KekId, kms_type: str, kms_key_id: str, shared: bool ) -> Optional[Kek]: + if self._executor.client is None: + raise RuleError("client not configured") try: return self._executor.client.register_kek(kek_id.name, kms_type, kms_key_id, shared) except Exception as e: @@ -265,8 +269,9 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: is_read = ctx.rule_mode == RuleMode.READ if version is None or version == 0: version = 1 + # TODO: fallback value for name? 
dek_id = DekId( - kek.name, + kek.name, # type: ignore[arg-type] ctx.subject, version, self._cryptor.dek_format, @@ -278,12 +283,19 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: if dek is None or is_expired: if is_read: raise RuleError(f"no dek found for {dek_id.kek_name} during consume") + if self._kek is None: + raise RuleError("no kek found") encrypted_dek = None if not kek.shared: + if self._executor.config is None: + raise RuleError("config not found in executor") primitive = AeadWrapper(self._executor.config, self._kek) raw_dek = self._cryptor.generate_key() encrypted_dek = primitive.encrypt(raw_dek, self._cryptor.EMPTY_AAD) - new_version = dek.version + 1 if is_expired else 1 + if dek is None or dek.version is None: + new_version = 1 + else: + new_version = dek.version + 1 if is_expired else 1 try: dek = self._create_dek(dek_id, new_version, encrypted_dek) except RuleError as e: @@ -294,17 +306,18 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: key_bytes = dek.get_key_material_bytes() if key_bytes is None: if primitive is None: - primitive = AeadWrapper(self._executor.config, self._kek) + primitive = AeadWrapper(self._executor.config, self._kek) # type: ignore[arg-type] encrypted_dek = dek.get_encrypted_key_material_bytes() - raw_dek = primitive.decrypt(encrypted_dek, self._cryptor.EMPTY_AAD) + raw_dek = primitive.decrypt(encrypted_dek, self._cryptor.EMPTY_AAD) # type: ignore[arg-type] dek.set_key_material(raw_dek) return dek def _create_dek(self, dek_id: DekId, new_version: Optional[int], encrypted_dek: Optional[bytes]) -> Dek: + # TODO: fallback value for version? 
new_dek_id = DekId( dek_id.kek_name, dek_id.subject, - new_version, + new_version, # type: ignore[arg-type] dek_id.algorithm, dek_id.deleted, ) @@ -321,6 +334,8 @@ def _retrieve_dek_from_registry(self, key: DekId) -> Optional[Dek]: version = key.version if not version: version = 1 + if self._executor.client is None: + raise RuleError("client not configured") dek = self._executor.client.get_dek( key.kek_name, key.subject, key.algorithm, version, key.deleted) return dek if dek and dek.encrypted_key_material else None @@ -332,8 +347,14 @@ def _retrieve_dek_from_registry(self, key: DekId) -> Optional[Dek]: def _store_dek_to_registry(self, key: DekId, encrypted_dek: Optional[bytes]) -> Optional[Dek]: try: encrypted_dek_str = base64.b64encode(encrypted_dek).decode("utf-8") if encrypted_dek else None + if self._executor.client is None: + raise RuleError("client not configured") dek = self._executor.client.register_dek( - key.kek_name, key.subject, encrypted_dek_str, key.algorithm, key.version) + key.kek_name, + key.subject, + encrypted_dek_str, # type: ignore[arg-type] + key.algorithm, + key.version) return dek except Exception as e: if isinstance(e, SchemaRegistryError) and e.http_status_code == 409: @@ -345,7 +366,7 @@ def _is_expired(self, ctx: RuleContext, dek: Optional[Dek]) -> bool: return (ctx.rule_mode != RuleMode.READ and self._dek_expiry_days > 0 and dek is not None - and (now - dek.ts) / MILLIS_IN_DAY > self._dek_expiry_days) + and (now - dek.ts) / MILLIS_IN_DAY > self._dek_expiry_days) # type: ignore[operator] def transform(self, ctx: RuleContext, field_type: FieldType, field_value: Any) -> Any: if field_value is None: @@ -359,14 +380,19 @@ def transform(self, ctx: RuleContext, field_type: FieldType, field_value: Any) - version = -1 dek = self._get_or_create_dek(ctx, version) key_material_bytes = dek.get_key_material_bytes() + if key_material_bytes is None: + raise RuleError("no key material bytes found for dek") ciphertext = 
self._cryptor.encrypt(key_material_bytes, plaintext, Cryptor.EMPTY_AAD) if self._is_dek_rotated(): + if dek.version is None: + raise RuleError("no version found for dek") ciphertext = self._prefix_version(dek.version, ciphertext) if field_type == FieldType.STRING: return base64.b64encode(ciphertext).decode("utf-8") else: return self._to_object(field_type, ciphertext) elif ctx.rule_mode == RuleMode.READ: + ciphertext = None if field_type == FieldType.STRING: ciphertext = base64.b64decode(field_value) else: @@ -381,6 +407,8 @@ def transform(self, ctx: RuleContext, field_type: FieldType, field_value: Any) - raise RuleError("no version found in ciphertext") dek = self._get_or_create_dek(ctx, version) key_material_bytes = dek.get_key_material_bytes() + if key_material_bytes is None: + raise RuleError("no key material bytes found for dek") plaintext = self._cryptor.decrypt(key_material_bytes, ciphertext, Cryptor.EMPTY_AAD) return self._to_object(field_type, plaintext) else: @@ -421,6 +449,8 @@ def __init__(self, config: dict, kek: Kek): def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes: for index, kms_key_id in enumerate(self._kms_key_ids): try: + if self._kek.kms_type is None: + raise RuleError("no kms type found for kek") aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) return aead.encrypt(plaintext, associated_data) except Exception as e: @@ -433,6 +463,8 @@ def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes: def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes: for index, kms_key_id in enumerate(self._kms_key_ids): try: + if self._kek.kms_type is None: + raise RuleError("no kms type found for kek") aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) return aead.decrypt(ciphertext, associated_data) except Exception as e: @@ -452,7 +484,7 @@ def _get_kms_key_ids(self) -> List[str]: if alternate_kms_key_ids is not None: # Split the comma-separated list of alternate KMS key IDs and 
append to kms_key_ids kms_key_ids.extend([id.strip() for id in alternate_kms_key_ids.split(',') if id.strip()]) - return kms_key_ids + return kms_key_ids # type: ignore[return-value] def _get_aead(self, config: dict, kms_type: str, kms_key_id: str) -> aead.Aead: kek_url = kms_type + "://" + kms_key_id diff --git a/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py b/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py index f8523d73f..adee38241 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/hcvault/hcvault_client.py @@ -29,7 +29,7 @@ class HcVaultKmsClient(tink.KmsClient): """Basic HashiCorp Vault client for AEAD.""" def __init__( - self, key_uri: Optional[str], token: Optional[str], ns: Optional[str] = None, + self, key_uri: str, token: Optional[str], ns: Optional[str] = None, role_id: Optional[str] = None, secret_id: Optional[str] = None ) -> None: """Creates a new HcVaultKmsClient that is bound to the key specified in 'key_uri'. @@ -37,8 +37,7 @@ def __init__( Uses the specified credentials when communicating with the KMS. Args: - key_uri: The URI of the key the client should be bound to. If it is None - or empty, then the client is not bound to any particular key. + key_uri: The URI of the key the client should be bound to. token: The Vault token. ns: The Vault namespace. @@ -46,12 +45,11 @@ def __init__( TinkError: If the key uri is not valid. 
""" - if not key_uri: - self._key_uri = None - elif key_uri.startswith(VAULT_KEYURI_PREFIX): + if key_uri.startswith(VAULT_KEYURI_PREFIX): self._key_uri = key_uri else: raise tink.TinkError('Invalid key_uri.') + parsed = urlparse(key_uri[len(VAULT_KEYURI_PREFIX):]) vault_url = parsed.scheme + '://' + parsed.netloc self._client = hvac.Client( @@ -60,7 +58,7 @@ def __init__( namespace=ns, verify=False ) - if role_id and secret_id: + if role_id and secret_id and self._client is not None: self._client.auth.approle.login(role_id=role_id, secret_id=secret_id) def does_support(self, key_uri: str) -> bool: diff --git a/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py b/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py index ba92e0478..51f5508cb 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/localkms/local_client.py @@ -23,6 +23,8 @@ class LocalKmsClient(KmsClient): def __init__(self, secret: Optional[str] = None): + if secret is None: + raise TypeError("secret cannot be None") self._aead = self._get_primitive(secret) def _get_primitive(self, secret: str) -> aead.Aead: diff --git a/tools/unasync.py b/tools/unasync.py index feb616c59..68691d889 100644 --- a/tools/unasync.py +++ b/tools/unasync.py @@ -16,6 +16,18 @@ ("tests/schema_registry/_async", "tests/schema_registry/_sync"), ] +# Type hint patterns that should NOT have word boundaries (they contain brackets) +TYPE_HINT_SUBS = [ + (r'Coroutine\[Any, Any, ([^\]]+)\]', r'\1'), + (r'Coroutine\[None, None, ([^\]]+)\]', r'\1'), + (r'Awaitable\[([^\]]+)\]', r'\1'), + (r'AsyncIterator\[([^\]]+)\]', r'Iterator[\1]'), + (r'AsyncIterable\[([^\]]+)\]', r'Iterable[\1]'), + (r'AsyncGenerator\[([^,]+), ([^\]]+)\]', r'Generator[\1, \2, None]'), + (r'AsyncContextManager\[([^\]]+)\]', r'ContextManager[\1]'), +] + +# Regular substitutions that need word boundaries SUBS = [ ('from 
confluent_kafka.schema_registry.common import asyncinit', ''), ('@asyncinit', ''), @@ -36,6 +48,13 @@ (r'asyncio.run\((.*)\)', r'\2'), ] +# Compile type hint patterns without word boundaries +COMPILED_TYPE_HINT_SUBS = [ + (re.compile(regex), repl) + for regex, repl in TYPE_HINT_SUBS +] + +# Compile regular patterns with word boundaries COMPILED_SUBS = [ (re.compile(r'(^|\b)' + regex + r'($|\b)'), repl) for regex, repl in SUBS @@ -45,6 +64,11 @@ def unasync_line(line): + # First apply type hint transformations (without word boundaries) + for regex, repl in COMPILED_TYPE_HINT_SUBS: + line = re.sub(regex, repl, line) + + # Then apply regular transformations (with word boundaries) for index, (regex, repl) in enumerate(COMPILED_SUBS): old_line = line line = re.sub(regex, repl, line) @@ -189,5 +213,37 @@ def unasync(dir_pairs=None, check=False): '--check', action='store_true', help='Exit with non-zero status if sync directory has any differences') + parser.add_argument( + '--file', + type=str, + help='Convert a single file instead of all directories') args = parser.parse_args() - unasync(check=args.check) + + if args.file: + # Single file mode + async_file = args.file + if not os.path.exists(async_file): + print(f"Error: File {async_file} does not exist") + sys.exit(1) + + # Determine the sync file path + sync_file = None + for async_dir, sync_dir in ASYNC_TO_SYNC: + if async_file.startswith(async_dir): + sync_file = async_file.replace(async_dir, sync_dir, 1) + break + + if not sync_file: + print(f"Error: File {async_file} is not in a known async directory") + print(f"Known async directories: {[d[0] for d in ASYNC_TO_SYNC]}") + sys.exit(1) + + # Create the output directory if needed + os.makedirs(os.path.dirname(sync_file), exist_ok=True) + + print(f"Converting: {async_file} -> {sync_file}") + unasync_file(async_file, sync_file) + print("✅ Done!") + else: + # Directory mode (original behavior) + unasync(check=args.check)