Skip to content

Commit 4891ef2

Browse files
committed
[impl] Implement a test for multi_svr
1 parent 4ea07f1 commit 4891ef2

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

src/multi_svr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def predict(self, X):
7070
regressor.fit(X, y)
7171

7272
pred_y = regressor.predict(X)
73-
err = metrics.mean_squared_error(y, pred_y, multioutput='raw_values')
73+
errs = metrics.mean_squared_error(y, pred_y, multioutput='raw_values')
7474

7575
print('pred_y:', pred_y)
76-
print('err:', err)
76+
print('errs:', errs)

tests/multi_svr_test.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,41 @@
11
import unittest
2+
from sklearn import metrics
3+
24
import multi_svr
35

46
class MultiSVRTest(unittest.TestCase):
57

6-
def test_dummy(self):
7-
self.assertEqual(1, 1) # TODO impl
8+
def test_prediction(self):
9+
X = [
10+
[0, 0],
11+
[0, 10],
12+
[1, 10],
13+
[1, 20],
14+
[1, 30],
15+
[1, 40]
16+
]
17+
18+
y = [
19+
[0, 0],
20+
[0, 10],
21+
[2, 10],
22+
[2, 20],
23+
[2, 30],
24+
[2, 40]
25+
]
26+
27+
# Create SVR
28+
regressor = multi_svr.MutilSVR(kernel='linear')
29+
# Fit
30+
regressor.fit(X, y)
31+
# Predict
32+
pred_y = regressor.predict(X)
33+
# Calc errors
34+
errs = metrics.mean_squared_error(y, pred_y, multioutput='raw_values')
35+
36+
# Errors should be small
37+
assert(errs[0] < 0.05)
38+
assert(errs[1] < 0.05)
839

940

1041
def suite():

0 commit comments

Comments
 (0)