@@ -4457,13 +4457,50 @@ void test_multm_prev()
44574457
44584458 for (size_t i = 0 ; i < labels.size (); ++i)
44594459 {
4460- matrix<float , 0 , 1 > x = matrix_cast<float >(randm (dims, 1 )) * rnd.get_double_in_range (1 , 9 );
4461- const auto norm = sqrt (sum (squared (x)));
4462- if (norm < 3 )
4460+ const double class_boundary_1 = 3.0 ;
4461+ const double class_boundary_2 = 6.0 ;
4462+
4463+ const double desired_margin = 0.1 ;
4464+
4465+ const auto get_random_matrix = [&rnd, dims]()
4466+ {
4467+ return matrix<float , 0 , 1 >(matrix_cast<float >(randm (dims, 1 )) * rnd.get_double_in_range (1 , 9 ));
4468+ };
4469+
4470+ const auto get_distance_from_nearest_class_boundary = [class_boundary_1, class_boundary_2](double norm)
4471+ {
4472+ return std::min (
4473+ std::abs (norm - class_boundary_1),
4474+ std::abs (norm - class_boundary_2)
4475+ );
4476+ };
4477+
4478+ auto x = get_random_matrix ();
4479+ auto norm = sqrt (sum (squared (x)));
4480+ auto distance_from_nearest_class_boundary = get_distance_from_nearest_class_boundary (norm);
4481+
4482+ // Try again if the newly generated sample is very close to either of the class boundaries
4483+ int retry_counter = 0 ;
4484+ const int max_retry_counter = 10 ;
4485+ while (distance_from_nearest_class_boundary < desired_margin && ++retry_counter <= max_retry_counter)
4486+ {
4487+ const auto new_x = get_random_matrix ();
4488+ const auto new_norm = sqrt (sum (squared (new_x)));
4489+ const auto new_distance_from_nearest_class_boundary = get_distance_from_nearest_class_boundary (new_norm);
4490+
4491+ if (new_distance_from_nearest_class_boundary > distance_from_nearest_class_boundary)
4492+ {
4493+ x = new_x;
4494+ norm = new_norm;
4495+ distance_from_nearest_class_boundary = new_distance_from_nearest_class_boundary;
4496+ }
4497+ }
4498+
4499+ if (norm < class_boundary_1)
44634500 {
44644501 labels[i][0 ] = 1 .f ;
44654502 }
4466- else if (3 <= norm && norm < 6 )
4503+ else if (class_boundary_1 <= norm && norm < class_boundary_2 )
44674504 {
44684505 labels[i][0 ] = 1 .f ;
44694506 labels[i][1 ] = 1 .f ;
0 commit comments