diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 03514bb50db0..1e3bb2ece92a 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -32,6 +32,7 @@ import decimal import enum +import functools import itertools import json import logging @@ -354,6 +355,7 @@ def decode(self, value): NAMED_TUPLE_TYPE = 102 ENUM_TYPE = 103 NESTED_STATE_TYPE = 104 +DATACLASS_KW_ONLY_TYPE = 105 # Types that can be encoded as iterables, but are not literally # lists, etc. due to being lazy. The actual type is not preserved @@ -374,6 +376,18 @@ def _verify_dill_compat(): raise RuntimeError(base_error + f". Found dill version '{dill.__version__}") +dataclass_uses_kw_only: Callable[[Any], bool] +if dataclasses: + # Cache the result to avoid multiple checks for the same dataclass type. + @functools.cache + def dataclass_uses_kw_only(cls) -> bool: + return any( + field.init and field.kw_only for field in dataclasses.fields(cls)) + +else: + dataclass_uses_kw_only = lambda cls: False + + class FastPrimitivesCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" def __init__( @@ -497,18 +511,25 @@ def encode_special_deterministic(self, value, stream): self.encode_type(type(value), stream) stream.write(value.SerializePartialToString(deterministic=True), True) elif dataclasses and dataclasses.is_dataclass(value): - stream.write_byte(DATACLASS_TYPE) if not type(value).__dataclass_params__.frozen: raise TypeError( "Unable to deterministically encode non-frozen '%s' of type '%s' " "for the input of '%s'" % (value, type(value), self.requires_deterministic_step_label)) - self.encode_type(type(value), stream) - values = [ - getattr(value, field.name) for field in dataclasses.fields(value) - ] + init_fields = [field for field in dataclasses.fields(value) if field.init] try: - self.iterable_coder_impl.encode_to_stream(values, stream, True) + if dataclass_uses_kw_only(type(value)): + stream.write_byte(DATACLASS_KW_ONLY_TYPE) + self.encode_type(type(value), stream) + stream.write_var_int64(len(init_fields)) + for field in init_fields: + stream.write(field.name.encode("utf-8"), True) + self.encode_to_stream(getattr(value, field.name), stream, True) + else: # Not using kw_only, we can pass parameters by position. + stream.write_byte(DATACLASS_TYPE) + self.encode_type(type(value), stream) + values = [getattr(value, field.name) for field in init_fields] + self.iterable_coder_impl.encode_to_stream(values, stream, True) except Exception as e: raise TypeError(self._deterministic_encoding_error_msg(value)) from e elif isinstance(value, tuple) and hasattr(type(value), '_fields'): @@ -616,6 +637,14 @@ def decode_from_stream(self, stream, nested): msg = cls() msg.ParseFromString(stream.read_all(True)) return msg + elif t == DATACLASS_KW_ONLY_TYPE: + cls = self.decode_type(stream) + vlen = stream.read_var_int64() + fields = {} + for _ in range(vlen): + field_name = stream.read_all(True).decode('utf-8') + fields[field_name] = self.decode_from_stream(stream, True) + return cls(**fields) elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE: cls = self.decode_type(stream) return cls(*self.iterable_coder_impl.decode_from_stream(stream, True)) diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 8a57d1e63e2c..8f89ab9602c1 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -113,6 +113,11 @@ class FrozenDataClass: a: Any b: int + @dataclasses.dataclass(frozen=True, kw_only=True) + class FrozenKwOnlyDataClass: + c: int + d: int + @dataclasses.dataclass class UnFrozenDataClass: x: int @@ -303,9 +308,11 @@ def test_deterministic_coder(self, compat_version): if dataclasses is not None: self.check_coder(deterministic_coder, FrozenDataClass(1, 2)) + self.check_coder(deterministic_coder, FrozenKwOnlyDataClass(c=1, d=2)) with self.assertRaises(TypeError): self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2)) + with self.assertRaises(TypeError): self.check_coder( deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3)) @@ -742,6 +749,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic( from apache_beam.coders.coders_test_common import DefinesGetState from apache_beam.coders.coders_test_common import DefinesGetAndSetState from apache_beam.coders.coders_test_common import FrozenDataClass + from apache_beam.coders.coders_test_common import FrozenKwOnlyDataClass from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message @@ -777,6 +785,8 @@ def test_cross_process_encoding_of_special_types_is_deterministic( test_cases.extend([ ("frozen_dataclass", FrozenDataClass(1, 2)), ("frozen_dataclass_list", [FrozenDataClass(1, 2), FrozenDataClass(3, 4)]), + ("frozen_kwonly_dataclass", FrozenKwOnlyDataClass(c=1, d=2)), + ("frozen_kwonly_dataclass_list", [FrozenKwOnlyDataClass(c=1, d=2), FrozenKwOnlyDataClass(c=3, d=4)]), ]) compat_version = {'"'+ compat_version +'"' if compat_version else None}