1919"""
2020
2121import enum
22+ import multiprocessing
2223import os
2324import sys
2425import threading
2526import warnings
27+ import shutil
2628from concurrent .futures import thread
2729from contextlib import contextmanager
2830from typing import Any , Callable , Dict , List , Literal , Optional
@@ -285,10 +287,28 @@ def _restore_list(xs, state_dict: Dict[str, Any], mismatch: MismatchMode = 'erro
def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a (possibly nested) dict into a serializable state dict.

    Keys are normalized to their ``str()`` representation and values are
    recursively converted with ``msgpack_to_state_dict``.

    Args:
        xs: Mapping to convert. A ``brainstate.util.FlattedDict`` is first
            expanded back to its nested form.

    Returns:
        A plain dict with string keys and serialized values.

    Raises:
        ValueError: If two distinct keys share the same string representation
            (the mapping would be ambiguous), or if stringifying a key fails.
    """
    if isinstance(xs, brainstate.util.FlattedDict):
        xs = xs.to_nest()

    try:
        str_keys = set(str(k) for k in xs.keys())
    except TypeError as e:
        raise ValueError(f'Dict contains unhashable keys: {e}') from e

    if len(str_keys) != len(xs):
        # Group the original keys by their string form so the error message
        # shows exactly which keys collide.
        str_to_keys: Dict[str, list] = {}
        for k in xs:
            str_to_keys.setdefault(str(k), []).append(k)

        collisions = {sk: keys for sk, keys in str_to_keys.items() if len(keys) > 1}
        raise ValueError(
            f'Dict keys do not have a unique string representation. '
            f'Collisions: {collisions}'
        )

    return {
        str(key): msgpack_to_state_dict(value)
        for key, value in xs.items()
    }
@@ -347,7 +367,13 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any], mismatch: MismatchMode =
347367 # Keep original value if field is missing from state_dict
348368 fields [field ] = getattr (xs , field )
349369
350- return type (xs )(** fields )
370+ try :
371+ return type (xs )(** fields )
372+ except TypeError as e :
373+ raise TypeError (
374+ f"Failed to reconstruct namedtuple { type (xs ).__name__ } at path { current_path ()} : { e } . "
375+ f"Ensure the namedtuple class definition is available."
376+ ) from e
351377
352378
353379msgpack_register_serialization (
@@ -393,17 +419,19 @@ def _brainstate_dict_state(x: brainstate.State) -> Dict[str, Any]:
393419
394420
def _restore_brainstate(x: brainstate.State, state_dict: Dict, mismatch: MismatchMode = 'error') -> brainstate.State:
    """Restore a ``brainstate.State`` from a state dict by mutating it in-place.

    Instead of constructing a fresh State object, the restored value is
    written back into ``x.value``; this matches how State objects are
    handled throughout the codebase.

    Args:
        x: Template State object to restore (will be mutated).
        state_dict: Serialized state dictionary.
        mismatch: How to handle mismatches.

    Returns:
        The same State object with its value restored (for chaining).
    """
    restored_value = msgpack_from_state_dict(x.value, state_dict, mismatch=mismatch)
    x.value = restored_value
    return x
@@ -538,7 +566,21 @@ def _dict_to_tuple(dct):
538566
539567def _chunk (arr ) -> Dict [str , Any ]:
540568 """Convert array to a canonical dictionary of chunked arrays."""
541- chunksize = max (1 , int (MAX_CHUNK_SIZE / arr .dtype .itemsize ))
569+ itemsize = arr .dtype .itemsize
570+ if itemsize == 0 :
571+ raise ValueError (f"Cannot chunk array with zero itemsize dtype: { arr .dtype } " )
572+
573+ chunksize = max (1 , int (MAX_CHUNK_SIZE / itemsize ))
574+
575+ # Warn if chunking is very inefficient
576+ if chunksize < 1000 and arr .size > 1000000 :
577+ warnings .warn (
578+ f"Array chunking may be inefficient: dtype={ arr .dtype } , "
579+ f"itemsize={ itemsize } , chunksize={ chunksize } . "
580+ f"Consider using a different dtype or smaller arrays." ,
581+ UserWarning
582+ )
583+
542584 data = {'__msgpack_chunked_array__' : True ,
543585 'shape' : _tuple_to_dict (arr .shape )}
544586 flatarr = arr .reshape (- 1 )
@@ -613,22 +655,33 @@ def _msgpack_serialize(pytree, in_place: bool = False) -> bytes:
613655 return msgpack .packb (pytree , default = _msgpack_ext_pack , strict_types = True )
614656
615657
616- def _msgpack_restore (encoded_pytree : bytes ):
658+ def _msgpack_restore (encoded_pytree : bytes , max_size : Optional [ int ] = None ):
617659 """Restore data structure from bytes in msgpack format.
618660
619661 Low-level function that only supports python trees with array leaves,
620662 for custom objects use `from_bytes`.
621663
622664 Args:
623665 encoded_pytree: msgpack-encoded bytes of python tree.
666+ max_size: Maximum allowed size in bytes (default: 10GB)
624667
625668 Returns:
626669 Python tree of dict, list, tuple with python primitive
627670 and array leaves.
628671
629672 Raises:
673+ ValueError: If data exceeds max_size
630674 InvalidCheckpointPath: If the msgpack data is corrupt or invalid.
631675 """
676+ if max_size is None :
677+ max_size = 10 * (1024 ** 3 ) # 10GB default
678+
679+ if len (encoded_pytree ) > max_size :
680+ raise ValueError (
681+ f"Checkpoint data too large: { len (encoded_pytree )} bytes "
682+ f"exceeds maximum { max_size } bytes"
683+ )
684+
632685 try :
633686 state_dict = msgpack .unpackb (encoded_pytree , ext_hook = _msgpack_ext_unpack , raw = False )
634687 except (msgpack .exceptions .ExtraData ,
@@ -684,14 +737,28 @@ class _EmptyNode:
684737def _rename_fn (src , dst , overwrite = False ):
685738 """Rename file from src to dst, with overwrite control.
686739
740+ Uses os.replace() for atomic rename on both Unix and Windows (Python 3.3+).
741+
687742 Args:
688743 src: Source file path
689744 dst: Destination file path
690745 overwrite: If False, raise AlreadyExistsError when dst exists
691746 """
692- if os .path .exists (src ):
693- if os .path .exists (dst ) and not overwrite :
694- raise AlreadyExistsError (dst )
747+ if not os .path .exists (src ):
748+ return
749+
750+ if not overwrite and os .path .exists (dst ):
751+ raise AlreadyExistsError (dst )
752+
753+ try :
754+ os .replace (src , dst ) # Atomic on both platforms
755+ except OSError :
756+ # Fallback for edge cases
757+ if overwrite and os .path .exists (dst ):
758+ try :
759+ os .remove (dst )
760+ except OSError :
761+ pass
695762 os .rename (src , dst )
696763
697764
@@ -712,6 +779,7 @@ def __init__(self, max_workers: int = 1):
712779 self .executor = thread .ThreadPoolExecutor (max_workers = max_workers )
713780 self .save_future = None
714781 self ._closed = False
782+ self ._lock = threading .Lock ()
715783
716784 def __enter__ (self ):
717785 """Enter context manager."""
@@ -762,8 +830,10 @@ def save_async(self, task: Callable[[], Any]):
762830 """
763831 if self ._closed :
764832 raise RuntimeError ("Cannot save with a closed AsyncManager" )
765- self .wait_previous_save ()
766- self .save_future = self .executor .submit (task ) # type: ignore
833+
834+ with self ._lock :
835+ self .wait_previous_save ()
836+ self .save_future = self .executor .submit (task ) # type: ignore
767837
768838
769839def _save_main_ckpt_file (
@@ -855,11 +925,25 @@ def msgpack_save(
855925 if async_manager :
856926 async_manager .wait_previous_save ()
857927
858- if os .path .dirname (filename ):
859- os .makedirs (os .path .dirname (filename ), exist_ok = True )
928+ dirname = os .path .dirname (filename )
929+ if dirname :
930+ try :
931+ os .makedirs (dirname , exist_ok = True )
932+ except OSError as e :
933+ raise OSError (f"Cannot create directory { dirname } : { e } " ) from e
860934 if not overwrite and os .path .exists (filename ):
861935 raise InvalidCheckpointPath (filename )
862936
937+ # Warn on Windows if path exceeds MAX_PATH limitation
938+ if sys .platform == 'win32' :
939+ abs_path = os .path .abspath (filename )
940+ if len (abs_path ) > 260 :
941+ warnings .warn (
942+ f"Path length { len (abs_path )} exceeds Windows MAX_PATH (260). "
943+ "Consider using shorter paths or enabling long path support." ,
944+ UserWarning
945+ )
946+
863947 if isinstance (target , brainstate .util .FlattedDict ):
864948 target = target .to_nest ()
865949 target = _to_bytes (target )
@@ -925,23 +1009,36 @@ def msgpack_load(
9251009 if parallel and fp .seekable ():
9261010 buf_size = 128 << 20 # 128M buffer.
9271011 num_chunks = (file_size + buf_size - 1 ) // buf_size # Ceiling division
928- checkpoint_contents = bytearray (file_size )
929-
930- def read_chunk (i ):
931- # NOTE: We have to re-open the file to read each chunk, otherwise the
932- # parallelism has no effect. But we could reuse the file pointers
933- # within each thread.
934- with open (filename , 'rb' ) as f :
935- f .seek (i * buf_size )
936- buf = f .read (buf_size )
937- if buf :
938- checkpoint_contents [i * buf_size :i * buf_size + len (buf )] = buf
939- return len (buf )
940-
941- pool_size = 32
942- with thread .ThreadPoolExecutor (pool_size ) as pool :
943- # Use context manager for proper resource cleanup
944- wait = list (pool .map (read_chunk , range (num_chunks )))
1012+
1013+ try :
1014+ checkpoint_contents = bytearray (file_size )
1015+ except MemoryError :
1016+ # Fallback to sequential read for very large files
1017+ if verbose :
1018+ warnings .warn (
1019+ f"Insufficient memory for parallel load of { file_size } bytes. "
1020+ "Falling back to sequential read." ,
1021+ UserWarning
1022+ )
1023+ parallel = False
1024+ checkpoint_contents = fp .read ()
1025+
1026+ if parallel :
1027+ def read_chunk (i ):
1028+ # NOTE: We have to re-open the file to read each chunk, otherwise the
1029+ # parallelism has no effect. But we could reuse the file pointers
1030+ # within each thread.
1031+ with open (filename , 'rb' ) as f :
1032+ f .seek (i * buf_size )
1033+ buf = f .read (buf_size )
1034+ if buf :
1035+ checkpoint_contents [i * buf_size :i * buf_size + len (buf )] = buf
1036+ return len (buf )
1037+
1038+ pool_size = min (32 , max (1 , multiprocessing .cpu_count ()))
1039+ with thread .ThreadPoolExecutor (pool_size ) as pool :
1040+ # Use context manager for proper resource cleanup
1041+ wait = list (pool .map (read_chunk , range (num_chunks )))
9451042 else :
9461043 checkpoint_contents = fp .read ()
9471044
0 commit comments