@@ -353,28 +353,15 @@ def set_sys_probs(self, sys_probs=None, auto_prob_style: str = "prob_sys_size"):
353353 elif auto_prob_style == "prob_sys_size" :
354354 probs = self .prob_nbatches
355355 elif auto_prob_style [:14 ] == "prob_sys_size;" :
356- probs = self ._prob_sys_size_ext (auto_prob_style )
356+ probs = prob_sys_size_ext (
357+ auto_prob_style , self .get_nsystems (), self .nbatches
358+ )
357359 else :
358360 raise RuntimeError ("Unknown auto prob style: " + auto_prob_style )
359361 else :
360- probs = self . _process_sys_probs (sys_probs )
362+ probs = process_sys_probs (sys_probs , self . nbatches )
361363 self .sys_probs = probs
362364
363- def _get_sys_probs (self , sys_probs , auto_prob_style ): # depreciated
364- if sys_probs is None :
365- if auto_prob_style == "prob_uniform" :
366- prob_v = 1.0 / float (self .nsystems )
367- prob = [prob_v for ii in range (self .nsystems )]
368- elif auto_prob_style == "prob_sys_size" :
369- prob = self .prob_nbatches
370- elif auto_prob_style [:14 ] == "prob_sys_size;" :
371- prob = self ._prob_sys_size_ext (auto_prob_style )
372- else :
373- raise RuntimeError ("unknown style " + auto_prob_style )
374- else :
375- prob = self ._process_sys_probs (sys_probs )
376- return prob
377-
378365 def get_batch (self , sys_idx : Optional [int ] = None ) -> dict :
379366 # batch generation style altered by Ziyao Li:
380367 # one should specify the "sys_prob" and "auto_prob_style" params
@@ -623,42 +610,44 @@ def _check_type_map_consistency(self, type_map_list):
623610 ret = ii
624611 return ret
625612
626- def _process_sys_probs (self , sys_probs ):
627- sys_probs = np .array (sys_probs )
628- type_filter = sys_probs >= 0
629- assigned_sum_prob = np .sum (type_filter * sys_probs )
630- # 1e-8 is to handle floating point error; See #1917
631- assert (
632- assigned_sum_prob <= 1.0 + 1e-8
633- ), "the sum of assigned probability should be less than 1"
634- rest_sum_prob = 1.0 - assigned_sum_prob
635- if not np .isclose (rest_sum_prob , 0 ):
636- rest_nbatch = (1 - type_filter ) * self .nbatches
637- rest_prob = rest_sum_prob * rest_nbatch / np .sum (rest_nbatch )
638- ret_prob = rest_prob + type_filter * sys_probs
639- else :
640- ret_prob = sys_probs
641- assert np .isclose (np .sum (ret_prob ), 1 ), "sum of probs should be 1"
642- return ret_prob
643-
644- def _prob_sys_size_ext (self , keywords ):
645- block_str = keywords .split (";" )[1 :]
646- block_stt = []
647- block_end = []
648- block_weights = []
649- for ii in block_str :
650- stt = int (ii .split (":" )[0 ])
651- end = int (ii .split (":" )[1 ])
652- weight = float (ii .split (":" )[2 ])
653- assert weight >= 0 , "the weight of a block should be no less than 0"
654- block_stt .append (stt )
655- block_end .append (end )
656- block_weights .append (weight )
657- nblocks = len (block_str )
658- block_probs = np .array (block_weights ) / np .sum (block_weights )
659- sys_probs = np .zeros ([self .get_nsystems ()])
660- for ii in range (nblocks ):
661- nbatch_block = self .nbatches [block_stt [ii ] : block_end [ii ]]
662- tmp_prob = [float (i ) for i in nbatch_block ] / np .sum (nbatch_block )
663- sys_probs [block_stt [ii ] : block_end [ii ]] = tmp_prob * block_probs [ii ]
664- return sys_probs
613+
614+ def process_sys_probs (sys_probs , nbatch ):
615+ sys_probs = np .array (sys_probs )
616+ type_filter = sys_probs >= 0
617+ assigned_sum_prob = np .sum (type_filter * sys_probs )
618+ # 1e-8 is to handle floating point error; See #1917
619+ assert (
620+ assigned_sum_prob <= 1.0 + 1e-8
621+ ), "the sum of assigned probability should be less than 1"
622+ rest_sum_prob = 1.0 - assigned_sum_prob
623+ if not np .isclose (rest_sum_prob , 0 ):
624+ rest_nbatch = (1 - type_filter ) * nbatch
625+ rest_prob = rest_sum_prob * rest_nbatch / np .sum (rest_nbatch )
626+ ret_prob = rest_prob + type_filter * sys_probs
627+ else :
628+ ret_prob = sys_probs
629+ assert np .isclose (np .sum (ret_prob ), 1 ), "sum of probs should be 1"
630+ return ret_prob
631+
632+
633+ def prob_sys_size_ext (keywords , nsystems , nbatch ):
634+ block_str = keywords .split (";" )[1 :]
635+ block_stt = []
636+ block_end = []
637+ block_weights = []
638+ for ii in block_str :
639+ stt = int (ii .split (":" )[0 ])
640+ end = int (ii .split (":" )[1 ])
641+ weight = float (ii .split (":" )[2 ])
642+ assert weight >= 0 , "the weight of a block should be no less than 0"
643+ block_stt .append (stt )
644+ block_end .append (end )
645+ block_weights .append (weight )
646+ nblocks = len (block_str )
647+ block_probs = np .array (block_weights ) / np .sum (block_weights )
648+ sys_probs = np .zeros ([nsystems ])
649+ for ii in range (nblocks ):
650+ nbatch_block = nbatch [block_stt [ii ] : block_end [ii ]]
651+ tmp_prob = [float (i ) for i in nbatch_block ] / np .sum (nbatch_block )
652+ sys_probs [block_stt [ii ] : block_end [ii ]] = tmp_prob * block_probs [ii ]
653+ return sys_probs
0 commit comments