Skip to content

Commit bc5f5d4

Browse files
Grain Teamcopybara-github
authored andcommitted
Internal
PiperOrigin-RevId: 837296133
1 parent 9d36192 commit bc5f5d4

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

grain/_src/python/dataset/dataset_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,40 @@ def test_execution_summary_with_no_logging(self):
13921392
log_value = "Grain Dataset Execution Summary"
13931393
self.assertNotIn(log_value, "".join(logs.output))
13941394

1395+
@flagsaver.flagsaver(grain_py_debug_mode=True)
1396+
@mock.patch.object(dataset_stats, "_REPORTING_PERIOD_SEC", 0.05)
1397+
def test_execution_summary_with_mp_prefetch(self):
1398+
def worker_init_fn_wrapper(worker_index, worker_count):
1399+
del worker_index, worker_count
1400+
dataset_stats._REPORTING_PERIOD_SEC = 0.05
1401+
1402+
ds = dataset.MapDataset.range(10000).map(MapTransformAddingOne())
1403+
ds = ds.to_iter_dataset()
1404+
ds = ds.mp_prefetch(
1405+
options.MultiprocessingOptions(num_workers=1),
1406+
worker_init_fn=worker_init_fn_wrapper,
1407+
)
1408+
it = ds.__iter__()
1409+
_ = list(it)
1410+
all_nodes_present = False
1411+
while not all_nodes_present:
1412+
time.sleep(1)
1413+
all_nodes_present = True
1414+
summary = dataset.get_execution_summary(it)
1415+
node_names = {node.name for node in summary.nodes.values()}
1416+
all_nodes_present = all_nodes_present and any(
1417+
"RangeMapDataset" in name for name in node_names
1418+
)
1419+
all_nodes_present = all_nodes_present and any(
1420+
"MapMapDataset" in name for name in node_names
1421+
)
1422+
all_nodes_present = all_nodes_present and any(
1423+
"PrefetchDatasetIterator" in name for name in node_names
1424+
)
1425+
all_nodes_present = all_nodes_present and any(
1426+
"MultiprocessPrefetchDatasetIterator" in name for name in node_names
1427+
)
1428+
13951429

13961430
class GetElementSpecTest(parameterized.TestCase):
13971431

grain/_src/python/dataset/stats.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -818,10 +818,6 @@ def __init__(self, config: StatsConfig, parents: Sequence[Stats]):
818818
self._last_update_time = 0
819819
self._last_report_time = 0
820820
self._summary_dispatcher = None
821-
if self._config.stats_out_queue:
822-
self._summary_dispatcher = self._send_stats_to_main_process_loop
823-
elif self._config.log_summary:
824-
self._summary_dispatcher = self._logging_execution_summary_loop
825821

826822
def __reduce__(self):
827823
return _ExecutionStats, (self._config, self._parents)
@@ -939,14 +935,18 @@ def record_self_time(self, num_elements: int = 1, offset_ns: int = 0):
939935
target=self._reporting_loop, daemon=True
940936
)
941937
self._reporting_thread.start()
942-
if (
943-
self._summary_dispatcher_thread is None
944-
and self._summary_dispatcher is not None
938+
if self._summary_dispatcher_thread is None and (
939+
self._config.stats_out_queue or self._config.log_summary
945940
):
946941
with self._summary_dispatcher_thread_init_lock:
947942
# Check above together with update would not be atomic -- another
948943
# thread may have started the logging thread.
949944
if self._summary_dispatcher_thread is None:
945+
self._summary_dispatcher = (
946+
self._send_stats_to_main_process_loop
947+
if self._config.stats_out_queue
948+
else self._logging_execution_summary_loop
949+
)
950950
self._summary_dispatcher_thread = threading.Thread(
951951
target=self._summary_dispatcher, daemon=True
952952
)

0 commit comments

Comments
 (0)