
Commit 4f80afc

mathieuouillon authored and c-dilks committed
Alert track matching ai (#969)
* Start an implementation of track matching for ALERT
* Fix the path for the ALERT AI models
* ALERT: Improve the AI model for track finding
* ALERT: Change the name of the model class for track finding
* ALERT: Rename some variables and classes (Model to ModelTrackFinding)
* ALERT: Change superpreclusters to InterCluster
* Move the track matching to the ALERT engine
* ALERT: Change the structure of the AHDCEngine to write the interclusters and link them to their associated track:
  - Introduce a new intercluster data structure in alert.json to store the track ID and coordinates.
  - Add the track ID to the InterCluster class.
  - Rename TrackConstruction to TrackCandidatesGenerator.
  - Update the Track class to manage interclusters and clusters and to set track IDs.
  - Modify AHDCEngine to expose the same interface for AI and conventional track finding.
  - Use the AI precluster algorithm everywhere (CV and AI).
* Fix the track matching AI and the output bank
* Add the model and fix the constant loader in the ATOF
* Modify HipoDataSync initialization in AHDCEngine
* Update the model path for TrackMatchingAI to use the new path for AI models
* Remove the AI networks
* Update the model path for the ALERT model in ModelTrackFinding
* Update the model path for TrackMatchingAI
* fix(ci): use `--lfs` with `--unittests`

---------
Co-authored-by: Christopher Dilks <[email protected]>
1 parent c57f817 commit 4f80afc
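Taken together, the commit message describes one pipeline inside AHDCEngine: preclusters are merged into interclusters, grouped into track candidates, scored by the AI model, and the accepted tracks keep references to their interclusters. A rough sketch of that flow, using classes from this commit but with hypothetical helper names (findPreclusters, candidatesFrom) standing in for the engine's own steps that are not shown here:

    // Sketch only: PreClustering, InterCluster, AIPrediction, ModelTrackFinding and
    // TrackPrediction come from this commit; findPreclusters() and candidatesFrom()
    // are illustrative stand-ins for the actual AHDCEngine / TrackCandidatesGenerator calls.
    ArrayList<TrackPrediction> runAiTrackFinding(List<Hit> ahdcHits) throws Exception {
        ArrayList<PreCluster> preclusters = findPreclusters(ahdcHits);                 // AI precluster algorithm, now used for both CV and AI
        ArrayList<InterCluster> interclusters = new PreClustering().mergePreclusters(preclusters);
        ArrayList<ArrayList<InterCluster>> candidates = candidatesFrom(interclusters); // via TrackCandidatesGenerator
        return new AIPrediction().prediction(candidates, new ModelTrackFinding());     // one score per candidate
    }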

File tree: 21 files changed (+624, −375 lines)


.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ jobs:
         with:
           cvmfs_repositories: 'oasis.opensciencegrid.org'
       - name: unit tests
-        run: ./build-coatjava.sh --cvmfs --unittests --no-progress -T${{ env.nthreads }}
+        run: ./build-coatjava.sh --lfs --unittests --no-progress -T${{ env.nthreads }}
       - name: collect jacoco report
         if: ${{ matrix.JAVA_VERSION == env.JAVA_VERSION }}
         run: validation/jacoco-aggregate.sh

.gitlab-ci.yml

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ download:
   dependencies: [build]
   script:
     - tar -xzf coatjava.tar.gz
-    - ./build-coatjava.sh -T$JL_RUNNER_AVAIL_CPU --unittests --quiet --no-progress
+    - ./build-coatjava.sh -T$JL_RUNNER_AVAIL_CPU --lfs --unittests --quiet --no-progress
     - ./validation/jacoco-aggregate.sh
   artifacts:
     when: always

etc/bankdefs/hipo4/alert.json

Lines changed: 23 additions & 1 deletion
@@ -55,7 +55,18 @@
             "info": "path length inside atof wedge in mm"
         }
     ]
-},{
+},
+{
+    "name": "ALERT::ai:projections",
+    "group": 23000,
+    "item": 32,
+    "info": "Track Projections to ATOF given by AI",
+    "entries": [
+        {"name": "trackid", "type": "I", "info": "track id"},
+        {"name": "matched_atof_hit_id", "type": "I", "info": "id of the matched ATOF hit, -1 if no hit was matched"}
+    ]
+},
+{
     "name": "ATOF::hits",
     "group": 22500,
     "item": 21,
@@ -414,5 +425,16 @@
         {"name": "y5", "type": "F", "info": "Y5 position of the 5th superprecluster (mm)"},
         {"name": "pred", "type": "F", "info": "Prediction of the model: 0 mean bad track; 1 mean good track"}
     ]
+},
+{
+    "name": "AHDC::interclusters",
+    "group": 23000,
+    "item": 27,
+    "info": "InterClusters info",
+    "entries": [
+        {"name": "trackid", "type": "I", "info": "track id"},
+        {"name": "x", "type": "F", "info": "x info (mm)"},
+        {"name": "y", "type": "F", "info": "y info (mm)"}
+    ]
 }
]
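Both new banks are plain row-oriented HIPO banks. The engine code that fills them is not part of this diff; a sketch of how AHDC::interclusters could be written through the standard coatjava DataBank interface, assuming the usual org.jlab.io.base API and a hypothetical helper name:

    import org.jlab.io.base.DataBank;
    import org.jlab.io.base.DataEvent;
    import java.util.ArrayList;

    // Hypothetical helper (not in this commit): one bank row per InterCluster.
    private void writeInterClusterBank(DataEvent event, ArrayList<InterCluster> interclusters) {
        DataBank bank = event.createBank("AHDC::interclusters", interclusters.size());
        for (int row = 0; row < interclusters.size(); row++) {
            InterCluster ic = interclusters.get(row);
            bank.setInt("trackid", row, ic.getTrackId());   // -1 when no track claimed this intercluster
            bank.setFloat("x", row, (float) ic.getX());
            bank.setFloat("y", row, (float) ic.getY());
        }
        event.appendBanks(bank);
    }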

reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/AIPrediction.java

Lines changed: 23 additions & 21 deletions
@@ -2,32 +2,34 @@

 import java.util.ArrayList;

-import ai.djl.MalformedModelException;
-import ai.djl.inference.Predictor;
-import ai.djl.repository.zoo.ModelNotFoundException;
-import ai.djl.repository.zoo.ZooModel;
-import ai.djl.translate.TranslateException;
-
-import java.io.IOException;
-
 public class AIPrediction {


-    public AIPrediction() throws ModelNotFoundException, MalformedModelException, IOException {
-    }
+    public AIPrediction() {}

-    public ArrayList<TrackPrediction> prediction(ArrayList<ArrayList<PreclusterSuperlayer>> tracks, ZooModel<float[], Float> model) throws TranslateException {
+    public ArrayList<TrackPrediction> prediction(ArrayList<ArrayList<InterCluster>> tracks, ModelTrackFinding modelTrackFinding) throws Exception {
         ArrayList<TrackPrediction> result = new ArrayList<>();
-        for (ArrayList<PreclusterSuperlayer> track : tracks) {
-            float[] a = new float[]{(float) track.get(0).getX(), (float) track.get(0).getY(),
-                                    (float) track.get(1).getX(), (float) track.get(1).getY(),
-                                    (float) track.get(2).getX(), (float) track.get(2).getY(),
-                                    (float) track.get(3).getX(), (float) track.get(3).getY(),
-                                    (float) track.get(4).getX(), (float) track.get(4).getY(),
-            };
-
-            Predictor<float[], Float> my_predictor = model.newPredictor();
-            result.add(new TrackPrediction(my_predictor.predict(a), track));
+
+        if (tracks.isEmpty()) return result;
+
+        float[][] batchInput = new float[tracks.size()][10];
+        for (int i = 0; i < tracks.size(); i++) {
+            ArrayList<InterCluster> track = tracks.get(i);
+            batchInput[i][0] = (float) track.get(0).getX();
+            batchInput[i][1] = (float) track.get(0).getY();
+            batchInput[i][2] = (float) track.get(1).getX();
+            batchInput[i][3] = (float) track.get(1).getY();
+            batchInput[i][4] = (float) track.get(2).getX();
+            batchInput[i][5] = (float) track.get(2).getY();
+            batchInput[i][6] = (float) track.get(3).getX();
+            batchInput[i][7] = (float) track.get(3).getY();
+            batchInput[i][8] = (float) track.get(4).getX();
+            batchInput[i][9] = (float) track.get(4).getY();
+        }
+
+        float[] predictions = modelTrackFinding.batchPredict(batchInput);
+        for (int i = 0; i < tracks.size(); i++) {
+            result.add(new TrackPrediction(predictions[i], tracks.get(i)));
         }

         return result;
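A hedged usage sketch of the new interface (the calling method is hypothetical, not part of this diff):

    // Sketch: score all track candidates in a single batch instead of one predictor call per track.
    private ArrayList<TrackPrediction> scoreCandidates(ArrayList<ArrayList<InterCluster>> candidates) throws Exception {
        ModelTrackFinding model = new ModelTrackFinding();  // loads the PyTorch model once
        AIPrediction ai = new AIPrediction();
        // each candidate holds five interclusters, giving the ten input features built above
        return ai.prediction(candidates, model);
    }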

reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreclusterSuperlayer.java renamed to reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/InterCluster.java

Lines changed: 10 additions & 4 deletions
@@ -1,17 +1,16 @@
 package org.jlab.rec.ahdc.AI;

-import org.jlab.rec.ahdc.Hit.Hit;
 import org.jlab.rec.ahdc.PreCluster.PreCluster;

 import java.util.ArrayList;

-public class PreclusterSuperlayer {
+public class InterCluster {
+    private int trackId = -1;
     private final double x;
     private final double y;
     private ArrayList<PreCluster> preclusters = new ArrayList<>();

-
-    ; public PreclusterSuperlayer(ArrayList<PreCluster> preclusters_) {
+    public InterCluster(ArrayList<PreCluster> preclusters_) {
         this.preclusters = preclusters_;
         double x_ = 0;
         double y_ = 0;
@@ -43,6 +42,13 @@ public int getSuperlayer() {
         return this.preclusters.get(0).get_Super_layer();
     }

+    public int getTrackId() {
+        return trackId;
+    }
+
+    public void setTrackId(int trackId) {
+        this.trackId = trackId;
+    }

     public String toString() {
         return "PreCluster{" + "X: " + this.x + " Y: " + this.y + " phi: " + Math.atan2(this.y, this.x) + "}\n";
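The new trackId field is what lets the engine link each intercluster row back to its track. A minimal sketch of how a track's interclusters might be tagged (the helper method is assumed, not shown in this commit):

    // Sketch: once a track candidate is accepted, tag its interclusters with the track's id
    // so the AHDC::interclusters bank can carry the association.
    void tagInterClusters(ArrayList<InterCluster> candidate, int trackId) {
        for (InterCluster ic : candidate) {
            ic.setTrackId(trackId);   // default is -1, i.e. not attached to any track
        }
    }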

reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/Model.java renamed to reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/ModelTrackFinding.java

Lines changed: 44 additions & 8 deletions
@@ -21,29 +21,28 @@
  *
  * \todo fix class name
  */
-public class Model {
-    private ZooModel<float[], Float> model;
+public class ModelTrackFinding {
+    private final ZooModel<float[], Float> model;

-    public Model() {
-        Translator<float[], Float> my_translator = new Translator<float[], Float>() {
+    public ModelTrackFinding() {
+        Translator<float[], Float> my_translator = new Translator<>() {
             @Override
             public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
                 return ndList.get(0).getFloat();
             }

             @Override
             public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception {
-                NDManager manager = NDManager.newBaseManager();
-                NDArray samples = manager.zeros(new Shape(floats.length));
-                samples.set(floats);
+                NDManager manager = translatorContext.getNDManager();
+                NDArray samples = manager.create(floats);
                 return new NDList(samples);
             }
         };
         System.setProperty("ai.djl.pytorch.num_interop_threads", "1");
         System.setProperty("ai.djl.pytorch.num_threads", "1");
         System.setProperty("ai.djl.pytorch.graph_optimizer", "false");

-        String path = CLASResources.getResourcePath("etc/data/nnet/ALERT/model_AHDC/");
+        String path = CLASResources.getResourcePath("etc/data/nnet/rg-l/model_AHDC/");
         Criteria<float[], Float> my_model = Criteria.builder().setTypes(float[].class, Float.class)
                 .optModelPath(Paths.get(path))
                 .optEngine("PyTorch")
@@ -63,4 +62,41 @@ public NDList processInput(TranslatorContext translatorContext, float[] floats)
     public ZooModel<float[], Float> getModel() {
         return model;
     }
+
+    /**
+     * Batch prediction for improved performance.
+     * Predicts all tracks at once instead of one at a time.
+     * This is significantly faster due to reduced overhead and better GPU utilization.
+     *
+     * @param inputs Array of input features for each track
+     * @return Array of predictions for each track
+     */
+    public float[] batchPredict(float[][] inputs) throws Exception {
+        if (inputs == null || inputs.length == 0) {
+            return new float[0];
+        }
+
+        try (NDManager manager = NDManager.newBaseManager()) {
+            int batchSize = inputs.length;
+            NDArray batchInput = manager.create(inputs);
+            NDList inputList = new NDList(batchInput);
+            ai.djl.inference.Predictor<NDList, NDList> rawPredictor = model.newPredictor(new ai.djl.translate.NoopTranslator());
+            NDList output = rawPredictor.predict(inputList);

+            NDArray outputArray = output.get(0);
+            float[] results = new float[batchSize];
+
+            if (outputArray.getShape().dimension() == 2) {
+                for (int i = 0; i < batchSize; i++) {
+                    results[i] = outputArray.get(i, 0).getFloat();
+                }
+            } else {
+                for (int i = 0; i < batchSize; i++) {
+                    results[i] = outputArray.get(i).getFloat();
+                }
+            }
+
+            return results;
+        }
+    }
 }
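The Javadoc above describes the batched inference path; the expected input is one row of ten floats per candidate, the (x, y) of five interclusters in the same order AIPrediction builds them. A minimal usage sketch with illustrative coordinates (the wrapper method is hypothetical):

    // Sketch: two candidates, each described by five intercluster (x, y) pairs in mm.
    float[] scoreTwoCandidates() throws Exception {
        ModelTrackFinding model = new ModelTrackFinding();
        float[][] inputs = {
            { 34.1f,  2.0f, 41.8f,  3.1f, 49.6f,  4.4f, 57.3f,  5.9f, 65.1f,  7.5f },
            { 33.9f, -1.2f, 42.0f, -2.6f, 50.2f, -4.1f, 58.5f, -5.8f, 66.7f, -7.6f }
        };
        return model.batchPredict(inputs);   // one score per candidate; values near 1 mean "good track"
    }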

reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java

Lines changed: 10 additions & 97 deletions
@@ -3,119 +3,32 @@
 import org.jlab.rec.ahdc.Hit.Hit;
 import org.jlab.rec.ahdc.PreCluster.PreCluster;

-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Comparator;
-import java.util.List;
+import java.util.*;

 public class PreClustering {
+    static final double DISTANCE_MAX = 8.0;

-    private ArrayList<Hit> fill(List<Hit> hits, int super_layer, int layer) {
-
-        ArrayList<Hit> result = new ArrayList<>();
-        for (Hit hit : hits) {
-            if (hit.getSuperLayerId() == super_layer && hit.getLayerId() == layer) result.add(hit);
-        }
-        return result;
-    }
-
-    public ArrayList<PreCluster> find_preclusters_for_AI(List<Hit> AHDC_hits) {
-        ArrayList<PreCluster> preclusters = new ArrayList<>();
-
-        ArrayList<Hit> s1l1 = fill(AHDC_hits, 1, 1);
-        ArrayList<Hit> s2l1 = fill(AHDC_hits, 2, 1);
-        ArrayList<Hit> s2l2 = fill(AHDC_hits, 2, 2);
-        ArrayList<Hit> s3l1 = fill(AHDC_hits, 3, 1);
-        ArrayList<Hit> s3l2 = fill(AHDC_hits, 3, 2);
-        ArrayList<Hit> s4l1 = fill(AHDC_hits, 4, 1);
-        ArrayList<Hit> s4l2 = fill(AHDC_hits, 4, 2);
-        ArrayList<Hit> s5l1 = fill(AHDC_hits, 5, 1);
-
-        // Sort hits of each layers by phi:
-        Comparator<Hit> comparator = new Comparator<>() {
-            @Override
-            public int compare(Hit a1, Hit a2) {
-                return Double.compare(a1.getPhi(), a2.getPhi());
-            }
-        };
-
-        s1l1.sort(comparator);
-        s2l1.sort(comparator);
-        s2l2.sort(comparator);
-        s3l1.sort(comparator);
-        s3l2.sort(comparator);
-        s4l1.sort(comparator);
-        s4l2.sort(comparator);
-        s5l1.sort(comparator);
-
-        ArrayList<ArrayList<Hit>> all_super_layer = new ArrayList<>(Arrays.asList(s1l1, s2l1, s2l2, s3l1, s3l2, s4l1, s4l2, s5l1));
-
-        for (ArrayList<Hit> p : all_super_layer) {
-            for (Hit hit : p) {
-                hit.setUse(false);
-            }
-        }
-
-        for (ArrayList<Hit> p : all_super_layer) {
-            for (Hit hit : p) {
-                if (hit.is_NoUsed()) {
-                    ArrayList<Hit> temp = new ArrayList<>();
-                    temp.add(hit);
-                    hit.setUse(true);
-                    int expected_wire_plus = hit.getWireId() + 1;
-                    int expected_wire_minus = hit.getWireId() - 1;
-                    if (hit.getWireId() == 1)
-                        expected_wire_minus = hit.getNbOfWires();
-                    if (hit.getWireId() == hit.getNbOfWires())
-                        expected_wire_plus = 1;
-
-
-                    boolean has_next = true;
-                    while (has_next) {
-                        has_next = false;
-                        for (Hit hit1 : p) {
-                            if (hit1.is_NoUsed() && (hit1.getWireId() == expected_wire_minus || hit1.getWireId() == expected_wire_plus)) {
-                                temp.add(hit1);
-                                hit1.setUse(true);
-                                has_next = true;
-                                break;
-                            }
-                        }
-                    }
-                    if (!temp.isEmpty()) preclusters.add(new PreCluster(temp));
-                }
-            }
-        }
-        return preclusters;
-    }
-
-    public ArrayList<PreclusterSuperlayer> merge_preclusters(ArrayList<PreCluster> preclusters) {
-        double distance_max = 8.0;
-
-        ArrayList<PreclusterSuperlayer> superpreclusters = new ArrayList<>();
+    public ArrayList<InterCluster> mergePreclusters(ArrayList<PreCluster> preclusters) {
+        ArrayList<InterCluster> interclusters = new ArrayList<>();
         for (PreCluster precluster : preclusters) {
             if (!precluster.is_Used()) {
                 ArrayList<PreCluster> tmp = new ArrayList<>();
                 tmp.add(precluster);
                 precluster.set_Used(true);
                 for (PreCluster other : preclusters) {
-                    if (precluster.get_hits_list().get(precluster.get_hits_list().size() - 1).getSuperLayerId() == other.get_hits_list().get(other.get_hits_list().size() - 1).getSuperLayerId() && precluster.get_hits_list().get(precluster.get_hits_list().size() - 1).getLayerId() != other.get_hits_list().get(other.get_hits_list().size() - 1).getLayerId() && !other.is_Used()) {
-                        double dx = precluster.get_X() - other.get_X();
-                        double dy = precluster.get_Y() - other.get_Y();
-                        double distance = Math.sqrt(dx * dx + dy * dy);
-
-                        if (distance < distance_max) {
+                    if (precluster.get_hits_list().getLast().getSuperLayerId() == other.get_hits_list().getLast().getSuperLayerId()
+                            && precluster.get_hits_list().getLast().getLayerId() != other.get_hits_list().getLast().getLayerId()
+                            && !other.is_Used()) {
+                        if (Math.hypot(precluster.get_X() - other.get_X(), precluster.get_Y() - other.get_Y()) < DISTANCE_MAX) {
                             other.set_Used(true);
                             tmp.add(other);
                         }
                     }
                 }
-
-                if (!tmp.isEmpty()) superpreclusters.add(new PreclusterSuperlayer(tmp));
+                if (!tmp.isEmpty()) interclusters.add(new InterCluster(tmp));
             }
         }
-
-        return superpreclusters;
+        return interclusters;
     }

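The rewritten mergePreclusters pairs preclusters that sit in the same superlayer but in different layers and lie within DISTANCE_MAX = 8 mm of each other, producing one InterCluster per pair (or per unmatched precluster). A short usage sketch (the wrapper method and the upstream precluster list are assumed):

    // Sketch: build interclusters from preclusters found upstream; every resulting
    // InterCluster starts with trackId = -1 until a track candidate claims it.
    ArrayList<InterCluster> buildInterClusters(ArrayList<PreCluster> preclusters) {
        PreClustering preClustering = new PreClustering();
        return preClustering.mergePreclusters(preclusters);
    }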
0 commit comments

Comments
 (0)