Skip to content

Commit bfd78f0

Browse files
mathieuouillonbaltzell
authored andcommitted
Refactor: Encapsulate model preparation and loading into a dedicated class
Moved the preparation of the model and loading logic into a specific class to encapsulate functionality and reduce the number of imports.
1 parent 5c6be95 commit bfd78f0

File tree

2 files changed

+63
-54
lines changed

2 files changed

+63
-54
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package org.jlab.rec.ahdc.AI;
2+
3+
import ai.djl.MalformedModelException;
4+
import ai.djl.ndarray.NDArray;
5+
import ai.djl.ndarray.NDList;
6+
import ai.djl.ndarray.NDManager;
7+
import ai.djl.ndarray.types.Shape;
8+
import ai.djl.repository.zoo.Criteria;
9+
import ai.djl.repository.zoo.ModelNotFoundException;
10+
import ai.djl.repository.zoo.ZooModel;
11+
import ai.djl.training.util.ProgressBar;
12+
import ai.djl.translate.Translator;
13+
import ai.djl.translate.TranslatorContext;
14+
import org.jlab.utils.CLASResources;
15+
16+
import java.io.IOException;
17+
import java.nio.file.Paths;
18+
19+
public class Model {
20+
private ZooModel<float[], Float> model;
21+
22+
public Model() {
23+
Translator<float[], Float> my_translator = new Translator<float[], Float>() {
24+
@Override
25+
public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
26+
return ndList.get(0).getFloat();
27+
}
28+
29+
@Override
30+
public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception {
31+
NDManager manager = NDManager.newBaseManager();
32+
NDArray samples = manager.zeros(new Shape(floats.length));
33+
samples.set(floats);
34+
return new NDList(samples);
35+
}
36+
};
37+
38+
String path = CLASResources.getResourcePath("etc/nnet/ALERT/model_AHDC/");
39+
Criteria<float[], Float> my_model = Criteria.builder().setTypes(float[].class, Float.class)
40+
.optModelPath(Paths.get("etc/nnet/ALERT/model_AHDC/"))
41+
.optEngine("PyTorch")
42+
.optTranslator(my_translator)
43+
.optProgress(new ProgressBar())
44+
.build();
45+
46+
47+
try {
48+
model = my_model.loadModel();
49+
} catch (IOException | ModelNotFoundException | MalformedModelException e) {
50+
throw new RuntimeException(e);
51+
}
52+
53+
}
54+
55+
public ZooModel<float[], Float> getModel() {
56+
return model;
57+
}
58+
}

reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java

Lines changed: 5 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,12 @@
11
package org.jlab.rec.service;
22

3-
import ai.djl.MalformedModelException;
4-
import ai.djl.ndarray.NDArray;
5-
import ai.djl.ndarray.NDList;
6-
import ai.djl.ndarray.NDManager;
7-
import ai.djl.ndarray.types.Shape;
8-
import ai.djl.repository.zoo.Criteria;
9-
import ai.djl.repository.zoo.ModelNotFoundException;
10-
import ai.djl.repository.zoo.ZooModel;
11-
import ai.djl.training.util.ProgressBar;
12-
import ai.djl.translate.TranslateException;
13-
import ai.djl.translate.Translator;
14-
import ai.djl.translate.TranslatorContext;
153
import org.jlab.clas.reco.ReconstructionEngine;
164
import org.jlab.clas.tracking.kalmanfilter.Material;
175
import org.jlab.io.base.DataBank;
186
import org.jlab.io.base.DataEvent;
197
import org.jlab.io.hipo.HipoDataSource;
208
import org.jlab.io.hipo.HipoDataSync;
21-
import org.jlab.jnp.hipo4.data.SchemaFactory;
22-
import org.jlab.rec.ahdc.AI.AIPrediction;
23-
import org.jlab.rec.ahdc.AI.PreClustering;
24-
import org.jlab.rec.ahdc.AI.PreclusterSuperlayer;
25-
import org.jlab.rec.ahdc.AI.TrackConstruction;
26-
import org.jlab.rec.ahdc.AI.TrackPrediction;
9+
import org.jlab.rec.ahdc.AI.*;
2710
import org.jlab.rec.ahdc.Banks.RecoBankWriter;
2811
import org.jlab.rec.ahdc.Cluster.Cluster;
2912
import org.jlab.rec.ahdc.Cluster.ClusterFinder;
@@ -38,11 +21,8 @@
3821
import org.jlab.rec.ahdc.PreCluster.PreCluster;
3922
import org.jlab.rec.ahdc.PreCluster.PreClusterFinder;
4023
import org.jlab.rec.ahdc.Track.Track;
41-
import org.jlab.utils.CLASResources;
4224

4325
import java.io.File;
44-
import java.io.IOException;
45-
import java.nio.file.Paths;
4626
import java.util.*;
4727

4828
public class AHDCEngine extends ReconstructionEngine {
@@ -51,7 +31,7 @@ public class AHDCEngine extends ReconstructionEngine {
5131
private boolean use_AI_for_trackfinding;
5232
private String findingMethod;
5333
private HashMap<String, Material> materialMap;
54-
private ZooModel<float[], Float> model;
34+
private Model model;
5535

5636
public AHDCEngine() {
5737
super("ALERT", "ouillon", "1.0.1");
@@ -67,36 +47,7 @@ public boolean init() {
6747
materialMap = MaterialMap.generateMaterials();
6848
}
6949

70-
Translator<float[], Float> my_translator = new Translator<float[], Float>() {
71-
@Override
72-
public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
73-
return ndList.get(0).getFloat();
74-
}
75-
76-
@Override
77-
public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception {
78-
NDManager manager = NDManager.newBaseManager();
79-
NDArray samples = manager.zeros(new Shape(floats.length));
80-
samples.set(floats);
81-
return new NDList(samples);
82-
}
83-
};
84-
85-
String path = CLASResources.getResourcePath("etc/nnet/ALERT/model_AHDC/");
86-
Criteria<float[], Float> my_model = Criteria.builder().setTypes(float[].class, Float.class)
87-
.optModelPath(Paths.get(path))
88-
.optEngine("PyTorch")
89-
.optTranslator(my_translator)
90-
.optProgress(new ProgressBar())
91-
.build();
92-
93-
94-
try {
95-
model = my_model.loadModel();
96-
} catch (IOException | ModelNotFoundException | MalformedModelException e) {
97-
throw new RuntimeException(e);
98-
}
99-
50+
model = new Model();
10051

10152
return true;
10253
}
@@ -182,8 +133,8 @@ public int compare(Hit a1, Hit a2) {
182133

183134
try {
184135
AIPrediction aiPrediction = new AIPrediction();
185-
predictions = aiPrediction.prediction(tracks, model);
186-
} catch (ModelNotFoundException | MalformedModelException | IOException | TranslateException e) {
136+
predictions = aiPrediction.prediction(tracks, model.getModel());
137+
} catch (Exception e) {
187138
throw new RuntimeException(e);
188139
}
189140

0 commit comments

Comments
 (0)