@@ -684,12 +684,17 @@ def __init__(self):
684684 self .by_urn = {}
685685 self .by_logical_type = {}
686686 self .by_language_type = {}
687+ self ._custom_urns = set ()
687688
688- def add (self , urn , logical_type ):
689+ def _add_internal (self , urn , logical_type ):
689690 self .by_urn [urn ] = logical_type
690691 self .by_logical_type [logical_type ] = urn
691692 self .by_language_type [logical_type .language_type ()] = logical_type
692693
694+ def add (self , urn , logical_type ):
695+ self ._add_internal (urn , logical_type )
696+ self ._custom_urns .add (urn )
697+
693698 def get_logical_type_by_urn (self , urn ):
694699 return self .by_urn .get (urn , None )
695700
@@ -704,12 +709,23 @@ def copy(self):
704709 copy .by_urn .update (self .by_urn )
705710 copy .by_logical_type .update (self .by_logical_type )
706711 copy .by_language_type .update (self .by_language_type )
712+ copy ._custom_urns .update (self ._custom_urns )
707713 return copy
708714
715+ def copy_custom (self ):
716+ copy = LogicalTypeRegistry ()
717+ for urn in self ._custom_urns :
718+ logical_type = self .by_urn [urn ]
719+ copy .by_urn [urn ] = logical_type
720+ copy .by_logical_type [logical_type ] = urn
721+ copy .by_language_type [logical_type .language_type ()] = logical_type
722+ copy ._custom_urns .add (urn )
723+
709724 def load (self , another ):
710725 self .by_urn .update (another .by_urn )
711726 self .by_logical_type .update (another .by_logical_type )
712727 self .by_language_type .update (another .by_language_type )
728+ self ._custom_urns .update (another ._custom_urns )
713729
714730
715731LanguageT = TypeVar ('LanguageT' )
@@ -773,6 +789,19 @@ def to_language_type(self, value):
773789 """Convert an instance of RepresentationT to LanguageT."""
774790 raise NotImplementedError ()
775791
792+ @classmethod
793+ def _register_internal (cls , logical_type_cls ):
794+ """
795+ Register an implementation of LogicalType.
796+
797+ The types registered using this decorator are not pickled on pipeline
798+ submission, as it relies module import to be registered on worker
799+ initialization. Should be used within schemas module and static context.
800+ """
801+ cls ._known_logical_types ._add_internal (
802+ logical_type_cls .urn (), logical_type_cls )
803+ return logical_type_cls
804+
776805 @classmethod
777806 def register_logical_type (cls , logical_type_cls ):
778807 """Register an implementation of LogicalType."""
@@ -889,7 +918,7 @@ def _from_typing(cls, typ):
889918 ('micros' , np .int64 )])
890919
891920
892- @LogicalType .register_logical_type
921+ @LogicalType ._register_internal
893922class MillisInstant (NoArgumentLogicalType [Timestamp , np .int64 ]):
894923 """Millisecond-precision instant logical type handles values consistent with
895924 that encoded by ``InstantCoder`` in the Java SDK.
@@ -933,7 +962,7 @@ def to_language_type(self, value):
933962# Make sure MicrosInstant is registered after MillisInstant so that it
934963# overwrites the mapping of Timestamp language type representation choice and
935964# thus does not lose microsecond precision inside python sdk.
936- @LogicalType .register_logical_type
965+ @LogicalType ._register_internal
937966class MicrosInstant (NoArgumentLogicalType [Timestamp ,
938967 MicrosInstantRepresentation ]):
939968 """Microsecond-precision instant logical type that handles ``Timestamp``."""
@@ -960,7 +989,7 @@ def to_language_type(self, value):
960989 return Timestamp (seconds = int (value .seconds ), micros = int (value .micros ))
961990
962991
963- @LogicalType .register_logical_type
992+ @LogicalType ._register_internal
964993class PythonCallable (NoArgumentLogicalType [PythonCallableWithSource , str ]):
965994 """A logical type for PythonCallableSource objects."""
966995 @classmethod
@@ -1016,7 +1045,7 @@ def to_language_type(self, value):
10161045 return decimal .Decimal (value .decode ())
10171046
10181047
1019- @LogicalType .register_logical_type
1048+ @LogicalType ._register_internal
10201049class FixedPrecisionDecimalLogicalType (
10211050 LogicalType [decimal .Decimal ,
10221051 DecimalLogicalType ,
@@ -1068,10 +1097,10 @@ def _from_typing(cls, typ):
10681097
10691098# TODO(yathu,BEAM-10722): Investigate and resolve conflicts in logical type
10701099# registration when more than one logical types sharing the same language type
1071- LogicalType .register_logical_type (DecimalLogicalType )
1100+ LogicalType ._register_internal (DecimalLogicalType )
10721101
10731102
1074- @LogicalType .register_logical_type
1103+ @LogicalType ._register_internal
10751104class FixedBytes (PassThroughLogicalType [bytes , np .int32 ]):
10761105 """A logical type for fixed-length bytes."""
10771106 @classmethod
@@ -1104,7 +1133,7 @@ def argument(self):
11041133 return self .length
11051134
11061135
1107- @LogicalType .register_logical_type
1136+ @LogicalType ._register_internal
11081137class VariableBytes (PassThroughLogicalType [bytes , np .int32 ]):
11091138 """A logical type for variable-length bytes with specified maximum length."""
11101139 @classmethod
@@ -1134,7 +1163,7 @@ def argument(self):
11341163 return self .max_length
11351164
11361165
1137- @LogicalType .register_logical_type
1166+ @LogicalType ._register_internal
11381167class FixedString (PassThroughLogicalType [str , np .int32 ]):
11391168 """A logical type for fixed-length string."""
11401169 @classmethod
@@ -1167,7 +1196,7 @@ def argument(self):
11671196 return self .length
11681197
11691198
1170- @LogicalType .register_logical_type
1199+ @LogicalType ._register_internal
11711200class VariableString (PassThroughLogicalType [str , np .int32 ]):
11721201 """A logical type for variable-length string with specified maximum length."""
11731202 @classmethod
@@ -1200,7 +1229,7 @@ def argument(self):
12001229# TODO: A temporary fix for missing jdbc logical types.
12011230# See the discussion in https://github.com/apache/beam/issues/35738 for
12021231# more detail.
1203- @LogicalType .register_logical_type
1232+ @LogicalType ._register_internal
12041233class JdbcDateType (LogicalType [datetime .date , MillisInstant , str ]):
12051234 """
12061235 For internal use only; no backwards-compatibility guarantees.
@@ -1243,7 +1272,7 @@ def _from_typing(cls, typ):
12431272 return cls ()
12441273
12451274
1246- @LogicalType .register_logical_type
1275+ @LogicalType ._register_internal
12471276class JdbcTimeType (LogicalType [datetime .time , MillisInstant , str ]):
12481277 """
12491278 For internal use only; no backwards-compatibility guarantees.
0 commit comments