Skip to content

Commit 19bc7d4

Browse files
Fix Protobuf with schema registry usage for kafka-actions (#22265)
* Fix Protobuf with schema registry usage for kafka-actions
* lint
* changelog
1 parent 662ab49 commit 19bc7d4

File tree

6 files changed

+175
-30
lines changed

6 files changed

+175
-30
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix Kafka message decoding when using Protobuf with schema registry.

kafka_actions/datadog_checks/kafka_actions/message_deserializer.py

Lines changed: 114 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,83 @@
1717
SCHEMA_REGISTRY_MAGIC_BYTE = 0x00
1818

1919

20+
def _read_varint(data):
21+
shift = 0
22+
result = 0
23+
bytes_read = 0
24+
25+
for byte in data:
26+
bytes_read += 1
27+
result |= (byte & 0x7F) << shift
28+
if (byte & 0x80) == 0:
29+
return result, bytes_read
30+
shift += 7
31+
32+
raise ValueError("Incomplete varint")
33+
34+
35+
def _read_protobuf_message_indices(payload):
36+
"""
37+
Read the Confluent Protobuf message indices array.
38+
39+
The Confluent Protobuf wire format includes message indices after the schema ID:
40+
[message_indices_length:varint][message_indices:varint...]
41+
42+
The indices indicate which message type to use from the .proto schema.
43+
For example, [0] = first message, [1] = second message, [0, 0] = nested message.
44+
45+
Args:
46+
payload: bytes after the schema ID
47+
48+
Returns:
49+
tuple: (message_indices list, remaining payload bytes)
50+
"""
51+
array_len, bytes_read = _read_varint(payload)
52+
payload = payload[bytes_read:]
53+
54+
indices = []
55+
for _ in range(array_len):
56+
index, bytes_read = _read_varint(payload)
57+
indices.append(index)
58+
payload = payload[bytes_read:]
59+
60+
return indices, payload
61+
62+
63+
def _get_protobuf_message_class(schema_info, message_indices):
    """Resolve the protobuf message class addressed by Confluent message indices.

    Args:
        schema_info: Tuple of (descriptor_pool, file_descriptor_set)
        message_indices: List of indices (e.g., [0], [1], [2, 0] for nested)

    Returns:
        Message class for the specified type
    """
    pool, descriptor_set = schema_info

    # NOTE(review): only file[0] of the descriptor set is consulted here —
    # assumes the target message lives in the first file; confirm for
    # multi-file schemas with imported dependencies.
    file_proto = descriptor_set.file[0]

    # The first index picks a top-level message; any further indices
    # descend into its nested message types.
    descriptor_proto = file_proto.message_type[message_indices[0]]
    path = [descriptor_proto.name]
    for nested_index in message_indices[1:]:
        descriptor_proto = descriptor_proto.nested_type[nested_index]
        path.append(descriptor_proto.name)

    # Fully-qualified name as registered in the descriptor pool.
    if file_proto.package:
        qualified_name = '.'.join([file_proto.package] + path)
    else:
        qualified_name = '.'.join(path)

    return message_factory.GetMessageClass(pool.FindMessageTypeByName(qualified_name))
95+
96+
2097
class MessageDeserializer:
2198
"""Handles deserialization of Kafka messages with support for JSON, BSON, Protobuf, and Avro."""
2299

@@ -78,26 +155,29 @@ def _deserialize_bytes_maybe_schema_registry(
78155
)
79156
schema_id = int.from_bytes(message[1:5], 'big')
80157
message = message[5:] # Skip the magic byte and schema ID bytes
81-
return self._deserialize_bytes(message, message_format, schema), schema_id
158+
return self._deserialize_bytes(message, message_format, schema, uses_schema_registry=True), schema_id
82159
else:
83160
# Fallback behavior: try without schema registry format first, then with it
84161
try:
85-
return self._deserialize_bytes(message, message_format, schema), None
162+
return self._deserialize_bytes(message, message_format, schema, uses_schema_registry=False), None
86163
except (UnicodeDecodeError, json.JSONDecodeError, ValueError) as e:
87164
# If the message is not valid, it might be a schema registry message
88165
if len(message) < 5 or message[0] != SCHEMA_REGISTRY_MAGIC_BYTE:
89166
raise e
90167
schema_id = int.from_bytes(message[1:5], 'big')
91168
message = message[5:] # Skip the magic byte and schema ID bytes
92-
return self._deserialize_bytes(message, message_format, schema), schema_id
169+
return self._deserialize_bytes(message, message_format, schema, uses_schema_registry=True), schema_id
93170

94-
def _deserialize_bytes(self, message: bytes, message_format: str, schema) -> str | None:
171+
def _deserialize_bytes(
172+
self, message: bytes, message_format: str, schema, uses_schema_registry: bool = False
173+
) -> str | None:
95174
"""Deserialize message bytes to JSON string.
96175
97176
Args:
98177
message: Raw message bytes
99178
message_format: 'json', 'bson', 'protobuf', 'avro', or 'string'
100179
schema: Schema object (for protobuf/avro)
180+
uses_schema_registry: Whether to extract Confluent message indices from the message
101181
102182
Returns:
103183
JSON string representation, or None if message is empty
@@ -106,7 +186,7 @@ def _deserialize_bytes(self, message: bytes, message_format: str, schema) -> str
106186
return None
107187

108188
if message_format == 'protobuf':
109-
return self._deserialize_protobuf(message, schema)
189+
return self._deserialize_protobuf(message, schema, uses_schema_registry)
110190
elif message_format == 'avro':
111191
return self._deserialize_avro(message, schema)
112192
elif message_format == 'bson':
@@ -158,13 +238,30 @@ def _deserialize_bson(self, message: bytes) -> str | None:
158238
except Exception as e:
159239
raise ValueError(f"Failed to deserialize BSON message: {e}")
160240

161-
def _deserialize_protobuf(self, message: bytes, schema) -> str:
162-
"""Deserialize Protobuf message."""
163-
if schema is None:
241+
def _deserialize_protobuf(self, message: bytes, schema_info, uses_schema_registry: bool) -> str:
242+
"""Deserialize Protobuf message using google.protobuf with strict validation.
243+
244+
Args:
245+
message: Raw protobuf bytes
246+
schema_info: Tuple of (descriptor_pool, file_descriptor_set) from _build_protobuf_schema
247+
uses_schema_registry: Whether to extract Confluent message indices from the message
248+
"""
249+
if schema_info is None:
164250
raise ValueError("Protobuf schema is required")
165251

166252
try:
167-
bytes_consumed = schema.ParseFromString(message)
253+
if uses_schema_registry:
254+
message_indices, message = _read_protobuf_message_indices(message)
255+
# Empty indices array means use the first message type (index 0)
256+
if not message_indices:
257+
message_indices = [0]
258+
else:
259+
message_indices = [0]
260+
261+
message_class = _get_protobuf_message_class(schema_info, message_indices)
262+
schema_instance = message_class()
263+
264+
bytes_consumed = schema_instance.ParseFromString(message)
168265

169266
# Strict validation: ensure all bytes consumed
170267
if bytes_consumed != len(message):
@@ -173,7 +270,7 @@ def _deserialize_protobuf(self, message: bytes, schema) -> str:
173270
f"Read {bytes_consumed} bytes, but message has {len(message)} bytes."
174271
)
175272

176-
return MessageToJson(schema)
273+
return MessageToJson(schema_instance)
177274
except Exception as e:
178275
raise ValueError(f"Failed to deserialize Protobuf message: {e}")
179276

@@ -251,7 +348,12 @@ def _build_avro_schema(self, schema_str: str):
251348
return schema
252349

253350
def _build_protobuf_schema(self, schema_str: str):
254-
"""Build a Protobuf schema from base64-encoded FileDescriptorSet."""
351+
"""Build a Protobuf schema from base64-encoded FileDescriptorSet.
352+
353+
Returns:
354+
Tuple of (descriptor_pool, file_descriptor_set) for use with
355+
_get_protobuf_message_class to select the correct message type.
356+
"""
255357
schema_bytes = base64.b64decode(schema_str)
256358
descriptor_set = descriptor_pb2.FileDescriptorSet()
257359
descriptor_set.ParseFromString(schema_bytes)
@@ -260,23 +362,7 @@ def _build_protobuf_schema(self, schema_str: str):
260362
for fd_proto in descriptor_set.file:
261363
pool.Add(fd_proto)
262364

263-
first_fd = descriptor_set.file[0]
264-
first_message_proto = first_fd.message_type[0]
265-
266-
package = first_fd.package
267-
message_name = first_message_proto.name
268-
if package:
269-
full_name = f"{package}.{message_name}"
270-
else:
271-
full_name = message_name
272-
273-
message_descriptor = pool.FindMessageTypeByName(full_name)
274-
schema = message_factory.GetMessageClass(message_descriptor)()
275-
276-
if schema is None:
277-
raise ValueError("Protobuf schema cannot be None")
278-
279-
return schema
365+
return (pool, descriptor_set)
280366

281367

282368
class DeserializedMessage:

kafka_actions/tests/test_message_deserializer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,12 @@ def test_protobuf_explicit_schema_registry_configuration(self):
341341
b'\x1a\x0c\x41\x6c\x61\x6e\x20\x44\x6f\x6e\x6f\x76\x61\x6e'
342342
)
343343

344-
# Protobuf message WITH Schema Registry format (magic byte 0x00 + schema ID 350 = 0x015E)
344+
# Protobuf message WITH Schema Registry format (Confluent wire format)
345+
# - magic byte 0x00 + schema ID 350 = 0x015E
346+
# - message indices: [0] encoded as varint array (0x01 0x00 = 1 element, value 0)
345347
protobuf_message_with_sr = (
346-
b'\x00\x00\x00\x01\x5e'
348+
b'\x00\x00\x00\x01\x5e' # Schema Registry header
349+
b'\x01\x00' # Message indices: array length 1, index [0]
347350
b'\x08\xe8\xba\xb2\xeb\xd1\x9c\x02\x12\x1b\x54\x68\x65\x20\x47\x6f\x20\x50\x72\x6f\x67\x72\x61\x6d\x6d\x69\x6e\x67\x20\x4c\x61\x6e\x67\x75\x61\x67\x65'
348351
b'\x1a\x0c\x41\x6c\x61\x6e\x20\x44\x6f\x6e\x6f\x76\x61\x6e'
349352
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix Kafka message decoding when using Protobuf with schema registry.

kafka_consumer/datadog_checks/kafka_consumer/kafka_consumer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,9 @@ def _deserialize_protobuf(message, schema_info, uses_schema_registry):
769769
try:
770770
if uses_schema_registry:
771771
message_indices, message = _read_protobuf_message_indices(message)
772+
# Empty indices array means use the first message type (index 0)
773+
if not message_indices:
774+
message_indices = [0]
772775
else:
773776
message_indices = [0]
774777

kafka_consumer/tests/test_unit.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,57 @@ def test_protobuf_message_indices_with_schema_registry():
973973
assert result[0] and 'Fiction' in result[0]
974974

975975

976+
def test_protobuf_empty_message_indices_with_schema_registry():
    """Test Confluent Protobuf wire format with empty message indices array.

    An empty message-indices array (a single varint 0x00) means "use the
    first message type in the schema" (index 0). The fixture bytes were
    captured from a real Kafka topic so deserialization is exercised
    against the genuine Confluent wire format.
    """
    message_key = b'null'

    # Base64-encoded FileDescriptorSet from a real Kafka topic. It declares:
    #   message Purchase { string order_id = 1; string customer_id = 2;
    #                      int64 order_date = 3; string city = 6; string country = 7; }
    # (plus a second PurchaseV2 message that this test does not use).
    protobuf_schema = (
        'CrkDCgxzY2hlbWEucHJvdG8SCHB1cmNoYXNlIpMBCghQdXJjaGFzZRIZCghvcmRlcl9pZBgBIAEoCVIH'
        'b3JkZXJJZBIfCgtjdXN0b21lcl9pZBgCIAEoCVIKY3VzdG9tZXJJZBIdCgpvcmRlcl9kYXRlGAMgASgD'
        'UglvcmRlckRhdGUSEgoEY2l0eRgGIAEoCVIEY2l0eRIYCgdjb3VudHJ5GAcgASgJUgdjb3VudHJ5ItIB'
        'CgpQdXJjaGFzZVYyEiUKDnRyYW5zYWN0aW9uX2lkGAEgASgJUg10cmFuc2FjdGlvbklkEhcKB3VzZXJf'
        'aWQYAiABKAlSBnVzZXJJZBIcCgl0aW1lc3RhbXAYAyABKANSCXRpbWVzdGFtcBIaCghsb2NhdGlvbhgE'
        'IAEoCVIIbG9jYXRpb24SFgoGcmVnaW9uGAUgASgJUgZyZWdpb24SFgoGYW1vdW50GAYgASgBUgZhbW91'
        'bnQSGgoIY3VycmVuY3kYByABKAlSCGN1cnJlbmN5QiwKG2RhdGFkb2cua2Fma2EuZXhhbXBsZS5wcm90'
        'b0INUHVyY2hhc2VQcm90b2IGcHJvdG8z'
    )
    parsed_schema = build_schema('protobuf', protobuf_schema)

    # Captured message from Kafka topic "human-orders". Hex breakdown:
    #   00 00 00 00 01        Schema Registry header (magic byte + schema ID 1)
    #   00                    empty message-indices array (varint 0 = 0 elements)
    #   0a 05 31 32 33 34 35  start of the Purchase protobuf payload
    message_hex = '0000000001000a0531323334351205363738393018f4eae0c4b8333a064d657869636f'
    wire_bytes = bytes.fromhex(message_hex)

    # Explicit schema-registry mode (uses_schema_registry=True).
    result = deserialize_message(
        MockedMessage(wire_bytes, message_key), 'protobuf', parsed_schema, True, 'json', '', False
    )
    assert result[0], "Deserialization should succeed"
    assert '12345' in result[0], "Should contain order_id"
    assert '67890' in result[0], "Should contain customer_id"
    assert 'Mexico' in result[0], "Should contain country"
    assert result[1] == 1, "Should detect schema ID 1"

    # Auto-detect fallback mode (uses_schema_registry=False) must reach the same result.
    result_fallback = deserialize_message(
        MockedMessage(wire_bytes, message_key), 'protobuf', parsed_schema, False, 'json', '', False
    )
    assert result_fallback[0], "Fallback mode should also succeed"
    assert '12345' in result_fallback[0], "Fallback should contain order_id"
    assert result_fallback[1] == 1, "Fallback should detect schema ID 1"
1025+
1026+
9761027
def mocked_time():
    """Stand-in clock for tests: always returns the fixed value 400 so
    time-dependent assertions are deterministic."""
    return 400
9781029

0 commit comments

Comments
 (0)