Skip to content

Commit 70b64f2

Browse files
committed
fix predict output on cmd
1 parent 6a62fdd commit 70b64f2

File tree

3 files changed

+29
-13
lines changed

3 files changed

+29
-13
lines changed

include/thundersvm/model/svmmodel.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ class SvmModel {
117117

118118
//return label
119119
const vector<int> &get_label() const;
120+
121+
//return param.probability
122+
const bool is_prob() const;
120123
protected:
121124

122125
/**

src/thundersvm/model/svmmodel.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,15 @@ const SyncArray<float_type> &SvmModel::get_dec_value() const {
339339
int SvmModel::get_n_binary_models() const{
340340
return n_binary_models;
341341
}
342+
342343
const vector<float> &SvmModel::get_prob_predict() const{
343344
return prob_predict;
344345
}
345346

347+
const bool SvmModel::is_prob() const{
348+
return param.probability;
349+
}
350+
346351
int SvmModel::get_working_set_size(int n_instances, int n_features) {
347352
size_t free_mem = param.max_mem_size - SyncMem::get_total_memory_size();
348353
int ws_size = min(max2power(n_instances),

src/thundersvm/thundersvm-predict.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,29 @@ int main(int argc, char **argv) {
6161

6262
vector<float_type> predict_y;
6363
predict_y = model->predict(predict_dataset.instances(), -1);
64-
vector<int> label;
65-
label = model->get_label();
66-
vector<float> prob_predict;
67-
prob_predict = model->get_prob_predict();
68-
file<<"labels ";
69-
for (int i = 0; i < label.size(); i++){
70-
file<<label[i]<<" ";
71-
}
72-
file<<std::endl;
73-
for (int i = 0; i < predict_y.size(); i++) {
74-
file << predict_y[i]<<" ";
75-
for(int j = 0; j < label.size(); j++){
76-
file<<prob_predict[label.size() * i + j]<<" ";
64+
65+
if(model->is_prob() == 1) {
66+
vector<int> label;
67+
label = model->get_label();
68+
vector<float> prob_predict;
69+
prob_predict = model->get_prob_predict();
70+
file << "labels ";
71+
for (int i = 0; i < label.size(); i++) {
72+
file << label[i] << " ";
7773
}
7874
file << std::endl;
75+
for (int i = 0; i < predict_y.size(); i++) {
76+
file << predict_y[i] << " ";
77+
for (int j = 0; j < label.size(); j++) {
78+
file << prob_predict[label.size() * i + j] << " ";
79+
}
80+
file << std::endl;
81+
}
82+
}
83+
else{
84+
for (int i = 0; i < predict_y.size(); ++i) {
85+
file << predict_y[i] << std::endl;
86+
}
7987
}
8088
file.close();
8189

0 commit comments

Comments
 (0)