Skip to content

Commit 6978f18

Browse files
committed
Fix bad schema caching logic in serializer
Schema versions were cached by schema, which is incorrect when using the topic_record_name naming strategy.
1 parent 708da5e commit 6978f18

File tree

5 files changed

+77
-47
lines changed

5 files changed

+77
-47
lines changed

src/aws_schema_registry/avro.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
import fastavro
77

8-
from aws_schema_registry.schema import DataFormat
8+
from aws_schema_registry.schema import DataFormat, Schema
99

1010

11-
class AvroSchema:
11+
class AvroSchema(Schema):
1212
"""Implementation of the `Schema` protocol for Avro schemas.
1313
1414
Arguments:
@@ -19,29 +19,47 @@ class AvroSchema:
1919
itself
2020
"""
2121

22-
data_format: DataFormat = 'AVRO'
23-
parsed: dict
24-
2522
def __init__(self, string: str, return_record_name: bool = False):
26-
self.string = string
27-
self.parsed = fastavro.parse_schema(json.loads(string))
28-
# https://github.com/fastavro/fastavro/issues/415
29-
self.name = self.parsed.get('name', self.parsed['type'])
23+
self._dict = json.loads(string)
24+
self._parsed = fastavro.parse_schema(self._dict)
3025
self.return_record_name = return_record_name
3126

27+
def __hash__(self):
    """Hash the schema's JSON definition, ignoring object key order.

    ``__eq__`` compares the parsed schemas, and dict equality does not
    depend on key order. Hashing an unsorted dump let two schemas that
    compare equal (same definition, keys written in a different order)
    produce different hashes, so hash-based caches missed them. Sorting
    the keys makes the hash agree with ``__eq__`` in that case.
    """
    return hash(json.dumps(self._dict, sort_keys=True))


def __eq__(self, other):
    """Return True when `other` is an AvroSchema with the same parsed
    schema and the same `return_record_name` setting."""
    return (
        isinstance(other, AvroSchema)
        and self._parsed == other._parsed
        and self.return_record_name == other.return_record_name
    )


def __str__(self):
    """Return the schema definition serialized as a JSON string."""
    return json.dumps(self._dict)


def __repr__(self):
    """Return a debugging representation containing the schema
    definition."""
    # Wrap in a 1-tuple so the operand is always treated as one value.
    return '<AvroSchema %s>' % (self._dict,)
40+
41+
@property
def data_format(self) -> DataFormat:
    """The data format of this schema; always ``'AVRO'``."""
    avro_format = 'AVRO'
    return avro_format
44+
45+
@property
def fqn(self) -> str:
    """The fully-qualified name of this schema.

    Returns the parsed schema's ``name`` when it has one, otherwise its
    ``type``. A schema whose parsed form is a bare string is returned
    as-is instead of failing on the missing ``.get`` method.
    """
    parsed = self._parsed
    if isinstance(parsed, str):
        # The schema JSON was just a quoted type name, so there is no
        # mapping to look a name up in.
        return parsed
    # https://github.com/fastavro/fastavro/issues/415
    return parsed.get('name', parsed['type'])
49+
3250
def read(self, bytes_: bytes):
3351
b = BytesIO(bytes_)
3452
value = fastavro.schemaless_reader(
3553
b,
36-
self.parsed,
54+
self._parsed,
3755
return_record_name=self.return_record_name
3856
)
3957
b.close()
4058
return value
4159

4260
def write(self, data) -> bytes:
4361
b = BytesIO()
44-
fastavro.schemaless_writer(b, self.parsed, data)
62+
fastavro.schemaless_writer(b, self._parsed, data)
4563
value = b.getvalue()
4664
b.close()
4765
return value

src/aws_schema_registry/naming.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def record_name_strategy(topic: str, is_key: bool, schema: Schema) -> str:
6161
However, this requires that the fully-qualified record names uniquely
6262
and consistently identify a schema across the entire registry.
6363
"""
64-
return schema.name
64+
return schema.fqn
6565

6666

6767
def topic_record_name_strategy(topic: str, is_key: bool,
@@ -72,4 +72,4 @@ def topic_record_name_strategy(topic: str, is_key: bool,
7272
Additionally allows different topics to use the same record name for
7373
incompatible schemas.
7474
"""
75-
return f'{topic}-{schema.name}'
75+
return f'{topic}-{schema.fqn}'

src/aws_schema_registry/schema.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from __future__ import annotations
22

3+
from abc import ABC, abstractmethod
34
from dataclasses import dataclass
45
import sys
5-
from typing import Any, Optional
6+
from typing import Any, Optional, Hashable
67
from uuid import UUID
78

89
if sys.version_info[1] < 8: # for py37
9-
from typing_extensions import Literal, Protocol
10+
from typing_extensions import Literal
1011
else:
11-
from typing import Literal, Protocol
12+
from typing import Literal
1213

1314
DataFormat = Literal['AVRO', 'JSON']
1415

@@ -33,14 +34,26 @@
3334
SchemaVersionStatus = Literal['AVAILABLE', 'PENDING', 'FAILURE', 'DELETING']
3435

3536

36-
class Schema(Protocol):
37-
data_format: DataFormat
38-
name: str
39-
string: str
37+
class Schema(ABC, Hashable):
38+
"""Abstract base class for a schema implementation."""
39+
40+
@property
41+
@abstractmethod
42+
def data_format(self) -> DataFormat:
43+
"""The data format of this schema."""
44+
45+
@property
46+
@abstractmethod
47+
def fqn(self) -> str:
48+
"""The fully-qualified name of this schema."""
4049

41-
def read(self, bytes_: bytes) -> Any: ...
50+
@abstractmethod
51+
def read(self, bytes_: bytes) -> Any:
52+
"""Read bytes into a record."""
4253

43-
def write(self, data) -> bytes: ...
54+
@abstractmethod
55+
def write(self, data) -> bytes:
56+
"""Write a record into bytes."""
4457

4558

4659
@dataclass

src/aws_schema_registry/serde.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import functools
34
import logging
45
import sys
56
from typing import Any, Dict, NamedTuple
@@ -58,8 +59,6 @@ class SchemaRegistrySerializer:
5859
alternate strategies.
5960
"""
6061

61-
_cache: Dict[Schema, SchemaVersion]
62-
6362
def __init__(
6463
self,
6564
client: SchemaRegistryClient,
@@ -71,32 +70,29 @@ def __init__(
7170
self.is_key = is_key
7271
self.compatibility_mode: CompatibilityMode = compatibility_mode
7372
self.schema_naming_strategy = schema_naming_strategy
74-
self._cache = {}
7573

76-
def serialize(self, topic, data_and_schema: DataAndSchema):
74+
def serialize(self, topic: str, data_and_schema: DataAndSchema):
7775
if data_and_schema is None:
7876
return None
7977
if not isinstance(data_and_schema, DataAndSchema):
8078
raise TypeError('AvroSerializer can only serialize',
8179
f' {DataAndSchema}, got {type(data_and_schema)}')
8280
data, schema = data_and_schema
83-
schema_version = self._cache.get(schema)
84-
if not schema_version:
85-
schema_name = self.schema_naming_strategy(
86-
topic, self.is_key, schema
87-
)
88-
LOG.info('Schema %s not cached locally, registering...',
89-
schema_name)
90-
schema_version = self.client.get_or_register_schema_version(
91-
definition=schema.string,
92-
schema_name=schema_name,
93-
data_format=schema.data_format,
94-
compatibility_mode=self.compatibility_mode
95-
)
96-
self._cache[schema] = schema_version
81+
schema_version = self._get_schema_version(topic, schema)
9782
serialized = schema.write(data)
9883
return encode(serialized, schema_version.version_id)
9984

85+
def _get_schema_version(self, topic: str, schema: Schema) -> SchemaVersion:
    """Return the registry schema version for `schema` on `topic`.

    Results are memoized per ``(topic, schema)`` pair, so the registry
    client is called only the first time a pair is seen by this
    serializer.

    The cache is a plain dict stored on the instance, not a
    ``functools.lru_cache`` on the method. A decorator cache lives on the
    class, is keyed on ``self``, and with ``maxsize=None`` never evicts,
    so it would keep every serializer (and its client) alive for the
    lifetime of the process.
    """
    # Created lazily so this method does not depend on __init__ setting
    # the attribute up.
    cache = self.__dict__.setdefault('_schema_version_cache', {})
    key = (topic, schema)
    try:
        return cache[key]
    except KeyError:
        pass
    schema_name = self.schema_naming_strategy(topic, self.is_key, schema)
    LOG.info('Fetching schema %s...', schema_name)
    schema_version = self.client.get_or_register_schema_version(
        definition=str(schema),
        schema_name=schema_name,
        data_format=schema.data_format,
        compatibility_mode=self.compatibility_mode
    )
    cache[key] = schema_version
    return schema_version
95+
10096

10197
class SchemaRegistryDeserializer:
10298
"""Kafka serializer that uses the AWS Schema Registry.
@@ -146,10 +142,13 @@ def deserialize(self, topic: str, bytes_: bytes):
146142
schema_version = self.client.get_schema_version(
147143
version_id=schema_version_id
148144
)
149-
if schema_version.data_format == 'AVRO':
150-
writer_schema = AvroSchema(schema_version.definition)
151-
elif schema_version.data_format == 'JSON':
152-
raise NotImplementedError('JSON schema not supported')
145+
writer_schema = self._create_writer_schema(schema_version)
153146
self._writer_schemas[schema_version_id] = writer_schema
154-
LOG.info('Schema %s fetched', schema_version_id)
147+
LOG.info('Schema version %s fetched', schema_version_id)
155148
return writer_schema.read(data_bytes)
149+
150+
def _create_writer_schema(self, schema_version: SchemaVersion) -> Schema:
    """Build the schema object used to read data written with
    `schema_version`.

    Raises:
        NotImplementedError: if the data format is ``'JSON'``.
        ValueError: if the data format is anything other than
            ``'AVRO'`` or ``'JSON'``.
    """
    data_format = schema_version.data_format
    if data_format == 'AVRO':
        return AvroSchema(schema_version.definition)
    if data_format == 'JSON':
        raise NotImplementedError('JSON schema not supported')
    # Without this, an unrecognised format fell off the end of the
    # function and returned None.
    raise ValueError(f'Unsupported data format: {data_format!r}')

tests/test_avro.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
def test_fully_qualified_name():
55
s = AvroSchema('{"type": "record", "namespace": "foo", "name": "Bar"}')
6-
assert s.name == "foo.Bar"
6+
assert s.fqn == "foo.Bar"
77

88

99
def test_primitive_name():
1010
# fastavro does not fulfill this part of the Avro spec
1111
s = AvroSchema('{"type": "string"}')
12-
assert s.name == 'string'
12+
assert s.fqn == 'string'
1313

1414

1515
def test_readwrite():

0 commit comments

Comments
 (0)