Skip to content

Commit 0ebf84b

Browse files
authored
Add ib.collect support for raw records (#36516)
1 parent 4c08585 commit 0ebf84b

File tree

2 files changed

+105
-10
lines changed

2 files changed

+105
-10
lines changed

sdks/python/apache_beam/runners/interactive/interactive_beam.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -879,7 +879,8 @@ def collect(
879879
runner=None,
880880
options=None,
881881
force_compute=False,
882-
force_tuple=False):
882+
force_tuple=False,
883+
raw_records=False):
883884
"""Materializes the elements from a PCollection into a Dataframe.
884885
885886
This reads each element from file and reads only the amount that it needs
@@ -901,6 +902,8 @@ def collect(
901902
cached PCollections
902903
force_tuple: (optional) if True, return a 1-tuple of results rather than
903904
the bare results if only one PCollection is computed
905+
raw_records: (optional) if True, return a list of collected records
906+
without converting to a DataFrame (element values only). Default False.
904907
905908
For example::
906909
@@ -910,6 +913,9 @@ def collect(
910913
911914
# Run the pipeline and bring the PCollection into memory as a DataFrame.
912915
in_memory_square = collect(square, n=5)
916+
917+
# Run the pipeline and get the raw list of elements.
918+
raw_squares = collect(square, n=5, raw_records=True)
913919
"""
914920
if len(pcolls) == 0:
915921
return ()
@@ -986,15 +992,19 @@ def as_pcollection(pcoll_or_df):
986992
if n == float('inf'):
987993
n = None
988994

989-
# Collecting DataFrames may have a length > n, so slice again to be sure. Note
990-
# that array[:None] returns everything.
991-
empty = pd.DataFrame()
992-
result_tuple = tuple(
993-
elements_to_df(
994-
computed[pcoll],
995-
include_window_info=include_window_info,
996-
element_type=pcolls_to_element_types[pcoll])[:n] if pcoll in
997-
computed else empty for pcoll in pcolls)
995+
if raw_records:
996+
result_tuple = tuple([el.value for el in computed.get(pcoll, [])][:n]
997+
for pcoll in pcolls)
998+
else:
999+
# Collecting DataFrames may have a length > n, so slice again to be sure.
1000+
# Note that array[:None] returns everything.
1001+
empty = pd.DataFrame()
1002+
result_tuple = tuple(
1003+
elements_to_df(
1004+
computed.get(pcoll, []),
1005+
include_window_info=include_window_info,
1006+
element_type=pcolls_to_element_types[pcoll])[:n] if pcoll in
1007+
computed else empty for pcoll in pcolls)
9981008

9991009
if len(result_tuple) == 1 and not force_tuple:
10001010
return result_tuple[0]

sdks/python/apache_beam/runners/interactive/interactive_beam_test.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,91 @@ def is_triggered(self):
293293
self.assertTrue(ib.recordings.record(p))
294294
ib.recordings.stop(p)
295295

296+
def test_collect_raw_records_true(self):
297+
p = beam.Pipeline(ir.InteractiveRunner())
298+
data = list(range(5))
299+
pcoll = p | 'Create' >> beam.Create(data)
300+
ib.watch(locals())
301+
ie.current_env().track_user_pipelines()
302+
303+
result = ib.collect(pcoll, raw_records=True)
304+
self.assertIsInstance(result, list)
305+
self.assertEqual(result, data)
306+
307+
result_n = ib.collect(pcoll, n=3, raw_records=True)
308+
self.assertIsInstance(result_n, list)
309+
self.assertEqual(result_n, data[:3])
310+
311+
def test_collect_raw_records_false(self):
312+
p = beam.Pipeline(ir.InteractiveRunner())
313+
data = list(range(5))
314+
pcoll = p | 'Create' >> beam.Create(data)
315+
ib.watch(locals())
316+
ie.current_env().track_user_pipelines()
317+
318+
result = ib.collect(pcoll)
319+
self.assertNotIsInstance(result, list)
320+
self.assertTrue(
321+
hasattr(result, 'columns'), "Result should have 'columns' attribute")
322+
self.assertTrue(
323+
hasattr(result, 'values'), "Result should have 'values' attribute")
324+
325+
result_n = ib.collect(pcoll, n=3)
326+
self.assertNotIsInstance(result_n, list)
327+
self.assertTrue(
328+
hasattr(result_n, 'columns'),
329+
"Result (n=3) should have 'columns' attribute")
330+
self.assertTrue(
331+
hasattr(result_n, 'values'),
332+
"Result (n=3) should have 'values' attribute")
333+
334+
def test_collect_raw_records_true_multiple_pcolls(self):
335+
p = beam.Pipeline(ir.InteractiveRunner())
336+
data1 = list(range(3))
337+
data2 = [x * x for x in range(3)]
338+
pcoll1 = p | 'Create1' >> beam.Create(data1)
339+
pcoll2 = p | 'Create2' >> beam.Create(data2)
340+
ib.watch(locals())
341+
ie.current_env().track_user_pipelines()
342+
343+
result = ib.collect(pcoll1, pcoll2, raw_records=True)
344+
self.assertIsInstance(result, tuple)
345+
self.assertEqual(len(result), 2)
346+
self.assertIsInstance(result[0], list)
347+
self.assertEqual(result[0], data1)
348+
self.assertIsInstance(result[1], list)
349+
self.assertEqual(result[1], data2)
350+
351+
def test_collect_raw_records_false_multiple_pcolls(self):
352+
p = beam.Pipeline(ir.InteractiveRunner())
353+
data1 = list(range(3))
354+
data2 = [x * x for x in range(3)]
355+
pcoll1 = p | 'Create1' >> beam.Create(data1)
356+
pcoll2 = p | 'Create2' >> beam.Create(data2)
357+
ib.watch(locals())
358+
ie.current_env().track_user_pipelines()
359+
360+
result = ib.collect(pcoll1, pcoll2)
361+
self.assertIsInstance(result, tuple)
362+
self.assertEqual(len(result), 2)
363+
self.assertNotIsInstance(result[0], list)
364+
self.assertTrue(hasattr(result[0], 'columns'))
365+
self.assertNotIsInstance(result[1], list)
366+
self.assertTrue(hasattr(result[1], 'columns'))
367+
368+
def test_collect_raw_records_true_force_tuple(self):
369+
p = beam.Pipeline(ir.InteractiveRunner())
370+
data = list(range(5))
371+
pcoll = p | 'Create' >> beam.Create(data)
372+
ib.watch(locals())
373+
ie.current_env().track_user_pipelines()
374+
375+
result = ib.collect(pcoll, raw_records=True, force_tuple=True)
376+
self.assertIsInstance(result, tuple)
377+
self.assertEqual(len(result), 1)
378+
self.assertIsInstance(result[0], list)
379+
self.assertEqual(result[0], data)
380+
296381

297382
@unittest.skipIf(
298383
not ie.current_env().is_interactive_ready,

0 commit comments

Comments
 (0)