Skip to content

Commit 4ae836c

Browse files
committed
Merge branch 'regression' into spack
2 parents 5f7fc7f + 8e87243 commit 4ae836c

File tree

2 files changed

+32
-17
lines changed

2 files changed

+32
-17
lines changed

include/plssvm/svm/csvr.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,16 @@
2929

3030
#include "igor/igor.hpp" // igor::parser
3131

32-
#include <algorithm> // std::all_of
33-
#include <chrono> // std::chrono::{time_point, steady_clock, duration_cast, milliseconds}
34-
#include <cstddef> // std::size_t
35-
#include <memory> // std::addressof
36-
#include <optional> // std::make_optional
37-
#include <tuple> // std::tie
38-
#include <utility> // std::move
39-
#include <vector> // std::vector
32+
#include <algorithm> // std::all_of
33+
#include <chrono> // std::chrono::{time_point, steady_clock, duration_cast, milliseconds}
34+
#include <cmath> // std::round
35+
#include <cstddef> // std::size_t
36+
#include <memory> // std::addressof
37+
#include <optional> // std::make_optional
38+
#include <tuple> // std::tie
39+
#include <type_traits> // std::is_floating_point_v
40+
#include <utility> // std::move
41+
#include <vector> // std::vector
4042

4143
namespace plssvm {
4244

@@ -224,7 +226,11 @@ class csvr : virtual public csvm {
224226

225227
for (std::size_t i = 0; i < data.num_data_points(); ++i) {
226228
// TODO: is there multiclass regression? https://en.wikipedia.org/wiki/Multinomial_logistic_regression
227-
predicted_labels[i] = static_cast<label_type>(votes(i, 0));
229+
if constexpr (std::is_floating_point_v<label_type>) {
230+
predicted_labels[i] = static_cast<label_type>(votes(i, 0));
231+
} else {
232+
predicted_labels[i] = static_cast<label_type>(std::round(votes(i, 0)));
233+
}
228234
}
229235

230236
PLSSVM_DETAIL_TRACKING_PERFORMANCE_TRACKER_ADD_EVENT("predict end");

tests/backends/generic_base_csvr_tests.hpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,13 @@ TYPED_TEST_P(GenericCSVRKernelFunction, predict) {
9797
constexpr plssvm::kernel_function_type kernel = util::test_parameter_value_at_v<0, TypeParam>;
9898

9999
// create parameter struct
100-
plssvm::parameter params{ plssvm::kernel_type = kernel };
101-
if constexpr (kernel != plssvm::kernel_function_type::linear) {
102-
params.gamma = plssvm::real_type{ 1.0 };
100+
plssvm::parameter params{ plssvm::cost = 1000.0, plssvm::kernel_type = kernel };
101+
if constexpr (kernel == plssvm::kernel_function_type::polynomial) {
102+
params.degree = 1;
103+
params.gamma = 1.0;
104+
}
105+
if constexpr (kernel == plssvm::kernel_function_type::sigmoid) {
106+
params.gamma = 0.01;
103107
}
104108

105109
// create data set that is always classifiable
@@ -118,11 +122,16 @@ TYPED_TEST_P(GenericCSVRKernelFunction, predict) {
118122
const plssvm::regression_model<label_type> model = svr.fit(test_data, plssvm::epsilon = 1e-16);
119123

120124
// actual TEST: predict label
121-
[[maybe_unused]] const std::vector<label_type> calculated = svr.predict(model, test_data);
125+
std::vector<label_type> calculated = svr.predict(model, test_data);
122126

123127
// check the calculated result for correctness
124-
GTEST_SKIP() << "not yet implemented for C-SVR";
125-
// EXPECT_EQ(calculated, test_data.labels().value().get());
128+
if constexpr (std::is_floating_point_v<label_type>) {
129+
// convert a floating point label_type back to a plain integer
130+
for (label_type &val : calculated) {
131+
val = static_cast<label_type>(std::round(val));
132+
}
133+
}
134+
EXPECT_EQ(calculated, test_data.labels().value().get());
126135
}
127136

128137
TYPED_TEST_P(GenericCSVRKernelFunction, score_model) {
@@ -158,7 +167,7 @@ TYPED_TEST_P(GenericCSVRKernelFunction, score_model) {
158167
// check the calculated result for correctness
159168
// 1.0 is the maximum possible value
160169
// arbitrary small (negative) values are possible, but the "easy" data set shouldn't result in values smaller 0.0
161-
EXPECT_EXCLUSIVE_RANGE(calculated, plssvm::real_type{ 0.0 }, plssvm::real_type{ 1.0 });
170+
EXPECT_INCLUSIVE_RANGE(calculated, plssvm::real_type{ 0.0 }, plssvm::real_type{ 1.0 });
162171
}
163172

164173
TYPED_TEST_P(GenericCSVRKernelFunction, score) {
@@ -194,7 +203,7 @@ TYPED_TEST_P(GenericCSVRKernelFunction, score) {
194203
// check the calculated result for correctness
195204
// 1.0 is the maximum possible value
196205
// arbitrary small (negative) values are possible, but the "easy" data set shouldn't result in values smaller 0.0
197-
EXPECT_EXCLUSIVE_RANGE(calculated, plssvm::real_type{ 0.0 }, plssvm::real_type{ 1.0 });
206+
EXPECT_INCLUSIVE_RANGE(calculated, plssvm::real_type{ 0.0 }, plssvm::real_type{ 1.0 });
198207
}
199208

200209
REGISTER_TYPED_TEST_SUITE_P(GenericCSVRKernelFunction,

0 commit comments

Comments
 (0)