Skip to content

Commit d905f80

Browse files
authored
Add optional field for Unions to track types at runtime (#15)
1 parent 3a705c7 commit d905f80

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

src/py_avro_schema/_schemas.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272

7373
NamesType = List[str]
7474

75+
RUNTIME_TYPE_KEY = "_runtime_type"
76+
7577

7678
class TypeNotSupportedError(TypeError):
7779
"""Error raised when a Avro schema cannot be generated for a given Python type"""
@@ -137,6 +139,9 @@ class Option(enum.Flag):
137139
# the two cases.
138140
MARK_NON_TOTAL_TYPED_DICTS = enum.auto()
139141

142+
#: Adds a _runtime_type field to the record schemas that contains the name of the class
143+
ADD_RUNTIME_TYPE_FIELD = enum.auto()
144+
140145

141146
JSON_OPTIONS = [opt for opt in Option if opt.name and opt.name.startswith("JSON_")]
142147

@@ -1105,6 +1110,13 @@ def _record_field(self, py_field: dataclasses.Field) -> RecordField:
11051110

11061111
return field_obj
11071112

1113+
def data_before_deduplication(self, names: NamesType) -> JSONObj:
1114+
"""Return the schema data"""
1115+
data = super().data_before_deduplication(names)
1116+
if Option.ADD_RUNTIME_TYPE_FIELD in self.options:
1117+
data["fields"].append({"name": RUNTIME_TYPE_KEY, "type": ["null", "string"]})
1118+
return data
1119+
11081120

11091121
@register_schema
11101122
class PydanticSchema(RecordSchema):
@@ -1239,6 +1251,13 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
12391251
)
12401252
return field_obj
12411253

1254+
def data_before_deduplication(self, names: NamesType) -> JSONObj:
1255+
"""Return the schema data"""
1256+
data = super().data_before_deduplication(names)
1257+
if Option.ADD_RUNTIME_TYPE_FIELD in self.options:
1258+
data["fields"].append({"name": RUNTIME_TYPE_KEY, "type": ["null", "string"]})
1259+
return data
1260+
12421261

12431262
@register_schema
12441263
class TypedDictSchema(RecordSchema):

tests/test_avro_schema.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import py_avro_schema as pas
2020
from py_avro_schema._alias import register_type_alias, register_type_aliases
21+
from py_avro_schema._testing import assert_schema
2122

2223

2324
def test_package_has_version():
@@ -64,3 +65,18 @@ class PyTypedDict(TypedDict):
6465
"test_avro_schema.SuperOldDict",
6566
"test_avro_schema.VeryOldDict",
6667
]
68+
69+
70+
def test_add_type_field():
71+
class PyType:
72+
field: str
73+
74+
expected = {
75+
"type": "record",
76+
"name": "PyType",
77+
"fields": [
78+
{"name": "field", "type": "string"},
79+
{"name": "_runtime_type", "type": ["null", "string"]},
80+
],
81+
}
82+
assert_schema(PyType, expected, options=pas.Option.ADD_RUNTIME_TYPE_FIELD)

0 commit comments

Comments
 (0)