Skip to content

Commit b89ac23

Browse files
committed
allow for named custom array types
1 parent 49c52f4 commit b89ac23

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

schema_salad/avro/schema.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,47 @@ def items(self) -> Schema:
429429
return cast(Schema, self.get_prop("items"))
430430

431431

432+
class NamedArraySchema(NamedSchema):
433+
"""Avro named array schema class."""
434+
435+
def __init__(
436+
self,
437+
items: JsonDataType,
438+
names: Names,
439+
name: str,
440+
namespace: Optional[str] = None,
441+
doc: Optional[Union[str, list[str]]] = None,
442+
other_props: Optional[PropsType] = None,
443+
) -> None:
444+
"""Create a NamedArraySchema object."""
445+
# Call parent ctor
446+
NamedSchema.__init__(self, "array", name, namespace, names, other_props)
447+
# Add class members
448+
449+
if names is None:
450+
raise SchemaParseException("Must provide Names.")
451+
if isinstance(items, str) and names.has_name(items, None):
452+
items_schema = cast(Schema, names.get_name(items, None))
453+
else:
454+
try:
455+
items_schema = make_avsc_object(items, names)
456+
except Exception as err:
457+
raise SchemaParseException(
458+
f"Items schema ({items}) not a valid Avro schema: {err}. "
459+
f"Known names: {list(names.names.keys())})."
460+
) from err
461+
462+
self.set_prop("items", items_schema)
463+
if doc is not None:
464+
self.set_prop("doc", doc)
465+
466+
# read-only properties
467+
@property
468+
def items(self) -> Schema:
469+
"""Avro schema describing the array items' type."""
470+
return cast(Schema, self.get_prop("items"))
471+
472+
432473
class MapSchema(Schema):
433474
"""Avro map schema class."""
434475

@@ -740,6 +781,11 @@ def make_avsc_object(json_data: JsonDataType, names: Optional[Names] = None) ->
740781
if atype in VALID_TYPES:
741782
if atype == "array":
742783
items = json_data.get("items")
784+
if "name" in json_data and json_data["name"]:
785+
name = json_data["name"]
786+
namespace = json_data.get("namespace", names.default_namespace)
787+
doc = json_data.get("doc")
788+
return NamedArraySchema(items, names, name, namespace, doc, other_props)
743789
return ArraySchema(items, names, other_props)
744790
elif atype == "map":
745791
values = json_data.get("values")
@@ -748,8 +794,7 @@ def make_avsc_object(json_data: JsonDataType, names: Optional[Names] = None) ->
748794
namespace = json_data.get("namespace", names.default_namespace)
749795
doc = json_data.get("doc")
750796
return NamedMapSchema(values, names, name, namespace, doc, other_props)
751-
else:
752-
return MapSchema(values, names, other_props)
797+
return MapSchema(values, names, other_props)
753798
elif atype == "union":
754799
schemas = json_data.get("names")
755800
if not isinstance(schemas, list):
@@ -761,8 +806,7 @@ def make_avsc_object(json_data: JsonDataType, names: Optional[Names] = None) ->
761806
namespace = json_data.get("namespace", names.default_namespace)
762807
doc = json_data.get("doc")
763808
return NamedUnionSchema(schemas, names, name, namespace, doc)
764-
else:
765-
return UnionSchema(schemas, names)
809+
return UnionSchema(schemas, names)
766810
if atype is None:
767811
raise SchemaParseException(f'No "type" property: {json_data}')
768812
raise SchemaParseException(f"Undefined type: {atype}")

schema_salad/validate.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def friendly(v: Any) -> Any:
9292
"""Format an Avro schema into a pretty-printed representation."""
9393
if isinstance(v, avro.schema.NamedSchema):
9494
return avro_shortname(v.name)
95-
if isinstance(v, avro.schema.ArraySchema):
95+
if isinstance(v, (avro.schema.ArraySchema, avro.schema.NamedArraySchema)):
9696
return f"array of <{friendly(v.items)}>"
9797
if isinstance(v, (avro.schema.MapSchema, avro.schema.NamedMapSchema)):
9898
return f"map of <{friendly(v.values)}>"
@@ -207,7 +207,7 @@ def validate_ex(
207207
)
208208
)
209209
return False
210-
if isinstance(expected_schema, avro.schema.ArraySchema):
210+
if isinstance(expected_schema, (avro.schema.ArraySchema, avro.schema.NamedArraySchema)):
211211
if isinstance(datum, MutableSequence):
212212
for i, d in enumerate(datum):
213213
try:
@@ -258,7 +258,9 @@ def validate_ex(
258258
errors: list[SchemaSaladException] = []
259259
checked = []
260260
for s in expected_schema.schemas:
261-
if isinstance(datum, MutableSequence) and not isinstance(s, avro.schema.ArraySchema):
261+
if isinstance(datum, MutableSequence) and not isinstance(
262+
s, (avro.schema.ArraySchema, avro.schema.NamedArraySchema)
263+
):
262264
continue
263265
if isinstance(datum, MutableMapping) and not isinstance(
264266
s, (avro.schema.RecordSchema, avro.schema.MapSchema, avro.schema.NamedMapSchema)
@@ -268,6 +270,7 @@ def validate_ex(
268270
s,
269271
(
270272
avro.schema.ArraySchema,
273+
avro.schema.NamedArraySchema,
271274
avro.schema.RecordSchema,
272275
avro.schema.MapSchema,
273276
avro.schema.NamedMapSchema,

0 commit comments

Comments
 (0)