Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def encode_special_deterministic(self, value, stream):
stream.write_byte(PROTO_TYPE)
self.encode_type(type(value), stream)
stream.write(value.SerializePartialToString(deterministic=True), True)
elif dataclasses and dataclasses.is_dataclass(value):
elif dataclasses.is_dataclass(value):
if not type(value).__dataclass_params__.frozen:
raise TypeError(
"Unable to deterministically encode non-frozen '%s' of type '%s' "
Expand Down
26 changes: 24 additions & 2 deletions sdks/python/apache_beam/typehints/native_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,30 @@ def match_is_named_tuple(user_type):
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))


def match_is_dataclass(user_type):
return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
def match_dataclass_for_row(user_type):
  """Return True if ``user_type`` is a dataclass that the row coder handles.

  A non-frozen dataclass always matches. A frozen dataclass matches only when
  it was explicitly registered with the row coder, e.g.::

      beam.coders.typecoders.registry.register_coder(
          MyDataClass, beam.coders.RowCoder)
  """
  if not dataclasses.is_dataclass(user_type):
    return False

  # Mutable (non-frozen) dataclasses are always handled by the row coder.
  if not user_type.__dataclass_params__.frozen:
    return True

  # Deferred imports: importing these at module scope would create a circular
  # dependency between typehints and coders.
  # pylint: disable=wrong-import-position
  from apache_beam.coders import RowCoder
  from apache_beam.coders.typecoders import registry as coders_registry

  # Consult the registry's private _coders mapping rather than get_coder() so
  # that only explicit registrations count — no fallback coder is considered.
  if user_type not in coders_registry._coders:
    return False
  return coders_registry._coders[user_type] == RowCoder


def _match_is_optional(user_type):
Expand Down
7 changes: 5 additions & 2 deletions sdks/python/apache_beam/typehints/row_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import Tuple

from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SchemaTypeRegistry

Expand Down Expand Up @@ -91,6 +91,9 @@ def __init__(
# Currently registration happens when converting to schema protos, in
# apache_beam.typehints.schemas
self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)
if self._schema_id and _BEAM_SCHEMA_ID not in self._user_type.__dict__:
# schema id does not inherit. Unset if schema id is from base class
self._schema_id = None

self._schema_options = schema_options or []
self._field_options = field_options or {}
Expand All @@ -105,7 +108,7 @@ def from_user_type(
if match_is_named_tuple(user_type):
fields = [(name, user_type.__annotations__[name])
for name in user_type._fields]
elif match_is_dataclass(user_type):
elif match_dataclass_for_row(user_type):
fields = [(field.name, field.type)
for field in dataclasses.fields(user_type)]
else:
Expand Down
89 changes: 89 additions & 0 deletions sdks/python/apache_beam/typehints/row_type_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.typehints import row_type
from apache_beam.typehints import schemas


class RowTypeTest(unittest.TestCase):
Expand Down Expand Up @@ -85,6 +86,94 @@ def generate(num: int):
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
assert_that(result, equal_to([10] * 100))

  def test_group_by_key_namedtuple_union(self):
    """GroupByKey works when the key type hint is a union of NamedTuples."""
    Tuple1 = typing.NamedTuple("Tuple1", [("id", int)])

    Tuple2 = typing.NamedTuple("Tuple2", [("id", int), ("name", str)])

    def generate(num: int):
      # Emit each of the 4 distinct keys (2 ids x 2 tuple types) once per
      # input element, so every key ends up grouping 2 values.
      for i in range(2):
        yield (Tuple1(i), num)
        yield (Tuple2(i, 'a'), num)

    pipeline = TestPipeline(is_integration_test=False)

    with pipeline as p:
      result = (
          p
          | 'Create' >> beam.Create([i for i in range(2)])
          | 'Generate' >> beam.ParDo(generate).with_output_types(
              tuple[(Tuple1 | Tuple2), int])
          | 'GBK' >> beam.GroupByKey()
          | 'Count' >> beam.Map(lambda x: len(x[1])))
      # 4 distinct keys, each grouping 2 values.
      assert_that(result, equal_to([2] * 4))

  # A union of dataclasses as a type hint currently results in a
  # FastPrimitiveCoder, which fails at GBK.
  @unittest.skip("https://github.com/apache/beam/issues/22085")
  def test_group_by_key_inherited_dataclass_union(self):
    """GroupByKey over a union of base/derived dataclass keys (known broken)."""
    @dataclass
    class DataClassInt:
      id: int

    @dataclass
    class DataClassStr(DataClassInt):
      name: str

    # Explicitly register both classes with the row coder.
    beam.coders.typecoders.registry.register_coder(
        DataClassInt, beam.coders.RowCoder)
    beam.coders.typecoders.registry.register_coder(
        DataClassStr, beam.coders.RowCoder)

    def generate(num: int):
      # Yield one key of each dataclass type per id, per input element.
      for i in range(10):
        yield (DataClassInt(i), num)
        yield (DataClassStr(i, 'a'), num)

    pipeline = TestPipeline(is_integration_test=False)

    with pipeline as p:
      result = (
          p
          | 'Create' >> beam.Create([i for i in range(2)])
          | 'Generate' >> beam.ParDo(generate).with_output_types(
              tuple[(DataClassInt | DataClassStr), int])
          | 'GBK' >> beam.GroupByKey()
          | 'Count Elements' >> beam.Map(self._check_key_type_and_count))
      assert_that(result, equal_to([2] * 4))

  def test_derived_dataclass_schema_id(self):
    """A derived dataclass must not inherit _beam_schema_id from its base."""
    @dataclass
    class BaseDataClass:
      id: int

    @dataclass
    class DerivedDataClass(BaseDataClass):
      name: str

    # Generating a schema for the base class attaches _beam_schema_id to it.
    self.assertFalse(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
    schema_for_base = schemas.schema_from_element_type(BaseDataClass)
    self.assertTrue(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
    self.assertEqual(
        schema_for_base.id, getattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))

    # Getting the schema for BaseDataClass sets the _beam_schema_id
    schemas.typing_to_runner_api(
        BaseDataClass, schema_registry=schemas.SchemaTypeRegistry())

    # We create a RowTypeConstraint from DerivedDataClass.
    # It should not inherit the _beam_schema_id from BaseDataClass!
    derived_row_type = row_type.RowTypeConstraint.from_user_type(
        DerivedDataClass)
    self.assertIsNone(derived_row_type._schema_id)

    # The derived class gets its own schema id, distinct from the base's.
    schema_for_derived = schemas.schema_from_element_type(DerivedDataClass)
    self.assertTrue(hasattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
    self.assertEqual(
        schema_for_derived.id,
        getattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
    self.assertNotEqual(schema_for_derived.id, schema_for_base.id)


# Allow running this test module directly.
if __name__ == '__main__':
  unittest.main()
33 changes: 23 additions & 10 deletions sdks/python/apache_beam/typehints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
from apache_beam.typehints.native_type_compatibility import convert_to_python_type
from apache_beam.typehints.native_type_compatibility import extract_optional_type
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
Expand Down Expand Up @@ -335,19 +335,23 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType:
atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[int])))

elif _safe_issubclass(type_, Sequence) and not _safe_issubclass(type_, str):
element_type = self.typing_to_runner_api(_get_args(type_)[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))
arg_types = _get_args(type_)
if len(arg_types) > 0:
element_type = self.typing_to_runner_api(arg_types[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))

elif _safe_issubclass(type_, Mapping):
key_type, value_type = map(self.typing_to_runner_api, _get_args(type_))
return schema_pb2.FieldType(
map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str):
element_type = self.typing_to_runner_api(_get_args(type_)[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))
arg_types = _get_args(type_)
if len(arg_types) > 0:
element_type = self.typing_to_runner_api(arg_types[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))

try:
if LogicalType.is_known_logical_type(type_):
Expand Down Expand Up @@ -630,8 +634,10 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
Returns schema as a list of (name, python_type) tuples"""
if isinstance(element_type, row_type.RowTypeConstraint):
return named_fields_to_schema(element_type._fields)
elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
elif match_is_named_tuple(element_type) or match_dataclass_for_row(
element_type):
# schema id does not inherit from base classes
if row_type._BEAM_SCHEMA_ID in element_type.__dict__:
# if the named tuple's schema is in registry, we just use it instead of
# regenerating one.
schema_id = getattr(element_type, row_type._BEAM_SCHEMA_ID)
Expand All @@ -657,8 +663,15 @@ def union_schema_type(element_types):
element_types must be a set of schema-aware types whose fields have the
same naming and ordering.
"""
named_fields_and_types = []
for t in element_types:
n = named_fields_from_element_type(t)
if named_fields_and_types and len(named_fields_and_types[-1]) != len(n):
raise TypeError("element types has different number of fields")
named_fields_and_types.append(n)

union_fields_and_types = []
for field in zip(*[named_fields_from_element_type(t) for t in element_types]):
for field in zip(*named_fields_and_types):
names, types = zip(*field)
name_set = set(names)
if len(name_set) != 1:
Expand Down
Loading