117 changes: 11 additions & 106 deletions parsl/dataflow/dflow.py
@@ -6,7 +6,6 @@
import inspect
import logging
import os
import pickle
import random
import sys
import threading
@@ -50,7 +49,7 @@
from parsl.monitoring.remote import monitor_wrapper
from parsl.process_loggers import wrap_with_logs
from parsl.usage_tracking.usage import UsageTracker
from parsl.utils import Timer, get_all_checkpoints, get_std_fname_mode, get_version
from parsl.utils import get_std_fname_mode, get_version

logger = logging.getLogger(__name__)

@@ -101,8 +100,6 @@ def __init__(self, config: Config) -> None:

logger.info("Parsl version: {}".format(get_version()))

self.checkpoint_lock = threading.Lock()

self.usage_tracker = UsageTracker(self)
self.usage_tracker.send_start_message()

@@ -168,18 +165,12 @@ def __init__(self, config: Config) -> None:
self.monitoring_radio.send((MessageType.WORKFLOW_INFO,
workflow_info))

if config.checkpoint_files is not None:
checkpoint_files = config.checkpoint_files
elif config.checkpoint_files is None and config.checkpoint_mode is not None:
checkpoint_files = get_all_checkpoints(self.run_dir)
else:
checkpoint_files = []

self.memoizer = Memoizer(memoize=config.app_cache, checkpoint_files=checkpoint_files)
self.checkpointed_tasks = 0
self._checkpoint_timer = None
self.checkpoint_mode = config.checkpoint_mode
self.checkpointable_tasks: List[TaskRecord] = []
self.memoizer = Memoizer(memoize=config.app_cache,
checkpoint_mode=config.checkpoint_mode,
checkpoint_files=config.checkpoint_files,
checkpoint_period=config.checkpoint_period)
self.memoizer.run_dir = self.run_dir
self.memoizer.start()

# this must be set before executors are added since add_executors calls
# job_status_poller.add_executors.
@@ -195,17 +186,6 @@ def __init__(self, config: Config) -> None:
self.add_executors(config.executors)
self.add_executors([parsl_internal_executor])

if self.checkpoint_mode == "periodic":
if config.checkpoint_period is None:
raise ConfigurationError("Checkpoint period must be specified with periodic checkpoint mode")
else:
try:
h, m, s = map(int, config.checkpoint_period.split(':'))
except Exception:
raise ConfigurationError("invalid checkpoint_period provided: {0} expected HH:MM:SS".format(config.checkpoint_period))
checkpoint_period = (h * 3600) + (m * 60) + s
self._checkpoint_timer = Timer(self.checkpoint, interval=checkpoint_period, name="Checkpoint")

self.task_count = 0
self.tasks: Dict[int, TaskRecord] = {}
self.submitter_lock = threading.Lock()
@@ -565,18 +545,7 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
if not task_record['app_fu'] == future:
logger.error("Internal consistency error: callback future is not the app_fu in task structure, for task {}".format(task_id))

# Cover all checkpointing cases here:
# Do we need to checkpoint now, or queue for later,
# or do nothing?
if self.checkpoint_mode == 'task_exit':
self.checkpoint(tasks=[task_record])
elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
with self.checkpoint_lock:
self.checkpointable_tasks.append(task_record)
elif self.checkpoint_mode is None:
pass
else:
raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}")
self.memoizer.update_checkpoint(task_record)

self.wipe_task(task_id)
return
@@ -1205,13 +1174,7 @@ def cleanup(self) -> None:

self.log_task_states()

# checkpoint if any valid checkpoint method is specified
if self.checkpoint_mode is not None:
self.checkpoint()

if self._checkpoint_timer:
logger.info("Stopping checkpoint timer")
self._checkpoint_timer.close()
self.memoizer.close()

# Send final stats
self.usage_tracker.send_end_message()
@@ -1269,66 +1232,8 @@ def cleanup(self) -> None:
# should still see it.
logger.info("DFK cleanup complete")

def checkpoint(self, tasks: Optional[Sequence[TaskRecord]] = None) -> None:
"""Checkpoint the dfk incrementally to a checkpoint file.

When called, every task that has been completed yet not
checkpointed is checkpointed to a file.

Kwargs:
- tasks (List of task records) : List of task ids to checkpoint. Default=None
if set to None, we iterate over all tasks held by the DFK.

.. note::
Checkpointing only works if memoization is enabled

Returns:
Checkpoint dir if checkpoints were written successfully.
By default the checkpoints are written to the RUNDIR of the current
run under RUNDIR/checkpoints/tasks.pkl
"""
with self.checkpoint_lock:
if tasks:
checkpoint_queue = tasks
else:
checkpoint_queue = self.checkpointable_tasks
self.checkpointable_tasks = []

checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
checkpoint_tasks = checkpoint_dir + '/tasks.pkl'

if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)

count = 0

with open(checkpoint_tasks, 'ab') as f:
for task_record in checkpoint_queue:
task_id = task_record['id']

app_fu = task_record['app_fu']

if app_fu.done() and app_fu.exception() is None:
hashsum = task_record['hashsum']
if not hashsum:
continue
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}

# We are using pickle here since pickle dumps to a file in 'ab'
# mode behave like a incremental log.
pickle.dump(t, f)
count += 1
logger.debug("Task {} checkpointed".format(task_id))

self.checkpointed_tasks += count

if count == 0:
if self.checkpointed_tasks == 0:
logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
else:
logger.debug("No tasks checkpointed in this pass.")
else:
logger.info("Done checkpointing {} tasks".format(count))
def checkpoint(self) -> None:
self.memoizer.checkpoint()

@staticmethod
def _log_std_streams(task_record: TaskRecord) -> None:
128 changes: 126 additions & 2 deletions parsl/dataflow/memoization.py
@@ -4,15 +4,18 @@
import logging
import os
import pickle
import threading
import types
from concurrent.futures import Future
from functools import lru_cache, singledispatch
from typing import Any, Dict, List, Optional, Sequence
from typing import Any, Dict, List, Literal, Optional, Sequence

import typeguard

from parsl.dataflow.errors import BadCheckpoint
from parsl.dataflow.taskrecord import TaskRecord
from parsl.errors import ConfigurationError, InternalConsistencyError
from parsl.utils import Timer, get_all_checkpoints

logger = logging.getLogger(__name__)

@@ -146,7 +149,13 @@ class Memoizer:

"""

def __init__(self, *, memoize: bool = True, checkpoint_files: Sequence[str]):
run_dir: str

Comment on lines +152 to +153
Collaborator

Why make run_dir an optional instance member? For example, if checkpoint_files and checkpoint_mode are the right state and the consumer forgets to externally set it before calling .start() ... seems like a potential traceback?

As opposed to, say, choosing a default value, or ensuring that the mistake is caught at instantiation.

Collaborator Author

There are two initializations in the Parsl style of doing configuration/plugins: the user creates the object, and so user-specified parameters go into __init__. Secondly, there's the framework, the DFK, which invents runtime stuff like the rundir. By the time this happens, __init__ is long in the past - perhaps in a different source file and well before the DFK even begins to exist.

So it's probably right to make this a mandatory parameter of start().

The plugin style in general doesn't do well in differentiating between "uninitialised, configured" (pre-start) plugins and "initialised" (post-start) objects. That differentiation was something that was forced with monitoring plugin radios, because the "configured, uninitialized" state needs to be moved around over the network. And if I was going to re-implement this object model, I'd probably try to make that separation happen everywhere. But that's not how Parsl is right now.
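
A minimal sketch of that alternative, for illustration only: run_dir becomes a required argument to start() rather than an attribute the DFK assigns before calling it. The class name SketchMemoizer is hypothetical; this is not what the PR implements.

```python
# Hypothetical sketch: run_dir is supplied at start time rather than assigned
# externally, so the memoizer cannot be started without one.
class SketchMemoizer:
    def __init__(self, *, memoize: bool = True) -> None:
        # user-facing configuration only; no runtime paths are known yet
        self.memoize = memoize

    def start(self, *, run_dir: str) -> None:
        # the framework (the DFK) supplies run_dir once it exists
        self.run_dir = run_dir
        # mirrors the checkpoint directory layout used in this PR
        self.checkpoint_dir = f"{run_dir}/checkpoint"
```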

def __init__(self, *,
memoize: bool = True,
checkpoint_files: Sequence[str] | None,
checkpoint_period: Optional[str],
checkpoint_mode: Literal['task_exit', 'periodic', 'dfk_exit', 'manual'] | None):
"""Initialize the memoizer.

KWargs:
Expand All @@ -155,6 +164,26 @@ def __init__(self, *, memoize: bool = True, checkpoint_files: Sequence[str]):
"""
self.memoize = memoize

self.checkpointed_tasks = 0

self.checkpoint_lock = threading.Lock()

self.checkpoint_files = checkpoint_files
self.checkpoint_mode = checkpoint_mode
self.checkpoint_period = checkpoint_period

self.checkpointable_tasks: List[TaskRecord] = []

self._checkpoint_timer: Timer | None = None

def start(self) -> None:
if self.checkpoint_files is not None:
checkpoint_files = self.checkpoint_files
elif self.checkpoint_files is None and self.checkpoint_mode is not None:
checkpoint_files = get_all_checkpoints(self.run_dir)
else:
checkpoint_files = []

checkpoint = self.load_checkpoints(checkpoint_files)

if self.memoize:
Expand All @@ -164,6 +193,26 @@ def __init__(self, *, memoize: bool = True, checkpoint_files: Sequence[str]):
logger.info("App caching disabled for all apps")
self.memo_lookup_table = {}

if self.checkpoint_mode == "periodic":
if self.checkpoint_period is None:
raise ConfigurationError("Checkpoint period must be specified with periodic checkpoint mode")
else:
Comment on lines +197 to +199
Collaborator

I guess the checkpoint_period attribute can be changed later? In some sense "too bad," as this check is nominally not needed after instantiation or comes much later than would be ideal (i.e., "immediately").

Meanwhile, stylistically, given the [no-return] semantic of the raise, could remove the else: and unindent the block.

Collaborator Author

It's probably weird to change it at all, and the validation can go into the __init__ of Memoizer I think. But this PR was trying to move blocks as big as possible and the subsequent timer start can't happen until start time.
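
A standalone sketch of that stylistic point, using a hypothetical helper parse_checkpoint_period (not part of this PR): since the raise never returns, the else branch can be dropped and its body unindented without changing behaviour.

```python
from typing import Optional

from parsl.errors import ConfigurationError


# Hypothetical helper sketch: same validation as the PR, but with no else branch.
def parse_checkpoint_period(checkpoint_period: Optional[str]) -> int:
    if checkpoint_period is None:
        raise ConfigurationError("Checkpoint period must be specified with periodic checkpoint mode")
    try:
        h, m, s = map(int, checkpoint_period.split(':'))
    except Exception:
        raise ConfigurationError(
            "invalid checkpoint_period provided: {0} expected HH:MM:SS".format(checkpoint_period))
    return h * 3600 + m * 60 + s
```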

try:
h, m, s = map(int, self.checkpoint_period.split(':'))
except Exception:
raise ConfigurationError("invalid checkpoint_period provided: {0} expected HH:MM:SS".format(self.checkpoint_period))
Comment on lines +202 to +203
Collaborator

Any sense in raise ... from ...? Given the not None check and typing, .split() is guaranteed to work, and the only failures would be an invalid int or not enough items for the destructure. (If interested, could also limit the number of splits to 2.)

It would be nice to ensure that this field is valid much earlier. Perhaps in a setter-style setup?
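
For illustration, one reading of the raise ... from ... suggestion, again with a hypothetical helper: the underlying ValueError is chained, and split(':', 2) caps the result at three fields, so any extra colons end up in the last field and fail the int conversion.

```python
from parsl.errors import ConfigurationError


# Hypothetical sketch of the suggestion: chain the underlying error and cap
# the split at three fields so malformed strings fail inside the try block.
def parse_checkpoint_period(checkpoint_period: str) -> int:
    try:
        h, m, s = map(int, checkpoint_period.split(':', 2))
    except ValueError as e:
        raise ConfigurationError(
            "invalid checkpoint_period provided: {0} expected HH:MM:SS".format(checkpoint_period)) from e
    return h * 3600 + m * 60 + s
```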

checkpoint_period = (h * 3600) + (m * 60) + s
self._checkpoint_timer = Timer(self.checkpoint, interval=checkpoint_period, name="Checkpoint")

def close(self) -> None:
if self.checkpoint_mode is not None:
logger.info("Making final checkpoint")
self.checkpoint()

if self._checkpoint_timer:
logger.info("Stopping checkpoint timer")
self._checkpoint_timer.close()

def make_hash(self, task: TaskRecord) -> str:
"""Create a hash of the task inputs.

Expand Down Expand Up @@ -324,3 +373,78 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str,
return self._load_checkpoints(checkpointDirs)
else:
return {}

def update_checkpoint(self, task_record: TaskRecord) -> None:
if self.checkpoint_mode == 'task_exit':
self.checkpoint(task=task_record)
elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
with self.checkpoint_lock:
self.checkpointable_tasks.append(task_record)
elif self.checkpoint_mode is None:
pass
else:
raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}")

def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
"""Checkpoint the dfk incrementally to a checkpoint file.

When called with no argument, all tasks registered in self.checkpointable_tasks
will be checkpointed. When called with a single TaskRecord argument, that task will be
checkpointed.

By default the checkpoints are written to the RUNDIR of the current
run under RUNDIR/checkpoints/tasks.pkl

Kwargs:
- task (Optional task records) : A task to checkpoint. Default=None, meaning all
tasks registered for checkpointing.

.. note::
Checkpointing only works if memoization is enabled

"""
with self.checkpoint_lock:

if task:
checkpoint_queue = [task]
else:
checkpoint_queue = self.checkpointable_tasks

checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
checkpoint_tasks = checkpoint_dir + '/tasks.pkl'

if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)

count = 0

with open(checkpoint_tasks, 'ab') as f:
for task_record in checkpoint_queue:
task_id = task_record['id']

app_fu = task_record['app_fu']

if app_fu.done() and app_fu.exception() is None:
hashsum = task_record['hashsum']
if not hashsum:
continue
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}

# We are using pickle here since pickle dumps to a file in 'ab'
# mode behave like a incremental log.
pickle.dump(t, f)
count += 1
logger.debug("Task {} checkpointed".format(task_id))

self.checkpointed_tasks += count

if count == 0:
if self.checkpointed_tasks == 0:
logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
else:
logger.debug("No tasks checkpointed in this pass.")
else:
logger.info("Done checkpointing {} tasks".format(count))

if not task:
self.checkpointable_tasks = []