3232
3333import decimal
3434import enum
35+ import functools
3536import itertools
3637import json
3738import logging
@@ -354,6 +355,7 @@ def decode(self, value):
354355NAMED_TUPLE_TYPE = 102
355356ENUM_TYPE = 103
356357NESTED_STATE_TYPE = 104
358+ DATACLASS_KW_ONLY_TYPE = 105
357359
358360# Types that can be encoded as iterables, but are not literally
359361# lists, etc. due to being lazy. The actual type is not preserved
@@ -374,6 +376,18 @@ def _verify_dill_compat():
374376 raise RuntimeError (base_error + f". Found dill version '{ dill .__version__ } " )
375377
376378
379+ dataclass_uses_kw_only : Callable [[Any ], bool ]
380+ if dataclasses :
381+ # Cache the result to avoid multiple checks for the same dataclass type.
382+ @functools .cache
383+ def dataclass_uses_kw_only (cls ) -> bool :
384+ return any (
385+ field .init and field .kw_only for field in dataclasses .fields (cls ))
386+
387+ else :
388+ dataclass_uses_kw_only = lambda cls : False
389+
390+
377391class FastPrimitivesCoderImpl (StreamCoderImpl ):
378392 """For internal use only; no backwards-compatibility guarantees."""
379393 def __init__ (
@@ -497,18 +511,25 @@ def encode_special_deterministic(self, value, stream):
497511 self .encode_type (type (value ), stream )
498512 stream .write (value .SerializePartialToString (deterministic = True ), True )
499513 elif dataclasses and dataclasses .is_dataclass (value ):
500- stream .write_byte (DATACLASS_TYPE )
501514 if not type (value ).__dataclass_params__ .frozen :
502515 raise TypeError (
503516 "Unable to deterministically encode non-frozen '%s' of type '%s' "
504517 "for the input of '%s'" %
505518 (value , type (value ), self .requires_deterministic_step_label ))
506- self .encode_type (type (value ), stream )
507- values = [
508- getattr (value , field .name ) for field in dataclasses .fields (value )
509- ]
519+ init_fields = [field for field in dataclasses .fields (value ) if field .init ]
510520 try :
511- self .iterable_coder_impl .encode_to_stream (values , stream , True )
521+ if dataclass_uses_kw_only (type (value )):
522+ stream .write_byte (DATACLASS_KW_ONLY_TYPE )
523+ self .encode_type (type (value ), stream )
524+ stream .write_var_int64 (len (init_fields ))
525+ for field in init_fields :
526+ stream .write (field .name .encode ("utf-8" ), True )
527+ self .encode_to_stream (getattr (value , field .name ), stream , True )
528+ else : # Not using kw_only, we can pass parameters by position.
529+ stream .write_byte (DATACLASS_TYPE )
530+ self .encode_type (type (value ), stream )
531+ values = [getattr (value , field .name ) for field in init_fields ]
532+ self .iterable_coder_impl .encode_to_stream (values , stream , True )
512533 except Exception as e :
513534 raise TypeError (self ._deterministic_encoding_error_msg (value )) from e
514535 elif isinstance (value , tuple ) and hasattr (type (value ), '_fields' ):
@@ -616,6 +637,14 @@ def decode_from_stream(self, stream, nested):
616637 msg = cls ()
617638 msg .ParseFromString (stream .read_all (True ))
618639 return msg
640+ elif t == DATACLASS_KW_ONLY_TYPE :
641+ cls = self .decode_type (stream )
642+ vlen = stream .read_var_int64 ()
643+ fields = {}
644+ for _ in range (vlen ):
645+ field_name = stream .read_all (True ).decode ('utf-8' )
646+ fields [field_name ] = self .decode_from_stream (stream , True )
647+ return cls (** fields )
619648 elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE :
620649 cls = self .decode_type (stream )
621650 return cls (* self .iterable_coder_impl .decode_from_stream (stream , True ))
0 commit comments