Skip to content

Commit 6a62fdd

Browse files
committed
fix prob_predict on cmd #149
1 parent c7f6882 commit 6a62fdd

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

include/thundersvm/model/svmmodel.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ class SvmModel {
114114

115115
//return sv_indices
116116
const vector<int> &get_sv_ind() const;
117+
118+
//return label
119+
const vector<int> &get_label() const;
117120
protected:
118121

119122
/**

src/thundersvm/model/svmmodel.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,3 +405,7 @@ void SvmModel::compute_linear_coef_single_model(size_t n_feature, const bool zer
405405
int SvmModel::get_sv_max_index() const{
406406
return sv_max_index;
407407
}
408+
409+
const vector<int> &SvmModel::get_label() const{
410+
return label;
411+
}

src/thundersvm/thundersvm-predict.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,21 @@ int main(int argc, char **argv) {
6161

6262
vector<float_type> predict_y;
6363
predict_y = model->predict(predict_dataset.instances(), -1);
64-
for (int i = 0; i < predict_y.size(); ++i) {
65-
file << predict_y[i] << std::endl;
64+
vector<int> label;
65+
label = model->get_label();
66+
vector<float> prob_predict;
67+
prob_predict = model->get_prob_predict();
68+
file<<"labels ";
69+
for (int i = 0; i < label.size(); i++){
70+
file<<label[i]<<" ";
71+
}
72+
file<<std::endl;
73+
for (int i = 0; i < predict_y.size(); i++) {
74+
file << predict_y[i]<<" ";
75+
for(int j = 0; j < label.size(); j++){
76+
file<<prob_predict[label.size() * i + j]<<" ";
77+
}
78+
file << std::endl;
6679
}
6780
file.close();
6881

0 commit comments

Comments
 (0)