Skip to content

Commit b5884e8

Browse files
authored
Allow annotations to be attached to transforms via a context. (#33319)
This API is offered both on the pipeline object itself, and also as a (thread-local) top-level function as one many not always have an easy reference to the pipeline.
1 parent 35f9392 commit b5884e8

File tree

4 files changed

+138
-29
lines changed

4 files changed

+138
-29
lines changed

sdks/python/apache_beam/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
from apache_beam import metrics
9090
from apache_beam import typehints
9191
from apache_beam import version
92-
from apache_beam.pipeline import Pipeline
92+
from apache_beam.pipeline import *
9393
from apache_beam.transforms import *
9494
from apache_beam.pvalue import PCollection
9595
from apache_beam.pvalue import Row

sdks/python/apache_beam/pipeline.py

Lines changed: 83 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import re
5555
import shutil
5656
import tempfile
57+
import threading
5758
import unicodedata
5859
import uuid
5960
from collections import defaultdict
@@ -109,7 +110,7 @@
109110
from apache_beam.runners.runner import PipelineResult
110111
from apache_beam.transforms import environments
111112

112-
__all__ = ['Pipeline', 'PTransformOverride']
113+
__all__ = ['Pipeline', 'transform_annotations']
113114

114115

115116
class Pipeline(HasDisplayData):
@@ -226,7 +227,9 @@ def __init__(
226227
self.runner = runner
227228
# Stack of transforms generated by nested apply() calls. The stack will
228229
# contain a root node as an enclosing (parent) node for top transforms.
229-
self.transforms_stack = [AppliedPTransform(None, None, '', None)]
230+
self.transforms_stack = [
231+
AppliedPTransform(None, None, '', None, None, None)
232+
]
230233
# Set of transform labels (full labels) applied to the pipeline.
231234
# If a transform is applied and the full label is already in the set
232235
# then the transform will have to be cloned with a new label.
@@ -244,6 +247,7 @@ def __init__(
244247

245248
self._display_data = display_data or {}
246249
self._error_handlers = []
250+
self._annotations_stack = [{}]
247251

248252
def display_data(self):
249253
# type: () -> Dict[str, Any]
@@ -268,6 +272,24 @@ def _current_transform(self):
268272
"""Returns the transform currently on the top of the stack."""
269273
return self.transforms_stack[-1]
270274

275+
@contextlib.contextmanager
276+
def transform_annotations(self, **annotations):
277+
"""A context manager for attaching annotations to a set of transforms.
278+
279+
All transforms applied while this context is active will have these
280+
annotations attached. This includes sub-transforms applied within
281+
composite transforms.
282+
"""
283+
self._annotations_stack.append({
284+
**self._annotations_stack[-1], **encode_annotations(annotations)
285+
})
286+
yield
287+
self._annotations_stack.pop()
288+
289+
def _current_annotations(self):
290+
"""Returns the set of annotations that should be used on apply."""
291+
return {**_global_annotations_stack()[-1], **self._annotations_stack[-1]}
292+
271293
def _root_transform(self):
272294
# type: () -> AppliedPTransform
273295

@@ -316,7 +338,9 @@ def _replace_if_needed(self, original_transform_node):
316338
original_transform_node.parent,
317339
replacement_transform,
318340
original_transform_node.full_label,
319-
original_transform_node.main_inputs)
341+
original_transform_node.main_inputs,
342+
None,
343+
annotations=original_transform_node.annotations)
320344

321345
# TODO(https://github.com/apache/beam/issues/21178): Merge rather
322346
# than override.
@@ -741,7 +765,12 @@ def apply(
741765
'returned %s from %s' % (transform, inputs, pvalueish))
742766

743767
current = AppliedPTransform(
744-
self._current_transform(), transform, full_label, inputs)
768+
self._current_transform(),
769+
transform,
770+
full_label,
771+
inputs,
772+
None,
773+
annotations=self._current_annotations())
745774
self._current_transform().add_part(current)
746775

747776
try:
@@ -1014,7 +1043,7 @@ def from_runner_api(
10141043
root_transform_id, = proto.root_transform_ids
10151044
p.transforms_stack = [context.transforms.get_by_id(root_transform_id)]
10161045
else:
1017-
p.transforms_stack = [AppliedPTransform(None, None, '', None)]
1046+
p.transforms_stack = [AppliedPTransform(None, None, '', None, None, None)]
10181047
# TODO(robertwb): These are only needed to continue construction. Omit?
10191048
p.applied_labels = {
10201049
t.unique_name
@@ -1124,8 +1153,8 @@ def __init__(
11241153
transform, # type: Optional[ptransform.PTransform]
11251154
full_label, # type: str
11261155
main_inputs, # type: Optional[Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
1127-
environment_id=None, # type: Optional[str]
1128-
annotations=None, # type: Optional[Dict[str, bytes]]
1156+
environment_id, # type: Optional[str]
1157+
annotations, # type: Optional[Dict[str, bytes]]
11291158
):
11301159
# type: (...) -> None
11311160
self.parent = parent
@@ -1149,24 +1178,11 @@ def __init__(
11491178
transform.get_resource_hints()) if transform else {
11501179
} # type: Dict[str, bytes]
11511180

1152-
if annotations is None and transform:
1153-
1154-
def annotation_to_bytes(key, a: Any) -> bytes:
1155-
if isinstance(a, bytes):
1156-
return a
1157-
elif isinstance(a, str):
1158-
return a.encode('ascii')
1159-
elif isinstance(a, message.Message):
1160-
return a.SerializeToString()
1161-
else:
1162-
raise TypeError(
1163-
'Unknown annotation type %r (type %s) for %s' % (a, type(a), key))
1164-
1181+
if transform:
11651182
annotations = {
1166-
key: annotation_to_bytes(key, a)
1167-
for key,
1168-
a in transform.annotations().items()
1183+
**(annotations or {}), **encode_annotations(transform.annotations())
11691184
}
1185+
11701186
self.annotations = annotations
11711187

11721188
@property
@@ -1478,6 +1494,50 @@ def _merge_outer_resource_hints(self):
14781494
part._merge_outer_resource_hints()
14791495

14801496

1497+
def encode_annotations(annotations: Optional[Dict[str, Any]]):
1498+
"""Encodes non-byte annotation values as bytes."""
1499+
if not annotations:
1500+
return {}
1501+
1502+
def annotation_to_bytes(key, a: Any) -> bytes:
1503+
if isinstance(a, bytes):
1504+
return a
1505+
elif isinstance(a, str):
1506+
return a.encode('ascii')
1507+
elif isinstance(a, message.Message):
1508+
return a.SerializeToString()
1509+
else:
1510+
raise TypeError(
1511+
'Unknown annotation type %r (type %s) for %s' % (a, type(a), key))
1512+
1513+
return {key: annotation_to_bytes(key, a) for (key, a) in annotations.items()}
1514+
1515+
1516+
_global_annotations_stack_data = threading.local()
1517+
1518+
1519+
def _global_annotations_stack():
1520+
try:
1521+
return _global_annotations_stack_data.stack
1522+
except AttributeError:
1523+
_global_annotations_stack_data.stack = [{}]
1524+
return _global_annotations_stack_data.stack
1525+
1526+
1527+
@contextlib.contextmanager
1528+
def transform_annotations(**annotations):
1529+
"""A context manager for attaching annotations to a set of transforms.
1530+
1531+
All transforms applied while this context is active will have these
1532+
annotations attached. This includes sub-transforms applied within
1533+
composite transforms.
1534+
"""
1535+
cur_stack = _global_annotations_stack()
1536+
cur_stack.append({**cur_stack[-1], **encode_annotations(annotations)})
1537+
yield
1538+
cur_stack.pop()
1539+
1540+
14811541
class PTransformOverride(metaclass=abc.ABCMeta):
14821542
"""For internal use only; no backwards-compatibility guarantees.
14831543

sdks/python/apache_beam/pipeline_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,51 @@ def annotations(self):
10161016
transform.annotations['proto'], some_proto.SerializeToString())
10171017
self.assertEqual(seen, 2)
10181018

1019+
def assertHasAnnotation(self, pipeline_proto, transform, key, value):
1020+
for transform_proto in pipeline_proto.components.transforms.values():
1021+
if transform_proto.unique_name == transform:
1022+
self.assertIn(key, transform_proto.annotations.keys())
1023+
self.assertEqual(transform_proto.annotations[key], value)
1024+
break
1025+
else:
1026+
self.fail(
1027+
"Unknown transform: %r not in %s" % (
1028+
transform,
1029+
sorted([
1030+
t.unique_name
1031+
for t in pipeline_proto.components.transforms.values()
1032+
])))
1033+
1034+
def test_pipeline_context_annotations(self):
1035+
p = beam.Pipeline()
1036+
with p.transform_annotations(foo='first'):
1037+
pcoll = p | beam.Create([1, 2, 3]) | 'First' >> beam.Map(lambda x: x + 1)
1038+
with p.transform_annotations(foo='second'):
1039+
_ = pcoll | 'Second' >> beam.Map(lambda x: x * 2)
1040+
with p.transform_annotations(foo='nested', another='more'):
1041+
_ = pcoll | 'Nested' >> beam.Map(lambda x: x * 3)
1042+
1043+
proto = p.to_runner_api()
1044+
self.assertHasAnnotation(proto, 'First', 'foo', b'first')
1045+
self.assertHasAnnotation(proto, 'Second', 'foo', b'second')
1046+
self.assertHasAnnotation(proto, 'Nested', 'foo', b'nested')
1047+
self.assertHasAnnotation(proto, 'Nested', 'another', b'more')
1048+
1049+
def test_beam_context_annotations(self):
1050+
p = beam.Pipeline()
1051+
with beam.transform_annotations(foo='first'):
1052+
pcoll = p | beam.Create([1, 2, 3]) | 'First' >> beam.Map(lambda x: x + 1)
1053+
with beam.transform_annotations(foo='second'):
1054+
_ = pcoll | 'Second' >> beam.Map(lambda x: x * 2)
1055+
with beam.transform_annotations(foo='nested', another='more'):
1056+
_ = pcoll | 'Nested' >> beam.Map(lambda x: x * 3)
1057+
1058+
proto = p.to_runner_api()
1059+
self.assertHasAnnotation(proto, 'First', 'foo', b'first')
1060+
self.assertHasAnnotation(proto, 'Second', 'foo', b'second')
1061+
self.assertHasAnnotation(proto, 'Nested', 'foo', b'nested')
1062+
self.assertHasAnnotation(proto, 'Nested', 'another', b'more')
1063+
10191064
def test_transform_ids(self):
10201065
class MyPTransform(beam.PTransform):
10211066
def expand(self, p):

sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def test_group_by_key_input_visitor_with_valid_inputs(self):
272272
pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
273273
for pcoll in [pcoll1, pcoll2, pcoll3]:
274274
applied = AppliedPTransform(
275-
None, beam.GroupByKey(), "label", {'pcoll': pcoll})
275+
None, beam.GroupByKey(), "label", {'pcoll': pcoll}, None, None)
276276
applied.outputs[None] = PCollection(None)
277277
common.group_by_key_input_visitor().visit_transform(applied)
278278
self.assertEqual(
@@ -291,15 +291,17 @@ def test_group_by_key_input_visitor_with_invalid_inputs(self):
291291
for pcoll in [pcoll1, pcoll2]:
292292
with self.assertRaisesRegex(ValueError, err_msg):
293293
common.group_by_key_input_visitor().visit_transform(
294-
AppliedPTransform(None, beam.GroupByKey(), "label", {'in': pcoll}))
294+
AppliedPTransform(
295+
None, beam.GroupByKey(), "label", {'in': pcoll}, None, None))
295296

296297
def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
297298
p = TestPipeline()
298299
pcoll = PCollection(p)
299300
for transform in [beam.Flatten(), beam.Map(lambda x: x)]:
300301
pcoll.element_type = typehints.Any
301302
common.group_by_key_input_visitor().visit_transform(
302-
AppliedPTransform(None, transform, "label", {'in': pcoll}))
303+
AppliedPTransform(
304+
None, transform, "label", {'in': pcoll}, None, None))
303305
self.assertEqual(pcoll.element_type, typehints.Any)
304306

305307
def test_flatten_input_with_visitor_with_single_input(self):
@@ -319,7 +321,8 @@ def _test_flatten_input_visitor(self, input_type, output_type, num_inputs):
319321
output_pcoll = PCollection(p)
320322
output_pcoll.element_type = output_type
321323

322-
flatten = AppliedPTransform(None, beam.Flatten(), "label", inputs)
324+
flatten = AppliedPTransform(
325+
None, beam.Flatten(), "label", inputs, None, None)
323326
flatten.add_output(output_pcoll, None)
324327
DataflowRunner.flatten_input_visitor().visit_transform(flatten)
325328
for _ in range(num_inputs):
@@ -357,7 +360,8 @@ def test_side_input_visitor(self):
357360
z: (x, y, z),
358361
beam.pvalue.AsSingleton(pc),
359362
beam.pvalue.AsMultiMap(pc))
360-
applied_transform = AppliedPTransform(None, transform, "label", {'pc': pc})
363+
applied_transform = AppliedPTransform(
364+
None, transform, "label", {'pc': pc}, None, None)
361365
DataflowRunner.side_input_visitor().visit_transform(applied_transform)
362366
self.assertEqual(2, len(applied_transform.side_inputs))
363367
self.assertEqual(

0 commit comments

Comments
 (0)