forked from acts-project/acts
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTruthGraphBuilder.cpp
More file actions
197 lines (158 loc) · 6.45 KB
/
TruthGraphBuilder.cpp
File metadata and controls
197 lines (158 loc) · 6.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
#include "ActsExamples/TrackFindingGnn/TruthGraphBuilder.hpp"
#include "Acts/Definitions/Units.hpp"
#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include <algorithm>
using namespace Acts;
using namespace Acts::UnitLiterals;
namespace ActsExamples {
TruthGraphBuilder::TruthGraphBuilder(Config config,
std::unique_ptr<const Acts::Logger> logger)
: IAlgorithm("TruthGraphBuilder", std::move(logger)),
m_cfg(std::move(config)) {
m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
m_inputParticles.initialize(m_cfg.inputParticles);
m_outputGraph.initialize(m_cfg.outputGraph);
m_inputMeasParticlesMap.maybeInitialize(m_cfg.inputMeasurementParticlesMap);
m_inputSimhits.maybeInitialize(m_cfg.inputSimHits);
m_inputMeasSimhitMap.maybeInitialize(m_cfg.inputMeasurementSimHitsMap);
bool a = m_inputMeasParticlesMap.isInitialized();
bool b =
m_inputSimhits.isInitialized() && m_inputMeasSimhitMap.isInitialized();
// Logical XOR operation
if (!a != !b) {
throw std::invalid_argument("Missing inputs, cannot build truth graph");
}
}
std::vector<std::int64_t> TruthGraphBuilder::buildFromMeasurements(
const SpacePointContainer& spacePoints,
const SimParticleContainer& particles,
const MeasurementParticlesMap& measPartMap) const {
if (m_cfg.targetMinPT < 500_MeV) {
ACTS_WARNING(
"truth graph building based on distance from origin, this breaks down "
"for low pT particles. Consider using a higher target pT value");
}
// Associate tracks to graph, collect momentum
std::unordered_map<SimBarcode, std::vector<std::size_t>> tracks;
for (auto i = 0ul; i < spacePoints.size(); ++i) {
const auto measId =
spacePoints[i].sourceLinks()[0].template get<IndexSourceLink>().index();
auto [a, b] = measPartMap.equal_range(measId);
for (auto it = a; it != b; ++it) {
tracks[it->second].push_back(i);
}
}
// Collect edges for truth graph and target graph
std::vector<std::int64_t> graph;
std::size_t notFoundParticles = 0;
std::size_t moduleDuplicatesRemoved = 0;
for (auto& [pid, track] : tracks) {
auto found = particles.find(pid);
if (found == particles.end()) {
ACTS_VERBOSE("Did not find " << pid << ", skip track");
notFoundParticles++;
continue;
}
if (found->transverseMomentum() < m_cfg.targetMinPT ||
track.size() < m_cfg.targetMinSize) {
continue;
}
const Vector3 vtx = found->fourPosition().segment<3>(0);
auto radiusForOrdering = [&](std::size_t i) {
const auto& sp = spacePoints[i];
return std::hypot(sp.x() - vtx[0], sp.y() - vtx[1], sp.z() - vtx[2]);
};
// Sort by radius (this breaks down if the particle has to low momentum)
std::ranges::sort(track, {},
[&](const auto& t) { return radiusForOrdering(t); });
if (m_cfg.uniqueModules) {
auto newEnd = std::unique(
track.begin(), track.end(), [&](const auto& a, const auto& b) {
auto gidA = spacePoints[a]
.sourceLinks()[0]
.template get<IndexSourceLink>()
.geometryId();
auto gidB = spacePoints[b]
.sourceLinks()[0]
.template get<IndexSourceLink>()
.geometryId();
return gidA == gidB;
});
moduleDuplicatesRemoved += std::distance(newEnd, track.end());
track.erase(newEnd, track.end());
}
for (auto i = 0ul; i < track.size() - 1; ++i) {
graph.push_back(track[i]);
graph.push_back(track[i + 1]);
}
}
ACTS_DEBUG("Did not find particles for " << notFoundParticles << " tracks");
if (moduleDuplicatesRemoved > 0) {
ACTS_DEBUG(
"Removed " << moduleDuplicatesRemoved
<< " hit to ensure a unique hit per track and module");
}
return graph;
}
struct HitInfo {
std::size_t spacePointIndex;
std::int32_t hitIndex;
};
std::vector<std::int64_t> TruthGraphBuilder::buildFromSimhits(
const SpacePointContainer& spacePoints,
const MeasurementSimHitsMap& measHitMap, const SimHitContainer& simhits,
const SimParticleContainer& particles) const {
// Associate tracks to graph, collect momentum
std::unordered_map<SimBarcode, std::vector<HitInfo>> tracks;
for (auto i = 0ul; i < spacePoints.size(); ++i) {
const auto measId =
spacePoints[i].sourceLinks()[0].template get<IndexSourceLink>().index();
auto [a, b] = measHitMap.equal_range(measId);
for (auto it = a; it != b; ++it) {
const auto& hit = *simhits.nth(it->second);
tracks[hit.particleId()].push_back({i, hit.index()});
}
}
// Collect edges for truth graph and target graph
std::vector<std::int64_t> truthGraph;
for (auto& [pid, track] : tracks) {
// Sort by hit index, so the edges are connected correctly
std::ranges::sort(track, {}, [](const auto& t) { return t.hitIndex; });
auto found = particles.find(pid);
if (found == particles.end()) {
ACTS_WARNING("Did not find " << pid << ", skip track");
continue;
}
for (auto i = 0ul; i < track.size() - 1; ++i) {
if (found->transverseMomentum() > m_cfg.targetMinPT &&
track.size() >= m_cfg.targetMinSize) {
truthGraph.push_back(track[i].spacePointIndex);
truthGraph.push_back(track[i + 1].spacePointIndex);
}
}
}
return truthGraph;
}
ProcessCode TruthGraphBuilder::execute(const AlgorithmContext& ctx) const {
// Read input data
const auto& spacePoints = m_inputSpacePoints(ctx);
const auto& particles = m_inputParticles(ctx);
auto edges = (m_inputMeasParticlesMap.isInitialized())
? buildFromMeasurements(spacePoints, particles,
m_inputMeasParticlesMap(ctx))
: buildFromSimhits(spacePoints, m_inputMeasSimhitMap(ctx),
m_inputSimhits(ctx), particles);
ACTS_DEBUG("Truth track edges: " << edges.size() / 2);
Graph g;
g.edges = std::move(edges);
m_outputGraph(ctx, std::move(g));
return ProcessCode::SUCCESS;
}
} // namespace ActsExamples