Skip to content

Commit 94322a8

Browse files
committed
[impl] Implement MutilSVR
1 parent 30f8f4d commit 94322a8

File tree

1 file changed

+76
-2
lines changed

1 file changed

+76
-2
lines changed

src/multi_svr.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,76 @@
1-
def hello():
2-
return "hello, world from multi_svr"
1+
import numpy as np
2+
import sklearn
3+
from sklearn import svm
4+
5+
6+
class MutilSVR(sklearn.base.BaseEstimator, sklearn.base.RegressorMixin):
7+
def __init__(self, **kwargs):
8+
self.__init_kwargs = kwargs
9+
10+
def fit(self, X, y, **kwargs):
11+
X = np.array(X)
12+
y = np.array(y)
13+
14+
# Get dimension of y
15+
y_dim = np.ndim(y)
16+
if(y_dim == 2):
17+
# Feature dimension
18+
feature_dim = len(y[0])
19+
# Create SVRs
20+
self.svrs = [svm.SVR(**self.__init_kwargs) for _ in range(feature_dim)]
21+
22+
# For each SVR
23+
for curr_feature_dim, svr in enumerate(self.svrs): # (curr=Current)
24+
# Select y
25+
selected_y = y[:,curr_feature_dim]
26+
# Fit
27+
svr.fit(X, selected_y, **kwargs)
28+
else:
29+
raise Exception("Dimension of y must be 2, but found %d" % y_dim)
30+
31+
32+
def predict(self, X):
33+
# Init predict list
34+
preds = []
35+
# For each SVR
36+
for curr_feature_dim, svr in enumerate(self.svrs): # (curr=Current)
37+
# Predict
38+
pred = svr.predict(X)
39+
# Append to preds
40+
preds.append(pred)
41+
42+
pred = np.column_stack(tuple(preds))
43+
return pred
44+
45+
46+
47+
48+
if __name__ == '__main__':
49+
from sklearn import metrics
50+
X = [
51+
[0, 0],
52+
[0, 10],
53+
[1, 10],
54+
[1, 20],
55+
[1, 30],
56+
[1, 40]
57+
]
58+
59+
y = [
60+
[0, 0],
61+
[0, 10],
62+
[2, 10],
63+
[2, 20],
64+
[2, 30],
65+
[2, 40]
66+
]
67+
68+
regressor = MutilSVR(kernel='linear')
69+
70+
regressor.fit(X, y)
71+
72+
pred_y = regressor.predict(X)
73+
err = metrics.mean_squared_error(y, pred_y, multioutput='raw_values')
74+
75+
print('pred_y:', pred_y)
76+
print('err:', err)

0 commit comments

Comments
 (0)