diff --git a/kafka_actions/changelog.d/22265.fixed b/kafka_actions/changelog.d/22265.fixed
new file mode 100644
index 0000000000000..bd76d8982cbe3
--- /dev/null
+++ b/kafka_actions/changelog.d/22265.fixed
@@ -0,0 +1 @@
+Fix Kafka message decoding when using Protobuf with schema registry.
diff --git a/kafka_actions/datadog_checks/kafka_actions/message_deserializer.py b/kafka_actions/datadog_checks/kafka_actions/message_deserializer.py
index 6b39bf3dfab82..f33df3cf532eb 100644
--- a/kafka_actions/datadog_checks/kafka_actions/message_deserializer.py
+++ b/kafka_actions/datadog_checks/kafka_actions/message_deserializer.py
@@ -17,6 +17,83 @@
 SCHEMA_REGISTRY_MAGIC_BYTE = 0x00
 
 
+def _read_varint(data):
+    shift = 0
+    result = 0
+    bytes_read = 0
+
+    for byte in data:
+        bytes_read += 1
+        result |= (byte & 0x7F) << shift
+        if (byte & 0x80) == 0:
+            return result, bytes_read
+        shift += 7
+
+    raise ValueError("Incomplete varint")
+
+
+def _read_protobuf_message_indices(payload):
+    """
+    Read the Confluent Protobuf message indices array.
+
+    The Confluent Protobuf wire format includes message indices after the schema ID:
+    [message_indices_length:varint][message_indices:varint...]
+
+    The indices indicate which message type to use from the .proto schema.
+    For example, [0] = first message, [1] = second message, [0, 0] = nested message.
+
+    Args:
+        payload: bytes after the schema ID
+
+    Returns:
+        tuple: (message_indices list, remaining payload bytes)
+    """
+    array_len, bytes_read = _read_varint(payload)
+    payload = payload[bytes_read:]
+
+    indices = []
+    for _ in range(array_len):
+        index, bytes_read = _read_varint(payload)
+        indices.append(index)
+        payload = payload[bytes_read:]
+
+    return indices, payload
+
+
+def _get_protobuf_message_class(schema_info, message_indices):
+    """Get the protobuf message class based on schema info and message indices.
+
+    Args:
+        schema_info: Tuple of (descriptor_pool, file_descriptor_set)
+        message_indices: List of indices (e.g., [0], [1], [2, 0] for nested)
+
+    Returns:
+        Message class for the specified type
+    """
+    pool, descriptor_set = schema_info
+
+    # First index is the message type in the file
+    file_descriptor = descriptor_set.file[0]
+    message_descriptor_proto = file_descriptor.message_type[message_indices[0]]
+
+    package = file_descriptor.package
+    name_parts = [message_descriptor_proto.name]
+
+    # Handle nested messages if there are more indices
+    current_proto = message_descriptor_proto
+    for idx in message_indices[1:]:
+        current_proto = current_proto.nested_type[idx]
+        name_parts.append(current_proto.name)
+
+    if package:
+        full_name = f"{package}.{'.'.join(name_parts)}"
+    else:
+        full_name = '.'.join(name_parts)
+
+    message_descriptor = pool.FindMessageTypeByName(full_name)
+    return message_factory.GetMessageClass(message_descriptor)
+
+
 class MessageDeserializer:
     """Handles deserialization of Kafka messages with support for JSON, BSON, Protobuf, and Avro."""
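
The wire format these two helpers parse is simple: a varint element count, then that many varint indices, then the Protobuf body. A minimal sketch of `_read_protobuf_message_indices` on two illustrative payloads (byte values invented for the example, not taken from a real topic):

    # Indices array [0] followed by a dummy one-byte body:
    indices, body = _read_protobuf_message_indices(bytes([0x01, 0x00, 0x08]))
    assert indices == [0] and body == b'\x08'

    # Nested selection [2, 0]: first nested type of the third top-level message:
    indices, body = _read_protobuf_message_indices(bytes([0x02, 0x02, 0x00, 0x08]))
    assert indices == [2, 0] and body == b'\x08'
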
@@ -78,26 +155,29 @@ def _deserialize_bytes_maybe_schema_registry(
             )
             schema_id = int.from_bytes(message[1:5], 'big')
             message = message[5:]  # Skip the magic byte and schema ID bytes
-            return self._deserialize_bytes(message, message_format, schema), schema_id
+            return self._deserialize_bytes(message, message_format, schema, uses_schema_registry=True), schema_id
         else:
             # Fallback behavior: try without schema registry format first, then with it
             try:
-                return self._deserialize_bytes(message, message_format, schema), None
+                return self._deserialize_bytes(message, message_format, schema, uses_schema_registry=False), None
             except (UnicodeDecodeError, json.JSONDecodeError, ValueError) as e:
                 # If the message is not valid, it might be a schema registry message
                 if len(message) < 5 or message[0] != SCHEMA_REGISTRY_MAGIC_BYTE:
                     raise e
                 schema_id = int.from_bytes(message[1:5], 'big')
                 message = message[5:]  # Skip the magic byte and schema ID bytes
-                return self._deserialize_bytes(message, message_format, schema), schema_id
+                return self._deserialize_bytes(message, message_format, schema, uses_schema_registry=True), schema_id
 
-    def _deserialize_bytes(self, message: bytes, message_format: str, schema) -> str | None:
+    def _deserialize_bytes(
+        self, message: bytes, message_format: str, schema, uses_schema_registry: bool = False
+    ) -> str | None:
         """Deserialize message bytes to JSON string.
 
         Args:
             message: Raw message bytes
             message_format: 'json', 'bson', 'protobuf', 'avro', or 'string'
             schema: Schema object (for protobuf/avro)
+            uses_schema_registry: Whether to extract Confluent message indices from the message
 
         Returns:
             JSON string representation, or None if message is empty
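
For orientation, the Schema Registry framing stripped above is a fixed five-byte prefix: the magic byte 0x00 followed by the schema ID as a 4-byte big-endian integer. A standalone sketch of the header parse, reusing schema ID 350 from the kafka_actions test further down (the payload is a placeholder):

    raw = b'\x00\x00\x00\x01\x5e' + b'<body>'

    assert raw[0] == SCHEMA_REGISTRY_MAGIC_BYTE    # 0x00
    schema_id = int.from_bytes(raw[1:5], 'big')    # 0x0000015E == 350
    body = raw[5:]                                 # what _deserialize_bytes receives
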
@@ -106,7 +186,7 @@ def _deserialize_bytes(self, message: bytes, message_format: str, schema) -> str
             return None
 
         if message_format == 'protobuf':
-            return self._deserialize_protobuf(message, schema)
+            return self._deserialize_protobuf(message, schema, uses_schema_registry)
         elif message_format == 'avro':
             return self._deserialize_avro(message, schema)
         elif message_format == 'bson':
@@ -158,13 +238,30 @@ def _deserialize_bson(self, message: bytes) -> str | None:
         except Exception as e:
             raise ValueError(f"Failed to deserialize BSON message: {e}")
 
-    def _deserialize_protobuf(self, message: bytes, schema) -> str:
-        """Deserialize Protobuf message."""
-        if schema is None:
+    def _deserialize_protobuf(self, message: bytes, schema_info, uses_schema_registry: bool) -> str:
+        """Deserialize Protobuf message using google.protobuf with strict validation.
+
+        Args:
+            message: Raw protobuf bytes
+            schema_info: Tuple of (descriptor_pool, file_descriptor_set) from _build_protobuf_schema
+            uses_schema_registry: Whether to extract Confluent message indices from the message
+        """
+        if schema_info is None:
             raise ValueError("Protobuf schema is required")
 
         try:
-            bytes_consumed = schema.ParseFromString(message)
+            if uses_schema_registry:
+                message_indices, message = _read_protobuf_message_indices(message)
+                # Empty indices array means use the first message type (index 0)
+                if not message_indices:
+                    message_indices = [0]
+            else:
+                message_indices = [0]
+
+            message_class = _get_protobuf_message_class(schema_info, message_indices)
+            schema_instance = message_class()
+
+            bytes_consumed = schema_instance.ParseFromString(message)
 
             # Strict validation: ensure all bytes consumed
             if bytes_consumed != len(message):
@@ -173,7 +270,7 @@ def _deserialize_protobuf(self, message: bytes, schema) -> str:
                     f"Read {bytes_consumed} bytes, but message has {len(message)} bytes."
                 )
 
-            return MessageToJson(schema)
+            return MessageToJson(schema_instance)
         except Exception as e:
             raise ValueError(f"Failed to deserialize Protobuf message: {e}")
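
Putting the two framing layers together, a schema-registry Protobuf record now decodes in three steps: strip the five-byte header, read the message indices, and parse what remains with the class the indices select. A hedged walkthrough with illustrative bytes, where `schema_info` stands for the (pool, descriptor_set) tuple returned by `_build_protobuf_schema`:

    frame = (
        b'\x00'              # magic byte
        b'\x00\x00\x01\x5e'  # schema ID 350, big-endian
        b'\x01\x00'          # message indices: array length 1, index 0
        b'\x08\x01'          # protobuf body (illustrative: field 1, varint value 1)
    )

    schema_id = int.from_bytes(frame[1:5], 'big')              # 350
    indices, body = _read_protobuf_message_indices(frame[5:])  # ([0], b'\x08\x01')
    message_class = _get_protobuf_message_class(schema_info, indices or [0])
    message = message_class.FromString(body)
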
+ """ schema_bytes = base64.b64decode(schema_str) descriptor_set = descriptor_pb2.FileDescriptorSet() descriptor_set.ParseFromString(schema_bytes) @@ -260,23 +362,7 @@ def _build_protobuf_schema(self, schema_str: str): for fd_proto in descriptor_set.file: pool.Add(fd_proto) - first_fd = descriptor_set.file[0] - first_message_proto = first_fd.message_type[0] - - package = first_fd.package - message_name = first_message_proto.name - if package: - full_name = f"{package}.{message_name}" - else: - full_name = message_name - - message_descriptor = pool.FindMessageTypeByName(full_name) - schema = message_factory.GetMessageClass(message_descriptor)() - - if schema is None: - raise ValueError("Protobuf schema cannot be None") - - return schema + return (pool, descriptor_set) class DeserializedMessage: diff --git a/kafka_actions/tests/test_message_deserializer.py b/kafka_actions/tests/test_message_deserializer.py index a9e803f611542..9de2a4f71e814 100644 --- a/kafka_actions/tests/test_message_deserializer.py +++ b/kafka_actions/tests/test_message_deserializer.py @@ -341,9 +341,12 @@ def test_protobuf_explicit_schema_registry_configuration(self): b'\x1a\x0c\x41\x6c\x61\x6e\x20\x44\x6f\x6e\x6f\x76\x61\x6e' ) - # Protobuf message WITH Schema Registry format (magic byte 0x00 + schema ID 350 = 0x015E) + # Protobuf message WITH Schema Registry format (Confluent wire format) + # - magic byte 0x00 + schema ID 350 = 0x015E + # - message indices: [0] encoded as varint array (0x01 0x00 = 1 element, value 0) protobuf_message_with_sr = ( - b'\x00\x00\x00\x01\x5e' + b'\x00\x00\x00\x01\x5e' # Schema Registry header + b'\x01\x00' # Message indices: array length 1, index [0] b'\x08\xe8\xba\xb2\xeb\xd1\x9c\x02\x12\x1b\x54\x68\x65\x20\x47\x6f\x20\x50\x72\x6f\x67\x72\x61\x6d\x6d\x69\x6e\x67\x20\x4c\x61\x6e\x67\x75\x61\x67\x65' b'\x1a\x0c\x41\x6c\x61\x6e\x20\x44\x6f\x6e\x6f\x76\x61\x6e' ) diff --git a/kafka_consumer/changelog.d/22265.fixed b/kafka_consumer/changelog.d/22265.fixed new file mode 100644 index 0000000000000..bd76d8982cbe3 --- /dev/null +++ b/kafka_consumer/changelog.d/22265.fixed @@ -0,0 +1 @@ +Fix Kafka message decoding when using Protobuf with schema registry. diff --git a/kafka_consumer/datadog_checks/kafka_consumer/kafka_consumer.py b/kafka_consumer/datadog_checks/kafka_consumer/kafka_consumer.py index e0fac95ee663b..96a0f38d86af0 100644 --- a/kafka_consumer/datadog_checks/kafka_consumer/kafka_consumer.py +++ b/kafka_consumer/datadog_checks/kafka_consumer/kafka_consumer.py @@ -769,6 +769,9 @@ def _deserialize_protobuf(message, schema_info, uses_schema_registry): try: if uses_schema_registry: message_indices, message = _read_protobuf_message_indices(message) + # Empty indices array means use the first message type (index 0) + if not message_indices: + message_indices = [0] else: message_indices = [0] diff --git a/kafka_consumer/tests/test_unit.py b/kafka_consumer/tests/test_unit.py index 597f5f32024e2..c3da3a0751407 100644 --- a/kafka_consumer/tests/test_unit.py +++ b/kafka_consumer/tests/test_unit.py @@ -973,6 +973,57 @@ def test_protobuf_message_indices_with_schema_registry(): assert result[0] and 'Fiction' in result[0] +def test_protobuf_empty_message_indices_with_schema_registry(): + """Test Confluent Protobuf wire format with empty message indices array. + + When message indices array is empty (encoded as varint 0x00), it should + default to using the first message type (index 0). 
diff --git a/kafka_consumer/tests/test_unit.py b/kafka_consumer/tests/test_unit.py
index 597f5f32024e2..c3da3a0751407 100644
--- a/kafka_consumer/tests/test_unit.py
+++ b/kafka_consumer/tests/test_unit.py
@@ -973,6 +973,57 @@ def test_protobuf_message_indices_with_schema_registry():
     assert result[0] and 'Fiction' in result[0]
 
 
+def test_protobuf_empty_message_indices_with_schema_registry():
+    """Test Confluent Protobuf wire format with an empty message indices array.
+
+    When the message indices array is empty (encoded as varint 0x00), it should
+    default to using the first message type (index 0).
+
+    This test uses real message bytes from a Kafka topic to ensure the
+    deserialization handles the Confluent wire format correctly.
+    """
+    key = b'null'
+
+    # Schema from real Kafka topic - Purchase message
+    # message Purchase { string order_id = 1; string customer_id = 2; int64 order_date = 3;
+    # string city = 6; string country = 7; }
+    protobuf_schema = (
+        'CrkDCgxzY2hlbWEucHJvdG8SCHB1cmNoYXNlIpMBCghQdXJjaGFzZRIZCghvcmRlcl9pZBgBIAEoCVIH'
+        'b3JkZXJJZBIfCgtjdXN0b21lcl9pZBgCIAEoCVIKY3VzdG9tZXJJZBIdCgpvcmRlcl9kYXRlGAMgASgD'
+        'UglvcmRlckRhdGUSEgoEY2l0eRgGIAEoCVIEY2l0eRIYCgdjb3VudHJ5GAcgASgJUgdjb3VudHJ5ItIB'
+        'CgpQdXJjaGFzZVYyEiUKDnRyYW5zYWN0aW9uX2lkGAEgASgJUg10cmFuc2FjdGlvbklkEhcKB3VzZXJf'
+        'aWQYAiABKAlSBnVzZXJJZBIcCgl0aW1lc3RhbXAYAyABKANSCXRpbWVzdGFtcBIaCghsb2NhdGlvbhgE'
+        'IAEoCVIIbG9jYXRpb24SFgoGcmVnaW9uGAUgASgJUgZyZWdpb24SFgoGYW1vdW50GAYgASgBUgZhbW91'
+        'bnQSGgoIY3VycmVuY3kYByABKAlSCGN1cnJlbmN5QiwKG2RhdGFkb2cua2Fma2EuZXhhbXBsZS5wcm90'
+        'b0INUHVyY2hhc2VQcm90b2IGcHJvdG8z'
+    )
+    parsed_schema = build_schema('protobuf', protobuf_schema)
+
+    # Real message from Kafka topic "human-orders"
+    # Hex breakdown:
+    # 00 00 00 00 01 - Schema Registry header (magic byte + schema ID 1)
+    # 00 - Empty message indices array (varint 0 = 0 elements)
+    # 0a 05 31 32 33 34 35 ... - Protobuf payload (Purchase message)
+    message_hex = '0000000001000a0531323334351205363738393018f4eae0c4b8333a064d657869636f'
+    message_bytes = bytes.fromhex(message_hex)
+
+    # Test with uses_schema_registry=True (explicit)
+    result = deserialize_message(MockedMessage(message_bytes, key), 'protobuf', parsed_schema, True, 'json', '', False)
+    assert result[0], "Deserialization should succeed"
+    assert '12345' in result[0], "Should contain order_id"
+    assert '67890' in result[0], "Should contain customer_id"
+    assert 'Mexico' in result[0], "Should contain country"
+    assert result[1] == 1, "Should detect schema ID 1"
+
+    # Test with uses_schema_registry=False (fallback mode)
+    result_fallback = deserialize_message(
+        MockedMessage(message_bytes, key), 'protobuf', parsed_schema, False, 'json', '', False
+    )
+    assert result_fallback[0], "Fallback mode should also succeed"
+    assert '12345' in result_fallback[0], "Fallback should contain order_id"
+    assert result_fallback[1] == 1, "Fallback should detect schema ID 1"
+
+
 def mocked_time():
     return 400
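
As a closing cross-check, the captured frame used in the new test decodes by hand exactly as the hex breakdown in its comments describes (a sketch, assuming `_read_protobuf_message_indices` is imported from the module under test):

    frame = bytes.fromhex('0000000001000a0531323334351205363738393018f4eae0c4b8333a064d657869636f')

    assert frame[0] == 0x00                        # magic byte
    assert int.from_bytes(frame[1:5], 'big') == 1  # schema ID 1
    indices, body = _read_protobuf_message_indices(frame[5:])
    assert indices == []                           # empty array -> treated as [0] (Purchase)
    assert body[:7] == b'\x0a\x0512345'            # field 1 (order_id), length 5, "12345"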