112 changes: 75 additions & 37 deletions axlearn/common/checkpointer.py
@@ -147,51 +147,60 @@ def filter_for_validation(structure):
)


def _upload_dir(src_dir_handle: tempfile.TemporaryDirectory, *, dst_dir: str):
"""Upload a directory (non-recursively) from a temporary dir to dst_dir.

Temporary dir will be deleted after the upload is complete.
"""
src_dir = src_dir_handle.name
src_files = fs.listdir(src_dir)
# src_files will be empty if there are no tf savables (i.e., don't have any tf state to save).
# In this case, do not create empty dst_dirs.
if len(src_files):
fs.makedirs(dst_dir)
for item in src_files:
src_file = os.path.join(src_dir, item)
dst_file = os.path.join(dst_dir, item)
assert not fs.isdir(src_file)
fs.copy(src_file, dst_file, overwrite=True)
src_dir_handle.cleanup()
# def _upload_dir(src_dir_handle: tempfile.TemporaryDirectory, *, dst_dir: str):
# """Upload a directory (non-recursively) from a temporary dir to dst_dir.

# Temporary dir will be deleted after the upload is complete.
# """
# src_dir = src_dir_handle.name
# src_files = fs.listdir(src_dir)
# # src_files will be empty if there are no tf savables (i.e., don't have any tf state to save).
# # In this case, do not create empty dst_dirs.
# if len(src_files):
# fs.makedirs(dst_dir)
# for item in src_files:
# src_file = os.path.join(src_dir, item)
# dst_file = os.path.join(dst_dir, item)
# assert not fs.isdir(src_file)
# fs.copy(src_file, dst_file, overwrite=True)
# src_dir_handle.cleanup()


# pylint: disable=redefined-builtin
def async_save_tf_savables(
value_map: Nested[Any], *, executor: futures.ThreadPoolExecutor, dir: str
) -> futures.Future:
def async_save_tf_savables(value_map: Nested[Any], *, dir: str) -> list[tf.train.Checkpoint]:
"""Asynchronously saves TF savables from `value_map` into `dir`.

When this call returns, the writes to `dir` may still be in flight; they are not guaranteed
to be complete until `sync()` has been called on each returned `tf.train.Checkpoint`.
"""
# pylint: disable-next=consider-using-with
f = tempfile.TemporaryDirectory()
tf_ckpts = []
for path, value in utils.flatten_items(value_map):
tf_checkpoint = tf.train.Checkpoint(value)
tf_checkpoint.write(os.path.join(f.name, path))
return executor.submit(_upload_dir, f, dst_dir=dir)
tf_checkpoint.write(os.path.join(dir, path), tf.train.CheckpointOptions(enable_async=True))
tf_ckpts.append(tf_checkpoint)
return tf_ckpts


# # pylint: disable-next=redefined-builtin
# def restore_tf_savables(value_map: Nested[Any], *, dir: str) -> Nested[Any]:
# """Restores TF savables from `dir` into `value_map` in-place."""

# for path, value in utils.flatten_items(value_map):
# tf_checkpoint = tf.train.Checkpoint(value)
# tf_checkpoint.read(os.path.join(dir, path))

# return value_map

# pylint: disable-next=redefined-builtin
def restore_tf_savables(value_map: Nested[Any], *, dir: str) -> Nested[Any]:
def async_restore_tf_savables(value_map: Nested[Any], *, dir: str) -> list[tf.train.Checkpoint]:
"""Restores TF savables from `dir` into `value_map` in-place."""

ckpts = []
for path, value in utils.flatten_items(value_map):
tf_checkpoint = tf.train.Checkpoint(value)
tf_checkpoint.read(os.path.join(dir, path))
tf_checkpoint.read(os.path.join(dir, path), tf.train.CheckpointOptions(enable_async=True))
ckpts.append(tf_checkpoint)

return value_map
return ckpts
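
With this change, both helpers return the `tf.train.Checkpoint` handles whose writes or reads may still be in flight, and the caller is expected to `sync()` them before relying on the data. Below is a minimal standalone sketch of that contract for the save side. It assumes a TF release that supports `tf.train.CheckpointOptions(enable_async=True)` and `Checkpoint.sync()` (as the changed code itself does); the iterator and paths are illustrative only.

    import os
    import tempfile

    import tensorflow as tf

    # Hypothetical savable: a tf.data iterator is a trackable object that tf.train.Checkpoint can save.
    value_map = {"input_iter": iter(tf.data.Dataset.range(100))}
    ckpt_dir = tempfile.mkdtemp()  # Illustrative destination; stands in for .../tf_{process_index}.

    # Save: issue async writes and keep the Checkpoint handles.
    pending = []
    for name, value in value_map.items():
        ckpt = tf.train.Checkpoint(value)
        ckpt.write(os.path.join(ckpt_dir, name), tf.train.CheckpointOptions(enable_async=True))
        pending.append(ckpt)

    # ... overlap other work (e.g., tensorstore serialization) here ...

    # The writes are only durable once each handle has been synced.
    for ckpt in pending:
        ckpt.sync()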


@runtime_checkable
@@ -517,20 +526,28 @@ def save_to_dir(
# Wait for directory and index creation.
multihost_utils.sync_global_devices(ckpt_dir)
# Each worker writes its tf checkpoints under a different path.
save_tf_future = async_save_tf_savables(

# save_tf_future = async_save_tf_savables(
# spec.tf_ckpt_map,
# executor=self._executor,
# dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}"),
# )

tf_ckpts = async_save_tf_savables(
spec.tf_ckpt_map,
executor=self._executor,
dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}"),
)

maybe_save_python_savables(
spec.python_ckpt_map, dir=os.path.join(ckpt_dir, f"python_{jax.process_index()}")
)

def commit():
on_commit_callback(ckpt_dir=ckpt_dir, index=spec.index)
logging.info(
"Serialization of %s completed in %s seconds.",
ckpt_dir,
"Checkpointer[%s] saving total time at step %s: %s seconds.",
os.path.basename(os.path.dirname(ckpt_dir)),
step,
time.perf_counter() - start_time,
)

@@ -542,7 +559,7 @@ def commit():
spec.gda_values,
spec.tensorstore_specs,
on_commit_callback=commit,
additional_futures=[save_tf_future],
additional_futures=[self._executor.submit(ckpt.sync) for ckpt in tf_ckpts],
)
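
Because the serializer's completion machinery is future-based, each blocking `Checkpoint.sync()` call is wrapped in an executor-submitted future and handed over via `additional_futures`. A self-contained sketch of that wrapping pattern, with `fake_sync` standing in for the real `sync()`:

    import time
    from concurrent import futures

    def fake_sync():
        # Stand-in for tf.train.Checkpoint.sync(): blocks until the async write has landed.
        time.sleep(0.1)

    executor = futures.ThreadPoolExecutor(max_workers=4)
    additional_futures = [executor.submit(fake_sync) for _ in range(3)]

    # The commit path can then wait on everything uniformly and surface background errors.
    for f in futures.as_completed(additional_futures):
        f.result()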

def wait_until_finished(self):
@@ -561,13 +578,23 @@ def restore_from_dir(
check_state_structure(
read_index_file(ckpt_dir), target_structure=spec.index, validation=validation
)
restore_tf_savables(
# restore_tf_savables(
# spec.tf_ckpt_map, dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}")
# )
tf_ckpts = async_restore_tf_savables(
spec.tf_ckpt_map, dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}")
)
maybe_restore_python_savables(
spec.python_ckpt_map, dir=os.path.join(ckpt_dir, f"python_{jax.process_index()}")
)
return self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec)
#return self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec)

restored_state = self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec)

for tf_ckpt in tf_ckpts:
tf_ckpt.sync()

return restored_state

def _restore_tensorstore_state(
self, state, *, ckpt_dir: str, spec: CheckpointSpec, sync: bool = True
@@ -822,7 +849,8 @@ def __enter__(self):
This is typically invoked prior to the training loop.
"""
if self._within_context:
raise ValueError("Already in a context.")
logging.warning("Already in a context.")
#raise ValueError("Already in a context.")
self._within_context = True

def __exit__(
@@ -1042,6 +1070,7 @@ def save(
In addition to behavior in `BaseCheckpointer`, saving only happens if the configured
checkpoint policy returns True for the given step and evaler summaries.
"""
start_time = time.perf_counter()
if not self._save_policy(step=step, evaler_summaries=(evaler_summaries or {})):
return
if step < 0 or step >= 10**8:
@@ -1057,6 +1086,14 @@
ckpt_dir=ckpt_dir,
action=CheckpointerAction.SAVE,
)

if jax.process_index() == 0:
logging.info(
"Checkpointer[%s] saving stall time at step %s: %s seconds.",
os.path.basename(os.path.dirname(ckpt_dir)),
step,
time.perf_counter() - start_time,
)
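
As the `save()` docstring above states, a checkpoint is written only when the configured save policy returns True for the given step and evaler summaries; the policy is a callable invoked as `save_policy(step=..., evaler_summaries=...)`. A minimal illustrative policy follows (the helper name is hypothetical; axlearn ships its own policy helpers):

    def every_n_steps(n: int = 1000):
        """Returns an illustrative save policy that fires every `n` steps."""

        def fn(*, step: int, evaler_summaries: dict) -> bool:
            del evaler_summaries  # Unused by this simple policy.
            return step % n == 0

        return fn

    policy = every_n_steps(n=500)
    assert policy(step=1000, evaler_summaries={})
    assert not policy(step=999, evaler_summaries={})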

def _run_garbage_collection(self):
"""Runs one round of garbage collection of past checkpoints.
@@ -1100,7 +1137,8 @@ def _run_garbage_collection(self):

# For subsequent dirs, non-committed dirs are gc'ed, and committed dirs are kept according
# to keep_n_steps. Note that we iterate in order of oldest to newest.
last_kept_step = float("-inf")
#last_kept_step = float("-inf")
last_kept_step = 0
for saved_dir in reversed(dirs[len(remaining_dirs) + len(gc_dirs) :]):
saved_step = parse_step_from_dir(saved_dir)
if not (
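
The retention rule described in the comment above (committed checkpoints kept every so many steps, scanning from oldest to newest) can be sketched in isolation as follows. This is a hypothetical reconstruction of the keep condition, since the actual check sits in the elided lines of the diff:

    def steps_to_keep(committed_steps, keep_every_n_steps: int):
        # Keep a committed step only if it is at least keep_every_n_steps past the last kept one.
        kept = []
        last_kept = float("-inf")
        for step in sorted(committed_steps):
            if step - last_kept >= keep_every_n_steps:
                kept.append(step)
                last_kept = step
        return kept

    assert steps_to_keep([100, 200, 300, 450, 500], keep_every_n_steps=200) == [100, 300, 500]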