Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kafka_actions/changelog.d/22265.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix Kafka message decoding when using Protobuf with schema registry.
142 changes: 114 additions & 28 deletions kafka_actions/datadog_checks/kafka_actions/message_deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,83 @@
SCHEMA_REGISTRY_MAGIC_BYTE = 0x00


def _read_varint(data):
shift = 0
result = 0
bytes_read = 0

for byte in data:
bytes_read += 1
result |= (byte & 0x7F) << shift
if (byte & 0x80) == 0:
return result, bytes_read
shift += 7

raise ValueError("Incomplete varint")


def _read_protobuf_message_indices(payload):
    """
    Parse the Confluent Protobuf message-indices array from *payload*.

    Confluent's Protobuf wire format places, immediately after the schema ID:
    [message_indices_length:varint][message_indices:varint...]

    The indices select which message type of the .proto schema was used,
    e.g. [0] = first top-level message, [1] = second, [0, 0] = first
    nested message of the first one.

    Args:
        payload: bytes immediately following the schema ID

    Returns:
        tuple: (list of message indices, bytes remaining after the array)
    """
    count, consumed = _read_varint(payload)
    remaining = payload[consumed:]

    indices = []
    while len(indices) < count:
        value, consumed = _read_varint(remaining)
        indices.append(value)
        remaining = remaining[consumed:]

    return indices, remaining


def _get_protobuf_message_class(schema_info, message_indices):
    """Resolve the protobuf message class selected by the Confluent indices.

    Args:
        schema_info: Tuple of (descriptor_pool, file_descriptor_set)
        message_indices: List of indices (e.g., [0], [1], [2, 0] for nested)

    Returns:
        Message class for the specified type
    """
    pool, descriptor_set = schema_info

    # Walk the index path: the first entry picks a top-level message of the
    # first file; each later entry descends one level into nested types.
    file_proto = descriptor_set.file[0]
    node = file_proto.message_type[message_indices[0]]
    parts = [node.name]

    for nested_index in message_indices[1:]:
        node = node.nested_type[nested_index]
        parts.append(node.name)

    dotted = '.'.join(parts)
    full_name = f"{file_proto.package}.{dotted}" if file_proto.package else dotted

    descriptor = pool.FindMessageTypeByName(full_name)
    return message_factory.GetMessageClass(descriptor)


class MessageDeserializer:
"""Handles deserialization of Kafka messages with support for JSON, BSON, Protobuf, and Avro."""

Expand Down Expand Up @@ -78,26 +155,29 @@ def _deserialize_bytes_maybe_schema_registry(
)
schema_id = int.from_bytes(message[1:5], 'big')
message = message[5:] # Skip the magic byte and schema ID bytes
return self._deserialize_bytes(message, message_format, schema), schema_id
return self._deserialize_bytes(message, message_format, schema, uses_schema_registry=True), schema_id
else:
# Fallback behavior: try without schema registry format first, then with it
try:
return self._deserialize_bytes(message, message_format, schema), None
return self._deserialize_bytes(message, message_format, schema, uses_schema_registry=False), None
except (UnicodeDecodeError, json.JSONDecodeError, ValueError) as e:
# If the message is not valid, it might be a schema registry message
if len(message) < 5 or message[0] != SCHEMA_REGISTRY_MAGIC_BYTE:
raise e
schema_id = int.from_bytes(message[1:5], 'big')
message = message[5:] # Skip the magic byte and schema ID bytes
return self._deserialize_bytes(message, message_format, schema), schema_id
return self._deserialize_bytes(message, message_format, schema, uses_schema_registry=True), schema_id

def _deserialize_bytes(self, message: bytes, message_format: str, schema) -> str | None:
def _deserialize_bytes(
self, message: bytes, message_format: str, schema, uses_schema_registry: bool = False
) -> str | None:
"""Deserialize message bytes to JSON string.

Args:
message: Raw message bytes
message_format: 'json', 'bson', 'protobuf', 'avro', or 'string'
schema: Schema object (for protobuf/avro)
uses_schema_registry: Whether to extract Confluent message indices from the message

Returns:
JSON string representation, or None if message is empty
Expand All @@ -106,7 +186,7 @@ def _deserialize_bytes(self, message: bytes, message_format: str, schema) -> str
return None

if message_format == 'protobuf':
return self._deserialize_protobuf(message, schema)
return self._deserialize_protobuf(message, schema, uses_schema_registry)
elif message_format == 'avro':
return self._deserialize_avro(message, schema)
elif message_format == 'bson':
Expand Down Expand Up @@ -158,13 +238,30 @@ def _deserialize_bson(self, message: bytes) -> str | None:
except Exception as e:
raise ValueError(f"Failed to deserialize BSON message: {e}")

def _deserialize_protobuf(self, message: bytes, schema) -> str:
"""Deserialize Protobuf message."""
if schema is None:
def _deserialize_protobuf(self, message: bytes, schema_info, uses_schema_registry: bool) -> str:
"""Deserialize Protobuf message using google.protobuf with strict validation.

Args:
message: Raw protobuf bytes
schema_info: Tuple of (descriptor_pool, file_descriptor_set) from _build_protobuf_schema
uses_schema_registry: Whether to extract Confluent message indices from the message
"""
if schema_info is None:
raise ValueError("Protobuf schema is required")

try:
bytes_consumed = schema.ParseFromString(message)
if uses_schema_registry:
message_indices, message = _read_protobuf_message_indices(message)
# Empty indices array means use the first message type (index 0)
if not message_indices:
message_indices = [0]
else:
message_indices = [0]

message_class = _get_protobuf_message_class(schema_info, message_indices)
schema_instance = message_class()

bytes_consumed = schema_instance.ParseFromString(message)

# Strict validation: ensure all bytes consumed
if bytes_consumed != len(message):
Expand All @@ -173,7 +270,7 @@ def _deserialize_protobuf(self, message: bytes, schema) -> str:
f"Read {bytes_consumed} bytes, but message has {len(message)} bytes."
)

return MessageToJson(schema)
return MessageToJson(schema_instance)
except Exception as e:
raise ValueError(f"Failed to deserialize Protobuf message: {e}")

Expand Down Expand Up @@ -251,7 +348,12 @@ def _build_avro_schema(self, schema_str: str):
return schema

def _build_protobuf_schema(self, schema_str: str):
"""Build a Protobuf schema from base64-encoded FileDescriptorSet."""
"""Build a Protobuf schema from base64-encoded FileDescriptorSet.

Returns:
Tuple of (descriptor_pool, file_descriptor_set) for use with
_get_protobuf_message_class to select the correct message type.
"""
schema_bytes = base64.b64decode(schema_str)
descriptor_set = descriptor_pb2.FileDescriptorSet()
descriptor_set.ParseFromString(schema_bytes)
Expand All @@ -260,23 +362,7 @@ def _build_protobuf_schema(self, schema_str: str):
for fd_proto in descriptor_set.file:
pool.Add(fd_proto)

first_fd = descriptor_set.file[0]
first_message_proto = first_fd.message_type[0]

package = first_fd.package
message_name = first_message_proto.name
if package:
full_name = f"{package}.{message_name}"
else:
full_name = message_name

message_descriptor = pool.FindMessageTypeByName(full_name)
schema = message_factory.GetMessageClass(message_descriptor)()

if schema is None:
raise ValueError("Protobuf schema cannot be None")

return schema
return (pool, descriptor_set)


class DeserializedMessage:
Expand Down
7 changes: 5 additions & 2 deletions kafka_actions/tests/test_message_deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,12 @@ def test_protobuf_explicit_schema_registry_configuration(self):
b'\x1a\x0c\x41\x6c\x61\x6e\x20\x44\x6f\x6e\x6f\x76\x61\x6e'
)

# Protobuf message WITH Schema Registry format (magic byte 0x00 + schema ID 350 = 0x015E)
# Protobuf message WITH Schema Registry format (Confluent wire format)
# - magic byte 0x00 + schema ID 350 = 0x015E
# - message indices: [0] encoded as varint array (0x01 0x00 = 1 element, value 0)
protobuf_message_with_sr = (
b'\x00\x00\x00\x01\x5e'
b'\x00\x00\x00\x01\x5e' # Schema Registry header
b'\x01\x00' # Message indices: array length 1, index [0]
b'\x08\xe8\xba\xb2\xeb\xd1\x9c\x02\x12\x1b\x54\x68\x65\x20\x47\x6f\x20\x50\x72\x6f\x67\x72\x61\x6d\x6d\x69\x6e\x67\x20\x4c\x61\x6e\x67\x75\x61\x67\x65'
b'\x1a\x0c\x41\x6c\x61\x6e\x20\x44\x6f\x6e\x6f\x76\x61\x6e'
)
Expand Down
1 change: 1 addition & 0 deletions kafka_consumer/changelog.d/22265.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix Kafka message decoding when using Protobuf with schema registry.
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,9 @@ def _deserialize_protobuf(message, schema_info, uses_schema_registry):
try:
if uses_schema_registry:
message_indices, message = _read_protobuf_message_indices(message)
# Empty indices array means use the first message type (index 0)
if not message_indices:
message_indices = [0]
else:
message_indices = [0]

Expand Down
51 changes: 51 additions & 0 deletions kafka_consumer/tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,57 @@ def test_protobuf_message_indices_with_schema_registry():
assert result[0] and 'Fiction' in result[0]


def test_protobuf_empty_message_indices_with_schema_registry():
    """Test Confluent Protobuf wire format with empty message indices array.

    An empty indices array (a single varint 0x00) must fall back to the
    first message type (index 0). Uses real bytes captured from a Kafka
    topic so the full Confluent wire format path is exercised.
    """
    key = b'null'

    # Real schema (base64 FileDescriptorSet) containing two message types:
    # message Purchase { string order_id = 1; string customer_id = 2; int64 order_date = 3;
    # string city = 6; string country = 7; }
    protobuf_schema = (
        'CrkDCgxzY2hlbWEucHJvdG8SCHB1cmNoYXNlIpMBCghQdXJjaGFzZRIZCghvcmRlcl9pZBgBIAEoCVIH'
        'b3JkZXJJZBIfCgtjdXN0b21lcl9pZBgCIAEoCVIKY3VzdG9tZXJJZBIdCgpvcmRlcl9kYXRlGAMgASgD'
        'UglvcmRlckRhdGUSEgoEY2l0eRgGIAEoCVIEY2l0eRIYCgdjb3VudHJ5GAcgASgJUgdjb3VudHJ5ItIB'
        'CgpQdXJjaGFzZVYyEiUKDnRyYW5zYWN0aW9uX2lkGAEgASgJUg10cmFuc2FjdGlvbklkEhcKB3VzZXJf'
        'aWQYAiABKAlSBnVzZXJJZBIcCgl0aW1lc3RhbXAYAyABKANSCXRpbWVzdGFtcBIaCghsb2NhdGlvbhgE'
        'IAEoCVIIbG9jYXRpb24SFgoGcmVnaW9uGAUgASgJUgZyZWdpb24SFgoGYW1vdW50GAYgASgBUgZhbW91'
        'bnQSGgoIY3VycmVuY3kYByABKAlSCGN1cnJlbmN5QiwKG2RhdGFkb2cua2Fma2EuZXhhbXBsZS5wcm90'
        'b0INUHVyY2hhc2VQcm90b2IGcHJvdG8z'
    )
    schema = build_schema('protobuf', protobuf_schema)

    # Captured from Kafka topic "human-orders". Layout:
    #   00 00 00 00 01        - Schema Registry header (magic byte + schema ID 1)
    #   00                    - empty message indices array (varint 0 = 0 elements)
    #   0a 05 31 32 33 34 35  - start of the Purchase protobuf payload
    raw_message = bytes.fromhex('0000000001000a0531323334351205363738393018f4eae0c4b8333a064d657869636f')

    # Explicit schema-registry mode (uses_schema_registry=True)
    decoded = deserialize_message(MockedMessage(raw_message, key), 'protobuf', schema, True, 'json', '', False)
    assert decoded[0], "Deserialization should succeed"
    assert '12345' in decoded[0], "Should contain order_id"
    assert '67890' in decoded[0], "Should contain customer_id"
    assert 'Mexico' in decoded[0], "Should contain country"
    assert decoded[1] == 1, "Should detect schema ID 1"

    # Fallback mode (uses_schema_registry=False) must auto-detect the header
    decoded_fallback = deserialize_message(
        MockedMessage(raw_message, key), 'protobuf', schema, False, 'json', '', False
    )
    assert decoded_fallback[0], "Fallback mode should also succeed"
    assert '12345' in decoded_fallback[0], "Fallback should contain order_id"
    assert decoded_fallback[1] == 1, "Fallback should detect schema ID 1"


def mocked_time():
    """Return a fixed clock value so time-dependent tests are deterministic."""
    return 400

Expand Down
Loading