Skip to content

Commit 8383902

Browse files
committed
Track custom_urn set in logical type registry
1 parent fec392d commit 8383902

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

sdks/python/apache_beam/internal/cloudpickle_pickler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def dump_session(file_path):
259259

260260
with _pickle_lock, open(file_path, 'wb') as file:
261261
coder_reg = typecoders.registry.get_custom_type_coder_tuples()
262-
logical_type_reg = schemas.LogicalType._known_logical_types.copy()
262+
logical_type_reg = schemas.LogicalType._known_logical_types.copy_custom()
263263

264264
pickler = cloudpickle.CloudPickler(file)
265265
# TODO(https://github.com/apache/beam/issues/18500) add file system registry

sdks/python/apache_beam/typehints/schemas.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -684,12 +684,17 @@ def __init__(self):
684684
self.by_urn = {}
685685
self.by_logical_type = {}
686686
self.by_language_type = {}
687+
self._custom_urns = set()
687688

688-
def add(self, urn, logical_type):
689+
def _add_internal(self, urn, logical_type):
689690
self.by_urn[urn] = logical_type
690691
self.by_logical_type[logical_type] = urn
691692
self.by_language_type[logical_type.language_type()] = logical_type
692693

694+
def add(self, urn, logical_type):
695+
self._add_internal(urn, logical_type)
696+
self._custom_urns.add(urn)
697+
693698
def get_logical_type_by_urn(self, urn):
694699
return self.by_urn.get(urn, None)
695700

@@ -704,12 +709,23 @@ def copy(self):
704709
copy.by_urn.update(self.by_urn)
705710
copy.by_logical_type.update(self.by_logical_type)
706711
copy.by_language_type.update(self.by_language_type)
712+
copy._custom_urns.update(self._custom_urns)
707713
return copy
708714

715+
def copy_custom(self):
716+
copy = LogicalTypeRegistry()
717+
for urn in self._custom_urns:
718+
logical_type = self.by_urn[urn]
719+
copy.by_urn[urn] = logical_type
720+
copy.by_logical_type[logical_type] = urn
721+
copy.by_language_type[logical_type.language_type()] = logical_type
722+
copy._custom_urns.add(urn)
723+
709724
def load(self, another):
710725
self.by_urn.update(another.by_urn)
711726
self.by_logical_type.update(another.by_logical_type)
712727
self.by_language_type.update(another.by_language_type)
728+
self._custom_urns.update(another._custom_urns)
713729

714730

715731
LanguageT = TypeVar('LanguageT')
@@ -773,6 +789,19 @@ def to_language_type(self, value):
773789
"""Convert an instance of RepresentationT to LanguageT."""
774790
raise NotImplementedError()
775791

792+
@classmethod
793+
def _register_internal(cls, logical_type_cls):
794+
"""
795+
Register an implementation of LogicalType.
796+
797+
The types registered using this decorator are not pickled on pipeline
798+
submission, as it relies module import to be registered on worker
799+
initialization. Should be used within schemas module and static context.
800+
"""
801+
cls._known_logical_types._add_internal(
802+
logical_type_cls.urn(), logical_type_cls)
803+
return logical_type_cls
804+
776805
@classmethod
777806
def register_logical_type(cls, logical_type_cls):
778807
"""Register an implementation of LogicalType."""
@@ -889,7 +918,7 @@ def _from_typing(cls, typ):
889918
('micros', np.int64)])
890919

891920

892-
@LogicalType.register_logical_type
921+
@LogicalType._register_internal
893922
class MillisInstant(NoArgumentLogicalType[Timestamp, np.int64]):
894923
"""Millisecond-precision instant logical type handles values consistent with
895924
that encoded by ``InstantCoder`` in the Java SDK.
@@ -933,7 +962,7 @@ def to_language_type(self, value):
933962
# Make sure MicrosInstant is registered after MillisInstant so that it
934963
# overwrites the mapping of Timestamp language type representation choice and
935964
# thus does not lose microsecond precision inside python sdk.
936-
@LogicalType.register_logical_type
965+
@LogicalType._register_internal
937966
class MicrosInstant(NoArgumentLogicalType[Timestamp,
938967
MicrosInstantRepresentation]):
939968
"""Microsecond-precision instant logical type that handles ``Timestamp``."""
@@ -960,7 +989,7 @@ def to_language_type(self, value):
960989
return Timestamp(seconds=int(value.seconds), micros=int(value.micros))
961990

962991

963-
@LogicalType.register_logical_type
992+
@LogicalType._register_internal
964993
class PythonCallable(NoArgumentLogicalType[PythonCallableWithSource, str]):
965994
"""A logical type for PythonCallableSource objects."""
966995
@classmethod
@@ -1016,7 +1045,7 @@ def to_language_type(self, value):
10161045
return decimal.Decimal(value.decode())
10171046

10181047

1019-
@LogicalType.register_logical_type
1048+
@LogicalType._register_internal
10201049
class FixedPrecisionDecimalLogicalType(
10211050
LogicalType[decimal.Decimal,
10221051
DecimalLogicalType,
@@ -1068,10 +1097,10 @@ def _from_typing(cls, typ):
10681097

10691098
# TODO(yathu,BEAM-10722): Investigate and resolve conflicts in logical type
10701099
# registration when more than one logical types sharing the same language type
1071-
LogicalType.register_logical_type(DecimalLogicalType)
1100+
LogicalType._register_internal(DecimalLogicalType)
10721101

10731102

1074-
@LogicalType.register_logical_type
1103+
@LogicalType._register_internal
10751104
class FixedBytes(PassThroughLogicalType[bytes, np.int32]):
10761105
"""A logical type for fixed-length bytes."""
10771106
@classmethod
@@ -1104,7 +1133,7 @@ def argument(self):
11041133
return self.length
11051134

11061135

1107-
@LogicalType.register_logical_type
1136+
@LogicalType._register_internal
11081137
class VariableBytes(PassThroughLogicalType[bytes, np.int32]):
11091138
"""A logical type for variable-length bytes with specified maximum length."""
11101139
@classmethod
@@ -1134,7 +1163,7 @@ def argument(self):
11341163
return self.max_length
11351164

11361165

1137-
@LogicalType.register_logical_type
1166+
@LogicalType._register_internal
11381167
class FixedString(PassThroughLogicalType[str, np.int32]):
11391168
"""A logical type for fixed-length string."""
11401169
@classmethod
@@ -1167,7 +1196,7 @@ def argument(self):
11671196
return self.length
11681197

11691198

1170-
@LogicalType.register_logical_type
1199+
@LogicalType._register_internal
11711200
class VariableString(PassThroughLogicalType[str, np.int32]):
11721201
"""A logical type for variable-length string with specified maximum length."""
11731202
@classmethod
@@ -1200,7 +1229,7 @@ def argument(self):
12001229
# TODO: A temporary fix for missing jdbc logical types.
12011230
# See the discussion in https://github.com/apache/beam/issues/35738 for
12021231
# more detail.
1203-
@LogicalType.register_logical_type
1232+
@LogicalType._register_internal
12041233
class JdbcDateType(LogicalType[datetime.date, MillisInstant, str]):
12051234
"""
12061235
For internal use only; no backwards-compatibility guarantees.
@@ -1243,7 +1272,7 @@ def _from_typing(cls, typ):
12431272
return cls()
12441273

12451274

1246-
@LogicalType.register_logical_type
1275+
@LogicalType._register_internal
12471276
class JdbcTimeType(LogicalType[datetime.time, MillisInstant, str]):
12481277
"""
12491278
For internal use only; no backwards-compatibility guarantees.

0 commit comments

Comments
 (0)