Skip to content

Commit a2f2b2b

Browse files
committed
Skip some tests for non-interactve_env and fix errors in unit tests
1 parent a9a2d2b commit a2f2b2b

File tree

5 files changed

+66
-16
lines changed

5 files changed

+66
-16
lines changed

sdks/python/apache_beam/runners/interactive/interactive_beam_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ def _get_watched_pcollections_with_variable_names():
7171
return watched_pcollections
7272

7373

74+
@unittest.skipIf(
75+
not ie.current_env().is_interactive_ready,
76+
'[interactive] dependency is not installed.')
7477
@isolated_env
7578
class InteractiveBeamTest(unittest.TestCase):
7679
def setUp(self):
@@ -677,6 +680,9 @@ def test_default_value_for_invalid_worker_number(self):
677680
self.assertEqual(meta.num_workers, 2)
678681

679682

683+
@unittest.skipIf(
684+
not ie.current_env().is_interactive_ready,
685+
'[interactive] dependency is not installed.')
680686
@isolated_env
681687
class InteractiveBeamComputeTest(unittest.TestCase):
682688
def setUp(self):

sdks/python/apache_beam/runners/interactive/interactive_environment_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
_module_name = 'apache_beam.runners.interactive.interactive_environment_test'
3535

3636

37+
@unittest.skipIf(
38+
not ie.current_env().is_interactive_ready,
39+
'[interactive] dependency is not installed.')
3740
@isolated_env
3841
class InteractiveEnvironmentTest(unittest.TestCase):
3942
def setUp(self):

sdks/python/apache_beam/runners/interactive/recording_manager.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def __init__(
438438
self._executor = ThreadPoolExecutor(max_workers=os.cpu_count())
439439
self._env = ie.current_env()
440440
self._async_computations: dict[str, AsyncComputationResult] = {}
441-
self._pipeline_graph = PipelineGraph(self.user_pipeline)
441+
self._pipeline_graph = None
442442

443443
def _execute_pipeline_fragment(
444444
self,
@@ -492,7 +492,6 @@ def _run_async_computation(
492492
if not self._wait_for_dependencies(pcolls_to_compute, async_result):
493493
raise RuntimeError('Dependency computation failed or was cancelled.')
494494

495-
self._env.mark_pcollection_computing(pcolls_to_compute)
496495
_LOGGER.info(
497496
'Starting asynchronous computation for %d PCollections.',
498497
len(pcolls_to_compute))
@@ -696,6 +695,7 @@ def compute_async(
696695
async_result = AsyncComputationResult(
697696
future, pcolls_to_compute, self.user_pipeline, self)
698697
self._async_computations[async_result._display_id] = async_result
698+
self._env.mark_pcollection_computing(pcolls_to_compute)
699699

700700
def task():
701701
try:
@@ -709,22 +709,39 @@ def task():
709709
self._executor.submit(task)
710710
return async_result
711711

712+
def _get_pipeline_graph(self):
713+
"""Lazily initializes and returns the PipelineGraph."""
714+
if self._pipeline_graph is None:
715+
try:
716+
# Try to create the graph.
717+
self._pipeline_graph = PipelineGraph(self.user_pipeline)
718+
except (ImportError, NameError, AttributeError):
719+
# If pydot is missing, PipelineGraph() might crash.
720+
_LOGGER.warning(
721+
"Could not create PipelineGraph (pydot missing?). " \
722+
"Async features disabled."
723+
)
724+
self._pipeline_graph = None
725+
return self._pipeline_graph
726+
712727
def _get_pcoll_id_map(self):
713728
"""Creates a map from PCollection object to its ID in the proto."""
714729
pcoll_to_id = {}
715-
if self._pipeline_graph._pipeline_instrument:
716-
pcoll_to_id = self._pipeline_graph._pipeline_instrument._pcoll_to_pcoll_id
730+
graph = self._get_pipeline_graph()
731+
if graph and graph._pipeline_instrument:
732+
pcoll_to_id = graph._pipeline_instrument._pcoll_to_pcoll_id
717733
return {v: k for k, v in pcoll_to_id.items()}
718734

719735
def _get_all_dependencies(
720736
self,
721737
pcolls: set[beam.pvalue.PCollection]) -> set[beam.pvalue.PCollection]:
722738
"""Gets all upstream PCollection dependencies
723739
for the given set of PCollections."""
724-
if not self._pipeline_graph:
740+
graph = self._get_pipeline_graph()
741+
if not graph:
725742
return set()
726743

727-
analyzer = self._pipeline_graph._pipeline_instrument
744+
analyzer = graph._pipeline_instrument
728745
if not analyzer:
729746
return set()
730747

@@ -751,8 +768,8 @@ def _get_all_dependencies(
751768
queue = collections.deque(target_pcoll_ids)
752769
visited_pcoll_ids = set(target_pcoll_ids)
753770

754-
producers = self._pipeline_graph._producers
755-
transforms = self._pipeline_graph._pipeline_proto.components.transforms
771+
producers = graph._producers
772+
transforms = graph._pipeline_proto.components.transforms
756773

757774
while queue:
758775
pcoll_id = queue.popleft()

sdks/python/apache_beam/runners/interactive/recording_manager_test.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
from apache_beam.utils.windowed_value import WindowedValue
4848

4949

50+
@unittest.skipIf(
51+
not ie.current_env().is_interactive_ready,
52+
'[interactive] dependency is not installed.')
5053
class AsyncComputationResultTest(unittest.TestCase):
5154
def setUp(self):
5255
self.mock_future = MagicMock(spec=Future)
@@ -664,6 +667,9 @@ def test_describe(self):
664667
cache_manager.size('full', letters_stream.cache_key))
665668

666669

670+
@unittest.skipIf(
671+
not ie.current_env().is_interactive_ready,
672+
'[interactive] dependency is not installed.')
667673
class RecordingManagerTest(unittest.TestCase):
668674
def test_basic_execution(self):
669675
"""A basic pipeline to be used as a smoke test."""
@@ -987,12 +993,6 @@ def capture_task(task):
987993

988994
mock_submit.side_effect = capture_task
989995

990-
res = rm.compute_async({pcoll}, blocking=False)
991-
self.assertIs(res, mock_async_res_instance)
992-
mock_submit.assert_called_once()
993-
self.assertIsNotNone(task_submitted)
994-
995-
# Patch dependencies of _run_async_computation
996996
with patch.object(
997997
rm, '_wait_for_dependencies', return_value=True
998998
), patch.object(
@@ -1002,10 +1002,22 @@ def capture_task(task):
10021002
'mark_pcollection_computing',
10031003
wraps=ie.current_env().mark_pcollection_computing,
10041004
) as wrapped_mark:
1005-
# Run the task to trigger the marks
1006-
task_submitted()
1005+
1006+
res = rm.compute_async({pcoll}, blocking=False)
10071007
wrapped_mark.assert_called_once_with({pcoll})
10081008

1009+
# Run the task to trigger the marks
1010+
self.assertIs(res, mock_async_res_instance)
1011+
mock_submit.assert_called_once()
1012+
self.assertIsNotNone(task_submitted)
1013+
1014+
with patch.object(
1015+
rm, '_wait_for_dependencies', return_value=True
1016+
), patch.object(
1017+
rm, '_execute_pipeline_fragment'
1018+
) as _:
1019+
task_submitted()
1020+
10091021
self.assertTrue(pcoll in ie.current_env().computing_pcollections)
10101022

10111023
def test_get_all_dependencies(self):

sdks/python/apache_beam/runners/interactive/utils_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ def test_child_module_logger_can_override_logging_level(self, mock_emit):
244244
reason='[interactive] dependency is not installed.')
245245
class ProgressIndicatorTest(unittest.TestCase):
246246
def setUp(self):
247+
self.gcs_patcher = patch(
248+
'apache_beam.io.gcp.gcsfilesystem.GCSFileSystem.delete')
249+
self.gcs_patcher.start()
247250
ie.new_env()
248251

249252
@patch('IPython.get_ipython', new_callable=mock_get_ipython)
@@ -279,6 +282,9 @@ def test_progress_in_HTML_JS_when_in_notebook(
279282
mocked_html.assert_called()
280283
mocked_js.assert_called()
281284

285+
def tearDown(self):
286+
self.gcs_patcher.stop()
287+
282288

283289
@unittest.skipIf(
284290
not ie.current_env().is_interactive_ready,
@@ -287,6 +293,9 @@ class MessagingUtilTest(unittest.TestCase):
287293
SAMPLE_DATA = {'a': [1, 2, 3], 'b': 4, 'c': '5', 'd': {'e': 'f'}}
288294

289295
def setUp(self):
296+
self.gcs_patcher = patch(
297+
'apache_beam.io.gcp.gcsfilesystem.GCSFileSystem.delete')
298+
self.gcs_patcher.start()
290299
ie.new_env()
291300

292301
def test_as_json_decorator(self):
@@ -298,6 +307,9 @@ def dummy():
298307
# dictionaries remember the order of items inserted.
299308
self.assertEqual(json.loads(dummy()), MessagingUtilTest.SAMPLE_DATA)
300309

310+
def tearDown(self):
311+
self.gcs_patcher.stop()
312+
301313

302314
class GeneralUtilTest(unittest.TestCase):
303315
def test_pcoll_by_name(self):

0 commit comments

Comments
 (0)