Commit 5e3fa47

Refactoring functions

1 parent e3ce2a8 commit 5e3fa47

24 files changed: +6231 -198 lines
aws_lambda_powertools/utilities/kafka_consumer/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,9 +1,9 @@
-from aws_lambda_powertools.utilities.kafka_consumer.consumer_record import ConsumerRecord
+from aws_lambda_powertools.utilities.kafka_consumer.consumer_records import ConsumerRecords
 from aws_lambda_powertools.utilities.kafka_consumer.kafka_consumer import kafka_consumer
 from aws_lambda_powertools.utilities.kafka_consumer.schema_config import SchemaConfig

 __all__ = [
     "kafka_consumer",
-    "ConsumerRecord",
+    "ConsumerRecords",
     "SchemaConfig",
 ]
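
For consumers of the package, only the exported name changes; a minimal sketch of an import against the new surface, assuming the exports shown above:

from aws_lambda_powertools.utilities.kafka_consumer import (
    ConsumerRecords,
    SchemaConfig,
    kafka_consumer,
)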

aws_lambda_powertools/utilities/kafka_consumer/consumer_record.py renamed to aws_lambda_powertools/utilities/kafka_consumer/consumer_records.py

Lines changed: 34 additions & 25 deletions

@@ -4,10 +4,8 @@

 from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict
 from aws_lambda_powertools.utilities.data_classes.kafka_event import KafkaEvent, KafkaEventBase
-from aws_lambda_powertools.utilities.kafka_consumer.functions import (
-    deserialize_avro,
-    deserialize_protobuf_with_compiled_classes,
-)
+from aws_lambda_powertools.utilities.kafka_consumer.deserializer.deserializer import get_deserializer
+from aws_lambda_powertools.utilities.kafka_consumer.serialization.serialization import serialize_to_output_type

 if TYPE_CHECKING:
     from collections.abc import Iterator

@@ -25,32 +23,43 @@ def __init__(self, data: dict[str, Any], deserialize: SchemaConfig | None = None
         self.deserialize = deserialize

     @property
-    def key(self) -> str | None:
+    def key(self) -> Any:
         key = self.get("key")
-        if key and self.deserialize.key_schema_type:
-            if self.deserialize.value_schema_type == "AVRO":
-                return deserialize_avro(key, self.deserialize.value_schema_str)
-            elif self.deserialize.value_schema_type == "PROTOBUF":
-                return deserialize_protobuf_with_compiled_classes(key, self.deserialize.value_schema_str)
-            elif self.deserialize.value_schema_type == "JSON":
-                return self._json_deserializer(key)
-            else:
-                raise ValueError("Invalid value_schema_type")
+        if key and (self.deserialize and self.deserialize.key_schema_type):
+            deserializer = get_deserializer(
+                self.deserialize.value_schema_type,
+                self.deserialize.value_schema_str,
+            )
+            deserialized_key = deserializer.deserialize(key)
+
+            if self.deserialize.key_output_serializer:
+                return serialize_to_output_type(
+                    deserialized_key,
+                    self.deserialize.key_output_serializer,
+                )
+
+            return deserialized_key

         return key

     @property
-    def value(self) -> str:
+    def value(self) -> Any:
         value = self["value"]
-        if self.deserialize.value_schema_type:
-            if self.deserialize.value_schema_type == "AVRO":
-                return deserialize_avro(value, self.deserialize.value_schema_str)
-            elif self.deserialize.value_schema_type == "PROTOBUF":
-                return deserialize_protobuf_with_compiled_classes(value, self.deserialize.value_schema_str)
-            elif self.deserialize.value_schema_type == "JSON":
-                return self._json_deserializer(value)
-            else:
-                raise ValueError("Invalid value_schema_type")
+        if value and (self.deserialize and self.deserialize.value_schema_type):
+            deserializer = get_deserializer(
+                self.deserialize.value_schema_type,
+                self.deserialize.value_schema_str,
+            )
+            deserialized_value = deserializer.deserialize(value)
+
+            if self.deserialize.value_output_serializer:
+                return serialize_to_output_type(
+                    deserialized_value,
+                    self.deserialize.value_output_serializer,
+                )
+
+            return deserialized_value
+
         return value

     @property

@@ -81,7 +90,7 @@ def original_headers(self) -> dict[str, bytes]:
         return self["headers"]


-class ConsumerRecord(KafkaEvent):
+class ConsumerRecords(KafkaEvent):
     """Self-managed or MSK Apache Kafka event trigger
     Documentation:
     --------------
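
Both properties now share the same two-step flow: fetch a deserializer from the factory, then optionally map the resulting dict onto a caller-supplied output type via serialize_to_output_type (imported above but implemented outside this excerpt). A minimal sketch of that flow pulled out into a hypothetical helper:

# Hypothetical helper mirroring the two-step flow of the `key`/`value` properties above;
# `raw` and `schema_config` stand in for a real record payload and SchemaConfig.
from aws_lambda_powertools.utilities.kafka_consumer.deserializer.deserializer import get_deserializer
from aws_lambda_powertools.utilities.kafka_consumer.serialization.serialization import serialize_to_output_type


def decode_value(raw, schema_config):
    # Step 1: pick a deserializer for the configured schema type and decode to a dict.
    deserializer = get_deserializer(schema_config.value_schema_type, schema_config.value_schema_str)
    deserialized = deserializer.deserialize(raw)

    # Step 2: optionally map the dict onto a caller-supplied output type.
    if schema_config.value_output_serializer:
        return serialize_to_output_type(deserialized, schema_config.value_output_serializer)
    return deserialized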

aws_lambda_powertools/utilities/kafka_consumer/deserializer/__init__.py

Whitespace-only changes.
aws_lambda_powertools/utilities/kafka_consumer/deserializer/avro.py

Lines changed: 32 additions & 0 deletions

from __future__ import annotations

import io
from typing import Any

from avro.io import BinaryDecoder, DatumReader
from avro.schema import parse as parse_schema

from aws_lambda_powertools.utilities.kafka_consumer.deserializer.base import DeserializerBase
from aws_lambda_powertools.utilities.kafka_consumer.exceptions import (
    KafkaConsumerAvroMissingSchemaError,
    KafkaConsumerDeserializationError,
)


class AvroDeserializer(DeserializerBase):
    def __init__(self, schema_str: str):
        if not schema_str:
            raise KafkaConsumerAvroMissingSchemaError("Schema string must be provided for Avro deserialization")
        self.parsed_schema = parse_schema(schema_str)
        self.reader = DatumReader(self.parsed_schema)

    def deserialize(self, data: bytes | str) -> dict[str, Any]:
        try:
            value = self._decode_input(data)
            bytes_reader = io.BytesIO(value)
            decoder = BinaryDecoder(bytes_reader)
            return self.reader.read(decoder)
        except (TypeError, ValueError) as e:
            raise KafkaConsumerDeserializationError(
                f"Avro deserialization error: {type(e).__name__}: {str(e)}",
            ) from e
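
For a concrete picture of what AvroDeserializer consumes, here is a self-contained round trip; the Order schema and record are invented for illustration and assume the avro package is installed:

import io
import json

from avro.io import BinaryEncoder, DatumWriter
from avro.schema import parse as parse_schema

from aws_lambda_powertools.utilities.kafka_consumer.deserializer.avro import AvroDeserializer

# Hypothetical schema and record, used only to exercise AvroDeserializer.
schema_str = json.dumps({
    "type": "record",
    "name": "Order",
    "fields": [{"name": "id", "type": "string"}, {"name": "amount", "type": "double"}],
})

# Encode one record the way a Kafka producer would.
buffer = io.BytesIO()
writer = DatumWriter(parse_schema(schema_str))
writer.write({"id": "o-1", "amount": 9.99}, BinaryEncoder(buffer))

deserializer = AvroDeserializer(schema_str)
print(deserializer.deserialize(buffer.getvalue()))  # {'id': 'o-1', 'amount': 9.99}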
aws_lambda_powertools/utilities/kafka_consumer/deserializer/base.py

Lines changed: 24 additions & 0 deletions

from __future__ import annotations

import base64
from abc import ABC, abstractmethod
from typing import Any


class DeserializerBase(ABC):
    @abstractmethod
    def deserialize(self, data: bytes | str) -> dict[str, Any]:
        pass

    def _decode_input(self, data: bytes | str) -> bytes:
        if isinstance(data, str):
            return base64.b64decode(data)
        elif isinstance(data, bytes):
            return data
        else:
            try:
                return base64.b64decode(data)
            except Exception as e:
                raise TypeError(
                    f"Expected bytes or base64-encoded string, got {type(data).__name__}. Error: {str(e)}",
                ) from e
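
The string branch exists because Lambda delivers Kafka keys and values base64-encoded in the event payload, while bytes pass through untouched. A small sketch of that behavior, using a hypothetical subclass just to reach the shared helper:

import base64

from aws_lambda_powertools.utilities.kafka_consumer.deserializer.base import DeserializerBase


# Hypothetical subclass; it exists only to exercise _decode_input.
class EchoDeserializer(DeserializerBase):
    def deserialize(self, data: bytes | str) -> dict:
        return {"raw": self._decode_input(data)}


payload = base64.b64encode(b'{"hello": "world"}').decode()  # as it appears in a Lambda Kafka event
print(EchoDeserializer().deserialize(payload))   # {'raw': b'{"hello": "world"}'}
print(EchoDeserializer().deserialize(b"bytes"))  # {'raw': b'bytes'} -- bytes pass through unchanged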
aws_lambda_powertools/utilities/kafka_consumer/deserializer/deserializer.py

Lines changed: 26 additions & 0 deletions

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from aws_lambda_powertools.utilities.kafka_consumer.deserializer.base import DeserializerBase


def get_deserializer(schema_type: str, schema_value: Any) -> DeserializerBase:
    if schema_type == "AVRO":
        # Import here to avoid dependency if not used
        from aws_lambda_powertools.utilities.kafka_consumer.deserializer.avro import AvroDeserializer

        return AvroDeserializer(schema_value)
    elif schema_type == "PROTOBUF":
        # Import here to avoid dependency if not used
        from aws_lambda_powertools.utilities.kafka_consumer.deserializer.protobuf import ProtobufDeserializer

        return ProtobufDeserializer(schema_value)
    elif schema_type == "JSON":
        # Import here to avoid dependency if not used
        from aws_lambda_powertools.utilities.kafka_consumer.deserializer.json import JsonDeserializer

        return JsonDeserializer()
    else:
        raise ValueError(f"Invalid schema_type: {schema_type}")
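
The lazy imports mean the avro and protobuf packages are only required when those schema types are actually requested. A quick illustration of the dispatch, using the JSON branch since it needs no schema; the payload is invented:

import base64

from aws_lambda_powertools.utilities.kafka_consumer.deserializer.deserializer import get_deserializer

# JSON needs no schema, so schema_value can be None here.
deserializer = get_deserializer("JSON", None)
print(deserializer.deserialize(base64.b64encode(b'{"id": 1}').decode()))  # {'id': 1}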
aws_lambda_powertools/utilities/kafka_consumer/deserializer/json.py

Lines changed: 17 additions & 0 deletions

from __future__ import annotations

import json

from aws_lambda_powertools.utilities.kafka_consumer.deserializer.base import DeserializerBase
from aws_lambda_powertools.utilities.kafka_consumer.exceptions import KafkaConsumerDeserializationError


class JsonDeserializer(DeserializerBase):
    def deserialize(self, data: bytes | str) -> dict:
        try:
            value = self._decode_input(data)
            return json.loads(value.decode("utf-8"))
        except Exception as e:
            raise KafkaConsumerDeserializationError(
                f"JSON deserialization error: {type(e).__name__}: {str(e)}",
            ) from e
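
Both the happy path and the wrapped error are easy to see in isolation; the payloads below are invented:

import base64

from aws_lambda_powertools.utilities.kafka_consumer.deserializer.json import JsonDeserializer
from aws_lambda_powertools.utilities.kafka_consumer.exceptions import KafkaConsumerDeserializationError

deserializer = JsonDeserializer()
print(deserializer.deserialize(base64.b64encode(b'{"name": "ada"}').decode()))  # {'name': 'ada'}

try:
    deserializer.deserialize(b"not json")  # bytes pass through _decode_input, then json.loads fails
except KafkaConsumerDeserializationError as e:
    print(e)  # JSON deserialization error: JSONDecodeError: ...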
aws_lambda_powertools/utilities/kafka_consumer/deserializer/protobuf.py

Lines changed: 26 additions & 0 deletions

from __future__ import annotations

from typing import Any

from google.protobuf.json_format import MessageToDict

from aws_lambda_powertools.utilities.kafka_consumer.deserializer.base import DeserializerBase
from aws_lambda_powertools.utilities.kafka_consumer.exceptions import (
    KafkaConsumerDeserializationError,
)


class ProtobufDeserializer(DeserializerBase):
    def __init__(self, message_class: Any):
        self.message_class = message_class

    def deserialize(self, data: bytes | str) -> dict:
        try:
            value = self._decode_input(data)
            message = self.message_class()
            message.ParseFromString(value)
            return MessageToDict(message, preserving_proto_field_name=True)
        except Exception as e:
            raise KafkaConsumerDeserializationError(
                f"Protocol Buffer deserialization error: {type(e).__name__}: {str(e)}",
            ) from e
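
Unlike Avro, the "schema" here is a compiled message class rather than a schema string, which is what the factory passes straight through as schema_value. A hypothetical usage, assuming an order_pb2 module generated by protoc from an Order message with string id and double amount fields:

# `order_pb2` and its `Order` message are hypothetical protoc output, not part of this commit.
from order_pb2 import Order

from aws_lambda_powertools.utilities.kafka_consumer.deserializer.protobuf import ProtobufDeserializer

deserializer = ProtobufDeserializer(Order)

wire_bytes = Order(id="o-1", amount=9.99).SerializeToString()  # what a producer would publish
print(deserializer.deserialize(wire_bytes))  # {'id': 'o-1', 'amount': 9.99}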

aws_lambda_powertools/utilities/kafka_consumer/functions.py

Lines changed: 0 additions & 150 deletions
This file was deleted.
