Skip to content

Commit 27b6c1d

Browse files
authored
apply AI model for HB tracking (#1057)
* apply AI model for HB tracking * change some member variables from public to private in org.jlab.clas.tracking.kalmanfilter.zReference.KFitter
1 parent 6785f53 commit 27b6c1d

File tree

6 files changed

+641
-2
lines changed

6 files changed

+641
-2
lines changed

common-tools/clas-tracking/src/main/java/org/jlab/clas/tracking/kalmanfilter/zReference/KFitter.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public class KFitter extends AKFitter {
6060
Matrix result = new Matrix();
6161
Matrix result_inv = new Matrix();
6262
Matrix adj = new Matrix();
63-
63+
6464
public KFitter(boolean filter, int iterations, int dir, Swim swim, double Z[], Libr mo) {
6565
super(filter, iterations, dir, swim, mo);
6666
this.Z = Z;
@@ -861,7 +861,7 @@ private void calcFinalChisq(int sector) {
861861

862862
// Since no vertex inforamtion, the starting point for path length is the final point at the last layer.
863863
// After vertex information is obtained, transition for the starting point from the final point to vertex will be taken.
864-
private void calcFinalChisq(int sector, boolean nofilter) {
864+
public void calcFinalChisq(int sector, boolean nofilter) {
865865
int k = svzLength - 1;
866866
this.chi2 = 0;
867867
double path = 0;
@@ -1157,6 +1157,14 @@ public StateVecs getStateVecs() {
11571157
public double getNDFDAF(){
11581158
return ndfDAF;
11591159
}
1160+
1161+
public void setSvzLength(int svzlength){
1162+
this.svzLength = svzlength;
1163+
}
1164+
1165+
public int getSvzLength(){
1166+
return svzLength;
1167+
}
11601168

11611169
public void printlnMeasVecs() {
11621170
for (int i = 0; i < mv.measurements.size(); i++) {
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package org.jlab.rec.ai.dcHBTrackState;
2+
3+
import ai.djl.MalformedModelException;
4+
import ai.djl.ModelException;
5+
import ai.djl.inference.Predictor;
6+
import ai.djl.ndarray.*;
7+
import ai.djl.ndarray.types.Shape;
8+
import ai.djl.repository.zoo.*;
9+
import ai.djl.training.util.ProgressBar;
10+
import ai.djl.translate.*;
11+
12+
import java.io.IOException;
13+
import java.nio.file.Paths;
14+
import java.util.concurrent.*;
15+
import java.util.logging.Level;
16+
import java.util.logging.Logger;
17+
18+
import org.jlab.clas.reco.ReconstructionEngine;
19+
import org.jlab.io.base.*;
20+
import org.jlab.utils.system.ClasUtilsFile;
21+
import org.jlab.service.ai.PredictorPool;
22+
23+
24+
public class HBTrackStateEstimator{
25+
// ---------------- Configuration ----------------
26+
private String modelFile;
27+
28+
ZooModel<float[][], float[]> model;
29+
PredictorPool predictors;
30+
31+
// ---------------- Statistics for normalization of inputs and outputs of training samples ----------------
32+
//// Note: Statistics of hits and track states depends on training samples, so need to be renewed when training samples change!!!
33+
// Statistics of hits: doca, xm, xr, yr, z
34+
private float[] HIT_MEAN;
35+
private float[] HIT_STD;
36+
37+
// Statistics of track state: x, y, tx, ty, Q at z = 229 cm in the tilted sector frame
38+
private float[] STATE_MEAN;
39+
private float[] STATE_STD;
40+
41+
public HBTrackStateEstimator(String modelFile){
42+
this.modelFile = modelFile;
43+
44+
if(modelFile.contains("inbending")){
45+
HIT_MEAN = new float[]{0.52949071f, -45.771999f, -45.744694f, 57.336819f, 373.046356f};
46+
HIT_STD = new float[]{0.40272677f, 47.928203f, 48.379021f, 32.645191f, 111.54994f};
47+
STATE_MEAN = new float[]{-33.564308f, 0.010787425f, -0.15567796f, 0.0017755219f, 0.317530721f};
48+
STATE_STD = new float[]{28.667490f, 17.761129f, 0.11940812f, 0.074460238f, 0.74185127f};
49+
}
50+
else if(modelFile.contains("outbending")){
51+
HIT_MEAN = new float[]{0.53385729f, -59.236504f, -59.200584f, 50.136387f, 372.057922f};
52+
HIT_STD = new float[]{0.40085429f, 51.385536f, 51.840462f, 31.498201f, 111.50029f};
53+
STATE_MEAN = new float[]{-39.446106f, 0.17583229f, -0.18047817f, 0.0014163271f, -0.082320645f};
54+
STATE_STD = new float[]{33.733425f, 17.226780f, 0.14071095f, 0.072449364f, 0.72273886f};
55+
}
56+
else{
57+
Logger.getLogger(getClass().getName()).log(Level.SEVERE, "Name of model file does not include inbending or outbending");
58+
}
59+
60+
System.setProperty("ai.djl.pytorch.num_interop_threads", "1");
61+
System.setProperty("ai.djl.pytorch.num_threads", "1");
62+
System.setProperty("ai.djl.pytorch.graph_optimizer", "false");
63+
try {
64+
String modelPath = ClasUtilsFile.getResourceDir("CLAS12DIR", "etc/data/nnet/hbTSE/" + modelFile);
65+
66+
Criteria<float[][], float[]> criteria = Criteria.builder()
67+
.setTypes(float[][].class, float[].class)
68+
.optModelPath(Paths.get(modelPath))
69+
.optEngine("PyTorch")
70+
.optTranslator(getTranslator())
71+
.optProgress(new ProgressBar())
72+
.build();
73+
74+
model = criteria.loadModel();
75+
76+
int threads = 64;
77+
predictors = new PredictorPool(threads, model);
78+
79+
80+
} catch (IOException | ModelException e) {
81+
Logger.getLogger(getClass().getName()).log(Level.SEVERE, null, e);
82+
}
83+
}
84+
85+
// ---------------- Translator ----------------
86+
private Translator<float[][], float[]> getTranslator() {
87+
return new Translator<float[][], float[]>() {
88+
89+
@Override
90+
public NDList processInput(TranslatorContext ctx, float[][] hits) {
91+
NDManager manager = ctx.getNDManager();
92+
int n = hits.length;
93+
94+
float[][] norm = new float[n][5];
95+
for (int i = 0; i < n; i++)
96+
for (int j = 0; j < 5; j++)
97+
norm[i][j] = (hits[i][j] - HIT_MEAN[j]) / HIT_STD[j];
98+
99+
NDArray x = manager.create(norm);
100+
x = x.reshape(1, n, 5);
101+
return new NDList(x);
102+
}
103+
104+
@Override
105+
public float[] processOutput(TranslatorContext ctx, NDList list) {
106+
NDArray out = list.get(0); // [1,5]
107+
float[] y = out.toFloatArray();
108+
109+
for (int i = 0; i < 5; i++)
110+
y[i] = y[i] * STATE_STD[i] + STATE_MEAN[i];
111+
112+
return y;
113+
}
114+
115+
@Override
116+
public Batchifier getBatchifier() {
117+
return null;
118+
}
119+
};
120+
}
121+
122+
123+
public float[] predict(float[][] hits) {
124+
if (hits == null) return null;
125+
126+
if (hits.length == 0) {
127+
throw new IllegalArgumentException("HBInitialStateEstimator: empty hits");
128+
}
129+
130+
for (int i = 0; i < hits.length; i++) {
131+
if (hits[i].length != 5) {
132+
throw new IllegalArgumentException(
133+
"Expect 5 features per hit, got " + hits[i].length
134+
);
135+
}
136+
}
137+
138+
try {
139+
Predictor<float[][], float[]> predictor = predictors.take();
140+
try {
141+
return predictor.predict(hits);
142+
} finally {
143+
predictors.put(predictor);
144+
}
145+
} catch (TranslateException | InterruptedException e) {
146+
throw new RuntimeException(e);
147+
}
148+
}
149+
}
150+

reconstruction/dc/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@
4141
<artifactId>clas-tracking</artifactId>
4242
<version>13.5.2-SNAPSHOT</version>
4343
</dependency>
44+
<dependency>
45+
<groupId>org.jlab.clas12.detector</groupId>
46+
<artifactId>clas12detector-ai</artifactId>
47+
<version>13.5.2-SNAPSHOT</version>
48+
</dependency>
4449
<dependency>
4550
<groupId>org.jlab.clas</groupId>
4651
<artifactId>clas-jcsg</artifactId>

reconstruction/dc/src/main/java/org/jlab/rec/dc/track/TrackCandListFinder.java

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
import org.jlab.clas.tracking.utilities.MatrixOps.Libr;
3535
import org.jlab.clas.tracking.utilities.RungeKuttaDoca;
3636

37+
import org.jlab.rec.ai.dcHBTrackState.HBTrackStateEstimator;
38+
3739
/**
3840
* A class with a method implementing an algorithm that finds lists of track
3941
* candidates in the DC
@@ -204,6 +206,106 @@ private int calcInitTrkQ(double a, double TORSCALE) {
204206

205207
return q;
206208
}
209+
210+
public List<Track> getTrackCandsAI(CrossList crossList, DCGeant4Factory DcDetector, Swim dcSwim, HBTrackStateEstimator hbTSEstimator){
211+
List<Track> cands = new ArrayList();
212+
213+
for (List<Cross> aCrossList : crossList) {
214+
Track cand = new Track();
215+
216+
if (aCrossList.size() == 3 && this.PassNSuperlayerTracking(aCrossList, cand)) {
217+
cand.addAll(aCrossList);
218+
cand.setSector(aCrossList.get(0).get_Sector());
219+
int sector = cand.getSector();
220+
221+
List<Surface> measSurfaces = getMeasSurfaces(cand, DcDetector);
222+
223+
int numHits = 0;
224+
for(Surface surf : measSurfaces){
225+
numHits += surf.nMeas;
226+
}
227+
float[][] hits = new float[numHits][5];
228+
int indexHit = 0;
229+
for(int i = 0; i < measSurfaces.size(); i++){
230+
hits[indexHit][0] = (float)measSurfaces.get(i).doca[0];
231+
hits[indexHit][1] = (float)measSurfaces.get(i).wireLine[0].origin().x();
232+
hits[indexHit][2] = (float)measSurfaces.get(i).wireLine[0].end().x();
233+
hits[indexHit][3] = (float)measSurfaces.get(i).wireLine[0].end().y();
234+
hits[indexHit][4] = (float)measSurfaces.get(i).wireLine[0].end().z();
235+
236+
indexHit++;
237+
238+
if(measSurfaces.get(i).nMeas == 2){
239+
hits[indexHit][0] = (float)measSurfaces.get(i).doca[1];
240+
hits[indexHit][1] = (float)measSurfaces.get(i).wireLine[1].origin().x();
241+
hits[indexHit][2] = (float)measSurfaces.get(i).wireLine[1].end().x();
242+
hits[indexHit][3] = (float)measSurfaces.get(i).wireLine[1].end().y();
243+
hits[indexHit][4] = (float)measSurfaces.get(i).wireLine[1].end().z();
244+
245+
indexHit++;
246+
}
247+
}
248+
249+
float[] estSV = hbTSEstimator.predict(hits);
250+
StateVecs svs = new StateVecs();
251+
org.jlab.clas.tracking.kalmanfilter.AStateVecs.StateVec initSV = svs.new StateVec(0);
252+
initSV.x = estSV[0];
253+
initSV.y = estSV[1];
254+
initSV.z = 229.; // State vector at z = 229 for AI training samples
255+
initSV.tx = estSV[2];
256+
initSV.ty = estSV[3];
257+
initSV.Q = estSV[4];
258+
259+
RungeKuttaDoca rk = new RungeKuttaDoca();
260+
rk.SwimToZ(sector, initSV, dcSwim, measSurfaces.get(0).wireLine[0].end().z(), new float[3]);
261+
262+
KFitter kFZRef = new KFitter(true, 1, 1, dcSwim, Constants.getInstance().Z, Libr.JNP);
263+
Matrix initCMatrix = new Matrix();
264+
initSV.CM = new Matrix();
265+
kFZRef.init(measSurfaces, initSV);
266+
267+
org.jlab.clas.tracking.kalmanfilter.AStateVecs.StateVec finalSV = svs.new StateVec(initSV);
268+
rk.SwimToZ(sector, finalSV, dcSwim, measSurfaces.get(measSurfaces.size()-1).wireLine[0].end().z(), new float[3]);
269+
kFZRef.getStateVecs().transported(true).put(measSurfaces.size()-1, finalSV);
270+
271+
kFZRef.setSvzLength(measSurfaces.size());
272+
kFZRef.calcFinalChisq(sector, true);
273+
274+
StateVec stateVec = new StateVec(finalSV.x,
275+
finalSV.y, finalSV.tx, finalSV.ty);
276+
int q = (int) Math.signum(finalSV.Q);
277+
double p = 1. / Math.abs(finalSV.Q);
278+
stateVec.setZ(finalSV.z);
279+
280+
//set the track parameters
281+
cand.set_P(p);
282+
cand.set_Q(q);
283+
284+
// candidate parameters
285+
cand.set_FitChi2(kFZRef.chi2);
286+
cand.set_FitNDF(kFZRef.NDF);
287+
288+
cand.setFinalStateVec(stateVec);
289+
cand.set_Id(cands.size() + 1);
290+
this.setTrackPars(cand, null,
291+
null, stateVec,
292+
stateVec.getZ(),
293+
DcDetector, dcSwim);
294+
295+
if(cand.get_Vtx0() != null){
296+
Point3D VTCS = cand.get(cand.size()-1).getCoordsInTiltedSector(cand.get_Vtx0().x(), cand.get_Vtx0().y(), cand.get_Vtx0().z());
297+
double deltaPathToVtx = kFZRef.getDeltaPathToVtx(sector, VTCS.z());
298+
299+
List<org.jlab.rec.dc.trajectory.StateVec> kfStateVecsAlongTrajectory = setKFStateVecsAlongTrajectory(kFZRef, deltaPathToVtx);
300+
cand.setStateVecs(kfStateVecsAlongTrajectory);
301+
}
302+
303+
if (kFZRef.chi2 < Constants.MAXCHI2) cands.add(cand);
304+
}
305+
}
306+
307+
return cands;
308+
}
207309

208310
/**
209311
* @param crossList the input list of crosses

reconstruction/dc/src/main/java/org/jlab/service/dc/DCEngine.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.jlab.rec.dc.banks.Banks;
1313
import org.jlab.clas.tracking.kalmanfilter.zReference.KFitter;
1414
import org.jlab.clas.tracking.kalmanfilter.zReference.DAFilter;
15+
import org.jlab.rec.ai.dcHBTrackState.HBTrackStateEstimator;
1516

1617
public class DCEngine extends ReconstructionEngine {
1718

@@ -39,6 +40,9 @@ public class DCEngine extends ReconstructionEngine {
3940
private String dafChi2Cut = null;
4041
private String dafAnnealingFactorsTB = null;
4142

43+
protected String hbTSEModelFileInbending = "transformer_32d_4h_3l_inbending.pt"; // AI model file for HB track state estimator for inbending runs
44+
protected String hbTSEModelFileOutbending = "transformer_32d_4h_3l_outbending.pt"; // AI model file for HB track state estimator for outbending runs
45+
4246
public static final Logger LOGGER = Logger.getLogger(ReconstructionEngine.class.getName());
4347

4448

@@ -127,6 +131,14 @@ else if(this.getEngineConfigString("dcT2DFunc").equalsIgnoreCase("Polynomial"))
127131
dafAnnealingFactorsTB=this.getEngineConfigString("dafAnnealingFactorsTB");
128132
KFitter.setDafAnnealingFactorsTB(dafAnnealingFactorsTB);
129133
}
134+
135+
if (getEngineConfigString("hbTSEModelFileInbending") != null){
136+
hbTSEModelFileInbending = getEngineConfigString("hbTSEModelFileInbending");
137+
}
138+
139+
if (getEngineConfigString("hbTSEModelFileOutbending") != null){
140+
hbTSEModelFileOutbending = getEngineConfigString("hbTSEModelFileOutbending");
141+
}
130142

131143
// Set geometry shifts for alignment code
132144
if(this.getEngineConfigString("alignmentShifts")!=null) {

0 commit comments

Comments
 (0)