Skip to content

Commit f0124db

Browse files
authored
Merge pull request #2535 from ivikhrev/social-distance-fix
Social distance demo fix
2 parents 9687b11 + 6555c7e commit f0124db

File tree

2 files changed

+31
-43
lines changed

2 files changed

+31
-43
lines changed

demos/social_distance_demo/cpp/include/person_trackers.hpp

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -29,46 +29,42 @@ class PersonTrackers {
2929
PersonTrackers() : trackIdGenerator{0}, similarityThreshold{0.7f}, maxDisappeared{10} {}
3030

3131
void similarity(std::list<TrackableObject> &tos) {
32-
if (trackables.size() > 0) {
33-
for (const auto& to : tos) {
34-
std::deque<std::pair<int, float>> sim;
35-
for (auto &tracker : trackables) {
36-
if (!tracker.second.updated) {
37-
float cosine = cosineSimilarity(to.reid, tracker.second.reid);
38-
if (cosine > similarityThreshold) {
39-
sim.push_back(std::make_pair(tracker.first, cosine));
40-
}
32+
for (const auto& to : tos) {
33+
std::deque<std::pair<int, float>> sim;
34+
for (auto &tracker : trackables) {
35+
if (!tracker.second.updated) {
36+
float cosine = cosineSimilarity(to.reid, tracker.second.reid);
37+
if (cosine > similarityThreshold) {
38+
sim.push_back(std::make_pair(tracker.first, cosine));
4139
}
4240
}
41+
}
4342

44-
if (sim.empty()) {
45-
trackables.insert({ trackIdGenerator, to });
46-
trackables.at(trackIdGenerator).updated = true;
47-
trackables.at(trackIdGenerator).disappeared = 0;
48-
trackIdGenerator += 1;
49-
} else {
50-
int maxSimilarity = std::max_element(sim.begin(), sim.end(), [](std::pair<int, float> a, std::pair<int, float> b) {
51-
return std::get<1>(a) > std::get<1>(b);
52-
})->first;
53-
trackables.at(maxSimilarity) = to;
54-
trackables.at(maxSimilarity).updated = true;
55-
trackables.at(maxSimilarity).disappeared = 0;
56-
}
43+
if (sim.empty()) {
44+
trackables.insert({ trackIdGenerator, to });
45+
trackables.at(trackIdGenerator).updated = true;
46+
trackables.at(trackIdGenerator).disappeared = 0;
47+
trackIdGenerator += 1;
48+
} else {
49+
int maxSimilarity = std::max_element(sim.begin(), sim.end(), [](std::pair<int, float> a, std::pair<int, float> b) {
50+
return std::get<1>(a) > std::get<1>(b);
51+
})->first;
52+
trackables.at(maxSimilarity) = to;
53+
trackables.at(maxSimilarity).updated = true;
54+
trackables.at(maxSimilarity).disappeared = 0;
5755
}
56+
}
5857

59-
for (auto it = trackables.begin(); it != trackables.end(); ) {
60-
if (!it->second.updated) {
61-
it->second.disappeared += 1;
62-
if (it->second.disappeared > maxDisappeared) {
63-
it = trackables.erase(it);
64-
continue;
65-
}
58+
for (auto it = trackables.begin(); it != trackables.end(); ) {
59+
if (!it->second.updated) {
60+
it->second.disappeared += 1;
61+
if (it->second.disappeared > maxDisappeared) {
62+
it = trackables.erase(it);
63+
continue;
6664
}
67-
it->second.updated = false;
68-
++it;
6965
}
70-
} else {
71-
registerTrackables(tos);
66+
it->second.updated = false;
67+
++it;
7268
}
7369
}
7470

@@ -86,14 +82,6 @@ class PersonTrackers {
8682
return static_cast<float>(dot / (sqrt(denomA) * sqrt(denomB) + 1e-6));
8783
}
8884

89-
void registerTrackables(std::list<TrackableObject> &tos) {
90-
for (auto &to : tos) {
91-
to.disappeared = 0;
92-
trackables.insert({trackIdGenerator, to});
93-
trackIdGenerator += 1;
94-
}
95-
}
96-
9785
public:
9886
std::unordered_map<int, TrackableObject> trackables;
9987

demos/social_distance_demo/cpp/src/geodist.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,10 @@ std::tuple<bool, bool, double> socialDistance(std::tuple<int, int> &frameShape,
226226
std::tie(bdA, bdK) = getLineComponent(B, D);
227227
std::tie(acA, acK) = getLineComponent(A, C);
228228

229-
double bdinf = B.x < D.x ? -9999999999. : 9999999999.;
229+
double bdinf = std::lround(B.x) <= std::lround(D.x) ? -9999999999. : 9999999999.;
230230
Line2d BDinf = getLine(D, cv::Point2d(bdinf, getY(bdinf, bdA, bdK)));
231231

232-
double acinf = A.x < C.x ? -9999999999. : 9999999999.;
232+
double acinf = std::lround(A.x) <= std::lround(C.x) ? -9999999999. : 9999999999.;
233233
Line2d ACinf = getLine(C, cv::Point2d(acinf, getY(acinf, acA, acK)));
234234

235235
// Vanishing point

0 commit comments

Comments
 (0)