@@ -99,36 +99,12 @@ def __init__(
9999 # format string monitors
100100 monitors = self ._format_seq_monitors (monitors )
101101 # get monitor targets
102- monitors = self ._find_monitor_targets (monitors )
102+ monitors = self ._find_seq_monitor_targets (monitors )
103103 elif isinstance (monitors , dict ):
104- _monitors = dict ()
105- for key , val in monitors .items ():
106- if not isinstance (key , str ):
107- raise MonitorError ('Expect the key of the dict "monitors" must be a string. But got '
108- f'{ type (key )} : { key } ' )
109- if isinstance (val , bm .Variable ):
110- val = (val , None )
111- if isinstance (val , (tuple , list )):
112- if not isinstance (val [0 ], bm .Variable ):
113- raise MonitorError ('Expect the format of (variable, index) in the monitor setting. '
114- f'But we got { val } ' )
115- if len (val ) == 1 :
116- _monitors [key ] = (val [0 ], None )
117- elif len (val ) == 2 :
118- if isinstance (val [1 ], (int , np .integer )):
119- idx = bm .array ([val [1 ]])
120- else :
121- idx = None if val [1 ] is None else bm .asarray (val [1 ])
122- _monitors [key ] = (val [0 ], idx )
123- else :
124- raise MonitorError ('Expect the format of (variable, index) in the monitor setting. '
125- f'But we got { val } ' )
126- elif callable (val ):
127- _monitors [key ] = val
128- else :
129- raise MonitorError ('The value of dict monitor expect a sequence with (variable, index) '
130- f'or a callable function. But we got { val } ' )
131- monitors = _monitors
104+ # format string monitors
105+ monitors = self ._format_dict_monitors (monitors )
106+ # get monitor targets
107+ monitors = self ._find_dict_monitor_targets (monitors )
132108 else :
133109 raise MonitorError (f'We only supports a format of list/tuple/dict of '
134110 f'"vars", while we got { type (monitors )} .' )
@@ -160,7 +136,7 @@ def __init__(
160136
161137 def _format_seq_monitors (self , monitors ):
162138 if not isinstance (monitors , (tuple , list )):
163- raise TypeError (f'Must be a sequence , but we got { type (monitors )} ' )
139+ raise TypeError (f'Must be a tuple/list , but we got { type (monitors )} ' )
164140 _monitors = []
165141 for mon in monitors :
166142 if isinstance (mon , str ):
@@ -183,7 +159,40 @@ def _format_seq_monitors(self, monitors):
183159 raise MonitorError (f'We do not support monitor with { type (mon )} : { mon } ' )
184160 return _monitors
185161
186- def _find_monitor_targets (self , _monitors ):
162+ def _format_dict_monitors (self , monitors ):
163+ if not isinstance (monitors , dict ):
164+ raise TypeError (f'Must be a dict, but we got { type (monitors )} ' )
165+ _monitors = dict ()
166+ for key , val in monitors .items ():
167+ if not isinstance (key , str ):
168+ raise MonitorError ('Expect the key of the dict "monitors" must be a string. But got '
169+ f'{ type (key )} : { key } ' )
170+ if isinstance (val , (bm .Variable , str )):
171+ val = (val , None )
172+
173+ if isinstance (val , (tuple , list )):
174+ if not isinstance (val [0 ], (bm .Variable , str )):
175+ raise MonitorError ('Expect the format of (variable, index) in the monitor setting. '
176+ f'But we got { val } ' )
177+ if len (val ) == 1 :
178+ _monitors [key ] = (val [0 ], None )
179+ elif len (val ) == 2 :
180+ if isinstance (val [1 ], (int , np .integer )):
181+ idx = bm .array ([val [1 ]])
182+ else :
183+ idx = None if val [1 ] is None else bm .asarray (val [1 ])
184+ _monitors [key ] = (val [0 ], idx )
185+ else :
186+ raise MonitorError ('Expect the format of (variable, index) in the monitor setting. '
187+ f'But we got { val } ' )
188+ elif callable (val ):
189+ _monitors [key ] = val
190+ else :
191+ raise MonitorError ('The value of dict monitor expect a sequence with (variable, index) '
192+ f'or a callable function. But we got { val } ' )
193+ return _monitors
194+
195+ def _find_seq_monitor_targets (self , _monitors ):
187196 if not isinstance (_monitors , (tuple , list )):
188197 raise TypeError (f'Must be a sequence, but we got { type (_monitors )} ' )
189198 # get monitor targets
@@ -214,6 +223,43 @@ def _find_monitor_targets(self, _monitors):
214223 monitors [key ] = (getattr (master , splits [- 1 ]), index )
215224 return monitors
216225
226+ def _find_dict_monitor_targets (self , _monitors ):
227+ if not isinstance (_monitors , dict ):
228+ raise TypeError (f'Must be a dict, but we got { type (_monitors )} ' )
229+ # get monitor targets
230+ monitors = {}
231+ name2node = None
232+ for _key , _mon in _monitors .items ():
233+ if isinstance (_mon , str ):
234+ if name2node is None :
235+ name2node = {node .name : node for node in list (self .target .nodes (level = - 1 ).unique ().values ())}
236+
237+ key , index = _mon [0 ], _mon [1 ]
238+ splits = key .split ('.' )
239+ if len (splits ) == 1 :
240+ if not hasattr (self .target , splits [0 ]):
241+ raise RunningError (f'{ self .target } does not has variable { key } .' )
242+ monitors [key ] = (getattr (self .target , splits [- 1 ]), index )
243+ else :
244+ if not hasattr (self .target , splits [0 ]):
245+ if splits [0 ] not in name2node :
246+ raise MonitorError (f'Cannot find target { key } in monitor of { self .target } , please check.' )
247+ else :
248+ master = name2node [splits [0 ]]
249+ assert len (splits ) == 2
250+ monitors [key ] = (getattr (master , splits [- 1 ]), index )
251+ else :
252+ master = self .target
253+ for s in splits [:- 1 ]:
254+ try :
255+ master = getattr (master , s )
256+ except KeyError :
257+ raise MonitorError (f'Cannot find { key } in { master } , please check.' )
258+ monitors [key ] = (getattr (master , splits [- 1 ]), index )
259+ else :
260+ monitors [_key ] = _mon
261+ return monitors
262+
217263 def __del__ (self ):
218264 if hasattr (self , 'mon' ):
219265 for key in tuple (self .mon .keys ()):
0 commit comments