@@ -97,9 +97,13 @@ TYPED_TEST_P(GenericCSVRKernelFunction, predict) {
9797 constexpr plssvm::kernel_function_type kernel = util::test_parameter_value_at_v<0 , TypeParam>;
9898
9999 // create parameter struct
100- plssvm::parameter params{ plssvm::kernel_type = kernel };
101- if constexpr (kernel != plssvm::kernel_function_type::linear) {
102- params.gamma = plssvm::real_type{ 1.0 };
100+ plssvm::parameter params{ plssvm::cost = 1000.0 , plssvm::kernel_type = kernel };
101+ if constexpr (kernel == plssvm::kernel_function_type::polynomial) {
102+ params.degree = 1 ;
103+ params.gamma = 1.0 ;
104+ }
105+ if constexpr (kernel == plssvm::kernel_function_type::sigmoid) {
106+ params.gamma = 0.01 ;
103107 }
104108
105109 // create data set that is always classifiable
@@ -118,11 +122,16 @@ TYPED_TEST_P(GenericCSVRKernelFunction, predict) {
118122 const plssvm::regression_model<label_type> model = svr.fit (test_data, plssvm::epsilon = 1e-16 );
119123
120124 // actual TEST: predict label
121- [[maybe_unused]] const std::vector<label_type> calculated = svr.predict (model, test_data);
125+ std::vector<label_type> calculated = svr.predict (model, test_data);
122126
123127 // check the calculated result for correctness
124- GTEST_SKIP () << " not yet implemented for C-SVR" ;
125- // EXPECT_EQ(calculated, test_data.labels().value().get());
128+ if constexpr (std::is_floating_point_v<label_type>) {
129+ // convert a floating point label_type back to a plain integer
130+ for (label_type &val : calculated) {
131+ val = static_cast <label_type>(std::round (val));
132+ }
133+ }
134+ EXPECT_EQ (calculated, test_data.labels ().value ().get ());
126135}
127136
128137TYPED_TEST_P (GenericCSVRKernelFunction, score_model) {
@@ -158,7 +167,7 @@ TYPED_TEST_P(GenericCSVRKernelFunction, score_model) {
158167 // check the calculated result for correctness
159168 // 1.0 is the maximum possible value
160169 // arbitrary small (negative) values are possible, but the "easy" data set shouldn't result in values smaller 0.0
161- EXPECT_EXCLUSIVE_RANGE (calculated, plssvm::real_type{ 0.0 }, plssvm::real_type{ 1.0 });
170+ EXPECT_INCLUSIVE_RANGE (calculated, plssvm::real_type{ 0.0 }, plssvm::real_type{ 1.0 });
162171}
163172
164173TYPED_TEST_P (GenericCSVRKernelFunction, score) {
@@ -194,7 +203,7 @@ TYPED_TEST_P(GenericCSVRKernelFunction, score) {
194203 // check the calculated result for correctness
195204 // 1.0 is the maximum possible value
196205 // arbitrary small (negative) values are possible, but the "easy" data set shouldn't result in values smaller 0.0
197- EXPECT_EXCLUSIVE_RANGE (calculated, plssvm::real_type{ 0.0 }, plssvm::real_type{ 1.0 });
206+ EXPECT_INCLUSIVE_RANGE (calculated, plssvm::real_type{ 0.0 }, plssvm::real_type{ 1.0 });
198207}
199208
200209REGISTER_TYPED_TEST_SUITE_P (GenericCSVRKernelFunction,
0 commit comments