@@ -42,64 +42,20 @@ class DotDict(dict):
4242 >>> f(d)
4343 TypeError: Argument 'a' of type <class 'str'> is not a valid JAX type.
4444
45- At this moment, you can label this attribute `names` as not a key in the dictionary
46- by using the syntax::
47-
48- >>> d.add_attr_not_key('names')
49- >>> f(d)
50- {'a': DeviceArray(10, dtype=int32, weak_type=True),
51- 'b': DeviceArray(20, dtype=int32, weak_type=True),
52- 'c': DeviceArray(30, dtype=int32, weak_type=True)}
53-
5445 """
5546
56- '''Used to exclude variables that '''
57- attrs_not_keys = ('attrs_not_keys' , 'var_names' )
58-
5947 def __init__ (self , * args , ** kwargs ):
6048 super ().__init__ (* args , ** kwargs )
6149 self .__dict__ = self
62- self .var_names = ()
6350
6451 def copy (self ) -> 'DotDict' :
6552 return type (self )(super ().copy ())
6653
67- def keys (self ):
68- """Retrieve all keys in the dict, excluding ignored keys."""
69- keys = []
70- for k in super (DotDict , self ).keys ():
71- if k not in self .attrs_not_keys :
72- keys .append (k )
73- return tuple (keys )
74-
75- def values (self ):
76- """Retrieve all values in the dict, excluding values of ignored keys."""
77- values = []
78- for k , v in super (DotDict , self ).items ():
79- if k not in self .attrs_not_keys :
80- values .append (v )
81- return tuple (values )
82-
83- def items (self ):
84- """Retrieve all items in the dict, excluding ignored items."""
85- items = []
86- for k , v in super (DotDict , self ).items ():
87- if k not in self .attrs_not_keys :
88- items .append ((k , v ))
89- return items
90-
9154 def to_numpy (self ):
9255 """Change all values to numpy arrays."""
9356 for key in tuple (self .keys ()):
9457 self [key ] = np .asarray (self [key ])
9558
96- def add_attr_not_key (self , * args ):
97- """Add excluded attribute when retrieving dictionary keys. """
98- for arg in args :
99- if not isinstance (arg , str ):
100- raise TypeError ('Only support string.' )
101- self .attrs_not_keys += args
102-
10359 def update (self , * args , ** kwargs ):
10460 super ().update (* args , ** kwargs )
10561 return self
@@ -179,7 +135,7 @@ def subset(self, var_type):
179135
180136 >>> import brainpy as bp
181137 >>>
182- >>> some_collector = Collector ()
138+ >>> some_collector = DotDict ()
183139 >>>
184140 >>> # get all trainable variables
185141 >>> some_collector.subset(bp.math.TrainVar)
0 commit comments