Skip to content

Commit ab1d004

Browse files
authored
Merge pull request #147 from Starfish-develop/ml/interpolator
2 parents 42217c6 + 8775c03 commit ab1d004

File tree

5 files changed

+34
-24
lines changed

5 files changed

+34
-24
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
repos:
22
- repo: https://github.com/python/black
3-
rev: stable
3+
rev: 22.3.0
44
hooks:
55
- id: black

Starfish/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.4.1"
1+
__version__ = "0.4.2"
22

33
from .spectrum import Spectrum
44

Starfish/grid_tools/interpolators.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ class IndexInterpolator:
1919
"""
2020

2121
def __init__(self, parameter_list):
22-
parameter_list = np.asarray(parameter_list)
23-
self.npars = parameter_list.shape[-1]
24-
self.parameter_list = np.unique(parameter_list)
25-
self.index_interpolator = interp1d(
26-
self.parameter_list, np.arange(len(self.parameter_list)), kind="linear"
27-
)
22+
parameter_list = list(parameter_list)
23+
self.npars = len(parameter_list)
24+
self.parameter_list = [np.unique(pars) for pars in parameter_list]
25+
idxs = [np.arange(len(pars)) for pars in self.parameter_list]
26+
self.index_interpolators = [
27+
interp1d(pars, idx, kind="linear")
28+
for pars, idx in zip(self.parameter_list, idxs)
29+
]
2830

2931
def __call__(self, param):
3032
"""
@@ -34,26 +36,33 @@ def __call__(self, param):
3436
:type param: list
3537
:raises ValueError: if *value* is out of bounds.
3638
37-
:returns: ((low_val, high_val), (frac_low, frac_high)), the lower and higher bounding points in the grid
38-
and the fractional distance (0 - 1) between them and the value.
39+
:returns: ((low_val, high_val), (low_dist, high_dist)), the lower and higher bounding points in the grid
40+
and the fractional distance (0 - 1) from the two points.
3941
"""
4042
if len(param) != self.npars:
4143
raise ValueError(
4244
"Incorrect number of parameters. Expected {} but got {}".format(
4345
self.npars, len(param)
4446
)
4547
)
46-
try:
47-
index = self.index_interpolator(param)
48-
except ValueError:
49-
raise ValueError("Requested param {} is out of bounds.".format(param))
50-
high = np.ceil(index).astype(int)
51-
low = np.floor(index).astype(int)
52-
frac_index = index - low
53-
return (
54-
(self.parameter_list[low], self.parameter_list[high]),
55-
((1 - frac_index), frac_index),
56-
)
48+
lows = np.empty(self.npars)
49+
highs = np.empty(self.npars)
50+
fracs = np.empty(self.npars)
51+
for i in range(self.npars):
52+
# get interpolated index
53+
try:
54+
index = self.index_interpolators[i](param[i])
55+
except ValueError:
56+
raise ValueError("Requested param {} is out of bounds.".format(param))
57+
low = np.floor(index).astype(int)
58+
high = np.ceil(index).astype(int)
59+
frac = index - low
60+
# get bounding params
61+
lows[i] = self.parameter_list[i][low]
62+
highs[i] = self.parameter_list[i][high]
63+
fracs[i] = frac
64+
65+
return (lows, highs), (1 - fracs, fracs)
5766

5867

5968
class Interpolator:

Starfish/grid_tools/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def determine_chunk_log(wl, wl_min, wl_max):
217217
else:
218218
break
219219

220-
assert type(chunk) == np.int, "Chunk is not an integer!. Chunk is {}".format(chunk)
220+
assert type(chunk) == int, "Chunk is not an integer!. Chunk is {}".format(chunk)
221221

222222
if chunk < len_wl:
223223
# Now that we have determined the length of the chunk of the synthetic
@@ -345,4 +345,4 @@ def idl_float(idl_num: str) -> float:
345345
```
346346
"""
347347
idl_str = idl_num.lower()
348-
return np.float(idl_str.replace("d", "e"))
348+
return float(idl_str.replace("d", "e"))

tests/test_grid_tools/test_interpolators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
class TestIndexInterpolator:
88
@pytest.fixture
99
def mock_index_interpolator(self, grid_points):
10-
yield IndexInterpolator(grid_points)
10+
pars = [np.unique(pars) for pars in grid_points.T]
11+
yield IndexInterpolator(pars)
1112

1213
@pytest.mark.parametrize(
1314
"input, expected",

0 commit comments

Comments
 (0)