Skip to content

Commit 87de10f

Browse files
authored
Add support for kw_only dataclasses (addresses #36978) (#36979)
* add support for kw_only dataclasses (#36978) * use a different type for kw_ony dataclasses * allow passing positional parameters to create dataclasses when possible * remove wrong TODO * add function typehint for pylint * minor refactoring
1 parent ff01a52 commit 87de10f

File tree

2 files changed

+45
-6
lines changed

2 files changed

+45
-6
lines changed

sdks/python/apache_beam/coders/coder_impl.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import decimal
3434
import enum
35+
import functools
3536
import itertools
3637
import json
3738
import logging
@@ -354,6 +355,7 @@ def decode(self, value):
354355
NAMED_TUPLE_TYPE = 102
355356
ENUM_TYPE = 103
356357
NESTED_STATE_TYPE = 104
358+
DATACLASS_KW_ONLY_TYPE = 105
357359

358360
# Types that can be encoded as iterables, but are not literally
359361
# lists, etc. due to being lazy. The actual type is not preserved
@@ -374,6 +376,18 @@ def _verify_dill_compat():
374376
raise RuntimeError(base_error + f". Found dill version '{dill.__version__}")
375377

376378

379+
dataclass_uses_kw_only: Callable[[Any], bool]
380+
if dataclasses:
381+
# Cache the result to avoid multiple checks for the same dataclass type.
382+
@functools.cache
383+
def dataclass_uses_kw_only(cls) -> bool:
384+
return any(
385+
field.init and field.kw_only for field in dataclasses.fields(cls))
386+
387+
else:
388+
dataclass_uses_kw_only = lambda cls: False
389+
390+
377391
class FastPrimitivesCoderImpl(StreamCoderImpl):
378392
"""For internal use only; no backwards-compatibility guarantees."""
379393
def __init__(
@@ -497,18 +511,25 @@ def encode_special_deterministic(self, value, stream):
497511
self.encode_type(type(value), stream)
498512
stream.write(value.SerializePartialToString(deterministic=True), True)
499513
elif dataclasses and dataclasses.is_dataclass(value):
500-
stream.write_byte(DATACLASS_TYPE)
501514
if not type(value).__dataclass_params__.frozen:
502515
raise TypeError(
503516
"Unable to deterministically encode non-frozen '%s' of type '%s' "
504517
"for the input of '%s'" %
505518
(value, type(value), self.requires_deterministic_step_label))
506-
self.encode_type(type(value), stream)
507-
values = [
508-
getattr(value, field.name) for field in dataclasses.fields(value)
509-
]
519+
init_fields = [field for field in dataclasses.fields(value) if field.init]
510520
try:
511-
self.iterable_coder_impl.encode_to_stream(values, stream, True)
521+
if dataclass_uses_kw_only(type(value)):
522+
stream.write_byte(DATACLASS_KW_ONLY_TYPE)
523+
self.encode_type(type(value), stream)
524+
stream.write_var_int64(len(init_fields))
525+
for field in init_fields:
526+
stream.write(field.name.encode("utf-8"), True)
527+
self.encode_to_stream(getattr(value, field.name), stream, True)
528+
else: # Not using kw_only, we can pass parameters by position.
529+
stream.write_byte(DATACLASS_TYPE)
530+
self.encode_type(type(value), stream)
531+
values = [getattr(value, field.name) for field in init_fields]
532+
self.iterable_coder_impl.encode_to_stream(values, stream, True)
512533
except Exception as e:
513534
raise TypeError(self._deterministic_encoding_error_msg(value)) from e
514535
elif isinstance(value, tuple) and hasattr(type(value), '_fields'):
@@ -616,6 +637,14 @@ def decode_from_stream(self, stream, nested):
616637
msg = cls()
617638
msg.ParseFromString(stream.read_all(True))
618639
return msg
640+
elif t == DATACLASS_KW_ONLY_TYPE:
641+
cls = self.decode_type(stream)
642+
vlen = stream.read_var_int64()
643+
fields = {}
644+
for _ in range(vlen):
645+
field_name = stream.read_all(True).decode('utf-8')
646+
fields[field_name] = self.decode_from_stream(stream, True)
647+
return cls(**fields)
619648
elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE:
620649
cls = self.decode_type(stream)
621650
return cls(*self.iterable_coder_impl.decode_from_stream(stream, True))

sdks/python/apache_beam/coders/coders_test_common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ class FrozenDataClass:
113113
a: Any
114114
b: int
115115

116+
@dataclasses.dataclass(frozen=True, kw_only=True)
117+
class FrozenKwOnlyDataClass:
118+
c: int
119+
d: int
120+
116121
@dataclasses.dataclass
117122
class UnFrozenDataClass:
118123
x: int
@@ -303,9 +308,11 @@ def test_deterministic_coder(self, compat_version):
303308

304309
if dataclasses is not None:
305310
self.check_coder(deterministic_coder, FrozenDataClass(1, 2))
311+
self.check_coder(deterministic_coder, FrozenKwOnlyDataClass(c=1, d=2))
306312

307313
with self.assertRaises(TypeError):
308314
self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2))
315+
309316
with self.assertRaises(TypeError):
310317
self.check_coder(
311318
deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3))
@@ -742,6 +749,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic(
742749
from apache_beam.coders.coders_test_common import DefinesGetState
743750
from apache_beam.coders.coders_test_common import DefinesGetAndSetState
744751
from apache_beam.coders.coders_test_common import FrozenDataClass
752+
from apache_beam.coders.coders_test_common import FrozenKwOnlyDataClass
745753
746754
747755
from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
@@ -777,6 +785,8 @@ def test_cross_process_encoding_of_special_types_is_deterministic(
777785
test_cases.extend([
778786
("frozen_dataclass", FrozenDataClass(1, 2)),
779787
("frozen_dataclass_list", [FrozenDataClass(1, 2), FrozenDataClass(3, 4)]),
788+
("frozen_kwonly_dataclass", FrozenKwOnlyDataClass(c=1, d=2)),
789+
("frozen_kwonly_dataclass_list", [FrozenKwOnlyDataClass(c=1, d=2), FrozenKwOnlyDataClass(c=3, d=4)]),
780790
])
781791
782792
compat_version = {'"'+ compat_version +'"' if compat_version else None}

0 commit comments

Comments
 (0)