5454import re
5555import shutil
5656import tempfile
57+ import threading
5758import unicodedata
5859import uuid
5960from collections import defaultdict
109110 from apache_beam .runners .runner import PipelineResult
110111 from apache_beam .transforms import environments
111112
112- __all__ = ['Pipeline' , 'PTransformOverride ' ]
113+ __all__ = ['Pipeline' , 'transform_annotations ' ]
113114
114115
115116class Pipeline (HasDisplayData ):
@@ -226,7 +227,9 @@ def __init__(
226227 self .runner = runner
227228 # Stack of transforms generated by nested apply() calls. The stack will
228229 # contain a root node as an enclosing (parent) node for top transforms.
229- self .transforms_stack = [AppliedPTransform (None , None , '' , None )]
230+ self .transforms_stack = [
231+ AppliedPTransform (None , None , '' , None , None , None )
232+ ]
230233 # Set of transform labels (full labels) applied to the pipeline.
231234 # If a transform is applied and the full label is already in the set
232235 # then the transform will have to be cloned with a new label.
@@ -244,6 +247,7 @@ def __init__(
244247
245248 self ._display_data = display_data or {}
246249 self ._error_handlers = []
250+ self ._annotations_stack = [{}]
247251
248252 def display_data (self ):
249253 # type: () -> Dict[str, Any]
@@ -268,6 +272,24 @@ def _current_transform(self):
268272 """Returns the transform currently on the top of the stack."""
269273 return self .transforms_stack [- 1 ]
270274
@contextlib.contextmanager
def transform_annotations(self, **annotations):
  """A context manager for attaching annotations to a set of transforms.

  Pipeline method. All transforms applied while this context is active will
  have these annotations attached. This includes sub-transforms applied
  within composite transforms.

  The given annotations are encoded with ``encode_annotations`` and layered
  on top of the frame currently at the top of ``self._annotations_stack``,
  so nested contexts accumulate and inner values override outer ones for
  the same key.

  Args:
    **annotations: Annotation values keyed by annotation name.

  Raises:
    TypeError: Propagated from ``encode_annotations`` when a value has an
      unsupported type. Encoding happens before the frame is pushed, so the
      stack is left untouched in that case.
  """
  self._annotations_stack.append({
      **self._annotations_stack[-1], **encode_annotations(annotations)
  })
  try:
    yield
  finally:
    # Always drop the frame pushed above, even when the with-body raises.
    # Otherwise these annotations would stay on the stack and be attached
    # to every transform applied afterwards.
    self._annotations_stack.pop()
288+
def _current_annotations(self):
  """Returns the annotations to attach to a transform being applied.

  Pipeline method. Starts from the top frame of the global annotations
  stack and overlays the top frame of this pipeline's own stack, so a
  pipeline-level value wins when both define the same key. A new dict is
  returned each call.
  """
  merged = dict(_global_annotations_stack()[-1])
  merged.update(self._annotations_stack[-1])
  return merged
292+
271293 def _root_transform (self ):
272294 # type: () -> AppliedPTransform
273295
@@ -316,7 +338,9 @@ def _replace_if_needed(self, original_transform_node):
316338 original_transform_node .parent ,
317339 replacement_transform ,
318340 original_transform_node .full_label ,
319- original_transform_node .main_inputs )
341+ original_transform_node .main_inputs ,
342+ None ,
343+ annotations = original_transform_node .annotations )
320344
321345 # TODO(https://github.com/apache/beam/issues/21178): Merge rather
322346 # than override.
@@ -741,7 +765,12 @@ def apply(
741765 'returned %s from %s' % (transform , inputs , pvalueish ))
742766
743767 current = AppliedPTransform (
744- self ._current_transform (), transform , full_label , inputs )
768+ self ._current_transform (),
769+ transform ,
770+ full_label ,
771+ inputs ,
772+ None ,
773+ annotations = self ._current_annotations ())
745774 self ._current_transform ().add_part (current )
746775
747776 try :
@@ -1014,7 +1043,7 @@ def from_runner_api(
10141043 root_transform_id , = proto .root_transform_ids
10151044 p .transforms_stack = [context .transforms .get_by_id (root_transform_id )]
10161045 else :
1017- p .transforms_stack = [AppliedPTransform (None , None , '' , None )]
1046+ p .transforms_stack = [AppliedPTransform (None , None , '' , None , None , None )]
10181047 # TODO(robertwb): These are only needed to continue construction. Omit?
10191048 p .applied_labels = {
10201049 t .unique_name
@@ -1124,8 +1153,8 @@ def __init__(
11241153 transform , # type: Optional[ptransform.PTransform]
11251154 full_label , # type: str
11261155 main_inputs , # type: Optional[Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
1127- environment_id = None , # type: Optional[str]
1128- annotations = None , # type: Optional[Dict[str, bytes]]
1156+ environment_id , # type: Optional[str]
1157+ annotations , # type: Optional[Dict[str, bytes]]
11291158 ):
11301159 # type: (...) -> None
11311160 self .parent = parent
@@ -1149,24 +1178,11 @@ def __init__(
11491178 transform .get_resource_hints ()) if transform else {
11501179 } # type: Dict[str, bytes]
11511180
1152- if annotations is None and transform :
1153-
1154- def annotation_to_bytes (key , a : Any ) -> bytes :
1155- if isinstance (a , bytes ):
1156- return a
1157- elif isinstance (a , str ):
1158- return a .encode ('ascii' )
1159- elif isinstance (a , message .Message ):
1160- return a .SerializeToString ()
1161- else :
1162- raise TypeError (
1163- 'Unknown annotation type %r (type %s) for %s' % (a , type (a ), key ))
1164-
1181+ if transform :
11651182 annotations = {
1166- key : annotation_to_bytes (key , a )
1167- for key ,
1168- a in transform .annotations ().items ()
1183+ ** (annotations or {}), ** encode_annotations (transform .annotations ())
11691184 }
1185+
11701186 self .annotations = annotations
11711187
11721188 @property
@@ -1478,6 +1494,50 @@ def _merge_outer_resource_hints(self):
14781494 part ._merge_outer_resource_hints ()
14791495
14801496
def encode_annotations(annotations: Optional[Dict[str, Any]]):
  """Returns a new dict with every annotation value converted to bytes.

  Conversion rules, checked in this order:
    * ``bytes`` values are passed through unchanged.
    * ``str`` values are encoded as ASCII.
    * ``message.Message`` values are serialized with ``SerializeToString``.

  Args:
    annotations: Mapping of annotation name to value. A falsy argument
      (``None`` or an empty mapping) yields an empty dict.

  Returns:
    A dict mapping each annotation name to its bytes encoding.

  Raises:
    TypeError: If a value is none of the supported types.
  """
  def _as_bytes(name, value: Any) -> bytes:
    if isinstance(value, bytes):
      return value
    if isinstance(value, str):
      return value.encode('ascii')
    if isinstance(value, message.Message):
      return value.SerializeToString()
    raise TypeError(
        'Unknown annotation type %r (type %s) for %s' %
        (value, type(value), name))

  encoded = {}
  for name, value in (annotations or {}).items():
    encoded[name] = _as_bytes(name, value)
  return encoded
1514+
1515+
# Thread-local holder for the global annotations stack; the ``stack``
# attribute is created lazily by _global_annotations_stack().
_global_annotations_stack_data = threading.local()


def _global_annotations_stack():
  """Returns the calling thread's global annotations stack.

  The stack is a list of dicts kept on a ``threading.local`` object, so each
  thread sees its own list. On first use in a thread it is initialised to a
  single empty frame, ``[{}]``; later calls return that same list object.
  """
  holder = _global_annotations_stack_data
  if not hasattr(holder, 'stack'):
    holder.stack = [{}]
  return holder.stack
1525+
1526+
@contextlib.contextmanager
def transform_annotations(**annotations):
  """A context manager for attaching annotations to a set of transforms.

  All transforms applied while this context is active will have these
  annotations attached. This includes sub-transforms applied within
  composite transforms.

  The given annotations are encoded with ``encode_annotations`` and layered
  on top of the current top frame of the stack returned by
  ``_global_annotations_stack()``, so nested contexts accumulate and inner
  values override outer ones for the same key.

  Args:
    **annotations: Annotation values keyed by annotation name.

  Raises:
    TypeError: Propagated from ``encode_annotations`` when a value has an
      unsupported type. Encoding happens before the frame is pushed, so the
      stack is left untouched in that case.
  """
  cur_stack = _global_annotations_stack()
  cur_stack.append({**cur_stack[-1], **encode_annotations(annotations)})
  try:
    yield
  finally:
    # Always drop the frame pushed above, even when the with-body raises.
    # Otherwise the stale frame would keep annotating every transform
    # applied later through this same stack.
    cur_stack.pop()
1539+
1540+
14811541class PTransformOverride (metaclass = abc .ABCMeta ):
14821542 """For internal use only; no backwards-compatibility guarantees.
14831543
0 commit comments