Skip to content

Commit a3caf39

Browse files
committed
fix
1 parent 8be650a commit a3caf39

File tree

6 files changed

+113
-53
lines changed

6 files changed

+113
-53
lines changed

src/confluent_kafka/schema_registry/_async/avro.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717
import io
1818
import json
19-
from typing import Dict, Union, Optional, Callable
19+
from typing import Any, Coroutine, Dict, Union, Optional, Callable, cast
2020

2121
from fastavro import schemaless_reader, schemaless_writer
2222
from confluent_kafka.schema_registry.common import asyncinit
@@ -206,7 +206,7 @@ async def __init_impl(
206206
self._registry = schema_registry_client
207207
self._schema_id = None
208208
self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance()
209-
self._known_subjects = set()
209+
self._known_subjects: set[str] = set()
210210
self._parsed_schemas = ParsedSchemaCache()
211211

212212
if to_dict is not None and not callable(to_dict):
@@ -243,11 +243,17 @@ async def __init_impl(
243243
not isinstance(self._use_latest_with_metadata, dict)):
244244
raise ValueError("use.latest.with.metadata must be a dict value")
245245

246-
self._subject_name_func = conf_copy.pop('subject.name.strategy')
246+
self._subject_name_func = cast(
247+
Callable[[Optional[SerializationContext], Optional[str]], Optional[str]],
248+
conf_copy.pop('subject.name.strategy')
249+
)
247250
if not callable(self._subject_name_func):
248251
raise ValueError("subject.name.strategy must be callable")
249252

250-
self._schema_id_serializer = conf_copy.pop('schema.id.serializer')
253+
self._schema_id_serializer = cast(
254+
Callable[[bytes, Optional[SerializationContext], Any], bytes],
255+
conf_copy.pop('schema.id.serializer')
256+
)
251257
if not callable(self._schema_id_serializer):
252258
raise ValueError("schema.id.serializer must be callable")
253259

@@ -286,7 +292,7 @@ async def __init_impl(
286292

287293
__init__ = __init_impl
288294

289-
def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
295+
def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]:
290296
return self.__serialize(obj, ctx)
291297

292298
async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
@@ -485,11 +491,17 @@ async def __init_impl(
485491
not isinstance(self._use_latest_with_metadata, dict)):
486492
raise ValueError("use.latest.with.metadata must be a dict value")
487493

488-
self._subject_name_func = conf_copy.pop('subject.name.strategy')
494+
self._subject_name_func = cast(
495+
Callable[[Optional[SerializationContext], Optional[str]], Optional[str]],
496+
conf_copy.pop('subject.name.strategy')
497+
)
489498
if not callable(self._subject_name_func):
490499
raise ValueError("subject.name.strategy must be callable")
491500

492-
self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer')
501+
self._schema_id_deserializer = cast(
502+
Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO],
503+
conf_copy.pop('schema.id.deserializer')
504+
)
493505
if not callable(self._schema_id_deserializer):
494506
raise ValueError("schema.id.deserializer must be callable")
495507

@@ -517,11 +529,11 @@ async def __init_impl(
517529

518530
__init__ = __init_impl
519531

520-
def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]:
532+
def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Union[dict, object, None]]:
521533
return self.__deserialize(data, ctx)
522534

523535
async def __deserialize(
524-
self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]:
536+
self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]:
525537
"""
526538
Deserialize Avro binary encoded data with Confluent Schema Registry framing to
527539
a dict, or object instance according to from_dict, if specified.

src/confluent_kafka/schema_registry/_async/json_schema.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717
import io
1818
import orjson
19-
from typing import Union, Optional, Tuple, Callable
19+
from typing import Any, Coroutine, Union, Optional, Tuple, Callable, cast
2020

2121
from cachetools import LRUCache
2222
from jsonschema import ValidationError
@@ -227,7 +227,7 @@ async def __init_impl(
227227
rule_registry if rule_registry else RuleRegistry.get_global_instance()
228228
)
229229
self._schema_id = None
230-
self._known_subjects = set()
230+
self._known_subjects: set[str] = set()
231231
self._parsed_schemas = ParsedSchemaCache()
232232
self._validators = LRUCache(1000)
233233

@@ -265,11 +265,17 @@ async def __init_impl(
265265
not isinstance(self._use_latest_with_metadata, dict)):
266266
raise ValueError("use.latest.with.metadata must be a dict value")
267267

268-
self._subject_name_func = conf_copy.pop('subject.name.strategy')
268+
self._subject_name_func = cast(
269+
Callable[[Optional[SerializationContext], Optional[str]], Optional[str]],
270+
conf_copy.pop('subject.name.strategy')
271+
)
269272
if not callable(self._subject_name_func):
270273
raise ValueError("subject.name.strategy must be callable")
271274

272-
self._schema_id_serializer = conf_copy.pop('schema.id.serializer')
275+
self._schema_id_serializer = cast(
276+
Callable[[bytes, Optional[SerializationContext], Any], bytes],
277+
conf_copy.pop('schema.id.serializer')
278+
)
273279
if not callable(self._schema_id_serializer):
274280
raise ValueError("schema.id.serializer must be callable")
275281

@@ -297,7 +303,7 @@ async def __init_impl(
297303

298304
__init__ = __init_impl
299305

300-
def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
306+
def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]:
301307
return self.__serialize(obj, ctx)
302308

303309
async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
@@ -526,12 +532,18 @@ async def __init_impl(
526532
not isinstance(self._use_latest_with_metadata, dict)):
527533
raise ValueError("use.latest.with.metadata must be a dict value")
528534

529-
self._subject_name_func = conf_copy.pop('subject.name.strategy')
535+
self._subject_name_func = cast(
536+
Callable[[Optional[SerializationContext], Optional[str]], Optional[str]],
537+
conf_copy.pop('subject.name.strategy')
538+
)
530539
if not callable(self._subject_name_func):
531540
raise ValueError("subject.name.strategy must be callable")
532541

533-
self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer')
534-
if not callable(self._subject_name_func):
542+
self._schema_id_deserializer = cast(
543+
Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO],
544+
conf_copy.pop('schema.id.deserializer')
545+
)
546+
if not callable(self._schema_id_deserializer):
535547
raise ValueError("schema.id.deserializer must be callable")
536548

537549
self._validate = conf_copy.pop('validate')
@@ -559,10 +571,10 @@ async def __init_impl(
559571

560572
__init__ = __init_impl
561573

562-
def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
574+
def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]:
563575
return self.__deserialize(data, ctx)
564576

565-
async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
577+
async def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
566578
"""
567579
Deserialize a JSON encoded record with Confluent Schema Registry framing to
568580
a dict, or object instance according to from_dict if from_dict is specified.

src/confluent_kafka/schema_registry/_async/protobuf.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717

1818
import io
19-
from typing import Set, List, Union, Optional, Tuple
19+
from typing import Any, Coroutine, Set, List, Union, Optional, Tuple, Callable, cast
2020

2121
from google.protobuf import json_format, descriptor_pb2
2222
from google.protobuf.descriptor_pool import DescriptorPool
@@ -271,7 +271,7 @@ async def __init_impl(
271271
self._registry = schema_registry_client
272272
self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance()
273273
self._schema_id = None
274-
self._known_subjects = set()
274+
self._known_subjects: set[str] = set()
275275
self._msg_class = msg_type
276276
self._parsed_schemas = ParsedSchemaCache()
277277

@@ -359,7 +359,7 @@ async def _resolve_dependencies(
359359
reference.version))
360360
return schema_refs
361361

362-
def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
362+
def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]:
363363
return self.__serialize(message, ctx)
364364

365365
async def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
@@ -534,11 +534,17 @@ async def __init_impl(
534534
not isinstance(self._use_latest_with_metadata, dict)):
535535
raise ValueError("use.latest.with.metadata must be a dict value")
536536

537-
self._subject_name_func = conf_copy.pop('subject.name.strategy')
537+
self._subject_name_func = cast(
538+
Callable[[Optional[SerializationContext], Optional[str]], Optional[str]],
539+
conf_copy.pop('subject.name.strategy')
540+
)
538541
if not callable(self._subject_name_func):
539542
raise ValueError("subject.name.strategy must be callable")
540543

541-
self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer')
544+
self._schema_id_deserializer = cast(
545+
Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO],
546+
conf_copy.pop('schema.id.deserializer')
547+
)
542548
if not callable(self._schema_id_deserializer):
543549
raise ValueError("schema.id.deserializer must be callable")
544550

@@ -557,10 +563,10 @@ async def __init_impl(
557563

558564
__init__ = __init_impl
559565

560-
def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
566+
def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]:
561567
return self.__deserialize(data, ctx)
562568

563-
async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
569+
async def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
564570
"""
565571
Deserialize a serialized protobuf message with Confluent Schema Registry
566572
framing.
@@ -600,7 +606,7 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] =
600606
if subject is None:
601607
subject = self._subject_name_func(ctx, writer_desc.full_name)
602608
if subject is not None:
603-
latest_schema = self._get_reader_schema(subject, fmt='serialized')
609+
latest_schema = await self._get_reader_schema(subject, fmt='serialized')
604610
else:
605611
writer_schema_raw = None
606612
writer_schema = None

src/confluent_kafka/schema_registry/_sync/avro.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717
import io
1818
import json
19-
from typing import Dict, Union, Optional, Callable
19+
from typing import Any, Coroutine, Dict, Union, Optional, Callable, cast
2020

2121
from fastavro import schemaless_reader, schemaless_writer
2222

@@ -206,7 +206,7 @@ def __init_impl(
206206
self._registry = schema_registry_client
207207
self._schema_id = None
208208
self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance()
209-
self._known_subjects = set()
209+
self._known_subjects: set[str] = set()
210210
self._parsed_schemas = ParsedSchemaCache()
211211

212212
if to_dict is not None and not callable(to_dict):
@@ -243,11 +243,17 @@ def __init_impl(
243243
not isinstance(self._use_latest_with_metadata, dict)):
244244
raise ValueError("use.latest.with.metadata must be a dict value")
245245

246-
self._subject_name_func = conf_copy.pop('subject.name.strategy')
246+
self._subject_name_func = cast(
247+
Callable[[Optional[SerializationContext], Optional[str]], Optional[str]],
248+
conf_copy.pop('subject.name.strategy')
249+
)
247250
if not callable(self._subject_name_func):
248251
raise ValueError("subject.name.strategy must be callable")
249252

250-
self._schema_id_serializer = conf_copy.pop('schema.id.serializer')
253+
self._schema_id_serializer = cast(
254+
Callable[[bytes, Optional[SerializationContext], Any], bytes],
255+
conf_copy.pop('schema.id.serializer')
256+
)
251257
if not callable(self._schema_id_serializer):
252258
raise ValueError("schema.id.serializer must be callable")
253259

@@ -286,7 +292,7 @@ def __init_impl(
286292

287293
__init__ = __init_impl
288294

289-
def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
295+
def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]:
290296
return self.__serialize(obj, ctx)
291297

292298
def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
@@ -485,11 +491,17 @@ def __init_impl(
485491
not isinstance(self._use_latest_with_metadata, dict)):
486492
raise ValueError("use.latest.with.metadata must be a dict value")
487493

488-
self._subject_name_func = conf_copy.pop('subject.name.strategy')
494+
self._subject_name_func = cast(
495+
Callable[[Optional[SerializationContext], Optional[str]], Optional[str]],
496+
conf_copy.pop('subject.name.strategy')
497+
)
489498
if not callable(self._subject_name_func):
490499
raise ValueError("subject.name.strategy must be callable")
491500

492-
self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer')
501+
self._schema_id_deserializer = cast(
502+
Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO],
503+
conf_copy.pop('schema.id.deserializer')
504+
)
493505
if not callable(self._schema_id_deserializer):
494506
raise ValueError("schema.id.deserializer must be callable")
495507

@@ -517,11 +529,11 @@ def __init_impl(
517529

518530
__init__ = __init_impl
519531

520-
def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]:
532+
def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Union[dict, object, None]]:
521533
return self.__deserialize(data, ctx)
522534

523535
def __deserialize(
524-
self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]:
536+
self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]:
525537
"""
526538
Deserialize Avro binary encoded data with Confluent Schema Registry framing to
527539
a dict, or object instance according to from_dict, if specified.

src/confluent_kafka/schema_registry/_sync/json_schema.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717
import io
1818
import orjson
19-
from typing import Union, Optional, Tuple, Callable
19+
from typing import Any, Coroutine, Union, Optional, Tuple, Callable, cast
2020

2121
from cachetools import LRUCache
2222
from jsonschema import ValidationError
@@ -227,7 +227,7 @@ def __init_impl(
227227
rule_registry if rule_registry else RuleRegistry.get_global_instance()
228228
)
229229
self._schema_id = None
230-
self._known_subjects = set()
230+
self._known_subjects: set[str] = set()
231231
self._parsed_schemas = ParsedSchemaCache()
232232
self._validators = LRUCache(1000)
233233

@@ -265,11 +265,17 @@ def __init_impl(
265265
not isinstance(self._use_latest_with_metadata, dict)):
266266
raise ValueError("use.latest.with.metadata must be a dict value")
267267

268-
self._subject_name_func = conf_copy.pop('subject.name.strategy')
268+
self._subject_name_func = cast(
269+
Callable[[Optional[SerializationContext], Optional[str]], Optional[str]],
270+
conf_copy.pop('subject.name.strategy')
271+
)
269272
if not callable(self._subject_name_func):
270273
raise ValueError("subject.name.strategy must be callable")
271274

272-
self._schema_id_serializer = conf_copy.pop('schema.id.serializer')
275+
self._schema_id_serializer = cast(
276+
Callable[[bytes, Optional[SerializationContext], Any], bytes],
277+
conf_copy.pop('schema.id.serializer')
278+
)
273279
if not callable(self._schema_id_serializer):
274280
raise ValueError("schema.id.serializer must be callable")
275281

@@ -297,7 +303,7 @@ def __init_impl(
297303

298304
__init__ = __init_impl
299305

300-
def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
306+
def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]:
301307
return self.__serialize(obj, ctx)
302308

303309
def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
@@ -526,12 +532,18 @@ def __init_impl(
526532
not isinstance(self._use_latest_with_metadata, dict)):
527533
raise ValueError("use.latest.with.metadata must be a dict value")
528534

529-
self._subject_name_func = conf_copy.pop('subject.name.strategy')
535+
self._subject_name_func = cast(
536+
Callable[[Optional[SerializationContext], Optional[str]], Optional[str]],
537+
conf_copy.pop('subject.name.strategy')
538+
)
530539
if not callable(self._subject_name_func):
531540
raise ValueError("subject.name.strategy must be callable")
532541

533-
self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer')
534-
if not callable(self._subject_name_func):
542+
self._schema_id_deserializer = cast(
543+
Callable[[bytes, Optional[SerializationContext], Any], io.BytesIO],
544+
conf_copy.pop('schema.id.deserializer')
545+
)
546+
if not callable(self._schema_id_deserializer):
535547
raise ValueError("schema.id.deserializer must be callable")
536548

537549
self._validate = conf_copy.pop('validate')
@@ -559,10 +571,10 @@ def __init_impl(
559571

560572
__init__ = __init_impl
561573

562-
def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
574+
def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Coroutine[Any, Any, Optional[bytes]]:
563575
return self.__deserialize(data, ctx)
564576

565-
def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
577+
def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
566578
"""
567579
Deserialize a JSON encoded record with Confluent Schema Registry framing to
568580
a dict, or object instance according to from_dict if from_dict is specified.

0 commit comments

Comments
 (0)