Skip to content

Commit 701c97d

Browse files
committed
Handle some cases when inferring a schema from a dataclass
* Make sure the Beam schema ID is not inherited from a base class * Fix IndexError when inferring the element type of a custom Iterable without type hints
1 parent ab56619 commit 701c97d

File tree

3 files changed

+45
-7
lines changed

3 files changed

+45
-7
lines changed

sdks/python/apache_beam/typehints/row_type.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def __init__(
9191
# Currently registration happens when converting to schema protos, in
9292
# apache_beam.typehints.schemas
9393
self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)
94+
if self._schema_id and _BEAM_SCHEMA_ID not in self._user_type.__dict__:
95+
# Schema ids are not inherited; unset one that comes from a base class.
96+
self._schema_id = None
9497

9598
self._schema_options = schema_options or []
9699
self._field_options = field_options or {}

sdks/python/apache_beam/typehints/row_type_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from apache_beam.testing.util import assert_that
2727
from apache_beam.testing.util import equal_to
2828
from apache_beam.typehints import row_type
29+
from apache_beam.typehints import schemas
2930

3031

3132
class RowTypeTest(unittest.TestCase):
@@ -85,6 +86,35 @@ def generate(num: int):
8586
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
8687
assert_that(result, equal_to([10] * 100))
8788

89+
def test_derived_dataclass_schema_id(self):
  """A derived dataclass must not reuse the Beam schema id of its base."""
  @dataclass
  class BaseDataClass:
    id: int

  @dataclass
  class DerivedDataClass(BaseDataClass):
    name: str

  # Getting the schema for BaseDataClass sets the _beam_schema_id on it.
  # The return value is not needed; the call is made for that side effect.
  self.assertFalse(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
  schemas.schema_from_element_type(BaseDataClass)
  self.assertTrue(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))

  schemas.typing_to_runner_api(
      BaseDataClass, schema_registry=schemas.SchemaTypeRegistry())

  # We create a RowTypeConstraint from DerivedDataClass.
  # It should not inherit the _beam_schema_id from BaseDataClass!
  derived_row_type = row_type.RowTypeConstraint.from_user_type(
      DerivedDataClass)
  self.assertIsNone(derived_row_type._schema_id)

  # After inferring its schema, the id visible on DerivedDataClass must
  # differ from the one on BaseDataClass.
  schemas.schema_from_element_type(DerivedDataClass)
  self.assertNotEqual(
      getattr(BaseDataClass, row_type._BEAM_SCHEMA_ID),
      getattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
117+
88118

89119
if __name__ == '__main__':
90120
unittest.main()

sdks/python/apache_beam/typehints/schemas.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -335,19 +335,23 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType:
335335
atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[int])))
336336

337337
elif _safe_issubclass(type_, Sequence) and not _safe_issubclass(type_, str):
338-
element_type = self.typing_to_runner_api(_get_args(type_)[0])
339-
return schema_pb2.FieldType(
340-
array_type=schema_pb2.ArrayType(element_type=element_type))
338+
arg_types = _get_args(type_)
339+
if len(arg_types) > 0:
340+
element_type = self.typing_to_runner_api(arg_types[0])
341+
return schema_pb2.FieldType(
342+
array_type=schema_pb2.ArrayType(element_type=element_type))
341343

342344
elif _safe_issubclass(type_, Mapping):
343345
key_type, value_type = map(self.typing_to_runner_api, _get_args(type_))
344346
return schema_pb2.FieldType(
345347
map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))
346348

347349
elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str):
348-
element_type = self.typing_to_runner_api(_get_args(type_)[0])
349-
return schema_pb2.FieldType(
350-
array_type=schema_pb2.ArrayType(element_type=element_type))
350+
arg_types = _get_args(type_)
351+
if len(arg_types) > 0:
352+
element_type = self.typing_to_runner_api(arg_types[0])
353+
return schema_pb2.FieldType(
354+
array_type=schema_pb2.ArrayType(element_type=element_type))
351355

352356
try:
353357
if LogicalType.is_known_logical_type(type_):
@@ -631,7 +635,8 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
631635
if isinstance(element_type, row_type.RowTypeConstraint):
632636
return named_fields_to_schema(element_type._fields)
633637
elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
634-
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
638+
# Schema ids are not inherited, so look only in the class's own __dict__.
639+
if row_type._BEAM_SCHEMA_ID in element_type.__dict__:
635640
# if the named tuple's schema is in registry, we just use it instead of
636641
# regenerating one.
637642
schema_id = getattr(element_type, row_type._BEAM_SCHEMA_ID)

0 commit comments

Comments
 (0)