Skip to content

Commit 35a8e1f

Browse files
authored
Improve the somewhat flaky test_loss_multibinary_log by avoiding samples very close to class boundaries (#3112)
1 parent 5ce4891 commit 35a8e1f

File tree

1 file changed

+41
-4
lines changed

1 file changed

+41
-4
lines changed

dlib/test/dnn.cpp

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4457,13 +4457,50 @@ void test_multm_prev()
44574457

44584458
for (size_t i = 0; i < labels.size(); ++i)
44594459
{
4460-
matrix<float, 0, 1> x = matrix_cast<float>(randm(dims, 1)) * rnd.get_double_in_range(1, 9);
4461-
const auto norm = sqrt(sum(squared(x)));
4462-
if (norm < 3)
4460+
const double class_boundary_1 = 3.0;
4461+
const double class_boundary_2 = 6.0;
4462+
4463+
const double desired_margin = 0.1;
4464+
4465+
const auto get_random_matrix = [&rnd, dims]()
4466+
{
4467+
return matrix<float, 0, 1>(matrix_cast<float>(randm(dims, 1)) * rnd.get_double_in_range(1, 9));
4468+
};
4469+
4470+
const auto get_distance_from_nearest_class_boundary = [class_boundary_1, class_boundary_2](double norm)
4471+
{
4472+
return std::min(
4473+
std::abs(norm - class_boundary_1),
4474+
std::abs(norm - class_boundary_2)
4475+
);
4476+
};
4477+
4478+
auto x = get_random_matrix();
4479+
auto norm = sqrt(sum(squared(x)));
4480+
auto distance_from_nearest_class_boundary = get_distance_from_nearest_class_boundary(norm);
4481+
4482+
// Try again if the newly generated sample is very close to either of the class boundaries
4483+
int retry_counter = 0;
4484+
const int max_retry_counter = 10;
4485+
while (distance_from_nearest_class_boundary < desired_margin && ++retry_counter <= max_retry_counter)
4486+
{
4487+
const auto new_x = get_random_matrix();
4488+
const auto new_norm = sqrt(sum(squared(new_x)));
4489+
const auto new_distance_from_nearest_class_boundary = get_distance_from_nearest_class_boundary(new_norm);
4490+
4491+
if (new_distance_from_nearest_class_boundary > distance_from_nearest_class_boundary)
4492+
{
4493+
x = new_x;
4494+
norm = new_norm;
4495+
distance_from_nearest_class_boundary = new_distance_from_nearest_class_boundary;
4496+
}
4497+
}
4498+
4499+
if (norm < class_boundary_1)
44634500
{
44644501
labels[i][0] = 1.f;
44654502
}
4466-
else if (3 <= norm && norm < 6)
4503+
else if (class_boundary_1 <= norm && norm < class_boundary_2)
44674504
{
44684505
labels[i][0] = 1.f;
44694506
labels[i][1] = 1.f;

0 commit comments

Comments
 (0)