Skip to content

Commit 388e5e5

Browse files
authored
Supports Asynchronous Runs in Interactive Beam (#36853)
* Supports Asynchronous Runs in Interactive Beam * use PEP-585 generics * Skip some tests for non-interactive_env and fix errors in unit tests
1 parent 543056a commit 388e5e5

File tree

7 files changed

+1521
-19
lines changed

7 files changed

+1521
-19
lines changed

sdks/python/apache_beam/runners/interactive/interactive_beam.py

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,9 @@
3535
# pytype: skip-file
3636

3737
import logging
38+
from collections.abc import Iterable
3839
from datetime import timedelta
3940
from typing import Any
40-
from typing import Dict
41-
from typing import Iterable
42-
from typing import List
4341
from typing import Optional
4442
from typing import Union
4543

@@ -57,6 +55,7 @@
5755
from apache_beam.runners.interactive.display.pcoll_visualization import visualize
5856
from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll
5957
from apache_beam.runners.interactive.options import interactive_options
58+
from apache_beam.runners.interactive.recording_manager import AsyncComputationResult
6059
from apache_beam.runners.interactive.utils import deferred_df_to_pcollection
6160
from apache_beam.runners.interactive.utils import elements_to_df
6261
from apache_beam.runners.interactive.utils import find_pcoll_name
@@ -275,7 +274,7 @@ class Recordings():
275274
"""
276275
def describe(
277276
self,
278-
pipeline: Optional[beam.Pipeline] = None) -> Dict[str, Any]: # noqa: F821
277+
pipeline: Optional[beam.Pipeline] = None) -> dict[str, Any]: # noqa: F821
279278
"""Returns a description of all the recordings for the given pipeline.
280279
281280
If no pipeline is given then this returns a dictionary of descriptions for
@@ -417,10 +416,10 @@ class Clusters:
417416
# DATAPROC_IMAGE_VERSION = '2.0.XX-debian10'
418417

419418
def __init__(self) -> None:
420-
self.dataproc_cluster_managers: Dict[ClusterMetadata,
419+
self.dataproc_cluster_managers: dict[ClusterMetadata,
421420
DataprocClusterManager] = {}
422-
self.master_urls: Dict[str, ClusterMetadata] = {}
423-
self.pipelines: Dict[beam.Pipeline, DataprocClusterManager] = {}
421+
self.master_urls: dict[str, ClusterMetadata] = {}
422+
self.pipelines: dict[beam.Pipeline, DataprocClusterManager] = {}
424423
self.default_cluster_metadata: Optional[ClusterMetadata] = None
425424

426425
def create(
@@ -511,7 +510,7 @@ def cleanup(
511510
def describe(
512511
self,
513512
cluster_identifier: Optional[ClusterIdentifier] = None
514-
) -> Union[ClusterMetadata, List[ClusterMetadata]]:
513+
) -> Union[ClusterMetadata, list[ClusterMetadata]]:
515514
"""Describes the ClusterMetadata by a ClusterIdentifier.
516515
517516
If no cluster_identifier is given or if the cluster_identifier is unknown,
@@ -679,7 +678,7 @@ def run_pipeline(self):
679678

680679
@progress_indicated
681680
def show(
682-
*pcolls: Union[Dict[Any, PCollection], Iterable[PCollection], PCollection],
681+
*pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection],
683682
include_window_info: bool = False,
684683
visualize_data: bool = False,
685684
n: Union[int, str] = 'inf',
@@ -1012,6 +1011,88 @@ def as_pcollection(pcoll_or_df):
10121011
return result_tuple
10131012

10141013

1014+
@progress_indicated
def compute(
    *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection],
    wait_for_inputs: bool = True,
    blocking: bool = False,
    runner=None,
    options=None,
    force_compute=False,
) -> Optional[AsyncComputationResult]:
  """Computes the given PCollections, potentially asynchronously.

  Args:
    *pcolls: PCollections to compute. Can be a single PCollection, an iterable
      of PCollections, or a dictionary with PCollections as values.
    wait_for_inputs: Whether to wait until the asynchronous dependencies are
      computed. Setting this to False allows to immediately schedule the
      computation, but also potentially results in running the same pipeline
      stages multiple times.
    blocking: If False, the computation will run in non-blocking fashion. In
      Colab/IPython environment this mode will also provide the controls for
      the running pipeline. If True, the computation will block until the
      pipeline is done.
    runner: (optional) the runner with which to compute the results.
    options: (optional) any additional pipeline options to use to compute the
      results.
    force_compute: (optional) if True, forces recomputation rather than using
      cached PCollections.

  Returns:
    An AsyncComputationResult object if blocking is False, otherwise None.

  Raises:
    ValueError: if an argument is not a dict, an iterable or a PCollection,
      or if the given PCollections do not all belong to the same pipeline.
  """
  # Flatten dicts/iterables into a single list of PCollections (deferred
  # DataFrames are resolved to their backing PCollections below).
  flatten_pcolls = []
  for pcoll_container in pcolls:
    if isinstance(pcoll_container, dict):
      flatten_pcolls.extend(pcoll_container.values())
    elif isinstance(pcoll_container, (beam.pvalue.PCollection, DeferredBase)):
      flatten_pcolls.append(pcoll_container)
    else:
      try:
        flatten_pcolls.extend(iter(pcoll_container))
      except TypeError as exc:
        # Chain the underlying TypeError for easier debugging.
        raise ValueError(
            f'The given pcoll {pcoll_container} is not a dict, an iterable or '
            'a PCollection.') from exc

  # Deduplicate, converting deferred DataFrames to PCollections and watching
  # the anonymous conversions so the interactive environment can track them.
  pcolls_set = set()
  for pcoll in flatten_pcolls:
    if isinstance(pcoll, DeferredBase):
      pcoll, _ = deferred_df_to_pcollection(pcoll)
      watch({f'anonymous_pcollection_{id(pcoll)}': pcoll})
    # Validate explicitly rather than with `assert`, which is stripped when
    # Python runs with -O and would silently skip this check.
    if not isinstance(pcoll, beam.pvalue.PCollection):
      raise ValueError(f'{pcoll} is not an apache_beam.pvalue.PCollection.')
    pcolls_set.add(pcoll)

  if not pcolls_set:
    _LOGGER.info('No PCollections to compute.')
    return None

  # Resolve the user pipeline owning the PCollections; watch it as an
  # anonymous pipeline if the environment does not know it yet.
  pcoll_pipeline = next(iter(pcolls_set)).pipeline
  user_pipeline = ie.current_env().user_pipeline(pcoll_pipeline)
  if not user_pipeline:
    watch({f'anonymous_pipeline_{id(pcoll_pipeline)}': pcoll_pipeline})
    user_pipeline = pcoll_pipeline

  for pcoll in pcolls_set:
    if pcoll.pipeline is not user_pipeline:
      raise ValueError('All PCollections must belong to the same pipeline.')

  recording_manager = ie.current_env().get_recording_manager(
      user_pipeline, create_if_absent=True)

  return recording_manager.compute_async(
      pcolls_set,
      wait_for_inputs=wait_for_inputs,
      blocking=blocking,
      runner=runner,
      options=options,
      force_compute=force_compute,
  )
1094+
1095+
10151096
@progress_indicated
10161097
def show_graph(pipeline):
10171098
"""Shows the current pipeline shape of a given Beam pipeline as a DAG.

0 commit comments

Comments
 (0)