|
35 | 35 | # pytype: skip-file |
36 | 36 |
|
37 | 37 | import logging |
| 38 | +from collections.abc import Iterable |
38 | 39 | from datetime import timedelta |
39 | 40 | from typing import Any |
40 | | -from typing import Dict |
41 | | -from typing import Iterable |
42 | | -from typing import List |
43 | 41 | from typing import Optional |
44 | 42 | from typing import Union |
45 | 43 |
|
|
57 | 55 | from apache_beam.runners.interactive.display.pcoll_visualization import visualize |
58 | 56 | from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll |
59 | 57 | from apache_beam.runners.interactive.options import interactive_options |
| 58 | +from apache_beam.runners.interactive.recording_manager import AsyncComputationResult |
60 | 59 | from apache_beam.runners.interactive.utils import deferred_df_to_pcollection |
61 | 60 | from apache_beam.runners.interactive.utils import elements_to_df |
62 | 61 | from apache_beam.runners.interactive.utils import find_pcoll_name |
@@ -275,7 +274,7 @@ class Recordings(): |
275 | 274 | """ |
276 | 275 | def describe( |
277 | 276 | self, |
278 | | - pipeline: Optional[beam.Pipeline] = None) -> Dict[str, Any]: # noqa: F821 |
| 277 | + pipeline: Optional[beam.Pipeline] = None) -> dict[str, Any]: # noqa: F821 |
279 | 278 | """Returns a description of all the recordings for the given pipeline. |
280 | 279 |
|
281 | 280 | If no pipeline is given then this returns a dictionary of descriptions for |
@@ -417,10 +416,10 @@ class Clusters: |
417 | 416 | # DATAPROC_IMAGE_VERSION = '2.0.XX-debian10' |
418 | 417 |
|
419 | 418 | def __init__(self) -> None: |
420 | | - self.dataproc_cluster_managers: Dict[ClusterMetadata, |
| 419 | + self.dataproc_cluster_managers: dict[ClusterMetadata, |
421 | 420 | DataprocClusterManager] = {} |
422 | | - self.master_urls: Dict[str, ClusterMetadata] = {} |
423 | | - self.pipelines: Dict[beam.Pipeline, DataprocClusterManager] = {} |
| 421 | + self.master_urls: dict[str, ClusterMetadata] = {} |
| 422 | + self.pipelines: dict[beam.Pipeline, DataprocClusterManager] = {} |
424 | 423 | self.default_cluster_metadata: Optional[ClusterMetadata] = None |
425 | 424 |
|
426 | 425 | def create( |
@@ -511,7 +510,7 @@ def cleanup( |
511 | 510 | def describe( |
512 | 511 | self, |
513 | 512 | cluster_identifier: Optional[ClusterIdentifier] = None |
514 | | - ) -> Union[ClusterMetadata, List[ClusterMetadata]]: |
| 513 | + ) -> Union[ClusterMetadata, list[ClusterMetadata]]: |
515 | 514 | """Describes the ClusterMetadata by a ClusterIdentifier. |
516 | 515 |
|
517 | 516 | If no cluster_identifier is given or if the cluster_identifier is unknown, |
@@ -679,7 +678,7 @@ def run_pipeline(self): |
679 | 678 |
|
680 | 679 | @progress_indicated |
681 | 680 | def show( |
682 | | - *pcolls: Union[Dict[Any, PCollection], Iterable[PCollection], PCollection], |
| 681 | + *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection], |
683 | 682 | include_window_info: bool = False, |
684 | 683 | visualize_data: bool = False, |
685 | 684 | n: Union[int, str] = 'inf', |
@@ -1012,6 +1011,88 @@ def as_pcollection(pcoll_or_df): |
1012 | 1011 | return result_tuple |
1013 | 1012 |
|
1014 | 1013 |
|
| 1014 | +@progress_indicated |
| 1015 | +def compute( |
| 1016 | + *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection], |
| 1017 | + wait_for_inputs: bool = True, |
| 1018 | + blocking: bool = False, |
| 1019 | + runner=None, |
| 1020 | + options=None, |
| 1021 | + force_compute=False, |
| 1022 | +) -> Optional[AsyncComputationResult]: |
| 1023 | + """Computes the given PCollections, potentially asynchronously. |
| 1024 | +
|
| 1025 | + Args: |
| 1026 | + *pcolls: PCollections to compute. Can be a single PCollection, an iterable |
| 1027 | + of PCollections, or a dictionary with PCollections as values. |
| 1028 | + wait_for_inputs: Whether to wait until the asynchronous dependencies are
| 1029 | + computed. Setting this to False allows the computation to be scheduled
| 1030 | + immediately, but it may also result in running the same pipeline stages
| 1031 | + multiple times.
| 1032 | + blocking: If False, the computation will run in a non-blocking fashion. In
| 1033 | + a Colab/IPython environment this mode also provides controls for the
| 1034 | + running pipeline. If True, the computation will block until the pipeline
| 1035 | + is done.
| 1036 | + runner: (optional) the runner with which to compute the results. |
| 1037 | + options: (optional) any additional pipeline options to use to compute the |
| 1038 | + results. |
| 1039 | + force_compute: (optional) if True, forces recomputation rather than using |
| 1040 | + cached PCollections. |
| 1041 | +
|
| 1042 | + Returns: |
| 1043 | + An AsyncComputationResult object if blocking is False, otherwise None. |
| 1044 | + """ |
| 1045 | + flatten_pcolls = [] |
| 1046 | + for pcoll_container in pcolls: |
| 1047 | + if isinstance(pcoll_container, dict): |
| 1048 | + flatten_pcolls.extend(pcoll_container.values()) |
| 1049 | + elif isinstance(pcoll_container, (beam.pvalue.PCollection, DeferredBase)): |
| 1050 | + flatten_pcolls.append(pcoll_container) |
| 1051 | + else: |
| 1052 | + try: |
| 1053 | + flatten_pcolls.extend(iter(pcoll_container)) |
| 1054 | + except TypeError: |
| 1055 | + raise ValueError( |
| 1056 | + f'The given pcoll {pcoll_container} is not a dict, an iterable or ' |
| 1057 | + 'a PCollection.') |
| 1058 | + |
| 1059 | + pcolls_set = set() |
| 1060 | + for pcoll in flatten_pcolls: |
| 1061 | + if isinstance(pcoll, DeferredBase): |
| 1062 | + pcoll, _ = deferred_df_to_pcollection(pcoll) |
| 1063 | + watch({f'anonymous_pcollection_{id(pcoll)}': pcoll}) |
| 1064 | + assert isinstance( |
| 1065 | + pcoll, beam.pvalue.PCollection |
| 1066 | + ), f'{pcoll} is not an apache_beam.pvalue.PCollection.' |
| 1067 | + pcolls_set.add(pcoll) |
| 1068 | + |
| 1069 | + if not pcolls_set: |
| 1070 | + _LOGGER.info('No PCollections to compute.') |
| 1071 | + return None |
| 1072 | + |
| 1073 | + pcoll_pipeline = next(iter(pcolls_set)).pipeline |
| 1074 | + user_pipeline = ie.current_env().user_pipeline(pcoll_pipeline) |
| 1075 | + if not user_pipeline: |
| 1076 | + watch({f'anonymous_pipeline_{id(pcoll_pipeline)}': pcoll_pipeline}) |
| 1077 | + user_pipeline = pcoll_pipeline |
| 1078 | + |
| 1079 | + for pcoll in pcolls_set: |
| 1080 | + if pcoll.pipeline is not user_pipeline: |
| 1081 | + raise ValueError('All PCollections must belong to the same pipeline.') |
| 1082 | + |
| 1083 | + recording_manager = ie.current_env().get_recording_manager( |
| 1084 | + user_pipeline, create_if_absent=True) |
| 1085 | + |
| 1086 | + return recording_manager.compute_async( |
| 1087 | + pcolls_set, |
| 1088 | + wait_for_inputs=wait_for_inputs, |
| 1089 | + blocking=blocking, |
| 1090 | + runner=runner, |
| 1091 | + options=options, |
| 1092 | + force_compute=force_compute, |
| 1093 | + ) |
| 1094 | + |
| 1095 | + |
1015 | 1096 | @progress_indicated |
1016 | 1097 | def show_graph(pipeline): |
1017 | 1098 | """Shows the current pipeline shape of a given Beam pipeline as a DAG. |
|