Skip to content

Commit 8c699b3

Browse files
authored
Upgrade orbax checkpointer to 0.11.15 (#1255)
* Support Orbax 0.11.14 and newer
* upgrade orbax to 0.11.14
* move orbax dependency to core group
* Revert "move orbax dependency to core group"
  This reverts commit ef2566a.
* upgrade orbax to 0.11.15
1 parent e683576 commit 8c699b3

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

axlearn/common/checkpointer_orbax_emergency.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ def save(
751751
# including step time in total blocking time.
752752
start_t = time.perf_counter()
753753
self._get_tensor_manager(state_with_tensors).save(
754-
step=step, args=ocp.args.PyTreeSave(item=state_with_tensors)
754+
step=step, args=ocp.args.Composite(state=ocp.args.PyTreeSave(item=state_with_tensors))
755755
)
756756
time_diff = time.perf_counter() - start_t
757757
if self._composite_save_policy(step=step, evaler_summaries=self._eval_summaries):
@@ -808,7 +808,9 @@ def restore(
808808

809809
restored_state_with_tensors = tensor_manager.restore(
810810
step=step,
811-
args=ocp.args.PyTreeRestore(item=self._get_abstract_state(state_with_tensors)),
811+
args=ocp.args.Composite(
812+
state=ocp.args.PyTreeRestore(item=self._get_abstract_state(state_with_tensors))
813+
),
812814
)
813815
# Merge non-tensor and tensor states by replacing leaves of the non-tensor Pytree with the
814816
# not-None leaves of the tensor Pytree.
@@ -826,7 +828,8 @@ def wait_until_finished(self):
826828
self._non_tensor_manager.wait_until_finished()
827829
self._tensor_manager.wait_until_finished()
828830

829-
def stop(self):
831+
def stop(self, *, has_exception: bool = False):
830832
"""See `BaseCheckpointer.stop` for details."""
831-
self._non_tensor_manager.stop()
832-
self._tensor_manager.close()
833+
self._non_tensor_manager.stop(has_exception=has_exception)
834+
if self._tensor_manager:
835+
self._tensor_manager.close()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ mmau = [
154154
# Orbax checkpointing.
155155
orbax = [
156156
"humanize==4.10.0",
157-
"orbax-checkpoint==0.11.1",
157+
"orbax-checkpoint==0.11.15",
158158
]
159159
# Audio dependencies.
160160
audio = [

0 commit comments

Comments
 (0)