Skip to content

Commit 425ce26

Browse files
ydjin0602, Ydjin0602, and ods
authored
fix batch serializers (#887)
* key and value serialization for producer batch builder
* fixes
* Add test for serialization in batch
* Fix linting errors
* Add changelog entry

---------

Co-authored-by: Ydjin0602 <[email protected]>
Co-authored-by: Denis Otkidach <[email protected]>
1 parent 5e0e882 commit 425ce26

File tree

4 files changed

+91
-12
lines changed

4 files changed

+91
-12
lines changed

CHANGES.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
Changelog
33
=========
44

5+
Unreleased
6+
==========
7+
8+
Bugfixes:
9+
10+
* Fix serialization for batch (issue #886, pr #887 by @ydjin0602)
11+
12+
513
0.10.0 (2023-12-15)
614
===================
715

aiokafka/producer/message_accumulator.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515

1616
class BatchBuilder:
17-
def __init__(self, magic, batch_size, compression_type,
18-
*, is_transactional):
17+
def __init__(
18+
self, magic, batch_size, compression_type,
19+
*, is_transactional, key_serializer=None, value_serializer=None
20+
):
1921
if magic < 2:
2022
assert not is_transactional
2123
self._builder = LegacyRecordBatchBuilder(
@@ -28,6 +30,20 @@ def __init__(self, magic, batch_size, compression_type,
2830
self._relative_offset = 0
2931
self._buffer = None
3032
self._closed = False
33+
self._key_serializer = key_serializer
34+
self._value_serializer = value_serializer
35+
36+
def _serialize(self, key, value):
37+
if self._key_serializer is None:
38+
serialized_key = key
39+
else:
40+
serialized_key = self._key_serializer(key)
41+
if self._value_serializer is None:
42+
serialized_value = value
43+
else:
44+
serialized_value = self._value_serializer(value)
45+
46+
return serialized_key, serialized_value
3147

3248
def append(self, *, timestamp, key, value, headers=[]):
3349
"""Add a message to the batch.
@@ -49,8 +65,9 @@ def append(self, *, timestamp, key, value, headers=[]):
4965
if self._closed:
5066
return None
5167

68+
key_bytes, value_bytes = self._serialize(key, value)
5269
metadata = self._builder.append(
53-
self._relative_offset, timestamp, key, value,
70+
self._relative_offset, timestamp, key=key_bytes, value=value_bytes,
5471
headers=headers)
5572

5673
# Check if we could add the message
@@ -422,7 +439,7 @@ def drain_by_nodes(self, ignore_nodes, muted_partitions=set()):
422439

423440
return nodes, unknown_leaders_exist
424441

425-
def create_builder(self):
442+
def create_builder(self, key_serializer=None, value_serializer=None):
426443
if self._api_version >= (0, 11):
427444
magic = 2
428445
elif self._api_version >= (0, 10):
@@ -435,8 +452,13 @@ def create_builder(self):
435452
self._txn_manager.transactional_id is not None:
436453
is_transactional = True
437454
return BatchBuilder(
438-
magic, self._batch_size, self._compression_type,
439-
is_transactional=is_transactional)
455+
magic,
456+
self._batch_size,
457+
self._compression_type,
458+
is_transactional=is_transactional,
459+
key_serializer=key_serializer,
460+
value_serializer=value_serializer
461+
)
440462

441463
def _append_batch(self, builder, tp):
442464
# We must do this before actual add takes place to check for errors.

aiokafka/producer/producer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,14 +348,14 @@ async def partitions_for(self, topic):
348348
return (await self.client._wait_on_metadata(topic))
349349

350350
def _serialize(self, topic, key, value):
351-
if self._key_serializer:
352-
serialized_key = self._key_serializer(key)
353-
else:
351+
if self._key_serializer is None:
354352
serialized_key = key
355-
if self._value_serializer:
356-
serialized_value = self._value_serializer(value)
357353
else:
354+
serialized_key = self._key_serializer(key)
355+
if self._value_serializer is None:
358356
serialized_value = value
357+
else:
358+
serialized_value = self._value_serializer(value)
359359

360360
message_size = LegacyRecordBatchBuilder.record_overhead(
361361
self._producer_magic)
@@ -484,7 +484,9 @@ def create_batch(self):
484484
Returns:
485485
BatchBuilder: empty batch to be filled and submitted by the caller.
486486
"""
487-
return self._message_accumulator.create_builder()
487+
return self._message_accumulator.create_builder(
488+
key_serializer=self._key_serializer, value_serializer=self._value_serializer
489+
)
488490

489491
async def send_batch(self, batch, topic, *, partition):
490492
"""Submit a BatchBuilder for publication.

tests/test_producer.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,53 @@ async def test_producer_send_batch(self):
380380
await producer.send_batch(
381381
batch, self.topic, partition=partition)
382382

383+
@run_until_complete
384+
async def test_producer_send_batch_with_serializer(self):
385+
def key_serializer(val):
386+
return val.upper().encode()
387+
388+
def value_serializer(val):
389+
return json.dumps(val, separators=(',', ':')).encode()
390+
391+
producer = AIOKafkaProducer(
392+
bootstrap_servers=self.hosts,
393+
key_serializer=key_serializer,
394+
value_serializer=value_serializer,
395+
)
396+
await producer.start()
397+
398+
partitions = await producer.partitions_for(self.topic)
399+
partition = partitions.pop()
400+
401+
batch = producer.create_batch()
402+
batch.append(key="key1", value={"value": 111}, timestamp=None)
403+
batch.append(key="key2", value={"value": 222}, timestamp=None)
404+
self.assertEqual(batch.record_count(), 2)
405+
406+
# batch gets properly sent
407+
future = await producer.send_batch(
408+
batch, self.topic, partition=partition)
409+
resp = await future
410+
await producer.stop()
411+
self.assertEqual(resp.partition, partition)
412+
413+
consumer = AIOKafkaConsumer(
414+
self.topic,
415+
bootstrap_servers=self.hosts,
416+
enable_auto_commit=True,
417+
auto_offset_reset="earliest")
418+
await consumer.start()
419+
420+
msg = await consumer.getone()
421+
self.assertEqual(msg.key, b"KEY1")
422+
self.assertEqual(msg.value, b"{\"value\":111}")
423+
424+
msg = await consumer.getone()
425+
self.assertEqual(msg.key, b"KEY2")
426+
self.assertEqual(msg.value, b"{\"value\":222}")
427+
428+
await consumer.stop()
429+
383430
@pytest.mark.ssl
384431
@run_until_complete
385432
async def test_producer_ssl(self):

0 commit comments

Comments
 (0)