22
33import logging
44import os .path
5+ import warnings
6+ from collections import namedtuple
7+ from typing import Dict , Any , Tuple
58
69from brainpy import errors
710from brainpy .base import io , naming
811from brainpy .base .collector import Collector , ArrayCollector
912
1013math = None
14+ StateLoadResult = namedtuple ('StateLoadResult' , ['missing_keys' , 'unexpected_keys' ])
1115
1216__all__ = [
1317 'BrainPyObject' ,
@@ -57,25 +61,27 @@ def name(self, name: str = None):
5761 naming .check_name_uniqueness (name = self ._name , obj = self )
5862
5963 def register_implicit_vars (self , * variables , ** named_variables ):
60- from brainpy .math import Variable
64+ global math
65+ if math is None : from brainpy import math
66+
6167 for variable in variables :
62- if isinstance (variable , Variable ):
68+ if isinstance (variable , math . Variable ):
6369 self .implicit_vars [f'var{ id (variable )} ' ] = variable
6470 elif isinstance (variable , (tuple , list )):
6571 for v in variable :
66- if not isinstance (v , Variable ):
67- raise ValueError (f'Must be instance of { Variable .__name__ } , but we got { type (v )} ' )
72+ if not isinstance (v , math . Variable ):
73+ raise ValueError (f'Must be instance of { math . Variable .__name__ } , but we got { type (v )} ' )
6874 self .implicit_vars [f'var{ id (v )} ' ] = v
6975 elif isinstance (variable , dict ):
7076 for k , v in variable .items ():
71- if not isinstance (v , Variable ):
72- raise ValueError (f'Must be instance of { Variable .__name__ } , but we got { type (v )} ' )
77+ if not isinstance (v , math . Variable ):
78+ raise ValueError (f'Must be instance of { math . Variable .__name__ } , but we got { type (v )} ' )
7379 self .implicit_vars [k ] = v
7480 else :
7581 raise ValueError (f'Unknown type: { type (variable )} ' )
7682 for key , variable in named_variables .items ():
77- if not isinstance (variable , Variable ):
78- raise ValueError (f'Must be instance of { Variable .__name__ } , but we got { type (variable )} ' )
83+ if not isinstance (variable , math . Variable ):
84+ raise ValueError (f'Must be instance of { math . Variable .__name__ } , but we got { type (variable )} ' )
7985 self .implicit_vars [key ] = variable
8086
8187 def register_implicit_nodes (self , * nodes , node_cls : type = None , ** named_nodes ):
@@ -101,7 +107,11 @@ def register_implicit_nodes(self, *nodes, node_cls: type = None, **named_nodes):
101107 raise ValueError (f'Must be instance of { node_cls .__name__ } , but we got { type (node )} ' )
102108 self .implicit_nodes [key ] = node
103109
104- def vars (self , method = 'absolute' , level = - 1 , include_self = True ):
110+ def vars (self ,
111+ method : str = 'absolute' ,
112+ level : int = - 1 ,
113+ include_self : bool = True ,
114+ exclude_types : Tuple [type , ...] = None ):
105115 """Collect all variables in this node and the children nodes.
106116
107117 Parameters
@@ -112,6 +122,8 @@ def vars(self, method='absolute', level=-1, include_self=True):
112122 The hierarchy level to find variables.
113123 include_self: bool
114124 Whether include the variables in the self.
125+ exclude_types: tuple of type
126+ The type to exclude.
115127
116128 Returns
117129 -------
@@ -121,12 +133,19 @@ def vars(self, method='absolute', level=-1, include_self=True):
121133 global math
122134 if math is None : from brainpy import math
123135
136+ if exclude_types is None :
137+ exclude_types = (math .VariableView , )
124138 nodes = self .nodes (method = method , level = level , include_self = include_self )
125139 gather = ArrayCollector ()
126140 for node_path , node in nodes .items ():
127141 for k in dir (node ):
128142 v = getattr (node , k )
143+ include = False
129144 if isinstance (v , math .Variable ):
145+ include = True
146+ if len (exclude_types ) and isinstance (v , exclude_types ):
147+ include = False
148+ if include :
130149 if k not in node ._excluded_vars :
131150 gather [f'{ node_path } .{ k } ' if node_path else k ] = v
132151 gather .update ({f'{ node_path } .{ k } ' : v for k , v in node .implicit_vars .items ()})
@@ -306,6 +325,49 @@ def save_states(self, filename, variables=None, **setting):
306325 else :
307326 raise errors .BrainPyError (f'Unknown file format: { filename } . We only supports { io .SUPPORTED_FORMATS } ' )
308327
328+ def state_dict (self ):
329+ """Returns a dictionary containing a whole state of the module.
330+
331+ Returns
332+ -------
333+ out: dict
334+ A dictionary containing a whole state of the module.
335+ """
336+ return self .vars ().unique ().dict ()
337+
338+ def load_state_dict (self , state_dict : Dict [str , Any ], warn : bool = True ):
339+ """Copy parameters and buffers from :attr:`state_dict` into
340+ this module and its descendants.
341+
342+ Parameters
343+ ----------
344+ state_dict: dict
345+ A dict containing parameters and persistent buffers.
346+ warn: bool
347+ Warnings when there are missing keys or unexpected keys in the external ``state_dict``.
348+
349+ Returns
350+ -------
351+ out: StateLoadResult
352+ ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
353+
354+ * **missing_keys** is a list of str containing the missing keys
355+ * **unexpected_keys** is a list of str containing the unexpected keys
356+ """
357+ variables = self .vars ().unique ()
358+ keys1 = set (state_dict .keys ())
359+ keys2 = set (variables .keys ())
360+ unexpected_keys = list (keys1 - keys2 )
361+ missing_keys = list (keys2 - keys1 )
362+ for key in keys2 .intersection (keys1 ):
363+ variables [key ].value = state_dict [key ]
364+ if warn :
365+ if len (unexpected_keys ):
366+ warnings .warn (f'Unexpected keys in state_dict: { unexpected_keys } ' , UserWarning )
367+ if len (missing_keys ):
368+ warnings .warn (f'Missing keys in state_dict: { missing_keys } ' , UserWarning )
369+ return StateLoadResult (missing_keys , unexpected_keys )
370+
309371 # def to(self, devices):
310372 # global math
311373 # if math is None: from brainpy import math
@@ -317,7 +379,6 @@ def save_states(self, filename, variables=None, **setting):
317379 # all_vars = self.vars().unique()
318380 # for data in all_vars.values():
319381 # data[:] = math.asarray(data.value)
320- # # TODO
321382 #
322383 # def cuda(self):
323384 # global math
0 commit comments