From 2f21451e1cb84af909852d4a3b5b64a410cac8ae Mon Sep 17 00:00:00 2001 From: Ian Liao <55819364+ian-Liaozy@users.noreply.github.com> Date: Mon, 5 Jan 2026 22:14:54 +0000 Subject: [PATCH 1/4] ib.collect support wait_for_inputs option --- .../runners/interactive/interactive_beam.py | 10 +- .../interactive/interactive_beam_test.py | 104 ++++++++++++++++++ .../runners/interactive/recording_manager.py | 12 +- 3 files changed, 119 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py index 7b773fda5db8..457401aa29a7 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py @@ -879,7 +879,8 @@ def collect( options=None, force_compute=False, force_tuple=False, - raw_records=False): + raw_records=False, + wait_for_inputs=True): """Materializes the elements from a PCollection into a Dataframe. This reads each element from file and reads only the amount that it needs @@ -903,6 +904,10 @@ def collect( the bare results if only one PCollection is computed raw_records: (optional) if True, return a list of collected records without converting to a DataFrame. Default False. + wait_for_inputs: Whether to wait until the asynchronous dependencies are + computed. Setting this to False schedules the computation immediately, + but may cause the same pipeline stages to run multiple times. + Default True.
For example:: @@ -980,7 +985,8 @@ def as_pcollection(pcoll_or_df): max_duration=duration, runner=runner, options=options, - force_compute=force_compute) + force_compute=force_compute, + wait_for_inputs=wait_for_inputs) try: for pcoll in uncomputed: diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py index 21163fc121c5..d6bbcdae2db8 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py @@ -387,6 +387,110 @@ def test_collect_raw_records_true_force_tuple(self): self.assertIsInstance(result[0], list) self.assertEqual(result[0], data) + def test_collect_wait_for_inputs_true(self, mock_current_env): + mock_env = MagicMock() + mock_current_env.return_value = mock_env + mock_rm = MagicMock() + mock_env.get_recording_manager.return_value = mock_rm + mock_env.computed_pcollections = set() + + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) + + # Simulate pcoll1 being computed asynchronously + mock_env.is_pcollection_computing.return_value = True + async_res = MagicMock(spec=AsyncComputationResult) + mock_rm._async_computations = {'id1': async_res} + mock_rm._get_all_dependencies.return_value = {pcoll1} + + # Set up return value for record + mock_recording = MagicMock() + mock_rm.record.return_value = mock_recording + mock_rm._wait_for_dependencies.return_value = True + + ib.collect(pcoll2, wait_for_inputs=True) + + # Check if wait_for_dependencies was called because wait_for_inputs is True + mock_rm._wait_for_dependencies.assert_called_once_with({pcoll2}) + # Check that record was called with wait_for_inputs=True + mock_rm.record.assert_called_once_with({pcoll2}, + max_n=float('inf'), + max_duration=float('inf'), + runner=None, + options=None, + force_compute=False, + 
wait_for_inputs=True) + + def test_collect_wait_for_inputs_false(self, mock_current_env): + mock_env = MagicMock() + mock_current_env.return_value = mock_env + mock_rm = MagicMock() + mock_env.get_recording_manager.return_value = mock_rm + mock_env.computed_pcollections = set() + + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) + + # Simulate pcoll1 being computed asynchronously + mock_env.is_pcollection_computing.return_value = True + async_res = MagicMock(spec=AsyncComputationResult) + mock_rm._async_computations = {'id1': async_res} + mock_rm._get_all_dependencies.return_value = {pcoll1} + + # Set up return value for record + mock_recording = MagicMock() + mock_rm.record.return_value = mock_recording + + ib.collect(pcoll2, wait_for_inputs=False) + + # Check that wait_for_dependencies was NOT called + mock_rm._wait_for_dependencies.assert_not_called() + # Check that record was called with wait_for_inputs=False + mock_rm.record.assert_called_once_with({pcoll2}, + max_n=float('inf'), + max_duration=float('inf'), + runner=None, + options=None, + force_compute=False, + wait_for_inputs=False) + + def test_collect_wait_for_inputs_default(self, mock_current_env): + mock_env = MagicMock() + mock_current_env.return_value = mock_env + mock_rm = MagicMock() + mock_env.get_recording_manager.return_value = mock_rm + mock_env.computed_pcollections = set() + + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) + + # Simulate pcoll1 being computed asynchronously + mock_env.is_pcollection_computing.return_value = True + async_res = MagicMock(spec=AsyncComputationResult) + mock_rm._async_computations = {'id1': async_res} + mock_rm._get_all_dependencies.return_value = {pcoll1} + mock_rm._wait_for_dependencies.return_value = True + + # Set up return value for record + mock_recording = 
MagicMock() + mock_rm.record.return_value = mock_recording + + ib.collect(pcoll2) # wait_for_inputs defaults to True + + # Check that wait_for_dependencies was called + mock_rm._wait_for_dependencies.assert_called_once_with({pcoll2}) + # Check that record was called with wait_for_inputs=True + mock_rm.record.assert_called_once_with({pcoll2}, + max_n=float('inf'), + max_duration=float('inf'), + runner=None, + options=None, + force_compute=False, + wait_for_inputs=True) + @unittest.skipIf( not ie.current_env().is_interactive_ready, diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py index c19b60b64fd2..c768e4e6d943 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py @@ -849,7 +849,8 @@ def record( max_duration: Union[int, str], runner: runner.PipelineRunner = None, options: pipeline_options.PipelineOptions = None, - force_compute: bool = False) -> Recording: + force_compute: bool = False, + wait_for_inputs: bool = True) -> Recording: # noqa: F821 """Records the given PCollections.""" @@ -886,10 +887,11 @@ def record( # Start a pipeline fragment to start computing the PCollections. 
uncomputed_pcolls = set(pcolls).difference(computed_pcolls) if uncomputed_pcolls: - if not self._wait_for_dependencies(uncomputed_pcolls): - raise RuntimeError( - 'Cannot record because a dependency failed to compute' - ' asynchronously.') + if wait_for_inputs: + if not self._wait_for_dependencies(uncomputed_pcolls): + raise RuntimeError( + 'Cannot record because a dependency failed to compute' + ' asynchronously.') self._clear() From 2a515f564a44549032ba991ddb5535f02c496ff5 Mon Sep 17 00:00:00 2001 From: Ian Liao <55819364+ian-Liaozy@users.noreply.github.com> Date: Mon, 5 Jan 2026 23:14:06 +0000 Subject: [PATCH 2/4] Fix unit tests --- .../interactive/interactive_beam_test.py | 217 +++++++++--------- 1 file changed, 114 insertions(+), 103 deletions(-) diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py index d6bbcdae2db8..3d815308e639 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py @@ -387,109 +387,120 @@ def test_collect_raw_records_true_force_tuple(self): self.assertIsInstance(result[0], list) self.assertEqual(result[0], data) - def test_collect_wait_for_inputs_true(self, mock_current_env): - mock_env = MagicMock() - mock_current_env.return_value = mock_env - mock_rm = MagicMock() - mock_env.get_recording_manager.return_value = mock_rm - mock_env.computed_pcollections = set() - - p = beam.Pipeline(ir.InteractiveRunner()) - pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) - pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) - - # Simulate pcoll1 being computed asynchronously - mock_env.is_pcollection_computing.return_value = True - async_res = MagicMock(spec=AsyncComputationResult) - mock_rm._async_computations = {'id1': async_res} - mock_rm._get_all_dependencies.return_value = {pcoll1} - - # Set up return value for record - mock_recording = MagicMock() - 
mock_rm.record.return_value = mock_recording - mock_rm._wait_for_dependencies.return_value = True - - ib.collect(pcoll2, wait_for_inputs=True) - - # Check if wait_for_dependencies was called because wait_for_inputs is True - mock_rm._wait_for_dependencies.assert_called_once_with({pcoll2}) - # Check that record was called with wait_for_inputs=True - mock_rm.record.assert_called_once_with({pcoll2}, - max_n=float('inf'), - max_duration=float('inf'), - runner=None, - options=None, - force_compute=False, - wait_for_inputs=True) - - def test_collect_wait_for_inputs_false(self, mock_current_env): - mock_env = MagicMock() - mock_current_env.return_value = mock_env - mock_rm = MagicMock() - mock_env.get_recording_manager.return_value = mock_rm - mock_env.computed_pcollections = set() - - p = beam.Pipeline(ir.InteractiveRunner()) - pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) - pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) - - # Simulate pcoll1 being computed asynchronously - mock_env.is_pcollection_computing.return_value = True - async_res = MagicMock(spec=AsyncComputationResult) - mock_rm._async_computations = {'id1': async_res} - mock_rm._get_all_dependencies.return_value = {pcoll1} - - # Set up return value for record - mock_recording = MagicMock() - mock_rm.record.return_value = mock_recording - - ib.collect(pcoll2, wait_for_inputs=False) - - # Check that wait_for_dependencies was NOT called - mock_rm._wait_for_dependencies.assert_not_called() - # Check that record was called with wait_for_inputs=False - mock_rm.record.assert_called_once_with({pcoll2}, - max_n=float('inf'), - max_duration=float('inf'), - runner=None, - options=None, - force_compute=False, - wait_for_inputs=False) - - def test_collect_wait_for_inputs_default(self, mock_current_env): - mock_env = MagicMock() - mock_current_env.return_value = mock_env - mock_rm = MagicMock() - mock_env.get_recording_manager.return_value = mock_rm - mock_env.computed_pcollections = set() - - p = 
beam.Pipeline(ir.InteractiveRunner()) - pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) - pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) - - # Simulate pcoll1 being computed asynchronously - mock_env.is_pcollection_computing.return_value = True - async_res = MagicMock(spec=AsyncComputationResult) - mock_rm._async_computations = {'id1': async_res} - mock_rm._get_all_dependencies.return_value = {pcoll1} - mock_rm._wait_for_dependencies.return_value = True - - # Set up return value for record - mock_recording = MagicMock() - mock_rm.record.return_value = mock_recording - - ib.collect(pcoll2) # wait_for_inputs defaults to True - - # Check that wait_for_dependencies was called - mock_rm._wait_for_dependencies.assert_called_once_with({pcoll2}) - # Check that record was called with wait_for_inputs=True - mock_rm.record.assert_called_once_with({pcoll2}, - max_n=float('inf'), - max_duration=float('inf'), - runner=None, - options=None, - force_compute=False, - wait_for_inputs=True) + def test_collect_wait_for_inputs_true(self): + with patch('apache_beam.runners.interactive.interactive_beam.ie.current_env' + ) as mock_current_env: + mock_env = MagicMock() + mock_current_env.return_value = mock_env + mock_rm = MagicMock() + mock_env.get_recording_manager.return_value = mock_rm + mock_env.computed_pcollections = set() + mock_env.user_pipeline.return_value = lambda x: x # Mock user_pipeline + + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) + + # Simulate pcoll1 being computed asynchronously + mock_env.is_pcollection_computing.return_value = True + async_res = MagicMock(spec=AsyncComputationResult) + mock_rm._async_computations = {'id1': async_res} + mock_rm._get_all_dependencies.return_value = {pcoll1} + mock_rm._wait_for_dependencies.return_value = True + + # Set up return value for record + mock_recording = MagicMock() + mock_rm.record.return_value = mock_recording + 
+ ib.collect(pcoll2, wait_for_inputs=True) + + # Check wait_for_dependencies was called because wait_for_inputs is True + mock_rm._wait_for_dependencies.assert_called_once_with({pcoll2}, + async_result=None) + # Check that record was called with wait_for_inputs=True + mock_rm.record.assert_called_once_with({pcoll2}, + max_n=float('inf'), + max_duration=float('inf'), + runner=None, + options=None, + force_compute=False, + wait_for_inputs=True) + + def test_collect_wait_for_inputs_false(self): + with patch('apache_beam.runners.interactive.interactive_beam.ie.current_env' + ) as mock_current_env: + mock_env = MagicMock() + mock_current_env.return_value = mock_env + mock_rm = MagicMock() + mock_env.get_recording_manager.return_value = mock_rm + mock_env.computed_pcollections = set() + mock_env.user_pipeline.return_value = lambda x: x # Mock user_pipeline + + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) + + # Simulate pcoll1 being computed asynchronously + mock_env.is_pcollection_computing.return_value = True + async_res = MagicMock(spec=AsyncComputationResult) + mock_rm._async_computations = {'id1': async_res} + mock_rm._get_all_dependencies.return_value = {pcoll1} + + # Set up return value for record + mock_recording = MagicMock() + mock_rm.record.return_value = mock_recording + + ib.collect(pcoll2, wait_for_inputs=False) + + # Check that wait_for_dependencies was NOT called + mock_rm._wait_for_dependencies.assert_not_called() + # Check that record was called with wait_for_inputs=False + mock_rm.record.assert_called_once_with({pcoll2}, + max_n=float('inf'), + max_duration=float('inf'), + runner=None, + options=None, + force_compute=False, + wait_for_inputs=False) + + def test_collect_wait_for_inputs_default(self): + with patch('apache_beam.runners.interactive.interactive_beam.ie.current_env' + ) as mock_current_env: + mock_env = MagicMock() + 
mock_current_env.return_value = mock_env + mock_rm = MagicMock() + mock_env.get_recording_manager.return_value = mock_rm + mock_env.computed_pcollections = set() + mock_env.user_pipeline.return_value = lambda x: x # Mock user_pipeline + + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) + + # Simulate pcoll1 being computed asynchronously + mock_env.is_pcollection_computing.return_value = True + async_res = MagicMock(spec=AsyncComputationResult) + mock_rm._async_computations = {'id1': async_res} + mock_rm._get_all_dependencies.return_value = {pcoll1} + mock_rm._wait_for_dependencies.return_value = True + + # Set up return value for record + mock_recording = MagicMock() + mock_rm.record.return_value = mock_recording + + ib.collect(pcoll2) # wait_for_inputs defaults to True + + # Check that wait_for_dependencies was called + mock_rm._wait_for_dependencies.assert_called_once_with({pcoll2}, + async_result=None) + # Check that record was called with wait_for_inputs=True + mock_rm.record.assert_called_once_with({pcoll2}, + max_n=float('inf'), + max_duration=float('inf'), + runner=None, + options=None, + force_compute=False, + wait_for_inputs=True) @unittest.skipIf( From 229751c308d9e15d88ecb5870c109ff4b7559ea2 Mon Sep 17 00:00:00 2001 From: Ian Liao <55819364+ian-Liaozy@users.noreply.github.com> Date: Tue, 6 Jan 2026 00:13:46 +0000 Subject: [PATCH 3/4] Fix unit test assertion error --- .../runners/interactive/interactive_beam_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py index 3d815308e639..7531fd242ba3 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py @@ -395,7 +395,7 @@ def 
test_collect_wait_for_inputs_true(self): mock_rm = MagicMock() mock_env.get_recording_manager.return_value = mock_rm mock_env.computed_pcollections = set() - mock_env.user_pipeline.return_value = lambda x: x # Mock user_pipeline + mock_env.user_pipeline.side_effect = lambda x: x p = beam.Pipeline(ir.InteractiveRunner()) pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) @@ -434,7 +434,7 @@ def test_collect_wait_for_inputs_false(self): mock_rm = MagicMock() mock_env.get_recording_manager.return_value = mock_rm mock_env.computed_pcollections = set() - mock_env.user_pipeline.return_value = lambda x: x # Mock user_pipeline + mock_env.user_pipeline.side_effect = lambda x: x p = beam.Pipeline(ir.InteractiveRunner()) pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) @@ -471,7 +471,7 @@ def test_collect_wait_for_inputs_default(self): mock_rm = MagicMock() mock_env.get_recording_manager.return_value = mock_rm mock_env.computed_pcollections = set() - mock_env.user_pipeline.return_value = lambda x: x # Mock user_pipeline + mock_env.user_pipeline.side_effect = lambda x: x p = beam.Pipeline(ir.InteractiveRunner()) pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) From 75b6ae3bfaa8d2bb1e71058f3fe34207a8970a81 Mon Sep 17 00:00:00 2001 From: Ian Liao <55819364+ian-Liaozy@users.noreply.github.com> Date: Tue, 13 Jan 2026 23:32:59 +0000 Subject: [PATCH 4/4] remove incorrect assert in unit test --- .../runners/interactive/interactive_beam_test.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py index 7531fd242ba3..f0bb69ef249d 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py @@ -414,9 +414,6 @@ def test_collect_wait_for_inputs_true(self): ib.collect(pcoll2, wait_for_inputs=True) - # Check wait_for_dependencies was called because 
wait_for_inputs is True - mock_rm._wait_for_dependencies.assert_called_once_with({pcoll2}, - async_result=None) # Check that record was called with wait_for_inputs=True mock_rm.record.assert_called_once_with({pcoll2}, max_n=float('inf'), @@ -490,9 +487,6 @@ def test_collect_wait_for_inputs_default(self): ib.collect(pcoll2) # wait_for_inputs defaults to True - # Check that wait_for_dependencies was called - mock_rm._wait_for_dependencies.assert_called_once_with({pcoll2}, - async_result=None) # Check that record was called with wait_for_inputs=True mock_rm.record.assert_called_once_with({pcoll2}, max_n=float('inf'),