Skip to content

Commit 88b1dd8

Browse files
authored
Merge pull request #25 from mathieuouillon/AI_for_ALERT
Add AI for ALERT trackfinding
2 parents e2acd07 + 3521313 commit 88b1dd8

File tree

13 files changed

+550
-17
lines changed

13 files changed

+550
-17
lines changed

etc/bankdefs/hipo4/alert.json

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,5 +159,24 @@
159159
"info": "pz in MeV"
160160
}
161161
]
162-
}
162+
},
163+
{
164+
"name": "AHDC_AI::Prediction",
165+
"group": 23000,
166+
"item": 30,
167+
"info": "Prediction given by AI",
168+
"entries": [
169+
{"name": "X1", "type": "F", "info": "X1 position of the 1th superprecluster (mm)"},
170+
{"name": "Y1", "type": "F", "info": "Y1 position of the 1th superprecluster (mm)"},
171+
{"name": "X2", "type": "F", "info": "X2 position of the 2nd superprecluster (mm)"},
172+
{"name": "Y2", "type": "F", "info": "Y2 position of the 2nd superprecluster (mm)"},
173+
{"name": "X3", "type": "F", "info": "X3 position of the 3rd superprecluster (mm)"},
174+
{"name": "Y3", "type": "F", "info": "Y3 position of the 3rd superprecluster (mm)"},
175+
{"name": "X4", "type": "F", "info": "X4 position of the 4th superprecluster (mm)"},
176+
{"name": "Y4", "type": "F", "info": "Y4 position of the 4th superprecluster (mm)"},
177+
{"name": "X5", "type": "F", "info": "X5 position of the 5th superprecluster (mm)"},
178+
{"name": "Y5", "type": "F", "info": "Y5 position of the 5th superprecluster (mm)"},
179+
{"name": "Pred", "type": "F", "info": "Prediction of the model: 0 mean bad track; 1 mean good track"}
180+
]
181+
}
163182
]

etc/bankdefs/util/bankSplit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def create(dirname, banklist):
6565
raster = ["RASTER::position"]
6666
rich = ["RICH::tdc","RICH::Ring","RICH::Particle"]
6767
rtpc = ["RTPC::hits","RTPC::tracks","RTPC::KFtracks"]
68-
alert = ["AHDC::Track", "AHDC::MC", "AHDC::Hits", "AHDC::PreClusters", "AHDC::Clusters", "AHDC::KFTrack"]
68+
alert = ["AHDC::Track", "AHDC::MC", "AHDC::Hits", "AHDC::PreClusters", "AHDC::Clusters", "AHDC::KFTrack", "AHDC_AI::Prediction"]
6969
dets = band + raster + rich + rtpc + alert
7070

7171
# additions for the calibration schema:

parent/pom.xml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@
4646
<version>0.8.12</version>
4747
</dependency>
4848

49+
<dependency>
50+
<groupId>ai.djl</groupId>
51+
<artifactId>model-zoo</artifactId>
52+
<version>0.30.0</version>
53+
<scope>compile</scope>
54+
</dependency>
55+
<dependency>
56+
<groupId>ai.djl.pytorch</groupId>
57+
<artifactId>pytorch-model-zoo</artifactId>
58+
<version>0.30.0</version>
59+
</dependency>
60+
4961
</dependencies>
5062

5163
<build>

reconstruction/alert/pom.xml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@
4040
<version>11.0.5-SNAPSHOT</version>
4141
<scope>compile</scope>
4242
</dependency>
43+
<dependency>
44+
<groupId>ai.djl</groupId>
45+
<artifactId>model-zoo</artifactId>
46+
<version>0.30.0</version>
47+
<scope>compile</scope>
48+
</dependency>
49+
<dependency>
50+
<groupId>ai.djl.pytorch</groupId>
51+
<artifactId>pytorch-model-zoo</artifactId>
52+
<version>0.30.0</version>
53+
</dependency>
4354
</dependencies>
4455

4556
</project>
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package org.jlab.rec.ahdc.AI;
2+
3+
import java.util.ArrayList;
4+
5+
import ai.djl.MalformedModelException;
6+
import ai.djl.inference.Predictor;
7+
import ai.djl.repository.zoo.ModelNotFoundException;
8+
import ai.djl.repository.zoo.ZooModel;
9+
import ai.djl.translate.TranslateException;
10+
11+
import java.io.IOException;
12+
13+
public class AIPrediction {
14+
15+
16+
public AIPrediction() throws ModelNotFoundException, MalformedModelException, IOException {
17+
}
18+
19+
public ArrayList<TrackPrediction> prediction(ArrayList<ArrayList<PreclusterSuperlayer>> tracks, ZooModel<float[], Float> model) throws TranslateException {
20+
ArrayList<TrackPrediction> result = new ArrayList<>();
21+
for (ArrayList<PreclusterSuperlayer> track : tracks) {
22+
float[] a = new float[]{(float) track.get(0).getX(), (float) track.get(0).getY(),
23+
(float) track.get(1).getX(), (float) track.get(1).getY(),
24+
(float) track.get(2).getX(), (float) track.get(2).getY(),
25+
(float) track.get(3).getX(), (float) track.get(3).getY(),
26+
(float) track.get(4).getX(), (float) track.get(4).getY(),
27+
};
28+
29+
Predictor<float[], Float> my_predictor = model.newPredictor();
30+
result.add(new TrackPrediction(my_predictor.predict(a), track));
31+
}
32+
33+
return result;
34+
}
35+
36+
37+
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package org.jlab.rec.ahdc.AI;
2+
3+
import org.jlab.rec.ahdc.Hit.Hit;
4+
import org.jlab.rec.ahdc.PreCluster.PreCluster;
5+
6+
import java.util.ArrayList;
7+
import java.util.Arrays;
8+
import java.util.Comparator;
9+
import java.util.List;
10+
11+
public class PreClustering {
12+
13+
private ArrayList<Hit> fill(List<Hit> hits, int super_layer, int layer) {
14+
15+
ArrayList<Hit> result = new ArrayList<>();
16+
for (Hit hit : hits) {
17+
if (hit.getSuperLayerId() == super_layer && hit.getLayerId() == layer) result.add(hit);
18+
}
19+
return result;
20+
}
21+
22+
public ArrayList<PreCluster> find_preclusters_for_AI(List<Hit> AHDC_hits) {
23+
ArrayList<PreCluster> preclusters = new ArrayList<>();
24+
25+
ArrayList<Hit> s1l1 = fill(AHDC_hits, 1, 1);
26+
ArrayList<Hit> s2l1 = fill(AHDC_hits, 2, 1);
27+
ArrayList<Hit> s2l2 = fill(AHDC_hits, 2, 2);
28+
ArrayList<Hit> s3l1 = fill(AHDC_hits, 3, 1);
29+
ArrayList<Hit> s3l2 = fill(AHDC_hits, 3, 2);
30+
ArrayList<Hit> s4l1 = fill(AHDC_hits, 4, 1);
31+
ArrayList<Hit> s4l2 = fill(AHDC_hits, 4, 2);
32+
ArrayList<Hit> s5l1 = fill(AHDC_hits, 5, 1);
33+
34+
// Sort hits of each layers by phi:
35+
s1l1.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
36+
s2l1.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
37+
s2l2.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
38+
s3l1.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
39+
s3l2.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
40+
s4l1.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
41+
s4l2.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
42+
s5l1.sort(new Comparator<Hit>() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}});
43+
44+
ArrayList<ArrayList<Hit>> all_super_layer = new ArrayList<>(Arrays.asList(s1l1, s2l1, s2l2, s3l1, s3l2, s4l1, s4l2, s5l1));
45+
46+
for (ArrayList<Hit> p : all_super_layer) {
47+
for (Hit hit : p) {
48+
hit.setUse(false);
49+
}
50+
}
51+
52+
for (ArrayList<Hit> p : all_super_layer) {
53+
for (Hit hit : p) {
54+
if (hit.is_NoUsed()) {
55+
ArrayList<Hit> temp = new ArrayList<>();
56+
temp.add(hit);
57+
hit.setUse(true);
58+
59+
boolean has_next = true;
60+
while (has_next) {
61+
has_next = false;
62+
for (Hit hit1 : p) {
63+
if (hit1.is_NoUsed() && (hit1.getWireId() == temp.get(temp.size() - 1).getWireId() + 1 || hit1.getWireId() == temp.get(temp.size() - 1).getWireId() - 1)) {
64+
temp.add(hit1);
65+
hit1.setUse(true);
66+
has_next = true;
67+
break;
68+
}
69+
}
70+
}
71+
if (!temp.isEmpty()) preclusters.add(new PreCluster(temp));
72+
}
73+
}
74+
}
75+
return preclusters;
76+
}
77+
78+
public ArrayList<PreclusterSuperlayer> merge_preclusters(ArrayList<PreCluster> preclusters) {
79+
double distance_max = 8.0;
80+
81+
ArrayList<PreclusterSuperlayer> superpreclusters = new ArrayList<>();
82+
for (PreCluster precluster : preclusters) {
83+
if (!precluster.is_Used()) {
84+
ArrayList<PreCluster> tmp = new ArrayList<>();
85+
tmp.add(precluster);
86+
precluster.set_Used(true);
87+
for (PreCluster other : preclusters) {
88+
if (precluster.get_hits_list().get(precluster.get_hits_list().size() - 1).getSuperLayerId() == other.get_hits_list().get(other.get_hits_list().size() - 1).getSuperLayerId() && precluster.get_hits_list().get(precluster.get_hits_list().size() - 1).getLayerId() != other.get_hits_list().get(other.get_hits_list().size() - 1).getLayerId() && !other.is_Used()) {
89+
double dx = precluster.get_X() - other.get_X();
90+
double dy = precluster.get_Y() - other.get_Y();
91+
double distance = Math.sqrt(dx * dx + dy * dy);
92+
93+
if (distance < distance_max) {
94+
other.set_Used(true);
95+
tmp.add(other);
96+
}
97+
}
98+
}
99+
100+
if (!tmp.isEmpty()) superpreclusters.add(new PreclusterSuperlayer(tmp));
101+
}
102+
}
103+
104+
return superpreclusters;
105+
}
106+
107+
108+
109+
110+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package org.jlab.rec.ahdc.AI;
2+
3+
import org.jlab.rec.ahdc.Hit.Hit;
4+
import org.jlab.rec.ahdc.PreCluster.PreCluster;
5+
6+
import java.util.ArrayList;
7+
8+
public class PreclusterSuperlayer {
9+
private final double x;
10+
private final double y;
11+
private ArrayList<PreCluster> preclusters = new ArrayList<>();
12+
13+
14+
; public PreclusterSuperlayer(ArrayList<PreCluster> preclusters_) {
15+
this.preclusters = preclusters_;
16+
double x_ = 0;
17+
double y_ = 0;
18+
19+
for (PreCluster p : this.preclusters) {
20+
x_ += p.get_X();
21+
y_ += p.get_Y();
22+
}
23+
this.x = x_ / this.preclusters.size();
24+
this.y = y_ / this.preclusters.size();
25+
26+
27+
28+
}
29+
30+
public ArrayList<PreCluster> getPreclusters() {
31+
return preclusters;
32+
}
33+
34+
public double getX() {
35+
return x;
36+
}
37+
38+
public double getY() {
39+
return y;
40+
}
41+
42+
43+
public String toString() {
44+
return "PreCluster{" + "X: " + this.x + " Y: " + this.y + " phi: " + Math.atan2(this.y, this.x) + "}\n";
45+
}
46+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package org.jlab.rec.ahdc.AI;
2+
3+
import org.jlab.rec.ahdc.Hit.Hit;
4+
5+
import java.io.File;
6+
import java.io.FileWriter;
7+
import java.io.IOException;
8+
import java.util.*;
9+
10+
public class TrackConstruction {
11+
public TrackConstruction() {}
12+
13+
private double mod(double x, double y) {
14+
15+
if (0. == y) return x;
16+
17+
double m = x - y * Math.floor(x / y);
18+
// handle boundary cases resulted from floating-point cut off:
19+
if (y > 0) { // modulo range: [0..y)
20+
if (m >= y) return 0; // Mod(-1e-16 , 360. ): m= 360.
21+
if (m < 0) {
22+
if (y + m == y) return 0; // just in case...
23+
else return y + m; // Mod(106.81415022205296 , _TWO_PI ): m= -1.421e-14
24+
}
25+
} else { // modulo range: (y..0]
26+
if (m <= y) return 0; // Mod(1e-16 , -360. ): m= -360.
27+
if (m > 0) {
28+
if (y + m == y) return 0; // just in case...
29+
else return y + m; // Mod(-106.81415022205296, -_TWO_PI): m= 1.421e-14
30+
}
31+
}
32+
33+
return m;
34+
}
35+
36+
private double warp_zero_two_pi(double angle) { return mod(angle, 2. * Math.PI); }
37+
38+
private boolean angle_in_range(double angle, double lower, double upper) { return warp_zero_two_pi(angle - lower) <= warp_zero_two_pi(upper - lower); }
39+
40+
41+
public ArrayList<ArrayList<PreclusterSuperlayer>> get_all_possible_track(ArrayList<PreclusterSuperlayer> preclusterSuperlayers) {
42+
43+
// Get seeds to start the track finding algorithm
44+
ArrayList<PreclusterSuperlayer> seeds = new ArrayList<>();
45+
for (PreclusterSuperlayer precluster : preclusterSuperlayers) {
46+
if (precluster.getPreclusters().get(0).get_hits_list().get(0).getSuperLayerId() == 1) seeds.add(precluster);
47+
}
48+
seeds.sort(new Comparator<PreclusterSuperlayer>() {
49+
@Override
50+
public int compare(PreclusterSuperlayer a1, PreclusterSuperlayer a2) {
51+
return Double.compare(Math.atan2(a1.getY(), a1.getX()), Math.atan2(a2.getY(), a2.getX()));
52+
}
53+
});
54+
// System.out.println("seeds: " + seeds);
55+
56+
// Get all possible tracks ----------------------------------------------------------------
57+
double max_angle = Math.toRadians(60);
58+
59+
ArrayList<ArrayList<PreclusterSuperlayer>> all_combinations = new ArrayList<>();
60+
for (PreclusterSuperlayer seed : seeds) {
61+
double phi_seed = warp_zero_two_pi(Math.atan2(seed.getY(), seed.getX()));
62+
63+
ArrayList<PreclusterSuperlayer> track = new ArrayList<>();
64+
for (PreclusterSuperlayer p : preclusterSuperlayers) {
65+
double phi_p = warp_zero_two_pi(Math.atan2(p.getY(), p.getX()));
66+
if (angle_in_range(phi_p, phi_seed - max_angle, phi_seed + max_angle)) track.add(p);
67+
}
68+
// System.out.println("track: " + track.size());
69+
70+
ArrayList<ArrayList<PreclusterSuperlayer>> combinations = new ArrayList<>(List.of(new ArrayList<>(List.of(seed))));
71+
// System.out.println("combinations: " + combinations);
72+
73+
for (int i = 1; i < 5; ++i) {
74+
ArrayList<ArrayList<PreclusterSuperlayer>> new_combinations = new ArrayList<>();
75+
for (ArrayList<PreclusterSuperlayer> combination : combinations) {
76+
77+
for (PreclusterSuperlayer precluster : track) {
78+
if (precluster.getPreclusters().get(0).get_hits_list().get(0).getSuperLayerId() == seed.getPreclusters().get(0).get_hits_list().get(0).getSuperLayerId() + i) {
79+
// System.out.printf("Good Precluster x: %.2f, y: %.2f, r: %.2f%n", precluster.getX(), precluster.getY(), Math.hypot(precluster.getX(), precluster.getY()));
80+
// System.out.println("combination: " + combination);
81+
82+
ArrayList<PreclusterSuperlayer> new_combination = new ArrayList<>(combination);
83+
new_combination.add(precluster);
84+
// System.out.println("new_combination: " + new_combination);
85+
new_combinations.add(new_combination);
86+
}
87+
}
88+
for (ArrayList<PreclusterSuperlayer> c : new_combinations) {
89+
// System.out.println("c.size: " + c.size() + ", c: " + c);
90+
}
91+
92+
}
93+
combinations = new_combinations;
94+
if (combinations.size() > 10000) break;
95+
}
96+
for (ArrayList<PreclusterSuperlayer> combination : combinations) {
97+
if (combination.size() == 5) {
98+
all_combinations.add(combination);
99+
}
100+
}
101+
}
102+
103+
return all_combinations;
104+
}
105+
106+
}

0 commit comments

Comments
 (0)