@@ -19,12 +19,14 @@ class IndexInterpolator:
1919 """
2020
2121 def __init__ (self , parameter_list ):
22- parameter_list = np .asarray (parameter_list )
23- self .npars = parameter_list .shape [- 1 ]
24- self .parameter_list = np .unique (parameter_list )
25- self .index_interpolator = interp1d (
26- self .parameter_list , np .arange (len (self .parameter_list )), kind = "linear"
27- )
22+ parameter_list = list (parameter_list )
23+ self .npars = len (parameter_list )
24+ self .parameter_list = [np .unique (pars ) for pars in parameter_list ]
25+ idxs = [np .arange (len (pars )) for pars in self .parameter_list ]
26+ self .index_interpolators = [
27+ interp1d (pars , idx , kind = "linear" )
28+ for pars , idx in zip (self .parameter_list , idxs )
29+ ]
2830
2931 def __call__ (self , param ):
3032 """
@@ -34,26 +36,33 @@ def __call__(self, param):
3436 :type param: list
3537 :raises ValueError: if *value* is out of bounds.
3638
37- :returns: ((low_val, high_val), (frac_low, frac_high )), the lower and higher bounding points in the grid
38- and the fractional distance (0 - 1) between them and the value .
39+ :returns: ((low_val, high_val), (low_dist, high_dist )), the lower and higher bounding points in the grid
40+ and the fractional distance (0 - 1) from the two points .
3941 """
4042 if len (param ) != self .npars :
4143 raise ValueError (
4244 "Incorrect number of parameters. Expected {} but got {}" .format (
4345 self .npars , len (param )
4446 )
4547 )
46- try :
47- index = self .index_interpolator (param )
48- except ValueError :
49- raise ValueError ("Requested param {} is out of bounds." .format (param ))
50- high = np .ceil (index ).astype (int )
51- low = np .floor (index ).astype (int )
52- frac_index = index - low
53- return (
54- (self .parameter_list [low ], self .parameter_list [high ]),
55- ((1 - frac_index ), frac_index ),
56- )
48+ lows = np .empty (self .npars )
49+ highs = np .empty (self .npars )
50+ fracs = np .empty (self .npars )
51+ for i in range (self .npars ):
52+ # get interpolated index
53+ try :
54+ index = self .index_interpolators [i ](param [i ])
55+ except ValueError :
56+ raise ValueError ("Requested param {} is out of bounds." .format (param ))
57+ low = np .floor (index ).astype (int )
58+ high = np .ceil (index ).astype (int )
59+ frac = index - low
60+ # get bounding params
61+ lows [i ] = self .parameter_list [i ][low ]
62+ highs [i ] = self .parameter_list [i ][high ]
63+ fracs [i ] = frac
64+
65+ return (lows , highs ), (1 - fracs , fracs )
5766
5867
5968class Interpolator :
0 commit comments