Skip to content

Commit fe2d7be

Browse files
authored
Harden msgpack checkpointing, AsyncManager, and Optax step API (#69)
* Refactor checkpoint handling and optimize the step method in the optimizer.
* Enhance the checkpoint module with improved error handling, async support, and new tests:
  - Added detailed error messages for dict key collisions and namedtuple reconstruction failures.
  - Implemented async save functionality with thread safety in AsyncManager.
  - Introduced tests for AsyncManager, including concurrent saves and memory handling.
  - Improved chunking logic with warnings for inefficient configurations.
  - Added validation for maximum data size in msgpack restoration.
* Bump version to 0.1.8
1 parent 9f86291 commit fe2d7be

File tree

5 files changed

+471
-53
lines changed

5 files changed

+471
-53
lines changed

braintools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515

1616

17-
__version__ = "0.1.7"
17+
__version__ = "0.1.8"
1818
__version_info__ = tuple(map(int, __version__.split(".")))
1919

2020
from . import conn

braintools/file/_msg_checkpoint.py

Lines changed: 131 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,12 @@
1919
"""
2020

2121
import enum
22+
import multiprocessing
2223
import os
2324
import sys
2425
import threading
2526
import warnings
27+
import shutil
2628
from concurrent.futures import thread
2729
from contextlib import contextmanager
2830
from typing import Any, Callable, Dict, List, Literal, Optional
@@ -285,10 +287,28 @@ def _restore_list(xs, state_dict: Dict[str, Any], mismatch: MismatchMode = 'erro
285287
def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]:
286288
if isinstance(xs, brainstate.util.FlattedDict):
287289
xs = xs.to_nest()
288-
str_keys = set(str(k) for k in xs.keys())
290+
291+
try:
292+
str_keys = set(str(k) for k in xs.keys())
293+
except TypeError as e:
294+
raise ValueError(f'Dict contains unhashable keys: {e}') from e
295+
289296
if len(str_keys) != len(xs):
290-
raise ValueError('Dict keys do not have a unique string representation: '
291-
f'{str_keys} vs given: {xs}')
297+
# Provide detailed error showing which keys collide
298+
str_to_keys = {}
299+
for k in xs.keys():
300+
sk = str(k)
301+
if sk in str_to_keys:
302+
str_to_keys[sk].append(k)
303+
else:
304+
str_to_keys[sk] = [k]
305+
306+
collisions = {sk: keys for sk, keys in str_to_keys.items() if len(keys) > 1}
307+
raise ValueError(
308+
f'Dict keys do not have a unique string representation. '
309+
f'Collisions: {collisions}'
310+
)
311+
292312
return {
293313
str(key): msgpack_to_state_dict(value)
294314
for key, value in xs.items()
@@ -347,7 +367,13 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any], mismatch: MismatchMode =
347367
# Keep original value if field is missing from state_dict
348368
fields[field] = getattr(xs, field)
349369

350-
return type(xs)(**fields)
370+
try:
371+
return type(xs)(**fields)
372+
except TypeError as e:
373+
raise TypeError(
374+
f"Failed to reconstruct namedtuple {type(xs).__name__} at path {current_path()}: {e}. "
375+
f"Ensure the namedtuple class definition is available."
376+
) from e
351377

352378

353379
msgpack_register_serialization(
@@ -393,17 +419,19 @@ def _brainstate_dict_state(x: brainstate.State) -> Dict[str, Any]:
393419

394420

395421
def _restore_brainstate(x: brainstate.State, state_dict: Dict, mismatch: MismatchMode = 'error') -> brainstate.State:
396-
"""Restore brainstate.State from state dict.
422+
"""Restore brainstate.State from state dict by mutating its value in-place.
397423
398-
Creates a new State object with the restored value instead of mutating the original.
424+
This function mutates the State object's value attribute rather than creating
425+
a new State object, which is consistent with how State objects are used
426+
throughout the codebase.
399427
400428
Args:
401-
x: Template State object
429+
x: Template State object to restore (will be mutated)
402430
state_dict: Serialized state dictionary
403431
mismatch: How to handle mismatches
404432
405433
Returns:
406-
A new State object with the restored value
434+
The same State object with restored value (for chaining)
407435
"""
408436
x.value = msgpack_from_state_dict(x.value, state_dict, mismatch=mismatch)
409437
return x
@@ -538,7 +566,21 @@ def _dict_to_tuple(dct):
538566

539567
def _chunk(arr) -> Dict[str, Any]:
540568
"""Convert array to a canonical dictionary of chunked arrays."""
541-
chunksize = max(1, int(MAX_CHUNK_SIZE / arr.dtype.itemsize))
569+
itemsize = arr.dtype.itemsize
570+
if itemsize == 0:
571+
raise ValueError(f"Cannot chunk array with zero itemsize dtype: {arr.dtype}")
572+
573+
chunksize = max(1, int(MAX_CHUNK_SIZE / itemsize))
574+
575+
# Warn if chunking is very inefficient
576+
if chunksize < 1000 and arr.size > 1000000:
577+
warnings.warn(
578+
f"Array chunking may be inefficient: dtype={arr.dtype}, "
579+
f"itemsize={itemsize}, chunksize={chunksize}. "
580+
f"Consider using a different dtype or smaller arrays.",
581+
UserWarning
582+
)
583+
542584
data = {'__msgpack_chunked_array__': True,
543585
'shape': _tuple_to_dict(arr.shape)}
544586
flatarr = arr.reshape(-1)
@@ -613,22 +655,33 @@ def _msgpack_serialize(pytree, in_place: bool = False) -> bytes:
613655
return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)
614656

615657

616-
def _msgpack_restore(encoded_pytree: bytes):
658+
def _msgpack_restore(encoded_pytree: bytes, max_size: Optional[int] = None):
617659
"""Restore data structure from bytes in msgpack format.
618660
619661
Low-level function that only supports python trees with array leaves,
620662
for custom objects use `from_bytes`.
621663
622664
Args:
623665
encoded_pytree: msgpack-encoded bytes of python tree.
666+
max_size: Maximum allowed size in bytes (default: 10GB)
624667
625668
Returns:
626669
Python tree of dict, list, tuple with python primitive
627670
and array leaves.
628671
629672
Raises:
673+
ValueError: If data exceeds max_size
630674
InvalidCheckpointPath: If the msgpack data is corrupt or invalid.
631675
"""
676+
if max_size is None:
677+
max_size = 10 * (1024 ** 3) # 10GB default
678+
679+
if len(encoded_pytree) > max_size:
680+
raise ValueError(
681+
f"Checkpoint data too large: {len(encoded_pytree)} bytes "
682+
f"exceeds maximum {max_size} bytes"
683+
)
684+
632685
try:
633686
state_dict = msgpack.unpackb(encoded_pytree, ext_hook=_msgpack_ext_unpack, raw=False)
634687
except (msgpack.exceptions.ExtraData,
@@ -684,14 +737,28 @@ class _EmptyNode:
684737
def _rename_fn(src, dst, overwrite=False):
685738
"""Rename file from src to dst, with overwrite control.
686739
740+
Uses os.replace() for atomic rename on both Unix and Windows (Python 3.3+).
741+
687742
Args:
688743
src: Source file path
689744
dst: Destination file path
690745
overwrite: If False, raise AlreadyExistsError when dst exists
691746
"""
692-
if os.path.exists(src):
693-
if os.path.exists(dst) and not overwrite:
694-
raise AlreadyExistsError(dst)
747+
if not os.path.exists(src):
748+
return
749+
750+
if not overwrite and os.path.exists(dst):
751+
raise AlreadyExistsError(dst)
752+
753+
try:
754+
os.replace(src, dst) # Atomic on both platforms
755+
except OSError:
756+
# Fallback for edge cases
757+
if overwrite and os.path.exists(dst):
758+
try:
759+
os.remove(dst)
760+
except OSError:
761+
pass
695762
os.rename(src, dst)
696763

697764

@@ -712,6 +779,7 @@ def __init__(self, max_workers: int = 1):
712779
self.executor = thread.ThreadPoolExecutor(max_workers=max_workers)
713780
self.save_future = None
714781
self._closed = False
782+
self._lock = threading.Lock()
715783

716784
def __enter__(self):
717785
"""Enter context manager."""
@@ -762,8 +830,10 @@ def save_async(self, task: Callable[[], Any]):
762830
"""
763831
if self._closed:
764832
raise RuntimeError("Cannot save with a closed AsyncManager")
765-
self.wait_previous_save()
766-
self.save_future = self.executor.submit(task) # type: ignore
833+
834+
with self._lock:
835+
self.wait_previous_save()
836+
self.save_future = self.executor.submit(task) # type: ignore
767837

768838

769839
def _save_main_ckpt_file(
@@ -855,11 +925,25 @@ def msgpack_save(
855925
if async_manager:
856926
async_manager.wait_previous_save()
857927

858-
if os.path.dirname(filename):
859-
os.makedirs(os.path.dirname(filename), exist_ok=True)
928+
dirname = os.path.dirname(filename)
929+
if dirname:
930+
try:
931+
os.makedirs(dirname, exist_ok=True)
932+
except OSError as e:
933+
raise OSError(f"Cannot create directory {dirname}: {e}") from e
860934
if not overwrite and os.path.exists(filename):
861935
raise InvalidCheckpointPath(filename)
862936

937+
# Warn on Windows if path exceeds MAX_PATH limitation
938+
if sys.platform == 'win32':
939+
abs_path = os.path.abspath(filename)
940+
if len(abs_path) > 260:
941+
warnings.warn(
942+
f"Path length {len(abs_path)} exceeds Windows MAX_PATH (260). "
943+
"Consider using shorter paths or enabling long path support.",
944+
UserWarning
945+
)
946+
863947
if isinstance(target, brainstate.util.FlattedDict):
864948
target = target.to_nest()
865949
target = _to_bytes(target)
@@ -925,23 +1009,36 @@ def msgpack_load(
9251009
if parallel and fp.seekable():
9261010
buf_size = 128 << 20 # 128M buffer.
9271011
num_chunks = (file_size + buf_size - 1) // buf_size # Ceiling division
928-
checkpoint_contents = bytearray(file_size)
929-
930-
def read_chunk(i):
931-
# NOTE: We have to re-open the file to read each chunk, otherwise the
932-
# parallelism has no effect. But we could reuse the file pointers
933-
# within each thread.
934-
with open(filename, 'rb') as f:
935-
f.seek(i * buf_size)
936-
buf = f.read(buf_size)
937-
if buf:
938-
checkpoint_contents[i * buf_size:i * buf_size + len(buf)] = buf
939-
return len(buf)
940-
941-
pool_size = 32
942-
with thread.ThreadPoolExecutor(pool_size) as pool:
943-
# Use context manager for proper resource cleanup
944-
wait = list(pool.map(read_chunk, range(num_chunks)))
1012+
1013+
try:
1014+
checkpoint_contents = bytearray(file_size)
1015+
except MemoryError:
1016+
# Fallback to sequential read for very large files
1017+
if verbose:
1018+
warnings.warn(
1019+
f"Insufficient memory for parallel load of {file_size} bytes. "
1020+
"Falling back to sequential read.",
1021+
UserWarning
1022+
)
1023+
parallel = False
1024+
checkpoint_contents = fp.read()
1025+
1026+
if parallel:
1027+
def read_chunk(i):
1028+
# NOTE: We have to re-open the file to read each chunk, otherwise the
1029+
# parallelism has no effect. But we could reuse the file pointers
1030+
# within each thread.
1031+
with open(filename, 'rb') as f:
1032+
f.seek(i * buf_size)
1033+
buf = f.read(buf_size)
1034+
if buf:
1035+
checkpoint_contents[i * buf_size:i * buf_size + len(buf)] = buf
1036+
return len(buf)
1037+
1038+
pool_size = min(32, max(1, multiprocessing.cpu_count()))
1039+
with thread.ThreadPoolExecutor(pool_size) as pool:
1040+
# Use context manager for proper resource cleanup
1041+
wait = list(pool.map(read_chunk, range(num_chunks)))
9451042
else:
9461043
checkpoint_contents = fp.read()
9471044

0 commit comments

Comments (0)