Skip to content

Commit 813ba85

Browse files
REF: Use string-based kernel identifiers in SVOREX and REDSVM (#28)
* REF: Replace kernel integer codes with string identifiers in SVOREX * DOC: Update docstrings in SVOREX * REF: Replace kernel integer codes with string identifiers in REDSVM * DOC: Update docstrings in REDSVM
1 parent 21d4037 commit 813ba85

File tree

4 files changed

+51
-33
lines changed

4 files changed

+51
-33
lines changed

orca_python/classifiers/REDSVM.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,17 @@ class REDSVM(BaseEstimator, ClassifierMixin):
2222
C : float, default=1
2323
Set the parameter C.
2424
25-
kernel : int, default=2
25+
kernel : str, default="rbf"
2626
Set type of kernel function.
27-
0 -- linear: u'*v
28-
1 -- polynomial: (gamma*u'*v + coef0)^degree
29-
2 -- radial basis function: exp(-gamma*|u-v|^2)
30-
3 -- sigmoid: tanh(gamma*u'*v + coef0)
31-
4 -- stump: -|u-v|_1 + coef0
32-
5 -- perceptron: -|u-v|_2 + coef0
33-
6 -- laplacian: exp(-gamma*|u-v|_1)
34-
7 -- exponential: exp(-gamma*|u-v|_2)
35-
8 -- precomputed kernel (kernel values in training_instance_matrix)
27+
- linear: u'*v
28+
- polynomial: (gamma*u'*v + coef0)^degree
29+
- rbf: exp(-gamma*|u-v|^2)
30+
- sigmoid: tanh(gamma*u'*v + coef0)
31+
- stump: -|u-v|_1 + coef0
32+
- perceptron: -|u-v|_2 + coef0
33+
- laplacian: exp(-gamma*|u-v|_1)
34+
- exponential: exp(-gamma*|u-v|_2)
35+
- precomputed: kernel values in training_instance_matrix
3636
3737
degree : int, default=3
3838
Set degree in kernel function.
@@ -77,7 +77,7 @@ class REDSVM(BaseEstimator, ClassifierMixin):
7777
def __init__(
7878
self,
7979
C=1,
80-
kernel=2,
80+
kernel="rbf",
8181
degree=3,
8282
gamma=None,
8383
coef0=0,
@@ -126,9 +126,23 @@ def fit(self, X, y):
126126
if self.gamma is None:
127127
self.gamma = 1 / np.size(X, 1)
128128

129+
# Map kernel type
130+
kernel_type_mapping = {
131+
"linear": 0,
132+
"poly": 1,
133+
"rbf": 2,
134+
"sigmoid": 3,
135+
"stump": 4,
136+
"perceptron": 5,
137+
"laplacian": 6,
138+
"exponential": 7,
139+
"precomputed": 8,
140+
}
141+
kernel_type = kernel_type_mapping.get(self.kernel, -1)
142+
129143
# Fit the model
130144
options = "-s 5 -t {} -d {} -g {} -r {} -c {} -m {} -e {} -h {} -q".format(
131-
str(self.kernel),
145+
str(kernel_type),
132146
str(self.degree),
133147
str(self.gamma),
134148
str(self.coef0),

orca_python/classifiers/SVOREX.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ class SVOREX(BaseEstimator, ClassifierMixin):
2020
C : float, default=1
2121
Set the parameter C.
2222
23-
kernel : int, default=0
23+
kernel : str, default="gaussian"
2424
Set type of kernel function.
25-
0 -- gaussian: use gaussian kernel
26-
1 -- linear: use imbalanced Linear kernel
27-
2 -- polynomial: use Polynomial kernel with order p
25+
- gaussian: use gaussian kernel
26+
- linear: use imbalanced Linear kernel
27+
- poly: use Polynomial kernel with order p
2828
2929
degree : int, default=2
3030
Set degree in kernel function.
@@ -56,7 +56,7 @@ class SVOREX(BaseEstimator, ClassifierMixin):
5656
5757
"""
5858

59-
def __init__(self, C=1.0, kernel=0, degree=2, tol=0.001, kappa=1):
59+
def __init__(self, C=1.0, kernel="gaussian", degree=2, tol=0.001, kappa=1):
6060
self.C = C
6161
self.kernel = kernel
6262
self.degree = degree
@@ -93,9 +93,9 @@ def fit(self, X, y):
9393

9494
arg = ""
9595
# Prepare the kernel type arguments
96-
if self.kernel == 1:
96+
if self.kernel == "linear":
9797
arg = "-L"
98-
elif self.kernel == 2:
98+
elif self.kernel == "poly":
9999
arg = "-P {}".format(self.degree)
100100

101101
# Fit the model

orca_python/classifiers/tests/test_redsvm.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ def y():
2424
@pytest.mark.parametrize(
2525
"kernel, expected_file",
2626
[
27-
(0, "predictions_linear_0.csv"),
28-
(1, "predictions_poly_0.csv"),
29-
(2, "predictions_rbf_0.csv"),
30-
(3, "predictions_sigmoid_0.csv"),
31-
(4, "predictions_stump_0.csv"),
32-
(5, "predictions_perceptron_0.csv"),
33-
(6, "predictions_laplacian_0.csv"),
34-
(7, "predictions_exponential_0.csv"),
27+
("linear", "predictions_linear_0.csv"),
28+
("poly", "predictions_poly_0.csv"),
29+
("rbf", "predictions_rbf_0.csv"),
30+
("sigmoid", "predictions_sigmoid_0.csv"),
31+
("stump", "predictions_stump_0.csv"),
32+
("perceptron", "predictions_perceptron_0.csv"),
33+
("laplacian", "predictions_laplacian_0.csv"),
34+
("exponential", "predictions_exponential_0.csv"),
3535
],
3636
)
3737
def test_redsvm_predict_matches_expected(kernel, expected_file):
@@ -65,11 +65,15 @@ def test_redsvm_predict_matches_expected(kernel, expected_file):
6565
@pytest.mark.parametrize(
6666
"param_name, invalid_value, error_msg",
6767
[
68-
("kernel", -1, "unknown kernel type"),
68+
("kernel", "unknown", "unknown kernel type"),
6969
("cache_size", -1, "cache_size <= 0"),
7070
("tol", -1, "eps <= 0"),
7171
("shrinking", 2, "shrinking != 0 and shrinking != 1"),
72-
("kernel", 8, "Wrong input format: sample_serial_number out of range"),
72+
(
73+
"kernel",
74+
"precomputed",
75+
"Wrong input format: sample_serial_number out of range",
76+
),
7377
],
7478
)
7579
def test_redsvm_fit_hyperparameters_validation(

orca_python/classifiers/tests/test_svorex.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def y():
2424
@pytest.mark.parametrize(
2525
"kernel, expected_file",
2626
[
27-
(0, "predictions_gaussian_0.csv"),
28-
(1, "predictions_linear_0.csv"),
29-
(2, "predictions_poly_0.csv"),
27+
("gaussian", "predictions_gaussian_0.csv"),
28+
("linear", "predictions_linear_0.csv"),
29+
("poly", "predictions_poly_0.csv"),
3030
],
3131
)
3232
def test_svorex_predict_matches_expected(kernel, expected_file):
@@ -53,7 +53,7 @@ def test_svorex_predict_matches_expected(kernel, expected_file):
5353
({"tol": 0}, "- T is invalid"),
5454
({"C": 0}, "- C is invalid"),
5555
({"kappa": 0}, "- K is invalid"),
56-
({"kernel": 2, "degree": 0}, "- P is invalid"),
56+
({"kernel": "poly", "degree": 0}, "- P is invalid"),
5757
({"kappa": -1}, "-1 is invalid"),
5858
],
5959
)

0 commit comments

Comments
 (0)