Skip to content

Commit 677794b

Browse files
authored
FIX problem with extra dimension that makes linear interpolation failing (#247)
* FIX problem with extra dimension that makes linear interpolation failing
1 parent e57fe51 commit 677794b

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

ezyrb/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,4 @@ def predict(self, new_point):
5555
:return: the interpolated values.
5656
:rtype: numpy.ndarray
5757
"""
58-
return self.interpolator(new_point)
58+
return self.interpolator(new_point).squeeze()

tests/test_linear.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from unittest import TestCase
55
from ezyrb import Linear, Database, POD, ReducedOrderModel
66

7-
class TestKNeighbors(TestCase):
7+
class TestLinear(TestCase):
88
def test_params(self):
99
reg = Linear(fill_value=0)
1010
assert reg.fill_value == 0
@@ -52,6 +52,13 @@ def test_with_db_predict(self):
5252
assert rom.predict([2]) == 5
5353
assert rom.predict([3]) == 3
5454

55+
Y = np.random.uniform(size=(3, 3))
56+
db = Database(np.array([1, 2, 3]), Y)
57+
rom = ReducedOrderModel(db, POD(), Linear())
58+
rom.fit()
59+
assert rom.predict([1.]).shape == (3,)
60+
61+
5562
def test_wrong1(self):
5663
# wrong number of params
5764
with warnings.catch_warnings():

0 commit comments

Comments
 (0)