File tree Expand file tree Collapse file tree 2 files changed +35
-4
lines changed Expand file tree Collapse file tree 2 files changed +35
-4
lines changed Original file line number Diff line number Diff line change @@ -70,7 +70,7 @@ def predict(self, X):
7070 regressor .fit (X , y )
7171
7272 pred_y = regressor .predict (X )
73- err = metrics .mean_squared_error (y , pred_y , multioutput = 'raw_values' )
73+ errs = metrics .mean_squared_error (y , pred_y , multioutput = 'raw_values' )
7474
7575 print ('pred_y:' , pred_y )
76- print ('err :' , err )
76+ print ('errs :' , errs )
Original file line number Diff line number Diff line change 11import unittest
2+ from sklearn import metrics
3+
24import multi_svr
35
46class MultiSVRTest (unittest .TestCase ):
57
6- def test_dummy (self ):
7- self .assertEqual (1 , 1 ) # TODO impl
8+ def test_prediction (self ):
9+ X = [
10+ [0 , 0 ],
11+ [0 , 10 ],
12+ [1 , 10 ],
13+ [1 , 20 ],
14+ [1 , 30 ],
15+ [1 , 40 ]
16+ ]
17+
18+ y = [
19+ [0 , 0 ],
20+ [0 , 10 ],
21+ [2 , 10 ],
22+ [2 , 20 ],
23+ [2 , 30 ],
24+ [2 , 40 ]
25+ ]
26+
27+ # Create SVR
28+ regressor = multi_svr .MutilSVR (kernel = 'linear' )
29+ # Fit
30+ regressor .fit (X , y )
31+ # Predict
32+ pred_y = regressor .predict (X )
33+ # Calc errors
34+ errs = metrics .mean_squared_error (y , pred_y , multioutput = 'raw_values' )
35+
36+ # Errors should be small
37+ assert (errs [0 ] < 0.05 )
38+ assert (errs [1 ] < 0.05 )
839
940
1041def suite ():
You can’t perform that action at this time.
0 commit comments