
Commit 1099cba

Internal refactoring
1 parent cc70001 commit 1099cba

File tree

5 files changed (+182, -38 lines)

aws_lambda_powertools/utilities/kafka_consumer/consumer_records.py

Lines changed: 67 additions & 27 deletions
```diff
@@ -3,7 +3,7 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
-from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict
+from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict, DictWrapper
 from aws_lambda_powertools.utilities.data_classes.kafka_event import KafkaEventBase, KafkaEventRecordBase
 from aws_lambda_powertools.utilities.kafka_consumer.deserializer.deserializer import get_deserializer
 from aws_lambda_powertools.utilities.kafka_consumer.serialization.serialization import serialize_to_output_type
@@ -14,6 +14,18 @@
     from aws_lambda_powertools.utilities.kafka_consumer.schema_config import SchemaConfig
 
 
+class ConsumerRecordSchemaMetadata(DictWrapper):
+    @property
+    def data_format(self) -> str | None:
+        """The data format of the Kafka record."""
+        return self.get("dataFormat", None)
+
+    @property
+    def schema_id(self) -> str | None:
+        """The schema id of the Kafka record."""
+        return self.get("schemaId", None)
+
+
 class ConsumerRecordRecords(KafkaEventRecordBase):
     """
     A Kafka Consumer Record
@@ -26,42 +38,54 @@ def __init__(self, data: dict[str, Any], schema_config: SchemaConfig | None = None):
     @cached_property
     def key(self) -> Any:
         key = self.get("key")
-        if key and (self.schema_config and self.schema_config.key_schema_type):
-            deserializer = get_deserializer(
-                self.schema_config.key_schema_type,
-                self.schema_config.key_schema_str,
-            )
-            deserialized_key = deserializer.deserialize(key)
 
-            if self.schema_config.key_output_serializer:
-                return serialize_to_output_type(
-                    deserialized_key,
-                    self.schema_config.key_output_serializer,
-                )
+        # Return None if key doesn't exist
+        if not key:
+            return None
+
+        # Determine schema type and schema string
+        schema_type = None
+        schema_str = None
+        output_serializer = None
+
+        if self.schema_config and self.schema_config.key_schema_type:
+            schema_type = self.schema_config.key_schema_type
+            schema_str = self.schema_config.key_schema_str
+            output_serializer = self.schema_config.key_output_serializer
 
-            return deserialized_key
+        # Always go through get_deserializer; a None schema type falls back to the default deserializer
+        deserializer = get_deserializer(schema_type, schema_str)
+        deserialized_value = deserializer.deserialize(key)
 
-        return key  # MISSING DESERIALIZER
+        # Apply output serializer if specified
+        if output_serializer:
+            return serialize_to_output_type(deserialized_value, output_serializer)
+
+        return deserialized_value
 
     @cached_property
     def value(self) -> Any:
         value = self["value"]
-        if value and (self.schema_config and self.schema_config.value_schema_type):
-            deserializer = get_deserializer(
-                self.schema_config.value_schema_type,
-                self.schema_config.value_schema_str,
-            )
-            deserialized_value = deserializer.deserialize(value)
 
-            if self.schema_config.value_output_serializer:
-                return serialize_to_output_type(
-                    deserialized_value,
-                    self.schema_config.value_output_serializer,
-                )
+        # Determine schema type and schema string
+        schema_type = None
+        schema_str = None
+        output_serializer = None
+
+        if self.schema_config and self.schema_config.value_schema_type:
+            schema_type = self.schema_config.value_schema_type
+            schema_str = self.schema_config.value_schema_str
+            output_serializer = self.schema_config.value_output_serializer
 
-            return deserialized_value
+        # Always go through get_deserializer; a None schema type falls back to the default deserializer
+        deserializer = get_deserializer(schema_type, schema_str)
+        deserialized_value = deserializer.deserialize(value)
 
-        return value  # MISSING DESERIALIZER
+        # Apply output serializer if specified
+        if output_serializer:
+            return serialize_to_output_type(deserialized_value, output_serializer)
+
+        return deserialized_value
 
     @property
     def original_value(self) -> str:
@@ -90,6 +114,22 @@ def headers(self) -> dict[str, bytes]:
         """Decodes the headers as a single dictionary."""
         return CaseInsensitiveDict((k, bytes(v)) for chunk in self.original_headers for k, v in chunk.items())
 
+    @property
+    def key_schema_metadata(self) -> ConsumerRecordSchemaMetadata | None:
+        """The schema metadata of the record key, if present."""
+        return (
+            None if self.get("keySchemaMetadata") is None else ConsumerRecordSchemaMetadata(self["keySchemaMetadata"])
+        )
+
+    @property
+    def value_schema_metadata(self) -> ConsumerRecordSchemaMetadata | None:
+        """The schema metadata of the record value, if present."""
+        return (
+            None
+            if self.get("valueSchemaMetadata") is None
+            else ConsumerRecordSchemaMetadata(self["valueSchemaMetadata"])
+        )
+
 
 class ConsumerRecords(KafkaEventBase):
     """Self-managed or MSK Apache Kafka event trigger
```

aws_lambda_powertools/utilities/kafka_consumer/deserializer/deserializer.py

Lines changed: 39 additions & 4 deletions
```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import hashlib
 from typing import TYPE_CHECKING, Any
 
 from aws_lambda_powertools.utilities.kafka_consumer.deserializer.default import DefaultDeserializer
@@ -8,6 +9,23 @@
 if TYPE_CHECKING:
     from aws_lambda_powertools.utilities.kafka_consumer.deserializer.base import DeserializerBase
 
+# Cache for deserializers
+_deserializer_cache: dict[str, DeserializerBase] = {}
+
+
+def _get_cache_key(schema_type: str | object, schema_value: Any) -> str:
+    if schema_value is None:
+        return str(schema_type)
+
+    if isinstance(schema_value, str):
+        # For string schemas like Avro, hash the content
+        schema_hash = hashlib.md5(schema_value.encode("utf-8")).hexdigest()
+    else:
+        # For objects like Protobuf, use the object id
+        schema_hash = str(id(schema_value))
+
+    return f"{schema_type}_{schema_hash}"
+
 
 def get_deserializer(schema_type: str | object, schema_value: Any) -> DeserializerBase:
     """
@@ -55,18 +73,35 @@ def get_deserializer(schema_type: str | object, schema_value: Any) -> DeserializerBase:
     >>> # Get a no-op deserializer for raw data
     >>> no_op_deserializer = get_deserializer("RAW", None)
     """
+
+    # Generate a cache key based on schema type and value
+    cache_key = _get_cache_key(schema_type, schema_value)
+
+    # Check if we already have this deserializer in cache
+    if cache_key in _deserializer_cache:
+        return _deserializer_cache[cache_key]
+
+    deserializer: DeserializerBase
+
     if schema_type == "AVRO":
         # Import here to avoid dependency if not used
         from aws_lambda_powertools.utilities.kafka_consumer.deserializer.avro import AvroDeserializer
 
-        return AvroDeserializer(schema_value)
+        deserializer = AvroDeserializer(schema_value)
     elif schema_type == "PROTOBUF":
         # Import here to avoid dependency if not used
         from aws_lambda_powertools.utilities.kafka_consumer.deserializer.protobuf import ProtobufDeserializer
 
-        return ProtobufDeserializer(schema_value)
+        deserializer = ProtobufDeserializer(schema_value)
     elif schema_type == "JSON":
-        return JsonDeserializer()
+        deserializer = JsonDeserializer()
+
+    else:
+        # Default to no-op deserializer
+        deserializer = DefaultDeserializer()
+
+    # Store in cache for future use
+    _deserializer_cache[cache_key] = deserializer
 
     # Default to default deserializer that is base64 decode + bytes decoded
-    return DefaultDeserializer()
+    return deserializer
```
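With the module-level cache, `get_deserializer` now returns the same instance for repeated calls with the same schema, which avoids re-parsing schemas on warm invocations. A minimal sketch of the memoization behavior, assuming the optional Avro dependency is installed (the schema string is illustrative):

```python
from aws_lambda_powertools.utilities.kafka_consumer.deserializer.deserializer import get_deserializer

# String schemas are keyed by an MD5 digest of their content, so two calls
# with identical schema strings resolve to the same cached instance.
avro_schema = '{"type": "record", "name": "User", "fields": [{"name": "name", "type": "string"}]}'
assert get_deserializer("AVRO", avro_schema) is get_deserializer("AVRO", avro_schema)

# With no schema value the cache key is just str(schema_type); unknown types
# fall through to the cached DefaultDeserializer.
assert get_deserializer(None, None) is get_deserializer(None, None)
```

Note the MD5 digest serves only as a content fingerprint for cache keys, not as a security measure; object schemas (e.g. Protobuf message classes) are keyed by `id()`, so a cache hit requires passing the same class object.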

tests/functional/kafka_consumer/_avro/test_kafka_consumer_with_avro.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@
 from aws_lambda_powertools.utilities.kafka_consumer.exceptions import (
     KafkaConsumerAvroSchemaParserError,
     KafkaConsumerDeserializationError,
+    KafkaConsumerMissingSchemaError,
 )
 from aws_lambda_powertools.utilities.kafka_consumer.kafka_consumer import kafka_consumer
 from aws_lambda_powertools.utilities.kafka_consumer.schema_config import SchemaConfig
@@ -292,3 +293,17 @@ def lambda_handler(event: ConsumerRecords, context):
     assert key_value_result["value_type"] == "UserValueDataClass"
     assert key_value_result["value_name"] == "John Doe"
     assert key_value_result["value_age"] == 30
+
+
+def test_kafka_consumer_without_avro_value_schema():
+    """Test error handling when the Avro value schema is missing."""
+
+    with pytest.raises(KafkaConsumerMissingSchemaError):
+        SchemaConfig(value_schema_type="AVRO", value_schema=None)
+
+
+def test_kafka_consumer_without_avro_key_schema():
+    """Test error handling when the Avro key schema is missing."""
+
+    with pytest.raises(KafkaConsumerMissingSchemaError):
+        SchemaConfig(key_schema_type="AVRO", key_schema=None)
```

tests/functional/kafka_consumer/_protobuf/test_kafka_consumer_with_protobuf.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -7,6 +7,7 @@
 from aws_lambda_powertools.utilities.kafka_consumer.consumer_records import ConsumerRecords
 from aws_lambda_powertools.utilities.kafka_consumer.exceptions import (
     KafkaConsumerDeserializationError,
+    KafkaConsumerMissingSchemaError,
 )
 from aws_lambda_powertools.utilities.kafka_consumer.kafka_consumer import kafka_consumer
 from aws_lambda_powertools.utilities.kafka_consumer.schema_config import SchemaConfig
@@ -316,3 +317,17 @@ def handler(event: ConsumerRecords, context):
     assert processed_records[0]["age"] == 30
     assert processed_records[1]["name"] == "Jane Smith"
     assert processed_records[1]["age"] == 25
+
+
+def test_kafka_consumer_without_protobuf_value_schema():
+    """Test error handling when the Protobuf value schema is missing."""
+
+    with pytest.raises(KafkaConsumerMissingSchemaError):
+        SchemaConfig(value_schema_type="PROTOBUF", value_schema=None)
+
+
+def test_kafka_consumer_without_protobuf_key_schema():
+    """Test error handling when the Protobuf key schema is missing."""
+
+    with pytest.raises(KafkaConsumerMissingSchemaError):
+        SchemaConfig(key_schema_type="PROTOBUF", key_schema=None)
```

tests/functional/kafka_consumer/required_dependencies/test_kafka_consumer.py

Lines changed: 46 additions & 7 deletions
```diff
@@ -213,23 +213,62 @@ def handler(event: ConsumerRecords, context):
     assert any(r["name"] == "Bob Johnson" and r["age"] == 40 for r in processed_records)
 
 
-def test_kafka_consumer_without_schema_config(kafka_event_with_json_data, lambda_context):
+def test_kafka_consumer_default_deserializer_value(kafka_event_with_json_data, lambda_context):
     """Test Kafka consumer when no schema config is provided."""
 
-    # Create dict to capture results
-    result_data = {}
+    base64_data = base64.b64encode(b"data")
+    kafka_event_with_json_data = deepcopy(kafka_event_with_json_data)
+    kafka_event_with_json_data["records"]["my-topic-1"][0]["value"] = base64_data
 
     @kafka_consumer()
     def handler(event: ConsumerRecords, context):
         # Capture the results to verify
         record = next(event.records)
         # Should get raw base64-encoded data with no deserialization
-        result_data["value_type"] = type(record.value).__name__
-        return {"processed": True}
+        return record.value
 
     # Call the handler
     result = handler(kafka_event_with_json_data, lambda_context)
 
     # Verify the results
-    assert result == {"processed": True}
-    assert result_data["value_type"] == "str"  # Raw base64 string
+    assert result == "data"
+
+
+def test_kafka_consumer_default_deserializer_key(kafka_event_with_json_data, lambda_context):
+    """Test Kafka consumer when no schema config is provided."""
+
+    base64_data = base64.b64encode(b"data")
+    kafka_event_with_json_data = deepcopy(kafka_event_with_json_data)
+    kafka_event_with_json_data["records"]["my-topic-1"][0]["key"] = base64_data
+
+    @kafka_consumer()
+    def handler(event: ConsumerRecords, context):
+        # Capture the result to verify
+        record = next(event.records)
+        # The default deserializer should base64-decode the key
+        return record.key
+
+    # Call the handler
+    result = handler(kafka_event_with_json_data, lambda_context)
+
+    # Verify the results
+    assert result == "data"
+
+
+def test_kafka_consumer_default_deserializer_key_is_none(kafka_event_with_json_data, lambda_context):
+    """Test Kafka consumer when the record key is absent."""
+
+    kafka_event_with_json_data["records"]["my-topic-1"][0]["key"] = None
+
+    @kafka_consumer()
+    def handler(event: ConsumerRecords, context):
+        # Read the first record
+        record = next(event.records)
+        # A missing key should come through as None
+        return record.key
+
+    # Call the handler
+    result = handler(kafka_event_with_json_data, lambda_context)
+
+    # Verify the results
+    assert result is None
```
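The renamed tests capture the behavior change in consumer_records.py: with no schema config, keys and values now pass through the default deserializer instead of being returned as raw base64 strings (the old assertion checked `type(record.value).__name__ == "str"` on the undecoded payload). A standalone sketch of that default step, assuming it is base64 decode plus bytes-to-str decode as the comment in deserializer.py describes:

```python
import base64


def default_decode(field: str | bytes) -> str:
    # Assumed equivalent of DefaultDeserializer: base64 decode, then utf-8 decode
    return base64.b64decode(field).decode("utf-8")


encoded = base64.b64encode(b"data")       # what the Lambda Kafka event carries
assert default_decode(encoded) == "data"  # matches the updated assertions
```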
