Skip to content

Commit b4773cf

Browse files
feat: MLSolver hashe test and compilation in the CI (acts-project#1959)
This MR adds the Ml Ambiguity solver to the CI. For now, only hashes tests are available, but once the performance evaluation for the solver has been refactored to use root, we should also add it to the test.
1 parent 8efbafb commit b4773cf

File tree

7 files changed

+42
-12
lines changed

7 files changed

+42
-12
lines changed

.github/workflows/builds.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ jobs:
166166
-DACTS_BUILD_EXAMPLES_EDM4HEP=ON
167167
-DACTS_FORCE_ASSERTIONS=ON
168168
-DACTS_BUILD_ANALYSIS_APPS=ON
169+
-DACTS_BUILD_PLUGIN_ONNX=ON
169170
170171
- name: Build
171172
run: cmake --build build

Examples/Algorithms/TrackFinding/include/ActsExamples/TrackFinding/AmbiguityResolutionMLAlgorithm.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#pragma once
1010

1111
#include "Acts/Plugins/Onnx/OnnxRuntimeBase.hpp"
12-
#include "ActsExamples/Framework/BareAlgorithm.hpp"
12+
#include "ActsExamples/Framework/IAlgorithm.hpp"
1313

1414
#include <string>
1515
#include <vector>
@@ -22,7 +22,7 @@ namespace ActsExamples {
2222
/// 1) Cluster together nearby tracks using shared hits
2323
/// 2) For each track use a neural network to compute a score
2424
/// 3) In each cluster keep the track with the highest score
25-
class AmbiguityResolutionMLAlgorithm final : public BareAlgorithm {
25+
class AmbiguityResolutionMLAlgorithm final : public IAlgorithm {
2626
public:
2727
struct Config {
2828
/// Input trajectories collection.

Examples/Algorithms/TrackFinding/src/AmbiguityResolutionMLAlgorithm.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
ActsExamples::AmbiguityResolutionMLAlgorithm::AmbiguityResolutionMLAlgorithm(
2929
ActsExamples::AmbiguityResolutionMLAlgorithm::Config cfg,
3030
Acts::Logging::Level lvl)
31-
: ActsExamples::BareAlgorithm("AmbiguityResolutionMLAlgorithm", lvl),
31+
: ActsExamples::IAlgorithm("AmbiguityResolutionMLAlgorithm", lvl),
3232
m_cfg(std::move(cfg)),
3333
m_env(ORT_LOGGING_LEVEL_WARNING, "MLClassifier"),
3434
m_duplicateClassifier(m_env, m_cfg.inputDuplicateNN.c_str()) {
@@ -71,7 +71,8 @@ std::unordered_map<int, std::vector<int>> clusterTracks(
7171
}
7272
// None of the hits have been matched to a track create a new cluster
7373
if (matchedTrack == hitToTrack.end()) {
74-
cluster.emplace(track->second.first, {track->second.first});
74+
cluster.emplace(track->second.first,
75+
std::vector<int>(1, track->second.first));
7576
for (const auto& hit : hits) {
7677
// Add the hits of the new cluster to the hitToTrack
7778
hitToTrack.emplace(hit, track->second.first);

Examples/Python/tests/root_file_hashes.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,4 @@ test_root_clusters_writer[configKwConstructor]__clusters.root: 97f04fdd2c0eef4d3
8585
test_root_clusters_writer[kwargsConstructor]__clusters.root: 97f04fdd2c0eef4d37dc8732dd25ab49a90bb51925b2638d94826becf5059fae
8686
test_exatrkx[onnx]__performance_track_finding.root: 89bf19c4e0414c78982e0b36c4125cd1d25335bca821941665d325152c05aeb1
8787
test_exatrkx[torch]__performance_track_finding.root: a06ff95c0baa4e07834d9c388af5f7da00f3c84ad9cf88aaf76385d50bf68195
88+
test_ML_Ambiguity_Solver__performance_ambiML.root: 080e183e758b8593a9c233e2d1b4d213f28fdcb18d82acefdac7c9a5a5763bfc

Examples/Python/tests/test_examples.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,6 @@ def test_seeding_orthogonal(tmp_path, trk_geo, field, assert_root_hash):
324324

325325

326326
def test_itk_seeding(tmp_path, trk_geo, field, assert_root_hash):
327-
328327
field = acts.ConstantBField(acts.Vector3(0, 0, 2 * acts.UnitConstants.T))
329328

330329
csv = tmp_path / "csv"
@@ -480,7 +479,6 @@ def test_propagation(tmp_path, trk_geo, field, seq, assert_root_hash):
480479
@pytest.mark.skipif(not geant4Enabled, reason="Geant4 not set up")
481480
@pytest.mark.skipif(not dd4hepEnabled, reason="DD4hep not set up")
482481
def test_material_recording(tmp_path, material_recording, assert_root_hash):
483-
484482
root_files = [
485483
(
486484
"geant4_material_tracks.root",
@@ -503,7 +501,6 @@ def test_material_recording(tmp_path, material_recording, assert_root_hash):
503501
@pytest.mark.skipif(not dd4hepEnabled, reason="DD4hep not set up")
504502
@pytest.mark.skipif(not geant4Enabled, reason="Geant4 not set up")
505503
def test_event_recording(tmp_path):
506-
507504
script = (
508505
Path(__file__).parent.parent.parent.parent
509506
/ "Examples"
@@ -1117,7 +1114,6 @@ def test_vertex_fitting(tmp_path):
11171114
def test_vertex_fitting_reading(
11181115
tmp_path, ptcl_gun, rng, finder, inputTracks, entries, assert_root_hash
11191116
):
1120-
11211117
ptcl_file = tmp_path / "particles.root"
11221118

11231119
detector, trackingGeometry, decorators = GenericDetector.create()
@@ -1240,6 +1236,40 @@ def test_full_chain_odd_example_pythia_geant4(tmp_path):
12401236
)
12411237

12421238

1239+
@pytest.mark.skipif(not dd4hepEnabled, reason="DD4hep not set up")
1240+
@pytest.mark.slow
1241+
def test_ML_Ambiguity_Solver(tmp_path, assert_root_hash):
1242+
root_file = "performance_ambiML.root"
1243+
output_dir = "odd_output"
1244+
assert not (tmp_path / root_file).exists()
1245+
# This test literally only ensures that the full chain example can run without erroring out
1246+
getOpenDataDetector(
1247+
getOpenDataDetectorDirectory()
1248+
) # just to make sure it can build
1249+
1250+
script = (
1251+
Path(__file__).parent.parent.parent.parent
1252+
/ "Examples"
1253+
/ "Scripts"
1254+
/ "Python"
1255+
/ "full_chain_odd.py"
1256+
)
1257+
assert script.exists()
1258+
env = os.environ.copy()
1259+
env["ACTS_LOG_FAILURE_THRESHOLD"] = "WARNING"
1260+
subprocess.check_call(
1261+
[sys.executable, str(script), "-n5", "--MLSolver"],
1262+
cwd=tmp_path,
1263+
env=env,
1264+
stderr=subprocess.STDOUT,
1265+
)
1266+
1267+
rfp = tmp_path / output_dir / root_file
1268+
assert rfp.exists()
1269+
1270+
assert_root_hash(root_file, rfp)
1271+
1272+
12431273
def test_bfield_writing(tmp_path, seq, assert_root_hash):
12441274
from bfield_writing import runBFieldWriting
12451275

cmake/ActsConfig.cmake.in

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,6 @@ endif()
8484

8585

8686
# dependencies which cannot be searched in CONFIG mode
87-
if(PluginOnnx IN_LIST Acts_COMPONENTS)
88-
find_package(OnnxRuntime REQUIRED)
89-
endif()
9087
if(PluginSycl IN_LIST Acts_COMPONENTS)
9188
find_package(SYCL REQUIRED)
9289
endif()

cmake/FindOnnxRuntime.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ endif()
1919
find_path(
2020
OnnxRuntime_INCLUDE_DIR
2121
NAMES core/session/onnxruntime_cxx_api.h #core/session/providers/cuda_provider_factory.h
22-
PATHS ${onxxruntime_DIR}
22+
PATHS ${onnxruntime_DIR}
2323
PATH_SUFFIXES include include/onnxruntime
2424
DOC "The ONNXRuntime include directory")
2525

0 commit comments

Comments
 (0)