Commit 1aff212

Move a lot of checkpointing code into Memoizer, from DFK (#3986)
This should not change behaviour: it only moves code around and adds more wiring. It works towards making all of the checkpointing/memoizer code pluggable, with the DFK having no knowledge of the specifics - for example, the DFK no longer imports `pickle`, because that is now the business of the memoizer.

Because of that goal, the parameters for Memoizer are arranged in two ways (a sketch of this wiring follows the commit message). Constructor parameters for Memoizer are the parameters that a future user will supply at configuration time, once Memoizer is exposed as a user plugin like other plugins. DFK internals that should be injected into the (eventually pluggable) memoizer (small m) are set as startup-time attributes, as for other plugins. Right now these two things happen right next to each other, but the Memoizer constructor call will move into user configuration space in a subsequent PR.

As before this PR, there is still a separation between "checkpointing" and "memoization" that is slightly artificial. Subsequent PRs will merge these actions together more in this Memoizer implementation.

This PR preserves a user-facing dfk.checkpoint() call, which becomes a passthrough to Memoizer.checkpoint(). I think that is likely to disappear in a future PR: manual checkpoint triggering will become a feature (or non-feature) of a specific memoizer implementation, and so a method directly on that memoizer (or not).

To go alongside the existing update_memo call, invoked by the DFK when a task is ready for memoization, this PR adds update_checkpoint, which captures the slightly different notion of a task being ready for checkpointing: the current implementation can memoize a task that is not yet complete, because it memoizes the future rather than the result, while a checkpoint needs the actual result to be ready and so must be called later. This latter call happens in the race-condition-prone handle_app_update. A later PR will remove this distinction and move everything to around the same place as update_memo.

See PR #3979 for a description of fixing a related race condition involving update_memo. See PR #3535 for broader context.

# Changed Behaviour

None

## Type of change

- New feature
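To make that split concrete, here is a minimal sketch of the new wiring pattern, restated from the diffs below. `wire_memoizer` is an illustrative helper for this description only, not a function added by the PR; the `Memoizer` constructor arguments, the `run_dir` attribute, and the `start()`/`close()` lifecycle calls are the ones visible in the diff.

```python
from parsl.dataflow.memoization import Memoizer


def wire_memoizer(config, run_dir):
    # Constructor parameters: what a user would eventually supply at
    # configuration time, once Memoizer is exposed as a user plugin.
    memoizer = Memoizer(memoize=config.app_cache,
                        checkpoint_mode=config.checkpoint_mode,
                        checkpoint_files=config.checkpoint_files,
                        checkpoint_period=config.checkpoint_period)

    # DFK internals: injected as startup-time attributes, as for other
    # plugins, rather than passed through the constructor.
    memoizer.run_dir = run_dir

    # start() now does the work the DFK used to do inline: choosing which
    # checkpoint files to load and starting the periodic checkpoint Timer.
    memoizer.start()
    return memoizer
```

During a run the DFK calls update_memo when a task is ready for memoization and update_checkpoint (from handle_app_update) once the result exists; at shutdown, cleanup() calls memoizer.close().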
1 parent 8641f1a commit 1aff212

File tree

2 files changed: +137 -108 lines changed


parsl/dataflow/dflow.py

Lines changed: 11 additions & 106 deletions
```diff
@@ -6,7 +6,6 @@
 import inspect
 import logging
 import os
-import pickle
 import random
 import sys
 import threading
@@ -50,7 +49,7 @@
 from parsl.monitoring.remote import monitor_wrapper
 from parsl.process_loggers import wrap_with_logs
 from parsl.usage_tracking.usage import UsageTracker
-from parsl.utils import Timer, get_all_checkpoints, get_std_fname_mode, get_version
+from parsl.utils import get_std_fname_mode, get_version

 logger = logging.getLogger(__name__)

@@ -101,8 +100,6 @@ def __init__(self, config: Config) -> None:

         logger.info("Parsl version: {}".format(get_version()))

-        self.checkpoint_lock = threading.Lock()
-
         self.usage_tracker = UsageTracker(self)
         self.usage_tracker.send_start_message()

@@ -168,18 +165,12 @@ def __init__(self, config: Config) -> None:
             self.monitoring_radio.send((MessageType.WORKFLOW_INFO,
                                         workflow_info))

-        if config.checkpoint_files is not None:
-            checkpoint_files = config.checkpoint_files
-        elif config.checkpoint_files is None and config.checkpoint_mode is not None:
-            checkpoint_files = get_all_checkpoints(self.run_dir)
-        else:
-            checkpoint_files = []
-
-        self.memoizer = Memoizer(memoize=config.app_cache, checkpoint_files=checkpoint_files)
-        self.checkpointed_tasks = 0
-        self._checkpoint_timer = None
-        self.checkpoint_mode = config.checkpoint_mode
-        self.checkpointable_tasks: List[TaskRecord] = []
+        self.memoizer = Memoizer(memoize=config.app_cache,
+                                 checkpoint_mode=config.checkpoint_mode,
+                                 checkpoint_files=config.checkpoint_files,
+                                 checkpoint_period=config.checkpoint_period)
+        self.memoizer.run_dir = self.run_dir
+        self.memoizer.start()

         # this must be set before executors are added since add_executors calls
         # job_status_poller.add_executors.
@@ -195,17 +186,6 @@ def __init__(self, config: Config) -> None:
         self.add_executors(config.executors)
         self.add_executors([parsl_internal_executor])

-        if self.checkpoint_mode == "periodic":
-            if config.checkpoint_period is None:
-                raise ConfigurationError("Checkpoint period must be specified with periodic checkpoint mode")
-            else:
-                try:
-                    h, m, s = map(int, config.checkpoint_period.split(':'))
-                except Exception:
-                    raise ConfigurationError("invalid checkpoint_period provided: {0} expected HH:MM:SS".format(config.checkpoint_period))
-            checkpoint_period = (h * 3600) + (m * 60) + s
-            self._checkpoint_timer = Timer(self.checkpoint, interval=checkpoint_period, name="Checkpoint")
-
         self.task_count = 0
         self.tasks: Dict[int, TaskRecord] = {}
         self.submitter_lock = threading.Lock()
@@ -565,18 +545,7 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
         if not task_record['app_fu'] == future:
             logger.error("Internal consistency error: callback future is not the app_fu in task structure, for task {}".format(task_id))

-        # Cover all checkpointing cases here:
-        # Do we need to checkpoint now, or queue for later,
-        # or do nothing?
-        if self.checkpoint_mode == 'task_exit':
-            self.checkpoint(tasks=[task_record])
-        elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
-            with self.checkpoint_lock:
-                self.checkpointable_tasks.append(task_record)
-        elif self.checkpoint_mode is None:
-            pass
-        else:
-            raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}")
+        self.memoizer.update_checkpoint(task_record)

         self.wipe_task(task_id)
         return
@@ -1205,13 +1174,7 @@ def cleanup(self) -> None:

         self.log_task_states()

-        # checkpoint if any valid checkpoint method is specified
-        if self.checkpoint_mode is not None:
-            self.checkpoint()
-
-        if self._checkpoint_timer:
-            logger.info("Stopping checkpoint timer")
-            self._checkpoint_timer.close()
+        self.memoizer.close()

         # Send final stats
         self.usage_tracker.send_end_message()
@@ -1269,66 +1232,8 @@ def cleanup(self) -> None:
         # should still see it.
         logger.info("DFK cleanup complete")

-    def checkpoint(self, tasks: Optional[Sequence[TaskRecord]] = None) -> None:
-        """Checkpoint the dfk incrementally to a checkpoint file.
-
-        When called, every task that has been completed yet not
-        checkpointed is checkpointed to a file.
-
-        Kwargs:
-            - tasks (List of task records) : List of task ids to checkpoint. Default=None
-              if set to None, we iterate over all tasks held by the DFK.
-
-        .. note::
-            Checkpointing only works if memoization is enabled
-
-        Returns:
-            Checkpoint dir if checkpoints were written successfully.
-            By default the checkpoints are written to the RUNDIR of the current
-            run under RUNDIR/checkpoints/tasks.pkl
-        """
-        with self.checkpoint_lock:
-            if tasks:
-                checkpoint_queue = tasks
-            else:
-                checkpoint_queue = self.checkpointable_tasks
-                self.checkpointable_tasks = []
-
-            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
-            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
-
-            if not os.path.exists(checkpoint_dir):
-                os.makedirs(checkpoint_dir, exist_ok=True)
-
-            count = 0
-
-            with open(checkpoint_tasks, 'ab') as f:
-                for task_record in checkpoint_queue:
-                    task_id = task_record['id']
-
-                    app_fu = task_record['app_fu']
-
-                    if app_fu.done() and app_fu.exception() is None:
-                        hashsum = task_record['hashsum']
-                        if not hashsum:
-                            continue
-                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
-
-                        # We are using pickle here since pickle dumps to a file in 'ab'
-                        # mode behave like a incremental log.
-                        pickle.dump(t, f)
-                        count += 1
-                        logger.debug("Task {} checkpointed".format(task_id))
-
-            self.checkpointed_tasks += count
-
-            if count == 0:
-                if self.checkpointed_tasks == 0:
-                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
-                else:
-                    logger.debug("No tasks checkpointed in this pass.")
-            else:
-                logger.info("Done checkpointing {} tasks".format(count))
+    def checkpoint(self) -> None:
+        self.memoizer.checkpoint()

     @staticmethod
     def _log_std_streams(task_record: TaskRecord) -> None:
```

parsl/dataflow/memoization.py

Lines changed: 126 additions & 2 deletions
```diff
@@ -4,15 +4,18 @@
 import logging
 import os
 import pickle
+import threading
 import types
 from concurrent.futures import Future
 from functools import lru_cache, singledispatch
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Literal, Optional, Sequence

 import typeguard

 from parsl.dataflow.errors import BadCheckpoint
 from parsl.dataflow.taskrecord import TaskRecord
+from parsl.errors import ConfigurationError, InternalConsistencyError
+from parsl.utils import Timer, get_all_checkpoints

 logger = logging.getLogger(__name__)

@@ -146,7 +149,13 @@ class Memoizer:

     """

-    def __init__(self, *, memoize: bool = True, checkpoint_files: Sequence[str]):
+    run_dir: str
+
+    def __init__(self, *,
+                 memoize: bool = True,
+                 checkpoint_files: Sequence[str] | None,
+                 checkpoint_period: Optional[str],
+                 checkpoint_mode: Literal['task_exit', 'periodic', 'dfk_exit', 'manual'] | None):
         """Initialize the memoizer.

         KWargs:
@@ -155,6 +164,26 @@ def __init__(self, *, memoize: bool = True, checkpoint_files: Sequence[str]):
         """
         self.memoize = memoize

+        self.checkpointed_tasks = 0
+
+        self.checkpoint_lock = threading.Lock()
+
+        self.checkpoint_files = checkpoint_files
+        self.checkpoint_mode = checkpoint_mode
+        self.checkpoint_period = checkpoint_period
+
+        self.checkpointable_tasks: List[TaskRecord] = []
+
+        self._checkpoint_timer: Timer | None = None
+
+    def start(self) -> None:
+        if self.checkpoint_files is not None:
+            checkpoint_files = self.checkpoint_files
+        elif self.checkpoint_files is None and self.checkpoint_mode is not None:
+            checkpoint_files = get_all_checkpoints(self.run_dir)
+        else:
+            checkpoint_files = []
+
         checkpoint = self.load_checkpoints(checkpoint_files)

         if self.memoize:
@@ -164,6 +193,26 @@ def __init__(self, *, memoize: bool = True, checkpoint_files: Sequence[str]):
             logger.info("App caching disabled for all apps")
             self.memo_lookup_table = {}

+        if self.checkpoint_mode == "periodic":
+            if self.checkpoint_period is None:
+                raise ConfigurationError("Checkpoint period must be specified with periodic checkpoint mode")
+            else:
+                try:
+                    h, m, s = map(int, self.checkpoint_period.split(':'))
+                except Exception:
+                    raise ConfigurationError("invalid checkpoint_period provided: {0} expected HH:MM:SS".format(self.checkpoint_period))
+            checkpoint_period = (h * 3600) + (m * 60) + s
+            self._checkpoint_timer = Timer(self.checkpoint, interval=checkpoint_period, name="Checkpoint")
+
+    def close(self) -> None:
+        if self.checkpoint_mode is not None:
+            logger.info("Making final checkpoint")
+            self.checkpoint()
+
+        if self._checkpoint_timer:
+            logger.info("Stopping checkpoint timer")
+            self._checkpoint_timer.close()
+
     def make_hash(self, task: TaskRecord) -> str:
         """Create a hash of the task inputs.

@@ -324,3 +373,78 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str,
             return self._load_checkpoints(checkpointDirs)
         else:
             return {}
+
+    def update_checkpoint(self, task_record: TaskRecord) -> None:
+        if self.checkpoint_mode == 'task_exit':
+            self.checkpoint(task=task_record)
+        elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
+            with self.checkpoint_lock:
+                self.checkpointable_tasks.append(task_record)
+        elif self.checkpoint_mode is None:
+            pass
+        else:
+            raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}")
+
+    def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
+        """Checkpoint the dfk incrementally to a checkpoint file.
+
+        When called with no argument, all tasks registered in self.checkpointable_tasks
+        will be checkpointed. When called with a single TaskRecord argument, that task will be
+        checkpointed.
+
+        By default the checkpoints are written to the RUNDIR of the current
+        run under RUNDIR/checkpoints/tasks.pkl
+
+        Kwargs:
+            - task (Optional task records) : A task to checkpoint. Default=None, meaning all
+              tasks registered for checkpointing.
+
+        .. note::
+            Checkpointing only works if memoization is enabled
+
+        """
+        with self.checkpoint_lock:
+
+            if task:
+                checkpoint_queue = [task]
+            else:
+                checkpoint_queue = self.checkpointable_tasks
+
+            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
+            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
+
+            if not os.path.exists(checkpoint_dir):
+                os.makedirs(checkpoint_dir, exist_ok=True)
+
+            count = 0
+
+            with open(checkpoint_tasks, 'ab') as f:
+                for task_record in checkpoint_queue:
+                    task_id = task_record['id']
+
+                    app_fu = task_record['app_fu']
+
+                    if app_fu.done() and app_fu.exception() is None:
+                        hashsum = task_record['hashsum']
+                        if not hashsum:
+                            continue
+                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
+
+                        # We are using pickle here since pickle dumps to a file in 'ab'
+                        # mode behave like a incremental log.
+                        pickle.dump(t, f)
+                        count += 1
+                        logger.debug("Task {} checkpointed".format(task_id))
+
+            self.checkpointed_tasks += count
+
+            if count == 0:
+                if self.checkpointed_tasks == 0:
+                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
+                else:
+                    logger.debug("No tasks checkpointed in this pass.")
+            else:
+                logger.info("Done checkpointing {} tasks".format(count))
+
+            if not task:
+                self.checkpointable_tasks = []
```
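The comment inside checkpoint() relies on a standard property of pickle: each pickle.dump onto a file opened in 'ab' mode appends an independent frame, so tasks.pkl behaves as an incremental log that can be replayed one record at a time. As a self-contained illustration only (the Memoizer's real loader is _load_checkpoints, which this diff does not touch), such a log could be read back like this:

```python
import pickle
from typing import Any, Dict


def read_checkpoint_log(path: str) -> Dict[str, Any]:
    """Replay a tasks.pkl written by repeated pickle.dump() calls in 'ab' mode.

    Each record is a dict of the shape written by checkpoint():
    {'hash': <task hash>, 'exception': None, 'result': <app result>};
    later records for the same hash overwrite earlier ones.
    """
    memo: Dict[str, Any] = {}
    with open(path, 'rb') as f:
        while True:
            try:
                record = pickle.load(f)  # read exactly one appended frame
            except EOFError:
                break                    # end of the incremental log
            memo[record['hash']] = record['result']
    return memo
```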
