@@ -16,7 +16,6 @@ def __init__ (self,
1616 rcut ,
1717 set_prefix = 'set' ,
1818 shuffle_test = True ,
19- run_opt = None ,
2019 type_map = None ,
2120 modifier = None ) :
2221 # init data
@@ -53,7 +52,7 @@ def __init__ (self,
5352 # natoms, nbatches
5453 ntypes = []
5554 for ii in self .data_systems :
56- ntypes .append (np . max ( ii .get_atom_type ()) + 1 )
55+ ntypes .append (ii .get_ntypes () )
5756 self .sys_ntypes = max (ntypes )
5857 self .natoms = []
5958 self .natoms_vec = []
@@ -81,10 +80,6 @@ def __init__ (self,
8180 warnings .warn ("system %s required test size is larger than the size of the dataset %s (%d > %d)" % \
8281 (self .system_dirs [ii ], chk_ret [0 ], test_size , chk_ret [1 ]))
8382
84- # print summary
85- if run_opt is not None :
86- self .print_summary (run_opt )
87-
8883
8984 def _load_test (self , ntests = - 1 ):
9085 self .test_data = collections .defaultdict (list )
@@ -155,24 +150,57 @@ def reduce(self,
def get_data_dict(self):
    """
    Return the data dictionary of the first data system.

    Returns
    -------
    The value returned by `get_data_dict()` of `self.data_systems[0]`.
    """
    first_system = self.data_systems[0]
    return first_system.get_data_dict()
157152
153+
def _get_sys_probs(self,
                   sys_probs,
                   auto_prob_style):
    """
    Resolve the per-system probabilities used to pick a system.

    Parameters
    ----------
    sys_probs
        Explicit probabilities. When not None they are handed to
        `self._process_sys_probs` and `auto_prob_style` is ignored.
    auto_prob_style
        Name of the automatic scheme used when `sys_probs` is None:
        "prob_uniform", "prob_sys_size" or "prob_sys_size;..." (block form).

    Returns
    -------
    None for "prob_uniform", `self.prob_nbatches` for "prob_sys_size",
    otherwise the result of the delegated helper.

    Raises
    ------
    RuntimeError
        If `auto_prob_style` is not one of the recognised styles.
    """
    # explicitly given probabilities always take precedence
    if sys_probs is not None:
        return self._process_sys_probs(sys_probs)
    if auto_prob_style == "prob_uniform":
        return None
    if auto_prob_style == "prob_sys_size":
        return self.prob_nbatches
    ext_prefix = "prob_sys_size;"
    if auto_prob_style[:len(ext_prefix)] == ext_prefix:
        return self._prob_sys_size_ext(auto_prob_style)
    raise RuntimeError("unkown style " + auto_prob_style)
169+
170+
158171 def get_batch (self ,
159172 sys_idx = None ,
160- sys_weights = None ,
161- style = "prob_sys_size" ) :
173+ sys_probs = None ,
174+ auto_prob_style = "prob_sys_size" ) :
175+ """
176+ Get a batch of data from the data system
177+
178+ Parameters
179+ ----------
180+ sys_idx: int
181+ The index of the system from which the batch is drawn.
182+ If sys_idx is not None, `sys_probs` and `auto_prob_style` are ignored
183+ If sys_idx is None, automatically determine the system according to `sys_probs` or `auto_prob_style`, see the following.
184+ sys_probs: list of float
185+ The probabilities of the systems to get the batch.
186+ Summation of positive elements of this list should be no greater than 1.
187+ Element of this list can be negative, the probability of the corresponding system is determined automatically by the number of batches in the system.
188+ auto_prob_style: str
189+ Determine the probability of systems automatically. The method is assigned by this key and can be
190+ - "prob_uniform" : the probabilities of all the systems are equal, namely 1.0/self.get_nsystems()
191+ - "prob_sys_size" : the probability of a system is proportional to the number of batches in the system
192+ - "prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;..." :
193+ the list of systems is divided into blocks. A block is specified by `stt_idx:end_idx:weight`,
194+ where `stt_idx` is the starting index of the system, `end_idx` is the ending (not including) index of the system,
195+ the probabilities of the systems in this block sum up to `weight` (after the weights of all blocks are normalized to sum to 1), and the relative probabilities within this block are proportional
196+ to the number of batches in the system.
197+ """
162198 if not hasattr (self , 'default_mesh' ) :
163199 self ._make_default_mesh ()
164200 if sys_idx is not None :
165201 self .pick_idx = sys_idx
166202 else :
167- if sys_weights is None :
168- if style == "prob_sys_size" :
169- prob = self .prob_nbatches
170- elif style == "prob_uniform" :
171- prob = None
172- else :
173- raise RuntimeError ("unkown get_batch style" )
174- else :
175- prob = self .process_sys_weights (sys_weights )
203+ prob = self ._get_sys_probs (sys_probs , auto_prob_style )
176204 self .pick_idx = np .random .choice (np .arange (self .nsystems ), p = prob )
177205 b_data = self .data_systems [self .pick_idx ].get_batch (self .batch_size [self .pick_idx ])
178206 b_data ["natoms_vec" ] = self .natoms_vec [self .pick_idx ]
@@ -224,21 +252,26 @@ def _format_name_length(self, name, width) :
224252 name = '-- ' + name
225253 return name
226254
def print_summary(self,
                  run_opt,
                  sys_probs=None,
                  auto_prob_style="prob_sys_size"):
    """
    Send a table summarising all data systems to `run_opt.message`.

    One row is written per system with its name, number of atoms, batch
    size, number of batches and the probability of picking the system.

    Parameters
    ----------
    run_opt
        Object whose `message` method receives the summary text.
    sys_probs
        Explicit system probabilities, forwarded to `self._get_sys_probs`.
    auto_prob_style
        Automatic probability style, forwarded to `self._get_sys_probs`.
    """
    prob = self._get_sys_probs(sys_probs, auto_prob_style)
    if prob is None:
        # `_get_sys_probs` encodes the uniform style as None; expand it so
        # that every row can print a number instead of failing on prob[ii].
        prob = [1.0 / self.nsystems for _ in range(self.nsystems)]
    # width of the system-name column; the rule lines are 72 dashes wide
    sys_width = 42
    pieces = [
        "---Summary of DataSystem------------------------------------------------ \n ",
        "find %d system(s):\n " % self.nsystems,
        "%s " % self._format_name_length('system', sys_width),
        "%s %s %s %5s \n " % ('natoms', 'bch_sz', 'n_bch', 'prob '),
    ]
    for ii in range(self.nsystems):
        pieces.append("%s %6d %6d %5d %5.3f \n " %
                      (self._format_name_length(self.system_dirs[ii], sys_width),
                       self.natoms[ii],
                       self.batch_size[ii],
                       self.nbatches[ii],
                       prob[ii]))
    pieces.append("------------------------------------------------------------------------\n ")
    run_opt.message("".join(pieces))
243276
244277
@@ -264,18 +297,39 @@ def _check_type_map_consistency(self, type_map_list):
264297 ret = ii
265298 return ret
266299
def _process_sys_probs(self, sys_probs):
    """
    Complete a partially specified list of system probabilities.

    Non-negative entries are kept as given. Negative entries mark systems
    whose probability is determined automatically: the probability mass
    left over by the assigned entries is shared among them in proportion
    to `self.nbatches`.

    Parameters
    ----------
    sys_probs
        One value per system; a negative value means "determine automatically".

    Returns
    -------
    numpy array of probabilities that sums to 1 (within rounding).

    Raises
    ------
    AssertionError
        If the assigned probabilities exceed 1, or the completed
        probabilities do not sum to 1 (e.g. mass is left over but no
        system with batches is available to receive it).
    """
    # tolerance for float rounding; tighter than the one np.random.choice uses
    tol = 1e-10
    sys_probs = np.array(sys_probs, dtype=float)
    assigned = sys_probs >= 0
    assigned_prob = np.where(assigned, sys_probs, 0.)
    assigned_sum_prob = np.sum(assigned_prob)
    assert assigned_sum_prob <= 1 + tol, "the sum of assigned probability should be less than 1"
    # clamp so that rounding noise can never yield a negative probability
    rest_sum_prob = max(1. - assigned_sum_prob, 0.)
    rest_nbatch = np.where(assigned, 0, np.asarray(self.nbatches))
    rest_total = np.sum(rest_nbatch)
    if rest_total > 0:
        rest_prob = rest_sum_prob * rest_nbatch / rest_total
    else:
        # every system is assigned (or the rest has no batches): avoid 0/0
        rest_prob = np.zeros_like(assigned_prob)
    ret_prob = rest_prob + assigned_prob
    assert np.isclose(np.sum(ret_prob), 1., rtol=0., atol=tol), "sum of probs should be 1"
    return ret_prob
278-
311+
def _prob_sys_size_ext(self, keywords):
    """
    Build system probabilities from a block specification string.

    `keywords` has the form
    "prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;...".
    The text before the first ';' is ignored. Each block covers the
    systems `stt_idx` (inclusive) to `end_idx` (exclusive). Block weights
    are normalized to sum to 1, and within a block the probability is
    shared in proportion to `self.nbatches`.

    Parameters
    ----------
    keywords : str
        The block specification string.

    Returns
    -------
    numpy array of length `self.get_nsystems()`; systems not covered by
    any block keep probability 0.

    Raises
    ------
    AssertionError
        If a block weight is negative.
    RuntimeError
        If no block carries a positive weight.
    """
    blocks = []
    for spec in keywords.split(';')[1:]:
        if not spec.strip():
            # tolerate a trailing or duplicated ';'
            continue
        fields = spec.split(':')
        stt = int(fields[0])
        end = int(fields[1])
        weight = float(fields[2])
        assert (weight >= 0), "the weight of a block should be no less than 0"
        blocks.append((stt, end, weight))
    total_weight = sum(weight for _, _, weight in blocks)
    if not total_weight > 0:
        # dividing by a zero total would silently produce NaN probabilities
        raise RuntimeError("the sum of the block weights should be larger than 0")
    sys_probs = np.zeros([self.get_nsystems()])
    for stt, end, weight in blocks:
        nbatch_block = np.asarray(self.nbatches[stt:end], dtype=float)
        sys_probs[stt:end] = nbatch_block / np.sum(nbatch_block) * (weight / total_weight)
    return sys_probs
279333
280334
281335
0 commit comments