Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions src/confluent_kafka/schema_registry/_async/avro.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,16 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731
parsed_schema = self._parsed_schema

with _ContextStringIO() as fo:
# write the record to the rest of the buffer
schemaless_writer(fo, parsed_schema, value)
buffer = fo.getvalue()
# Check if it's a simple bytes type
is_bytes = (parsed_schema == "bytes" or
(isinstance(parsed_schema, dict) and parsed_schema.get("type") == "bytes"))
if is_bytes:
# For simple bytes type, write value directly
buffer = value if isinstance(value, bytes) else value.encode()
else:
# write the record to the rest of the buffer
schemaless_writer(fo, parsed_schema, value)
buffer = fo.getvalue()

if latest_schema is not None:
buffer = self._execute_rules_with_phase(
Expand Down Expand Up @@ -585,17 +592,29 @@ async def __deserialize(
reader_schema_raw = writer_schema_raw
reader_schema = writer_schema

# Check if it's a simple bytes type
is_bytes = (writer_schema == "bytes" or
(isinstance(writer_schema, dict) and writer_schema.get("type") == "bytes"))

if migrations:
obj_dict = schemaless_reader(payload,
writer_schema,
None,
self._return_record_name)
if is_bytes:
# For simple bytes type, read payload directly
obj_dict = payload.read()
else:
obj_dict = schemaless_reader(payload,
writer_schema,
None,
self._return_record_name)
obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict)
else:
obj_dict = schemaless_reader(payload,
writer_schema,
reader_schema,
self._return_record_name)
if is_bytes:
# For simple bytes type, read payload directly
obj_dict = payload.read()
else:
obj_dict = schemaless_reader(payload,
writer_schema,
reader_schema,
self._return_record_name)

def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731
transform(rule_ctx, reader_schema, message, field_transform))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,8 @@ async def send_request(
if body is not None:
body = json.dumps(body)
headers = {'Content-Length': str(len(body)),
'Content-Type': "application/vnd.schemaregistry.v1+json"}
'Content-Type': "application/vnd.schemaregistry.v1+json",
'Accept-Version': "8.0"}

if self.bearer_auth_credentials_source:
await self.handle_bearer_auth(headers)
Expand Down
41 changes: 30 additions & 11 deletions src/confluent_kafka/schema_registry/_sync/avro.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,16 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731
parsed_schema = self._parsed_schema

with _ContextStringIO() as fo:
# write the record to the rest of the buffer
schemaless_writer(fo, parsed_schema, value)
buffer = fo.getvalue()
# Check if it's a simple bytes type
is_bytes = (parsed_schema == "bytes" or
(isinstance(parsed_schema, dict) and parsed_schema.get("type") == "bytes"))
if is_bytes:
# For simple bytes type, write value directly
buffer = value if isinstance(value, bytes) else value.encode()
else:
# write the record to the rest of the buffer
schemaless_writer(fo, parsed_schema, value)
buffer = fo.getvalue()

if latest_schema is not None:
buffer = self._execute_rules_with_phase(
Expand Down Expand Up @@ -585,17 +592,29 @@ def __deserialize(
reader_schema_raw = writer_schema_raw
reader_schema = writer_schema

# Check if it's a simple bytes type
is_bytes = (writer_schema == "bytes" or
(isinstance(writer_schema, dict) and writer_schema.get("type") == "bytes"))

if migrations:
obj_dict = schemaless_reader(payload,
writer_schema,
None,
self._return_record_name)
if is_bytes:
# For simple bytes type, read payload directly
obj_dict = payload.read()
else:
obj_dict = schemaless_reader(payload,
writer_schema,
None,
self._return_record_name)
obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict)
else:
obj_dict = schemaless_reader(payload,
writer_schema,
reader_schema,
self._return_record_name)
if is_bytes:
# For simple bytes type, read payload directly
obj_dict = payload.read()
else:
obj_dict = schemaless_reader(payload,
writer_schema,
reader_schema,
self._return_record_name)

def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731
transform(rule_ctx, reader_schema, message, field_transform))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,8 @@ def send_request(
if body is not None:
body = json.dumps(body)
headers = {'Content-Length': str(len(body)),
'Content-Type': "application/vnd.schemaregistry.v1+json"}
'Content-Type': "application/vnd.schemaregistry.v1+json",
'Accept-Version': "8.0"}

if self.bearer_auth_credentials_source:
self.handle_bearer_auth(headers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def new_kms_client(self, conf: Dict[str, Any], key_url: Optional[str]) -> KmsCli
role_external_id = conf.get(_ROLE_EXTERNAL_ID)
if role_external_id is None:
role_external_id = os.getenv("AWS_ROLE_EXTERNAL_ID")
role_web_identity_token_file = os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
key = conf.get(_ACCESS_KEY_ID)
secret = conf.get(_SECRET_ACCESS_KEY)
profile = conf.get(_PROFILE)
Expand All @@ -74,7 +75,8 @@ def new_kms_client(self, conf: Dict[str, Any], key_url: Optional[str]) -> KmsCli
)
else:
session = boto3.Session(region_name=region)
if role_arn is not None:
# If role_web_identity_token_file is set, use the DefaultCredentialsProvider
if role_arn is not None and role_web_identity_token_file is None:
sts_client = session.client('sts')
params = {
'RoleArn': role_arn,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -587,10 +587,7 @@ def register_dek(
encrypted_key_material=encrypted_key_material
)

response = self._rest_client.post('/dek-registry/v1/keks/{}/deks'
.format(urllib.parse.quote(kek_name)),
request.to_dict())
dek = Dek.from_dict(response)
dek = self._create_dek(kek_name, request)

self._dek_cache.set(cache_key, dek)
# Ensure latest dek is invalidated, such as in case of conflict (409)
Expand All @@ -611,6 +608,27 @@ def register_dek(

return dek

def _create_dek(
self, kek_name: str, request: CreateDekRequest
) -> Dek:
from confluent_kafka.schema_registry.error import SchemaRegistryError
try:
# Try newer API with subject in the path
path = '/dek-registry/v1/keks/{}/deks/{}'.format(
urllib.parse.quote(kek_name),
urllib.parse.quote(request.subject, safe='')
)
response = self._rest_client.post(path, request.to_dict())
return Dek.from_dict(response)
except SchemaRegistryError as e:
if e.http_status_code == 405:
# Try fallback to older API that does not have subject in the path
path = '/dek-registry/v1/keks/{}/deks'.format(urllib.parse.quote(kek_name))
response = self._rest_client.post(path, request.to_dict())
return Dek.from_dict(response)
else:
raise

def get_dek(
self, kek_name: str, subject: str, algorithm: DekAlgorithm = DekAlgorithm.AES256_GCM,
version: int = 1, deleted: bool = False
Expand Down
16 changes: 16 additions & 0 deletions tests/schema_registry/_async/test_avro_serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,22 @@ async def test_avro_serialize_use_schema_id():
assert obj == obj2


async def test_avro_serialize_bytes():
conf = {'url': _BASE_URL}
client = AsyncSchemaRegistryClient.new_client(conf)
ser_conf = {'auto.register.schemas': True}
obj = b'\x02\x03\x04'
schema = 'bytes'
ser = await AsyncAvroSerializer(client, schema_str=json.dumps(schema), conf=ser_conf)
ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
obj_bytes = await ser(obj, ser_ctx)
assert b'\x00\x00\x00\x00\x01\x02\x03\x04' == obj_bytes

deser = await AsyncAvroDeserializer(client)
obj2 = await deser(obj_bytes, ser_ctx)
assert obj == obj2


async def test_avro_serialize_nested():
conf = {'url': _BASE_URL}
client = AsyncSchemaRegistryClient.new_client(conf)
Expand Down
16 changes: 16 additions & 0 deletions tests/schema_registry/_sync/test_avro_serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,22 @@ def test_avro_serialize_use_schema_id():
assert obj == obj2


def test_avro_serialize_bytes():
conf = {'url': _BASE_URL}
client = SchemaRegistryClient.new_client(conf)
ser_conf = {'auto.register.schemas': True}
obj = b'\x02\x03\x04'
schema = 'bytes'
ser = AvroSerializer(client, schema_str=json.dumps(schema), conf=ser_conf)
ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
obj_bytes = ser(obj, ser_ctx)
assert b'\x00\x00\x00\x00\x01\x02\x03\x04' == obj_bytes

deser = AvroDeserializer(client)
obj2 = deser(obj_bytes, ser_ctx)
assert obj == obj2


def test_avro_serialize_nested():
conf = {'url': _BASE_URL}
client = SchemaRegistryClient.new_client(conf)
Expand Down