Skip to content

Commit 32dc4b2

Browse files
Adding support for Kafka Consumer - first commit
1 parent f998fdf commit 32dc4b2

File tree

8 files changed

+387
-26
lines changed

8 files changed

+387
-26
lines changed

aws_lambda_powertools/utilities/data_classes/kafka_event.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections.abc import Iterator
1111

1212

13-
class KafkaEventRecord(DictWrapper):
13+
class KafkaEventBase(DictWrapper):
1414
@property
1515
def topic(self) -> str:
1616
"""The Kafka topic."""
@@ -36,6 +36,8 @@ def timestamp_type(self) -> str:
3636
"""The Kafka record timestamp type."""
3737
return self["timestampType"]
3838

39+
40+
class KafkaEventRecord(KafkaEventBase):
3941
@property
4042
def key(self) -> str | None:
4143
"""
Lines changed: 113 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,122 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import TYPE_CHECKING, Any
44

5-
from aws_lambda_powertools.utilities.data_classes.kafka_event import KafkaEventRecord
5+
from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict
6+
from aws_lambda_powertools.utilities.data_classes.kafka_event import KafkaEvent, KafkaEventBase
7+
from aws_lambda_powertools.utilities.kafka_consumer.functions import (
8+
deserialize_avro,
9+
deserialize_protobuf_with_compiled_classes,
10+
)
611

12+
if TYPE_CHECKING:
13+
from collections.abc import Iterator
714

8-
class ConsumerRecord(KafkaEventRecord):
15+
from aws_lambda_powertools.utilities.kafka_consumer.schema_config import SchemaConfig
16+
17+
18+
class ConsumerRecordRecords(KafkaEventBase):
    """
    A single Kafka record, with optional schema-driven deserialization
    of its key and value.
    """

    def __init__(self, data: dict[str, Any], deserialize: SchemaConfig | None = None):
        super().__init__(data)
        # Optional schema configuration (AVRO/PROTOBUF/JSON). When None,
        # `key`/`value` return the raw (base64-encoded) payload strings.
        self.deserialize = deserialize

    @property
    def key(self) -> str | None:
        """The Kafka record key, deserialized with the configured *key* schema (if any)."""
        key = self.get("key")
        # Fix: dispatch on the key schema settings, not the value ones,
        # and guard against a missing schema configuration.
        # NOTE(review): assumes SchemaConfig exposes `key_schema_str`
        # alongside `key_schema_type` — confirm against SchemaConfig.
        if key and self.deserialize and self.deserialize.key_schema_type:
            if self.deserialize.key_schema_type == "AVRO":
                return deserialize_avro(key, self.deserialize.key_schema_str)
            elif self.deserialize.key_schema_type == "PROTOBUF":
                return deserialize_protobuf_with_compiled_classes(key, self.deserialize.key_schema_str)
            elif self.deserialize.key_schema_type == "JSON":
                # DictWrapper provides a default JSON deserializer.
                return self._json_deserializer(key)
            else:
                raise ValueError("Invalid key_schema_type")

        return key

    @property
    def value(self) -> str:
        """The Kafka record value, deserialized with the configured *value* schema (if any)."""
        value = self["value"]
        # Guard: `deserialize` is optional; without it, return the raw value.
        if self.deserialize and self.deserialize.value_schema_type:
            if self.deserialize.value_schema_type == "AVRO":
                return deserialize_avro(value, self.deserialize.value_schema_str)
            elif self.deserialize.value_schema_type == "PROTOBUF":
                return deserialize_protobuf_with_compiled_classes(value, self.deserialize.value_schema_str)
            elif self.deserialize.value_schema_type == "JSON":
                return self._json_deserializer(value)
            else:
                raise ValueError("Invalid value_schema_type")
        return value

    @property
    def original_value(self) -> str:
        """The original (base64 encoded) Kafka record value."""
        return self["value"]

    @property
    def original_key(self) -> str | None:
        """
        The original (base64 encoded) Kafka record key.

        This key is optional; if not provided,
        a round-robin algorithm will be used to determine
        the partition for the message.
        """
        return self.get("key")

    @property
    def headers(self) -> dict[str, bytes]:
        """Decodes the raw headers into a single case-insensitive dictionary."""
        # Fix: iterate the raw event data (`self["headers"]`), not this
        # property itself, which recursed infinitely in the original.
        return CaseInsensitiveDict((k, bytes(v)) for chunk in self["headers"] for k, v in chunk.items())

    @property
    def original_headers(self) -> list[dict[str, list[int]]]:
        """The raw Kafka record headers (list of single-entry name -> byte-list dicts)."""
        return self["headers"]
82+
83+
84+
class ConsumerRecord(KafkaEvent):
    """Self-managed or MSK Apache Kafka event trigger
    Documentation:
    --------------
    - https://docs.aws.amazon.com/lambda/latest/dg/with-kafka.html
    - https://docs.aws.amazon.com/lambda/latest/dg/with-msk.html
    """

    def __init__(self, data: dict[str, Any], deserialize: SchemaConfig | None = None):
        super().__init__(data)
        # Lazily-created iterator backing the `record` property.
        self._records: Iterator[ConsumerRecordRecords] | None = None
        self.deserialize = deserialize

    @property
    def records(self) -> Iterator[ConsumerRecordRecords]:
        """Yield every Kafka record in the event, across all topic-partitions."""
        for partition_records in self["records"].values():
            for raw_record in partition_records:
                yield ConsumerRecordRecords(data=raw_record, deserialize=self.deserialize)

    @property
    def record(self) -> ConsumerRecordRecords:
        """
        Returns the next Kafka record using an iterator.

        Returns
        -------
        ConsumerRecordRecords
            The next Kafka record.

        Raises
        ------
        StopIteration
            If there are no more records available.
        """
        if self._records is None:
            # First access: materialize the generator and keep it so
            # subsequent accesses continue from where we left off.
            self._records = self.records
        return next(self._records)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
class KafkaConsumerAvroSchemaMismatchError(Exception):
    """The message payload does not match the provided Avro schema."""


class KafkaConsumerDeserializationError(Exception):
    """The message payload could not be deserialized (Avro, Protobuf, or JSON)."""


class KafkaConsumerAvroMissingSchemaError(Exception):
    """No Avro schema was provided for deserialization."""
Lines changed: 128 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,150 @@
11
from __future__ import annotations
22

3+
import base64
34
import io
5+
from typing import Any
46

5-
from avro.errors import SchemaResolutionException
67
from avro.io import BinaryDecoder, DatumReader
8+
from avro.schema import parse as parse_schema
9+
from google.protobuf.json_format import MessageToDict
710

11+
from aws_lambda_powertools.utilities.kafka_consumer.exceptions import (
12+
KafkaConsumerAvroMissingSchemaError,
13+
KafkaConsumerAvroSchemaMismatchError,
14+
KafkaConsumerDeserializationError,
15+
)
816

9-
def deserialize_avro(avro_bytes, reader_schema: str | None = None):
17+
18+
def deserialize_avro(avro_bytes: bytes | str, value_schema_str: str) -> dict:
    """
    Deserialize Avro binary data to Python dictionary objects.

    This function handles the deserialization of Avro-formatted binary data
    using a specified schema string. It supports both raw bytes and
    base64-encoded string inputs.

    Parameters
    ----------
    avro_bytes : bytes or str
        Avro binary data, either as raw bytes or base64-encoded string.
        If a string is provided, it will be treated as base64-encoded.
    value_schema_str : str
        Avro schema definition in JSON string format to use for reading.
        Must be a valid Avro schema definition.

    Returns
    -------
    Any
        Deserialized Python dictionary representing the Avro data.

    Raises
    ------
    KafkaConsumerAvroMissingSchemaError
        If the schema is not provided.
    KafkaConsumerDeserializationError
        If schema parsing, schema resolution, or decoding fails.
    TypeError
        If avro_bytes is neither bytes nor a base64-encoded string.

    Examples
    --------
    >>> schema_str = '{"type": "record", "name": "User", "fields": [{"name": "name", "type": "string"}]}'
    >>> encoded_data = base64.b64encode(b'some-avro-binary-data')
    >>> user_dict = deserialize_avro(encoded_data, schema_str)
    """
    if not value_schema_str:
        raise KafkaConsumerAvroMissingSchemaError("Schema string must be provided for Avro deserialization")

    # Normalize the input OUTSIDE the main try block so the documented
    # TypeError propagates instead of being re-wrapped as a
    # deserialization error.
    if isinstance(avro_bytes, bytes):
        # Already raw bytes
        value = avro_bytes
    elif isinstance(avro_bytes, str):
        # Assume base64 encoded string (Lambda delivers payloads base64-encoded)
        value = base64.b64decode(avro_bytes)
    else:
        # Try base64 decoding as a fallback (e.g. bytearray)
        try:
            value = base64.b64decode(avro_bytes)
        except Exception as e:
            raise TypeError(
                f"Expected bytes or base64-encoded string, got {type(avro_bytes).__name__}. Error: {str(e)}",
            ) from e

    try:
        # Parse the provided schema and read the binary payload.
        parsed_schema = parse_schema(value_schema_str)
        reader = DatumReader(parsed_schema)
        decoder = BinaryDecoder(io.BytesIO(value))
        return reader.read(decoder)
    except Exception as e:
        # Fix: the original only caught the project's own exception types,
        # which nothing inside the try raised — so raw avro errors escaped
        # unwrapped, contradicting the documented contract. Wrap every
        # failure (schema parse, schema mismatch, corrupt data) in the
        # documented KafkaConsumerDeserializationError.
        raise KafkaConsumerDeserializationError(
            f"Deserialization failed: Unable to decode message data using Avro schema. "
            f"Error: {str(e)}. Check for data corruption or schema evolution issues.",
        ) from e
96+
97+
98+
def deserialize_protobuf_with_compiled_classes(
    protobuf_bytes: bytes | str,
    message_class: Any,
) -> dict[str, Any]:
    """
    A deserializer that works with pre-compiled protobuf classes.

    Parameters
    ----------
    protobuf_bytes : Union[bytes, str]
        Protocol Buffer binary data, either as raw bytes or base64-encoded string.
    message_class : Any
        The pre-compiled Protocol Buffer message class.

    Returns
    -------
    Dict[str, Any]
        Deserialized Python dictionary representing the Protocol Buffer data.

    Raises
    ------
    KafkaConsumerDeserializationError
        If parsing or conversion of the Protocol Buffer message fails.
    TypeError
        If protobuf_bytes is neither bytes nor a base64-encoded string.

    Example
    -------
    >>> from my_proto_package.user_pb2 import User
    >>> user_dict = deserialize_protobuf_with_compiled_classes(encoded_data, User)
    """
    # Normalize the input OUTSIDE the try block: in the original, the
    # TypeError raised here was immediately caught by the trailing
    # `except Exception` and re-wrapped, so the documented TypeError
    # could never reach the caller.
    if isinstance(protobuf_bytes, bytes):
        # Already raw bytes
        value = protobuf_bytes
    elif isinstance(protobuf_bytes, str):
        # Assume base64 encoded string
        value = base64.b64decode(protobuf_bytes)
    else:
        # Try base64 decoding as a fallback (e.g. bytearray)
        try:
            value = base64.b64decode(protobuf_bytes)
        except Exception as e:
            raise TypeError(
                f"Expected bytes or base64-encoded string, got {type(protobuf_bytes).__name__}. Error: {str(e)}",
            ) from e

    try:
        # Create message instance and deserialize
        message = message_class()
        message.ParseFromString(value)

        # Convert to dictionary, keeping the .proto field names as-is
        return MessageToDict(message, preserving_proto_field_name=True)
    except Exception as e:
        raise KafkaConsumerDeserializationError(
            f"Protocol Buffer deserialization error: {type(e).__name__}: {str(e)}",
        ) from e
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
from aws_lambda_powertools.middleware_factory import lambda_handler_decorator
6+
from aws_lambda_powertools.utilities.kafka_consumer.consumer_record import ConsumerRecord
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Callable
10+
11+
from aws_lambda_powertools.utilities.typing import LambdaContext
12+
13+
14+
@lambda_handler_decorator
def kafka_consumer(
    handler: Callable[[Any, LambdaContext], Any],
    event: dict[str, Any],
    context: LambdaContext,
    schema_registry_config: Any | None = None,
):
    """Middleware that wraps the raw Lambda event in a ConsumerRecord before
    invoking the handler, optionally deserializing record keys and values.

    Parameters
    ----------
    handler: Callable
        Lambda's handler
    event: dict[str, Any]
        Lambda's Event
    context: LambdaContext
        Lambda's Context
    schema_registry_config: SchemaConfig, optional
        Schema configuration used to deserialize the Kafka record keys
        and values (AVRO, PROTOBUF, or JSON)

    Example
    --------

    **Sample usage**

        from aws_lambda_powertools.utilities.kafka_consumer.kafka_consumer import kafka_consumer

        @kafka_consumer(schema_registry_config=schema_config)
        def handler(event: ConsumerRecord, context):
            return {"topic": event.record.topic}
    """
    return handler(ConsumerRecord(event, schema_registry_config), context)

0 commit comments

Comments
 (0)