Skip to content

Commit 44c02d5

Browse files
committed
Merge branch 'release/0.1.0'
2 parents 2a62872 + e737c17 commit 44c02d5

File tree

9 files changed

+237
-0
lines changed

9 files changed

+237
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
/.idea
2+
*.egg-info/
3+
__pycache__/

.travis.yml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# (from: https://qiita.com/masashi127/items/5bfcba5cad8e82958844)
2+
3+
language: python
4+
5+
python:
6+
- 3.4
7+
- 3.5
8+
- 3.6
9+
10+
addons:
11+
apt:
12+
packages:
13+
# (from: https://github.com/dnouri/nolearn/blob/master/.travis.yml)
14+
- libblas-dev
15+
- liblapack-dev
16+
- gfortran
17+
18+
19+
before_install:
20+
- pip install -U pip setuptools wheel # (from: https://github.com/dnouri/nolearn/blob/master/.travis.yml)
21+
22+
install:
23+
- travis_wait travis_retry pip install -r requirements.txt # (from: https://github.com/dnouri/nolearn/blob/master/.travis.yml)
24+
- pip install coveralls
25+
26+
script:
27+
- coverage run --source=nwtgck_hello_test setup.py test
28+
29+
after_success:
30+
- coveralls
31+
32+
cache:
33+
- apt
34+
- directories:
35+
- $HOME/.cache/pip

README.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# multi_svr [![Build Status](https://travis-ci.org/nwtgck/multi-svr-python.svg?branch=develop)](https://travis-ci.org/nwtgck/multi-svr-python) [![Coverage Status](https://coveralls.io/repos/github/nwtgck/multi-svr-python/badge.svg?branch=develop)](https://coveralls.io/github/nwtgck/multi-svr-python?branch=develop)
2+
3+
Support Vector Regression (SVR) for multidimensional labels
4+
5+
6+
## Installation
7+
8+
```bash
9+
pip3 install git+https://github.com/nwtgck/multi-svr-python
10+
```
11+
12+
13+
## Usage
14+
15+
```python
16+
X = [
17+
[0, 0],
18+
[0, 10],
19+
[1, 10],
20+
[1, 20],
21+
[1, 30],
22+
[1, 40]
23+
]
24+
25+
y = [
26+
[0, 0],
27+
[0, 10],
28+
[2, 10],
29+
[2, 20],
30+
[2, 30],
31+
[2, 40]
32+
]
33+
34+
# Create SVR
35+
regressor = multi_svr.MutilSVR(kernel='linear')
36+
# Fit
37+
regressor.fit(X, y)
38+
# Predict
39+
pred_y = regressor.predict(X)
40+
# Calc errors
41+
errs = metrics.mean_squared_error(y, pred_y)
42+
```
43+
44+
## How to test
45+
46+
```bash
47+
cd <this repo
48+
python setup.py test
49+
```

requirements.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
certifi==2017.11.5
2+
chardet==3.0.4
3+
docopt==0.6.2
4+
idna==2.6
5+
numpy==1.13.3
6+
requests==2.18.4
7+
scikit-learn==0.19.1
8+
scipy==1.0.0
9+
urllib3==1.22

setup.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# (from: https://github.com/masaponto/Python-MLP/blob/master/setup.py)
2+
# (from: https://qiita.com/masashi127/items/5bfcba5cad8e82958844)
3+
# (from: https://qiita.com/hotoku/items/4789533f5e497f3dc6e0)
4+
5+
from setuptools import setup, find_packages
6+
import sys
7+
8+
sys.path.append('./src')
9+
sys.path.append('./tests')
10+
11+
setup(
12+
name='multi_svr',
13+
version='0.1.0',
14+
description='SVR for multidimensional label',
15+
author='Ryo Ota',
16+
author_email='[email protected]',
17+
install_requires=['scikit-learn', 'numpy', 'SciPy'],
18+
py_modules=["multi_svr"],
19+
packages=find_packages(),
20+
test_suite='tests'
21+
)

src/__init__.py

Whitespace-only changes.

src/multi_svr.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import numpy as np
2+
import sklearn
3+
from sklearn import svm
4+
5+
6+
class MutilSVR(sklearn.base.BaseEstimator, sklearn.base.RegressorMixin):
7+
def __init__(self, **kwargs):
8+
self.__init_kwargs = kwargs
9+
10+
def fit(self, X, y, **kwargs):
11+
X = np.array(X)
12+
y = np.array(y)
13+
14+
# Get dimension of y
15+
y_dim = np.ndim(y)
16+
if(y_dim == 2):
17+
# Feature dimension
18+
feature_dim = len(y[0])
19+
# Create SVRs
20+
self.svrs = [svm.SVR(**self.__init_kwargs) for _ in range(feature_dim)]
21+
22+
# For each SVR
23+
for curr_feature_dim, svr in enumerate(self.svrs): # (curr=Current)
24+
# Select y
25+
selected_y = y[:,curr_feature_dim]
26+
# Fit
27+
svr.fit(X, selected_y, **kwargs)
28+
else:
29+
raise Exception("Dimension of y must be 2, but found %d" % y_dim)
30+
31+
32+
def predict(self, X):
33+
# Init predict list
34+
preds = []
35+
# For each SVR
36+
for curr_feature_dim, svr in enumerate(self.svrs): # (curr=Current)
37+
# Predict
38+
pred = svr.predict(X)
39+
# Append to preds
40+
preds.append(pred)
41+
42+
pred = np.column_stack(tuple(preds))
43+
return pred
44+
45+
46+
47+
48+
if __name__ == '__main__':
49+
from sklearn import metrics
50+
X = [
51+
[0, 0],
52+
[0, 10],
53+
[1, 10],
54+
[1, 20],
55+
[1, 30],
56+
[1, 40]
57+
]
58+
59+
y = [
60+
[0, 0],
61+
[0, 10],
62+
[2, 10],
63+
[2, 20],
64+
[2, 30],
65+
[2, 40]
66+
]
67+
68+
regressor = MutilSVR(kernel='linear')
69+
70+
regressor.fit(X, y)
71+
72+
pred_y = regressor.predict(X)
73+
errs = metrics.mean_squared_error(y, pred_y, multioutput='raw_values')
74+
75+
print('pred_y:', pred_y)
76+
print('errs:', errs)

tests/__init__.py

Whitespace-only changes.

tests/multi_svr_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import unittest
2+
from sklearn import metrics
3+
4+
import multi_svr
5+
6+
class MultiSVRTest(unittest.TestCase):
7+
8+
def test_prediction(self):
9+
X = [
10+
[0, 0],
11+
[0, 10],
12+
[1, 10],
13+
[1, 20],
14+
[1, 30],
15+
[1, 40]
16+
]
17+
18+
y = [
19+
[0, 0],
20+
[0, 10],
21+
[2, 10],
22+
[2, 20],
23+
[2, 30],
24+
[2, 40]
25+
]
26+
27+
# Create SVR
28+
regressor = multi_svr.MutilSVR(kernel='linear')
29+
# Fit
30+
regressor.fit(X, y)
31+
# Predict
32+
pred_y = regressor.predict(X)
33+
# Calc errors
34+
errs = metrics.mean_squared_error(y, pred_y, multioutput='raw_values')
35+
36+
# Errors should be small
37+
assert(errs[0] < 0.05)
38+
assert(errs[1] < 0.05)
39+
40+
41+
def suite():
42+
suite = unittest.TestSuite()
43+
suite.addTest(unittest.makeSuite(MultiSVRTest))
44+
return suite

0 commit comments

Comments
 (0)