Skip to content

Commit c002770

Browse files
Grain Teamcopybara-github
authored andcommitted
Internal
PiperOrigin-RevId: 837296133
1 parent 9d36192 commit c002770

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

grain/_src/python/dataset/dataset_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,32 @@ def test_execution_summary_with_no_logging(self):
13921392
log_value = "Grain Dataset Execution Summary"
13931393
self.assertNotIn(log_value, "".join(logs.output))
13941394

1395+
@flagsaver.flagsaver(grain_py_debug_mode=True)
1396+
def test_execution_summary_with_mp_prefetch(self):
1397+
ds = dataset.MapDataset.range(10000).map(MapTransformAddingOne())
1398+
ds = ds.to_iter_dataset()
1399+
ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=1))
1400+
it = ds.__iter__()
1401+
_ = list(it)
1402+
all_nodes_present = False
1403+
while not all_nodes_present:
1404+
time.sleep(1)
1405+
all_nodes_present = True
1406+
summary = dataset.get_execution_summary(it)
1407+
node_names = {node.name for node in summary.nodes.values()}
1408+
all_nodes_present = all_nodes_present and any(
1409+
"RangeMapDataset" in name for name in node_names
1410+
)
1411+
all_nodes_present = all_nodes_present and any(
1412+
"MapMapDataset" in name for name in node_names
1413+
)
1414+
all_nodes_present = all_nodes_present and any(
1415+
"PrefetchDatasetIterator" in name for name in node_names
1416+
)
1417+
all_nodes_present = all_nodes_present and any(
1418+
"MultiprocessPrefetchDatasetIterator" in name for name in node_names
1419+
)
1420+
13951421

13961422
class GetElementSpecTest(parameterized.TestCase):
13971423

grain/_src/python/dataset/stats.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -818,10 +818,6 @@ def __init__(self, config: StatsConfig, parents: Sequence[Stats]):
818818
self._last_update_time = 0
819819
self._last_report_time = 0
820820
self._summary_dispatcher = None
821-
if self._config.stats_out_queue:
822-
self._summary_dispatcher = self._send_stats_to_main_process_loop
823-
elif self._config.log_summary:
824-
self._summary_dispatcher = self._logging_execution_summary_loop
825821

826822
def __reduce__(self):
827823
return _ExecutionStats, (self._config, self._parents)
@@ -939,14 +935,18 @@ def record_self_time(self, num_elements: int = 1, offset_ns: int = 0):
939935
target=self._reporting_loop, daemon=True
940936
)
941937
self._reporting_thread.start()
942-
if (
943-
self._summary_dispatcher_thread is None
944-
and self._summary_dispatcher is not None
938+
if self._summary_dispatcher_thread is None and (
939+
self._config.stats_out_queue or self._config.log_summary
945940
):
946941
with self._summary_dispatcher_thread_init_lock:
947942
# Check above together with update would not be atomic -- another
948943
# thread may have started the logging thread.
949944
if self._summary_dispatcher_thread is None:
945+
self._summary_dispatcher = (
946+
self._send_stats_to_main_process_loop
947+
if self._config.stats_out_queue
948+
else self._logging_execution_summary_loop
949+
)
950950
self._summary_dispatcher_thread = threading.Thread(
951951
target=self._summary_dispatcher, daemon=True
952952
)

0 commit comments

Comments
 (0)