Commit b5e1f40

Refactoring tests
1 parent 7ec47f1 commit b5e1f40

File tree

11 files changed: +235 -702 lines changed


aws_lambda_powertools/utilities/kafka/consumer_records.py

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ def key(self) -> Any:
 
         if self.schema_config and self.schema_config.key_schema_type:
             schema_type = self.schema_config.key_schema_type
-            schema_str = self.schema_config.key_schema_str
+            schema_str = self.schema_config.key_schema
             output_serializer = self.schema_config.key_output_serializer
 
         # Always use get_deserializer if None it will default to DEFAULT
@@ -74,7 +74,7 @@ def value(self) -> Any:
 
         if self.schema_config and self.schema_config.value_schema_type:
             schema_type = self.schema_config.value_schema_type
-            schema_str = self.schema_config.value_schema_str
+            schema_str = self.schema_config.value_schema
             output_serializer = self.schema_config.value_output_serializer
 
         # Always use get_deserializer if None it will default to DEFAULT

aws_lambda_powertools/utilities/kafka/schema_config.py

Lines changed: 2 additions & 2 deletions
@@ -77,8 +77,8 @@ def __init__(
             )
 
         self.value_schema_type = value_schema_type
-        self.value_schema_str = value_schema
+        self.value_schema = value_schema
         self.value_output_serializer = value_output_serializer
         self.key_schema_type = key_schema_type
-        self.key_schema_str = key_schema
+        self.key_schema = key_schema
         self.key_output_serializer = key_output_serializer
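
To make the rename concrete, here is a minimal sketch of how a caller sees it; the Avro schema string and the assert are illustrative only, not part of this commit:

from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig

# Hypothetical Avro schema, used only for illustration.
USER_SCHEMA = """{
  "type": "record", "name": "User",
  "fields": [{"name": "name", "type": "string"}, {"name": "age", "type": "int"}]
}"""

config = SchemaConfig(value_schema_type="AVRO", value_schema=USER_SCHEMA)

# The attribute now mirrors the constructor argument name.
assert config.value_schema == USER_SCHEMA  # was config.value_schema_str before this commit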

aws_lambda_powertools/utilities/kafka/serialization/base.py

Lines changed: 10 additions & 8 deletions
@@ -4,6 +4,8 @@
 from typing import TYPE_CHECKING, Any
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from aws_lambda_powertools.utilities.kafka.serialization.types import T
 
 
@@ -16,21 +18,21 @@ class OutputSerializerBase(ABC):
 
     Methods
     -------
-    serialize(data, output_class)
+    serialize(data, output)
         Abstract method that must be implemented by subclasses to serialize data.
 
     Examples
     --------
     >>> class MyOutputSerializer(OutputSerializerBase):
-    ...     def serialize(self, data: dict[str, Any], output_class=None):
-    ...         if output_class:
+    ...     def serialize(self, data: dict[str, Any], output=None):
+    ...         if output:
     ...             # Convert dictionary to class instance
-    ...             return output_class(**data)
+    ...             return output(**data)
     ...         return data  # Return as is if no output class provided
     """
 
     @abstractmethod
-    def serialize(self, data: dict[str, Any], output_class: type[T] | None = None) -> T | dict[str, Any]:
+    def serialize(self, data: dict[str, Any], output: type[T] | Callable | None = None) -> T | dict[str, Any]:
         """
         Serialize dictionary data into a specific output format or class instance.
 
@@ -41,14 +43,14 @@ def serialize(self, data: dict[str, Any], output_class: type[T] | None = None) -> T | dict[str, Any]:
         ----------
         data : dict[str, Any]
             The dictionary data to serialize.
-        output_class : type[T] or None, optional
+        output : type[T] or None, optional
             Optional class type to convert the dictionary into. If provided,
             the method should return an instance of this class.
 
         Returns
         -------
         T or dict[str, Any]
-            An instance of output_class if provided, otherwise a processed dictionary.
-            The generic type T represents the type of the output_class.
+            An instance of output if provided, otherwise a processed dictionary.
+            The generic type T represents the type of the output.
         """
         raise NotImplementedError("Subclasses must implement this method")

aws_lambda_powertools/utilities/kafka/serialization/custom_dict.py

Lines changed: 4 additions & 13 deletions
@@ -2,23 +2,14 @@
 
 from typing import TYPE_CHECKING, Any
 
-from aws_lambda_powertools.utilities.kafka.exceptions import KafkaConsumerOutputSerializerError
 from aws_lambda_powertools.utilities.kafka.serialization.base import OutputSerializerBase
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from aws_lambda_powertools.utilities.kafka.serialization.types import T
 
 
 class CustomDictOutputSerializer(OutputSerializerBase):
-    def serialize(self, data: dict[str, Any], output_class: type[T] | None = None) -> T | dict[str, Any]:
-        if output_class is None:
-            return data
-
-        if not hasattr(output_class, "to_dict"):
-            raise KafkaConsumerOutputSerializerError("The output serialization class must have to_dict method")
-
-        # Instantiate and then populate
-        instance = output_class
-        for key, value in data.items():
-            setattr(instance, key, value)
-        return instance
+    def serialize(self, data: dict[str, Any], output: type[T] | Callable | None = None) -> T | dict[str, Any]:
+        return data if output is None else output(data)
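
With the pass-through rewrite above, any callable supplied as the output receives the already-deserialized dict and its return value becomes the record value; a small sketch with an invented helper function:

from aws_lambda_powertools.utilities.kafka.serialization.custom_dict import CustomDictOutputSerializer

def add_display_name(data: dict) -> dict:
    # Hypothetical callable, only for illustration.
    return {**data, "display_name": f"{data['name']} ({data['age']})"}

serializer = CustomDictOutputSerializer()
record = {"name": "John Doe", "age": 30}

assert serializer.serialize(record) == record  # no output: dict returned as-is
assert serializer.serialize(record, add_display_name)["display_name"] == "John Doe (30)"
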
aws_lambda_powertools/utilities/kafka/serialization/dataclass.py

Lines changed: 8 additions & 5 deletions
@@ -1,18 +1,21 @@
 from __future__ import annotations
 
 from dataclasses import is_dataclass
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from aws_lambda_powertools.utilities.kafka.serialization.base import OutputSerializerBase
 from aws_lambda_powertools.utilities.kafka.serialization.types import T
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 
 class DataclassOutputSerializer(OutputSerializerBase):
-    def serialize(self, data: dict[str, Any], output_class: type[T] | None = None) -> T | dict[str, Any]:
-        if output_class is None:
+    def serialize(self, data: dict[str, Any], output: type[T] | Callable | None = None) -> T | dict[str, Any]:
+        if output is None:
             return data
 
-        if not is_dataclass(output_class):
+        if not is_dataclass(output):
             raise ValueError("Output class must be a dataclass")
 
-        return cast(T, output_class(**data))
+        return cast(T, output(**data))
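
A brief illustrative usage of the dataclass path with the renamed parameter; the User dataclass below is made up for the example:

from dataclasses import dataclass

from aws_lambda_powertools.utilities.kafka.serialization.dataclass import DataclassOutputSerializer

@dataclass
class User:  # hypothetical dataclass, not part of the commit
    name: str
    age: int

serializer = DataclassOutputSerializer()
user = serializer.serialize({"name": "John Doe", "age": 30}, User)
assert user == User(name="John Doe", age=30)  # non-dataclass outputs still raise ValueError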

aws_lambda_powertools/utilities/kafka/serialization/pydantic.py

Lines changed: 5 additions & 3 deletions
@@ -7,14 +7,16 @@
 from aws_lambda_powertools.utilities.kafka.serialization.base import OutputSerializerBase
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from aws_lambda_powertools.utilities.kafka.serialization.types import T
 
 
 class PydanticOutputSerializer(OutputSerializerBase):
-    def serialize(self, data: dict[str, Any], output_class: type[T] | None = None) -> T | dict[str, Any]:
-        if output_class is None:
+    def serialize(self, data: dict[str, Any], output: type[T] | Callable | None = None) -> T | dict[str, Any]:
+        if output is None:
             return data
 
         # Use TypeAdapter for better support of Union types and other complex types
-        adapter = TypeAdapter(output_class)
+        adapter: TypeAdapter = TypeAdapter(output)
         return adapter.validate_python(data)
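
For context, a minimal sketch of the Pydantic path; the User model is invented for the example. TypeAdapter is used here because it also accepts Union types and other annotations that are not BaseModel subclasses:

from pydantic import BaseModel

from aws_lambda_powertools.utilities.kafka.serialization.pydantic import PydanticOutputSerializer

class User(BaseModel):  # hypothetical model, not part of the commit
    name: str
    age: int

serializer = PydanticOutputSerializer()
user = serializer.serialize({"name": "John Doe", "age": 30}, User)
assert user == User(name="John Doe", age=30)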

aws_lambda_powertools/utilities/kafka/serialization/serialization.py

Lines changed: 12 additions & 7 deletions
@@ -7,23 +7,25 @@
 from aws_lambda_powertools.utilities.kafka.serialization.dataclass import DataclassOutputSerializer
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from aws_lambda_powertools.utilities.kafka.serialization.types import T
 
 
-def _get_output_serializer(output_class: type[T] | None = None) -> Any:
+def _get_output_serializer(output: type[T] | Callable | None = None) -> Any:
     """
     Returns the appropriate serializer for the given output class.
     Uses lazy imports to avoid unnecessary dependencies.
     """
-    if output_class is None:
+    if output is None:
         # Return a pass-through serializer if no output class is specified
        return CustomDictOutputSerializer()
 
     # Check if it's a dataclass
-    if is_dataclass(output_class):
+    if is_dataclass(output):
         return DataclassOutputSerializer()
 
-    if _is_pydantic_model(output_class):
+    if _is_pydantic_model(output):
         from aws_lambda_powertools.utilities.kafka.serialization.pydantic import PydanticOutputSerializer
 
         return PydanticOutputSerializer()
@@ -41,9 +43,12 @@ def _is_pydantic_model(obj: Any) -> bool:
     return False
 
 
-def serialize_to_output_type(data: object | dict[str, Any], output_class: type[T] | None = None) -> T | dict[str, Any]:
+def serialize_to_output_type(
+    data: object | dict[str, Any],
+    output: type[T] | Callable | None = None,
+) -> T | dict[str, Any]:
     """
     Helper function to directly serialize data to the specified output class
     """
-    serializer = _get_output_serializer(output_class)
-    return serializer.serialize(data, output_class)
+    serializer = _get_output_serializer(output)
+    return serializer.serialize(data, output)
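
A hedged sketch of the dispatch behaviour after the rename, assuming that outputs which are neither dataclasses nor Pydantic models (for example a plain callable, as used in the refactored tests below) fall through to the dict pass-through serializer; the dataclass is illustrative only:

from dataclasses import dataclass

from aws_lambda_powertools.utilities.kafka.serialization.serialization import serialize_to_output_type

@dataclass
class User:  # hypothetical dataclass, not part of the commit
    name: str
    age: int

record = {"name": "John Doe", "age": 30}

assert serialize_to_output_type(record) == record                         # no output: dict unchanged
assert serialize_to_output_type(record, User) == User("John Doe", 30)     # dataclass: instance
assert serialize_to_output_type(record, lambda d: d["name"]) == "John Doe"  # callable: its return value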

tests/functional/kafka_consumer/_avro/test_kafka_consumer_with_avro.py

Lines changed: 6 additions & 6 deletions
@@ -165,34 +165,34 @@ def test_kafka_consumer_with_avro_and_custom_object(
     kafka_event_with_avro_data,
     avro_value_schema,
     lambda_context,
-    user_value_dict,
 ):
     """Test Kafka consumer with Avro deserialization and custom object serialization."""
 
+    def dict_output(data: dict) -> dict:
+        return data
+
     # Create dict to capture results
     result_data = {}
 
     schema_config = SchemaConfig(
         value_schema_type="AVRO",
         value_schema=avro_value_schema,
-        value_output_serializer=user_value_dict,
+        value_output_serializer=dict_output,
     )
 
     @kafka_consumer(schema_config=schema_config)
     def handler(event: ConsumerRecords, context):
         # Capture the results to verify
         record = next(event.records)
-        result_data["value_type"] = type(record.value).__name__
-        result_data["name"] = record.value.name
-        result_data["age"] = record.value.age
+        result_data["name"] = record.value.get("name")
+        result_data["age"] = record.value.get("age")
         return {"processed": True}
 
     # Call the handler
     result = handler(kafka_event_with_avro_data, lambda_context)
 
     # Verify the results
     assert result == {"processed": True}
-    assert result_data["value_type"] == "UserValueDict"
     assert result_data["name"] == "John Doe"
     assert result_data["age"] == 30
 
tests/functional/kafka_consumer/_protobuf/test_kafka_consumer_with_protobuf.py

Lines changed: 5 additions & 12 deletions
@@ -217,40 +217,33 @@ def test_kafka_consumer_with_custom_object(
     """Test Kafka consumer with Protobuf deserialization and custom object serialization."""
 
     # Define a custom output object class
-    class UserCustomObject:
-        def __init__(self, proto_message):
-            self.name = proto_message.name
-            self.age = proto_message.age
-            self.custom_field = f"{proto_message.name} is {proto_message.age} years old"
+    def dict_output(data: dict) -> dict:
+        return data
 
     # Create dict to capture results
     result_data = {}
 
     schema_config = SchemaConfig(
         value_schema_type="PROTOBUF",
         value_schema=User,
-        value_output_serializer=lambda msg: UserCustomObject(msg),
+        value_output_serializer=dict_output,
     )
 
     @kafka_consumer(schema_config=schema_config)
     def handler(event: ConsumerRecords, context):
         # Capture the results to verify
         record = next(event.records)
-        result_data["value_type"] = type(record.value).__name__
-        result_data["name"] = record.value.name
-        result_data["age"] = record.value.age
-        result_data["custom_field"] = record.value.custom_field
+        result_data["name"] = record.value.get("name")
+        result_data["age"] = record.value.get("age")
         return {"processed": True}
 
     # Call the handler
     result = handler(kafka_event_with_proto_data, lambda_context)
 
     # Verify the results
     assert result == {"processed": True}
-    assert result_data["value_type"] == "UserCustomObject"
     assert result_data["name"] == "John Doe"
     assert result_data["age"] == 30
-    assert result_data["custom_field"] == "John Doe is 30 years old"
 
 
 def test_kafka_consumer_with_multiple_records(lambda_context):
