Skip to content

Commit a7442a3

Browse files
committed
fix a bug for linear coef
1 parent e6d365a commit a7442a3

File tree

8 files changed

+32
-14
lines changed

8 files changed

+32
-14
lines changed

include/thundersvm/dataset.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class DataSet {
8080

8181
const vector<int> original_index(int y_i, int y_j) const;
8282

83+
const bool is_zero_based() const;
8384
private:
8485
vector<float_type> y_;
8586
node2d instances_;
@@ -89,5 +90,6 @@ class DataSet {
8990
vector<int> count_; //the number of instances of each class
9091
vector<int> label_;
9192
vector<int> perm_;
93+
bool zero_based = 0; //is zero_based format dataset?
9294
};
9395
#endif //THUNDERSVM_DATASET_H

include/thundersvm/model/svmmodel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class SvmModel {
105105
//return prob_predict
106106
const vector<float> &get_prob_predict() const;
107107

108-
void compute_linear_coef_single_model(size_t n_feature);
108+
void compute_linear_coef_single_model(size_t n_feature, const bool zero_based);
109109
//get the params, for scikit load params
110110
void get_param(char* kernel_type, int* degree, float* gamma, float* coef0, int* probability);
111111

src/thundersvm/dataset.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ void DataSet::load_from_file(string file_name) {
7777
float v;
7878
CHECK_EQ(sscanf(tuple.c_str(), "%d:%f", &i, &v), 2) << "read error, using [index]:[value] format";
7979
instances_thread[tid].back().emplace_back(i, v);
80+
if(i == 0 && zero_based == 0) zero_based = 1;
8081
if (i > local_feature[tid]) local_feature[tid] = i;
8182
};
8283

@@ -158,7 +159,7 @@ void DataSet::load_from_dense(int row_size, int features, float* data, float* la
158159
if(label != NULL)
159160
y_.push_back(label[i]);
160161
instances_.emplace_back();
161-
for(int j = 0; j < features; j++){
162+
for(int j = 1; j <= features; j++){
162163
ind = j;
163164
v = data[off];
164165
off++;
@@ -285,4 +286,6 @@ const vector<float_type> &DataSet::y() const {
285286
return y_;
286287
}
287288

288-
289+
const bool DataSet::is_zero_based() const{
290+
return zero_based;
291+
}

src/thundersvm/model/nusvr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void NuSVR::train(const DataSet &dataset, SvmParam param) {
3636
save_svr_coef(alpha_2, dataset.instances());
3737

3838
if(param.kernel_type == SvmParam::LINEAR){
39-
compute_linear_coef_single_model(dataset.n_features());
39+
compute_linear_coef_single_model(dataset.n_features(), dataset.is_zero_based());
4040
}
4141
}
4242

src/thundersvm/model/oneclass_svc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ void OneClassSVC::train(const DataSet &dataset, SvmParam param) {
5151
coef.copy_from(coef_vec.data(), coef_vec.size());
5252

5353
if(param.kernel_type == SvmParam::LINEAR){
54-
compute_linear_coef_single_model(dataset.n_features());
54+
compute_linear_coef_single_model(dataset.n_features(), dataset.is_zero_based());
5555
}
5656
}
5757

src/thundersvm/model/svc.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,17 +109,24 @@ void SVC::train(const DataSet &dataset, SvmParam param) {
109109
///TODO: Use coef instead of alpha_data to compute linear_coef_data
110110
if(param.kernel_type == SvmParam::LINEAR){
111111
int k = 0;
112-
linear_coef.resize(n_binary_models * dataset_.n_features());
112+
if(dataset_.is_zero_based())
113+
linear_coef.resize(n_binary_models * (dataset_.n_features()+1));
114+
else
115+
linear_coef.resize(n_binary_models * dataset_.n_features());
113116
float_type *linear_coef_data = linear_coef.host_data();
114117
for (int i = 0; i < n_classes; i++){
115118
for (int j = i + 1; j < n_classes; j++){
116119
const float_type *alpha_data = alpha[k].host_data();
117120
DataSet::node2d ins = dataset_.instances(i, j);//get instances of class i and j
118121
for(int iid = 0; iid < ins.size(); iid++) {
119122
for (int fid = 0; fid < ins[iid].size(); fid++) {
120-
if(alpha_data[iid] != 0)
121-
linear_coef_data[k * dataset_.n_features() + ins[iid][fid].index - 1] += alpha_data[iid] * ins[iid][fid].value;
122-
}
123+
if(alpha_data[iid] != 0){
124+
if(dataset_.is_zero_based())
125+
linear_coef_data[k * dataset_.n_features() + ins[iid][fid].index] += alpha_data[iid] * ins[iid][fid].value;
126+
else
127+
linear_coef_data[k * dataset_.n_features() + ins[iid][fid].index - 1] += alpha_data[iid] * ins[iid][fid].value;
128+
}
129+
}
123130
}
124131
k++;
125132
}

src/thundersvm/model/svmmodel.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,17 +381,23 @@ void SvmModel::get_param(char* kernel_type, int* degree, float* gamma, float* co
381381
*probability = param.probability;
382382
}
383383

384-
void SvmModel::compute_linear_coef_single_model(size_t n_feature){
385-
linear_coef.resize(n_feature);
384+
void SvmModel::compute_linear_coef_single_model(size_t n_feature, const bool zero_based){
385+
if(zero_based)
386+
linear_coef.resize(n_feature+1);
387+
else
388+
linear_coef.resize(n_feature);
386389
float_type* linear_coef_data = linear_coef.host_data();
387390
float_type* coef_data = coef.host_data();
388391
for(int i = 0; i < n_total_sv; i++){
389392
for(int j = 0; j < sv[i].size(); j++){
390-
linear_coef_data[sv[i][j].index - 1] += coef_data[i] * sv[i][j].value;
393+
if(zero_based)
394+
linear_coef_data[sv[i][j].index] += coef_data[i] * sv[i][j].value;
395+
else
396+
linear_coef_data[sv[i][j].index - 1] += coef_data[i] * sv[i][j].value;
391397
}
392398
}
393399
}
394400

395401
int SvmModel::get_sv_max_index() const{
396402
return sv_max_index;
397-
}
403+
}

src/thundersvm/model/svr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void SVR::train(const DataSet &dataset, SvmParam param) {
3636
save_svr_coef(alpha_2, dataset.instances());
3737

3838
if(param.kernel_type == SvmParam::LINEAR){
39-
compute_linear_coef_single_model(dataset.n_features());
39+
compute_linear_coef_single_model(dataset.n_features(), dataset.is_zero_based());
4040
}
4141
}
4242

0 commit comments

Comments
 (0)