@@ -412,7 +412,7 @@ def sample_one_dimension(
412412
413413 def _sample_real_dimension (self , dimension , shape_size , below_points , above_points ):
414414 """Sample values for real dimension"""
415- if dimension .prior_name in ["uniform" , "reciprocal" ]:
415+ if any ( map ( dimension .prior_name . endswith , ["uniform" , "reciprocal" ])) :
416416 return self .sample_one_dimension (
417417 dimension ,
418418 shape_size ,
@@ -421,7 +421,9 @@ def _sample_real_dimension(self, dimension, shape_size, below_points, above_poin
421421 self ._sample_real_point ,
422422 )
423423 else :
424- raise NotImplementedError ()
424+ raise NotImplementedError (
425+ f"Prior { dimension .prior_name } is not supported for real values"
426+ )
425427
426428 def _sample_loguniform_real_point (self , dimension , below_points , above_points ):
427429 """Sample one value for real dimension in a loguniform way"""
@@ -555,10 +557,31 @@ class GMMSampler:
555557 weights: list
556558 Weights for each Gaussian components in the GMM
557559 Default: ``None``
560+ base_attempts: int, optional
561+ Base number of attempts to sample points within `low` and `high` bounds.
562+ Defaults to 10.
563+ attempts_factor: int, optional
564+ If sampling always falls out of bound try again with `attempts` * `attempts_factor`.
565+ Defaults to 10.
566+ max_attempts: int, optional
567+ If sampling always falls out of bound try again with `attempts` * `attempts_factor`
568+ up to `max_attempts` (inclusive).
569+ Defaults to 10000.
558570
559571 """
560572
561- def __init__ (self , tpe , mus , sigmas , low , high , weights = None ):
573+ def __init__ (
574+ self ,
575+ tpe ,
576+ mus ,
577+ sigmas ,
578+ low ,
579+ high ,
580+ weights = None ,
581+ base_attempts = 10 ,
582+ attempts_factor = 10 ,
583+ max_attempts = 10000 ,
584+ ):
562585 self .tpe = tpe
563586
564587 self .mus = mus
@@ -567,6 +590,10 @@ def __init__(self, tpe, mus, sigmas, low, high, weights=None):
567590 self .high = high
568591 self .weights = weights if weights is not None else len (mus ) * [1.0 / len (mus )]
569592
593+ self .base_attempts = base_attempts
594+ self .attempts_factor = attempts_factor
595+ self .max_attempts = max_attempts
596+
570597 self .pdfs = []
571598 self ._build_mixture ()
572599
@@ -575,24 +602,38 @@ def _build_mixture(self):
575602 for mu , sigma in zip (self .mus , self .sigmas ):
576603 self .pdfs .append (norm (mu , sigma ))
577604
578- def sample (self , num = 1 , attempts = 10 ):
605+ def sample (self , num = 1 , attempts = None ):
579606 """Sample required number of points"""
607+ if attempts is None :
608+ attempts = self .base_attempts
609+
580610 point = []
581611 for _ in range (num ):
582612 pdf = numpy .argmax (self .tpe .rng .multinomial (1 , self .weights ))
583- new_points = list (
584- self .pdfs [pdf ].rvs (size = attempts , random_state = self .tpe .rng )
585- )
586- while True :
587- if not new_points :
588- raise RuntimeError (
589- f"Failed to sample in interval ({ self .low } , { self .high } )"
590- )
591- pt = new_points .pop (0 )
592- if self .low <= pt <= self .high :
593- point .append (pt )
613+ attempts_tried = 0
614+ while attempts_tried < attempts :
615+ new_points = self .pdfs [pdf ].rvs (
616+ size = attempts , random_state = self .tpe .rng
617+ )
618+ valid_points = (self .low <= new_points ) * (self .high >= new_points )
619+
620+ if any (valid_points ):
621+ index = numpy .argmax (valid_points )
622+ point .append (float (new_points [index ]))
594623 break
595624
625+ index = None
626+ attempts_tried += 1
627+
628+ if index is None and attempts >= self .max_attempts :
629+ raise RuntimeError (
630+ f"Failed to sample in interval ({ self .low } , { self .high } )"
631+ )
632+ elif index is None :
633+ point .append (
634+ self .sample (num = 1 , attempts = attempts * self .attempts_factor )[0 ]
635+ )
636+
596637 return point
597638
598639 def get_loglikelis (self , points ):
0 commit comments