2828from karapace .utils import assert_never , json_decode , json_encode , JSONDecodeError
2929from typing import Any , cast , Collection , Dict , Final , final , Mapping , Sequence
3030
31+ import avro .schema
3132import hashlib
3233import logging
34+ import re
3335
3436LOG = logging .getLogger (__name__ )
3537
@@ -152,6 +154,7 @@ def normalize_schema_str(
152154 except JSONDecodeError as e :
153155 LOG .info ("Schema is not valid JSON" )
154156 raise e
157+
155158 elif schema_type == SchemaType .PROTOBUF :
156159 if schema :
157160 schema_str = str (schema )
@@ -194,6 +197,45 @@ def schema(self) -> Draft7Validator | AvroSchema | ProtobufSchema:
194197 return parsed_typed_schema .schema
195198
196199
class AvroMerge:
    """Combine an Avro schema and the schemas it references into one JSON union.

    Every schema (each dependency, then the main schema) is wrapped in a
    uniquely named record with a single field called ``name`` whose type is a
    union holding that schema.  The wrapped records are joined into a JSON
    array, with the main schema's wrapper emitted last.
    """

    def __init__(self, schema_str: str, dependencies: Mapping[str, Dependency] | None = None):
        """Store the schema text and its dependencies.

        Args:
            schema_str: JSON text of the main schema.  It is decoded and
                re-encoded with ``compact=True, sort_keys=True`` before being
                stored.
            dependencies: Optional mapping of dependencies.  Each value is read
                as ``dependency.schema.schema_str`` and
                ``dependency.schema.dependencies`` by :meth:`builder`.
        """
        self.schema_str = json_encode(json_decode(schema_str), compact=True, sort_keys=True)
        self.dependencies = dependencies
        # Counter used to give every wrapper record a distinct name.
        self.unique_id = 0
        # Matches schema text whose first non-blank character opens a JSON array.
        self.regex = re.compile(r"^\s*\[")

    def union_safe_schema_str(self, schema_str: str) -> str:
        """Return ``schema_str`` wrapped in a record named after ``self.unique_id``.

        The wrapper record has one field, ``name``, whose type is a union.  When
        ``schema_str`` already starts with ``[`` it is used as that union
        unchanged; otherwise it is enclosed in ``[...]`` first.
        """
        # in case we meet union - we use it as is

        base_schema = (
            f'{{"name": "___RESERVED_KARAPACE_WRAPPER_NAME_{self.unique_id}___",'
            f'"type": "record", "fields": [{{"name": "name", "type":'
        )
        if self.regex.match(schema_str):
            return f"{base_schema} {schema_str}}}]}}"
        return f"{base_schema} [{schema_str}]}}]}}"

    def builder(self, schema_str: str, dependencies: Mapping[str, Dependency] | None = None) -> str:
        """To support references in AVRO we iteratively merge all referenced schemas with current schema

        The traversal uses an explicit stack instead of recursion.  A schema
        that has dependencies is pushed back without them and its dependencies
        are pushed on top, so every dependency is emitted before the schema
        that refers to it.

        Args:
            schema_str: JSON text of the schema to merge.
            dependencies: Optional mapping of that schema's dependencies.

        Returns:
            The wrapped schemas joined with ``",\\n "`` (no enclosing brackets).
        """
        stack: list[tuple[str, Mapping[str, Dependency] | None]] = [(schema_str, dependencies)]
        merged_schemas = []

        while stack:
            current_schema_str, current_dependencies = stack.pop()
            if current_dependencies:
                # Revisit this schema once all of its dependencies are emitted.
                stack.append((current_schema_str, None))
                # Materialise the values first: only ``dict`` views support
                # ``reversed()`` directly, a generic ``Mapping.values()`` view
                # does not.  Pushing in reverse keeps the mapping's own order
                # when the entries are popped.
                for dependency in reversed(list(current_dependencies.values())):
                    stack.append((dependency.schema.schema_str, dependency.schema.dependencies))
            else:
                self.unique_id += 1
                merged_schemas.append(self.union_safe_schema_str(current_schema_str))

        return ",\n ".join(merged_schemas)

    def wrap(self) -> str:
        """Return the merged schemas of this instance enclosed in a JSON array."""
        return "[\n " + self.builder(self.schema_str, self.dependencies) + "\n ]"
237+
238+
197239def parse (
198240 schema_type : SchemaType ,
199241 schema_str : str ,
@@ -206,18 +248,37 @@ def parse(
206248) -> ParsedTypedSchema :
207249 if schema_type not in [SchemaType .AVRO , SchemaType .JSONSCHEMA , SchemaType .PROTOBUF ]:
208250 raise InvalidSchema (f"Unknown parser { schema_type } for { schema_str } " )
209-
251+ parsed_schema_result : Draft7Validator | AvroSchema | ProtobufSchema
210252 parsed_schema : Draft7Validator | AvroSchema | ProtobufSchema
211253 if schema_type is SchemaType .AVRO :
212254 try :
255+ if dependencies :
256+ wrapped_schema_str = AvroMerge (schema_str , dependencies ).wrap ()
257+ else :
258+ wrapped_schema_str = schema_str
213259 parsed_schema = parse_avro_schema_definition (
214- schema_str ,
260+ wrapped_schema_str ,
215261 validate_enum_symbols = validate_avro_enum_symbols ,
216262 validate_names = validate_avro_names ,
217263 )
264+ if dependencies :
265+ if isinstance (parsed_schema , avro .schema .UnionSchema ):
266+ parsed_schema_result = parsed_schema .schemas [- 1 ].fields [0 ].type .schemas [- 1 ]
267+
268+ else :
269+ raise InvalidSchema
270+ else :
271+ parsed_schema_result = parsed_schema
272+ return ParsedTypedSchema (
273+ schema_type = schema_type ,
274+ schema_str = schema_str ,
275+ schema = parsed_schema_result ,
276+ references = references ,
277+ dependencies = dependencies ,
278+ schema_wrapped = parsed_schema ,
279+ )
218280 except (SchemaParseException , JSONDecodeError , TypeError ) as e :
219281 raise InvalidSchema from e
220-
221282 elif schema_type is SchemaType .JSONSCHEMA :
222283 try :
223284 parsed_schema = parse_jsonschema_definition (schema_str )
@@ -284,9 +345,10 @@ def __init__(
284345 schema : Draft7Validator | AvroSchema | ProtobufSchema ,
285346 references : Sequence [Reference ] | None = None ,
286347 dependencies : Mapping [str , Dependency ] | None = None ,
348+ schema_wrapped : Draft7Validator | AvroSchema | ProtobufSchema | None = None ,
287349 ) -> None :
288350 self ._schema_cached : Draft7Validator | AvroSchema | ProtobufSchema | None = schema
289-
351+ self . schema_wrapped = schema_wrapped
290352 super ().__init__ (
291353 schema_type = schema_type ,
292354 schema_str = schema_str ,
0 commit comments