File tree Expand file tree Collapse file tree 3 files changed +29
-13
lines changed
Expand file tree Collapse file tree 3 files changed +29
-13
lines changed Original file line number Diff line number Diff line change @@ -117,6 +117,9 @@ class SvmModel {
117117
118118 // return label
119119 const vector<int > &get_label () const ;
120+
121+ // return param.probability
122+ const bool is_prob () const ;
120123protected:
121124
122125 /* *
Original file line number Diff line number Diff line change @@ -339,10 +339,15 @@ const SyncArray<float_type> &SvmModel::get_dec_value() const {
339339int SvmModel::get_n_binary_models () const {
340340 return n_binary_models;
341341}
342+
342343const vector<float > &SvmModel::get_prob_predict () const {
343344 return prob_predict;
344345}
345346
347+ const bool SvmModel::is_prob () const {
348+ return param.probability ;
349+ }
350+
346351int SvmModel::get_working_set_size (int n_instances, int n_features) {
347352 size_t free_mem = param.max_mem_size - SyncMem::get_total_memory_size ();
348353 int ws_size = min (max2power (n_instances),
Original file line number Diff line number Diff line change @@ -61,21 +61,29 @@ int main(int argc, char **argv) {
6161
6262 vector<float_type> predict_y;
6363 predict_y = model->predict (predict_dataset.instances (), -1 );
64- vector<int > label;
65- label = model->get_label ();
66- vector<float > prob_predict;
67- prob_predict = model->get_prob_predict ();
68- file<<" labels " ;
69- for (int i = 0 ; i < label.size (); i++){
70- file<<label[i]<<" " ;
71- }
72- file<<std::endl;
73- for (int i = 0 ; i < predict_y.size (); i++) {
74- file << predict_y[i]<<" " ;
75- for (int j = 0 ; j < label.size (); j++){
76- file<<prob_predict[label.size () * i + j]<<" " ;
64+
65+ if (model->is_prob () == 1 ) {
66+ vector<int > label;
67+ label = model->get_label ();
68+ vector<float > prob_predict;
69+ prob_predict = model->get_prob_predict ();
70+ file << " labels " ;
71+ for (int i = 0 ; i < label.size (); i++) {
72+ file << label[i] << " " ;
7773 }
7874 file << std::endl;
75+ for (int i = 0 ; i < predict_y.size (); i++) {
76+ file << predict_y[i] << " " ;
77+ for (int j = 0 ; j < label.size (); j++) {
78+ file << prob_predict[label.size () * i + j] << " " ;
79+ }
80+ file << std::endl;
81+ }
82+ }
83+ else {
84+ for (int i = 0 ; i < predict_y.size (); ++i) {
85+ file << predict_y[i] << std::endl;
86+ }
7987 }
8088 file.close ();
8189
You can’t perform that action at this time.
0 commit comments