Skip to content

Commit 8466021

Browse files
committed
add attribute support_ #138
1 parent cbc08c6 commit 8466021

File tree

8 files changed

+24
-2
lines changed

8 files changed

+24
-2
lines changed

include/thundersvm/model/svmmodel.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ class SvmModel {
111111

112112
//return sv_max_index
113113
int get_sv_max_index() const;
114+
115+
//return sv_indices
116+
const vector<int> &get_sv_ind() const;
114117
protected:
115118

116119
/**
@@ -146,6 +149,8 @@ class SvmModel {
146149
*/
147150

148151
DataSet::node2d sv;
152+
///the indices of support vectors
153+
vector<int> sv_indices;
149154
///the number of support vectors for each class
150155
SyncArray<int> n_sv;
151156

python/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ The usage of thundersvm scikit interface is similar to sklearn.svm.
100100
The seed of the pseudo random number generator to use when shuffling the data. If int, random_state is the seed used by the random number generator; If RandomState instance, random_state is the random number generator; If None, the random number generator is the RandomState instance used by np.random.
101101

102102
### Attributes
103+
*support_*: array-like, shape = [n_SV]\
104+
indices of support vectors.
105+
103106
*support_vectors_*: array-like, shape = [n_SV, n_features]\
104107
support vectors.
105108

python/thundersvm/thundersvm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,16 @@ def fit(self, X, y):
129129
csr_col = (c_int * (self.n_sv * self.n_features))()
130130
csr_data = (c_float * (self.n_sv * self.n_features))()
131131
data_size = (c_int * 1)()
132-
thundersvm.get_sv(csr_row, csr_col, csr_data, data_size, c_void_p(self.model))
132+
sv_indices = (c_int * self.n_sv)()
133+
thundersvm.get_sv(csr_row, csr_col, csr_data, data_size, sv_indices, c_void_p(self.model))
133134
self.row = np.array([csr_row[index] for index in range(0, self.n_sv + 1)])
134135
self.col = np.array([csr_col[index] for index in range(0, data_size[0])])
135136
self.data = np.array([csr_data[index] for index in range(0, data_size[0])])
136137

137138
self.support_vectors_ = sp.csr_matrix((self.data, self.col, self.row))
138139
if self._sparse == False:
139140
self.support_vectors_ = self.support_vectors_.toarray(order = 'C')
141+
self.support_ = np.array([sv_indices[index] for index in range(0, self.n_sv)]).astype(int)
140142

141143
dual_coef = (c_float * ((self.n_classes - 1) * self.n_sv))()
142144
thundersvm.get_coef(dual_coef, self.n_classes, self.n_sv, c_void_p(self.model))

src/thundersvm/model/oneclass_svc.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ void OneClassSVC::train(const DataSet &dataset, SvmParam param) {
4040
for (int i = 0; i < n_instances; ++i) {
4141
if (alpha_data[i] != 0) {
4242
sv.push_back(dataset.instances()[i]);
43+
sv_indices.push_back(i);
4344
coef_vec.push_back(alpha_data[i]);
4445
}
4546
}

src/thundersvm/model/svc.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ void SVC::train(const DataSet &dataset, SvmParam param) {
6868
if (is_sv[original_index[j]]) {
6969
n_sv_data[i]++;
7070
sv.push_back(i_instances[j]);
71+
sv_indices.push_back(original_index[j]);
7172
}
7273
}
7374
}

src/thundersvm/model/svmmodel.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,10 @@ int SvmModel::get_n_classes() const {
324324
return n_classes;
325325
}
326326

327+
const vector<int> &SvmModel::get_sv_ind() const {
328+
return sv_indices;
329+
}
330+
327331
void SvmModel::set_max_iter(int iter) {
328332
max_iter = iter;
329333
}

src/thundersvm/model/svr.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ void SVR::save_svr_coef(const SyncArray<float_type> &alpha_2, const DataSet::nod
4949
float_type alpha_i = alpha_2_data[i] - alpha_2_data[i + n_instances];
5050
if (alpha_i != 0) {
5151
sv.push_back(instances[i]);
52+
sv_indices.push_back(i);
5253
coef_vec.push_back(alpha_i);
5354
}
5455
}

src/thundersvm/thundersvm-scikit.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ extern "C" {
260260
return;
261261
}
262262

263-
void get_sv(int* row, int* col, float* data, int* data_size, SvmModel* model){
263+
void get_sv(int* row, int* col, float* data, int* data_size, int* sv_indices, SvmModel* model){
264264
DataSet::node2d svs = model->svs();
265265
row[0] = 0;
266266
int data_ind = 0;
@@ -277,6 +277,11 @@ extern "C" {
277277
}
278278
}
279279
data_size[0] = data_ind;
280+
281+
vector<int> sv_index = model->get_sv_ind();
282+
for(int i = 0; i < sv_index.size(); i++){
283+
sv_indices[i] = sv_index[i];
284+
}
280285
return ;
281286
}
282287

0 commit comments

Comments
 (0)