diff --git a/parsl/dataflow/dflow.py b/parsl/dataflow/dflow.py
index a9a54fe5bc..01ec60f493 100644
--- a/parsl/dataflow/dflow.py
+++ b/parsl/dataflow/dflow.py
@@ -6,7 +6,6 @@
 import inspect
 import logging
 import os
-import pickle
 import random
 import sys
 import threading
@@ -50,7 +49,7 @@ from parsl.monitoring.remote import monitor_wrapper
 from parsl.process_loggers import wrap_with_logs
 from parsl.usage_tracking.usage import UsageTracker
-from parsl.utils import Timer, get_all_checkpoints, get_std_fname_mode, get_version
+from parsl.utils import get_std_fname_mode, get_version
 
 logger = logging.getLogger(__name__)
 
@@ -101,8 +100,6 @@ def __init__(self, config: Config) -> None:
 
         logger.info("Parsl version: {}".format(get_version()))
 
-        self.checkpoint_lock = threading.Lock()
-
         self.usage_tracker = UsageTracker(self)
         self.usage_tracker.send_start_message()
 
@@ -168,18 +165,12 @@ def __init__(self, config: Config) -> None:
             self.monitoring_radio.send((MessageType.WORKFLOW_INFO,
                                         workflow_info))
 
-        if config.checkpoint_files is not None:
-            checkpoint_files = config.checkpoint_files
-        elif config.checkpoint_files is None and config.checkpoint_mode is not None:
-            checkpoint_files = get_all_checkpoints(self.run_dir)
-        else:
-            checkpoint_files = []
-
-        self.memoizer = Memoizer(memoize=config.app_cache, checkpoint_files=checkpoint_files)
-        self.checkpointed_tasks = 0
-        self._checkpoint_timer = None
-        self.checkpoint_mode = config.checkpoint_mode
-        self.checkpointable_tasks: List[TaskRecord] = []
+        self.memoizer = Memoizer(memoize=config.app_cache,
+                                 checkpoint_mode=config.checkpoint_mode,
+                                 checkpoint_files=config.checkpoint_files,
+                                 checkpoint_period=config.checkpoint_period)
+        self.memoizer.run_dir = self.run_dir
+        self.memoizer.start()
 
         # this must be set before executors are added since add_executors calls
         # job_status_poller.add_executors.
@@ -195,17 +186,6 @@ def __init__(self, config: Config) -> None:
         self.add_executors(config.executors)
         self.add_executors([parsl_internal_executor])
 
-        if self.checkpoint_mode == "periodic":
-            if config.checkpoint_period is None:
-                raise ConfigurationError("Checkpoint period must be specified with periodic checkpoint mode")
-            else:
-                try:
-                    h, m, s = map(int, config.checkpoint_period.split(':'))
-                except Exception:
-                    raise ConfigurationError("invalid checkpoint_period provided: {0} expected HH:MM:SS".format(config.checkpoint_period))
-                checkpoint_period = (h * 3600) + (m * 60) + s
-                self._checkpoint_timer = Timer(self.checkpoint, interval=checkpoint_period, name="Checkpoint")
-
         self.task_count = 0
         self.tasks: Dict[int, TaskRecord] = {}
         self.submitter_lock = threading.Lock()
@@ -565,18 +545,7 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
         if not task_record['app_fu'] == future:
             logger.error("Internal consistency error: callback future is not the app_fu in task structure, for task {}".format(task_id))
 
-        # Cover all checkpointing cases here:
-        # Do we need to checkpoint now, or queue for later,
-        # or do nothing?
-        if self.checkpoint_mode == 'task_exit':
-            self.checkpoint(tasks=[task_record])
-        elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
-            with self.checkpoint_lock:
-                self.checkpointable_tasks.append(task_record)
-        elif self.checkpoint_mode is None:
-            pass
-        else:
-            raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}")
+        self.memoizer.update_checkpoint(task_record)
 
         self.wipe_task(task_id)
         return
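
With this change the DFK no longer decides when or whether to checkpoint; it hands each completed task to the Memoizer. A minimal sketch of the lifecycle the DFK now drives, assuming standalone use outside the DFK and a hypothetical run directory (normally this wiring happens inside DataFlowKernel.__init__ and cleanup, as in the hunks above):

```python
from parsl.dataflow.memoization import Memoizer

memoizer = Memoizer(memoize=True,
                    checkpoint_mode='task_exit',
                    checkpoint_files=None,
                    checkpoint_period=None)
memoizer.run_dir = '/tmp/parsl_run'  # hypothetical; must be set before start()
memoizer.start()  # resolve checkpoint files, load them, start any periodic timer

# ... as each task's AppFuture resolves, the DFK calls:
#         memoizer.update_checkpoint(task_record)

memoizer.close()  # final checkpoint (if a mode is configured) and timer shutdown
```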
@@ -1205,13 +1174,7 @@ def cleanup(self) -> None:
 
         self.log_task_states()
 
-        # checkpoint if any valid checkpoint method is specified
-        if self.checkpoint_mode is not None:
-            self.checkpoint()
-
-        if self._checkpoint_timer:
-            logger.info("Stopping checkpoint timer")
-            self._checkpoint_timer.close()
+        self.memoizer.close()
 
         # Send final stats
         self.usage_tracker.send_end_message()
@@ -1269,66 +1232,8 @@ def cleanup(self) -> None:
         # should still see it.
         logger.info("DFK cleanup complete")
 
-    def checkpoint(self, tasks: Optional[Sequence[TaskRecord]] = None) -> None:
-        """Checkpoint the dfk incrementally to a checkpoint file.
-
-        When called, every task that has been completed yet not
-        checkpointed is checkpointed to a file.
-
-        Kwargs:
-            - tasks (List of task records) : List of task ids to checkpoint. Default=None
-              if set to None, we iterate over all tasks held by the DFK.
-
-        .. note::
-            Checkpointing only works if memoization is enabled
-
-        Returns:
-            Checkpoint dir if checkpoints were written successfully.
-            By default the checkpoints are written to the RUNDIR of the current
-            run under RUNDIR/checkpoints/tasks.pkl
-        """
-        with self.checkpoint_lock:
-            if tasks:
-                checkpoint_queue = tasks
-            else:
-                checkpoint_queue = self.checkpointable_tasks
-                self.checkpointable_tasks = []
-
-            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
-            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
-
-            if not os.path.exists(checkpoint_dir):
-                os.makedirs(checkpoint_dir, exist_ok=True)
-
-            count = 0
-
-            with open(checkpoint_tasks, 'ab') as f:
-                for task_record in checkpoint_queue:
-                    task_id = task_record['id']
-
-                    app_fu = task_record['app_fu']
-
-                    if app_fu.done() and app_fu.exception() is None:
-                        hashsum = task_record['hashsum']
-                        if not hashsum:
-                            continue
-                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
-
-                        # We are using pickle here since pickle dumps to a file in 'ab'
-                        # mode behave like a incremental log.
-                        pickle.dump(t, f)
-                        count += 1
-                        logger.debug("Task {} checkpointed".format(task_id))
-
-            self.checkpointed_tasks += count
-
-            if count == 0:
-                if self.checkpointed_tasks == 0:
-                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
-                else:
-                    logger.debug("No tasks checkpointed in this pass.")
-            else:
-                logger.info("Done checkpointing {} tasks".format(count))
+    def checkpoint(self) -> None:
+        self.memoizer.checkpoint()
 
     @staticmethod
     def _log_std_streams(task_record: TaskRecord) -> None:
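
Note that DataFlowKernel.checkpoint() survives as a thin pass-through, so the documented manual checkpointing pattern should continue to work unchanged. A minimal sketch, assuming a thread executor and 'manual' mode:

```python
import parsl
from parsl.config import Config
from parsl.executors import ThreadPoolExecutor

config = Config(executors=[ThreadPoolExecutor()],
                app_cache=True,
                checkpoint_mode='manual')

dfk = parsl.load(config)
# ... submit apps and wait on their futures ...
dfk.checkpoint()  # now forwards to Memoizer.checkpoint()
dfk.cleanup()
```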
Please ensure caching is enabled") - else: - logger.debug("No tasks checkpointed in this pass.") - else: - logger.info("Done checkpointing {} tasks".format(count)) + def checkpoint(self) -> None: + self.memoizer.checkpoint() @staticmethod def _log_std_streams(task_record: TaskRecord) -> None: diff --git a/parsl/dataflow/memoization.py b/parsl/dataflow/memoization.py index 1abb0d53fa..ae7486d239 100644 --- a/parsl/dataflow/memoization.py +++ b/parsl/dataflow/memoization.py @@ -4,15 +4,18 @@ import logging import os import pickle +import threading import types from concurrent.futures import Future from functools import lru_cache, singledispatch -from typing import Any, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Literal, Optional, Sequence import typeguard from parsl.dataflow.errors import BadCheckpoint from parsl.dataflow.taskrecord import TaskRecord +from parsl.errors import ConfigurationError, InternalConsistencyError +from parsl.utils import Timer, get_all_checkpoints logger = logging.getLogger(__name__) @@ -146,7 +149,13 @@ class Memoizer: """ - def __init__(self, *, memoize: bool = True, checkpoint_files: Sequence[str]): + run_dir: str + + def __init__(self, *, + memoize: bool = True, + checkpoint_files: Sequence[str] | None, + checkpoint_period: Optional[str], + checkpoint_mode: Literal['task_exit', 'periodic', 'dfk_exit', 'manual'] | None): """Initialize the memoizer. KWargs: @@ -155,6 +164,26 @@ def __init__(self, *, memoize: bool = True, checkpoint_files: Sequence[str]): """ self.memoize = memoize + self.checkpointed_tasks = 0 + + self.checkpoint_lock = threading.Lock() + + self.checkpoint_files = checkpoint_files + self.checkpoint_mode = checkpoint_mode + self.checkpoint_period = checkpoint_period + + self.checkpointable_tasks: List[TaskRecord] = [] + + self._checkpoint_timer: Timer | None = None + + def start(self) -> None: + if self.checkpoint_files is not None: + checkpoint_files = self.checkpoint_files + elif self.checkpoint_files is None and self.checkpoint_mode is not None: + checkpoint_files = get_all_checkpoints(self.run_dir) + else: + checkpoint_files = [] + checkpoint = self.load_checkpoints(checkpoint_files) if self.memoize: @@ -164,6 +193,26 @@ def __init__(self, *, memoize: bool = True, checkpoint_files: Sequence[str]): logger.info("App caching disabled for all apps") self.memo_lookup_table = {} + if self.checkpoint_mode == "periodic": + if self.checkpoint_period is None: + raise ConfigurationError("Checkpoint period must be specified with periodic checkpoint mode") + else: + try: + h, m, s = map(int, self.checkpoint_period.split(':')) + except Exception: + raise ConfigurationError("invalid checkpoint_period provided: {0} expected HH:MM:SS".format(self.checkpoint_period)) + checkpoint_period = (h * 3600) + (m * 60) + s + self._checkpoint_timer = Timer(self.checkpoint, interval=checkpoint_period, name="Checkpoint") + + def close(self) -> None: + if self.checkpoint_mode is not None: + logger.info("Making final checkpoint") + self.checkpoint() + + if self._checkpoint_timer: + logger.info("Stopping checkpoint timer") + self._checkpoint_timer.close() + def make_hash(self, task: TaskRecord) -> str: """Create a hash of the task inputs. 
@@ -324,3 +373,78 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, Future[Any]]:
             return self._load_checkpoints(checkpointDirs)
         else:
             return {}
+
+    def update_checkpoint(self, task_record: TaskRecord) -> None:
+        if self.checkpoint_mode == 'task_exit':
+            self.checkpoint(task=task_record)
+        elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
+            with self.checkpoint_lock:
+                self.checkpointable_tasks.append(task_record)
+        elif self.checkpoint_mode is None:
+            pass
+        else:
+            raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}")
+
+    def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
+        """Checkpoint tasks incrementally to a checkpoint file.
+
+        When called with no argument, every task registered in self.checkpointable_tasks
+        is checkpointed. When called with a single TaskRecord via the task argument,
+        only that task is checkpointed.
+
+        By default the checkpoints are written to the RUNDIR of the current
+        run under RUNDIR/checkpoint/tasks.pkl
+
+        Kwargs:
+            - task (Optional task record) : A task to checkpoint. Default=None, meaning all
+              tasks registered for checkpointing.
+
+        .. note::
+            Checkpointing only works if memoization is enabled
+
+        """
+        with self.checkpoint_lock:
+
+            if task:
+                checkpoint_queue = [task]
+            else:
+                checkpoint_queue = self.checkpointable_tasks
+
+            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
+            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
+
+            if not os.path.exists(checkpoint_dir):
+                os.makedirs(checkpoint_dir, exist_ok=True)
+
+            count = 0
+
+            with open(checkpoint_tasks, 'ab') as f:
+                for task_record in checkpoint_queue:
+                    task_id = task_record['id']
+
+                    app_fu = task_record['app_fu']
+
+                    if app_fu.done() and app_fu.exception() is None:
+                        hashsum = task_record['hashsum']
+                        if not hashsum:
+                            continue
+                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
+
+                        # We are using pickle here since pickle dumps to a file in 'ab'
+                        # mode behave like an incremental log.
+                        pickle.dump(t, f)
+                        count += 1
+                        logger.debug("Task {} checkpointed".format(task_id))
+
+            self.checkpointed_tasks += count
+
+            if count == 0:
+                if self.checkpointed_tasks == 0:
+                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
+                else:
+                    logger.debug("No tasks checkpointed in this pass.")
+            else:
+                logger.info("Done checkpointing {} tasks".format(count))
+
+            if not task:
+                self.checkpointable_tasks = []
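
The comment about pickle and 'ab' mode is worth unpacking: each pickle.dump() appends one complete record, so the checkpoint file is an append-only log that a reader can replay with repeated pickle.load() calls until EOFError, matching how the module's _load_checkpoints consumes checkpoint files. A self-contained sketch with hypothetical records and a temporary path:

```python
import os
import pickle
import tempfile

records = [{'hash': 'h1', 'exception': None, 'result': 1},
           {'hash': 'h2', 'exception': None, 'result': 2}]

path = os.path.join(tempfile.mkdtemp(), 'tasks.pkl')

# Two separate appends, as two checkpoint passes would produce.
for r in records:
    with open(path, 'ab') as f:
        pickle.dump(r, f)

# Replay the log one record at a time.
loaded = []
with open(path, 'rb') as f:
    while True:
        try:
            loaded.append(pickle.load(f))
        except EOFError:
            break

assert loaded == records
```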