11package org .jlab .rec .service ;
22
3- import ai .djl .MalformedModelException ;
4- import ai .djl .ndarray .NDArray ;
5- import ai .djl .ndarray .NDList ;
6- import ai .djl .ndarray .NDManager ;
7- import ai .djl .ndarray .types .Shape ;
8- import ai .djl .repository .zoo .Criteria ;
9- import ai .djl .repository .zoo .ModelNotFoundException ;
10- import ai .djl .repository .zoo .ZooModel ;
11- import ai .djl .training .util .ProgressBar ;
12- import ai .djl .translate .TranslateException ;
13- import ai .djl .translate .Translator ;
14- import ai .djl .translate .TranslatorContext ;
153import org .jlab .clas .reco .ReconstructionEngine ;
164import org .jlab .clas .tracking .kalmanfilter .Material ;
175import org .jlab .io .base .DataBank ;
186import org .jlab .io .base .DataEvent ;
197import org .jlab .io .hipo .HipoDataSource ;
208import org .jlab .io .hipo .HipoDataSync ;
21- import org .jlab .jnp .hipo4 .data .SchemaFactory ;
22- import org .jlab .rec .ahdc .AI .AIPrediction ;
23- import org .jlab .rec .ahdc .AI .PreClustering ;
24- import org .jlab .rec .ahdc .AI .PreclusterSuperlayer ;
25- import org .jlab .rec .ahdc .AI .TrackConstruction ;
26- import org .jlab .rec .ahdc .AI .TrackPrediction ;
9+ import org .jlab .rec .ahdc .AI .*;
2710import org .jlab .rec .ahdc .Banks .RecoBankWriter ;
2811import org .jlab .rec .ahdc .Cluster .Cluster ;
2912import org .jlab .rec .ahdc .Cluster .ClusterFinder ;
3821import org .jlab .rec .ahdc .PreCluster .PreCluster ;
3922import org .jlab .rec .ahdc .PreCluster .PreClusterFinder ;
4023import org .jlab .rec .ahdc .Track .Track ;
41- import org .jlab .utils .CLASResources ;
4224
4325import java .io .File ;
44- import java .io .IOException ;
45- import java .nio .file .Paths ;
4626import java .util .*;
4727
4828public class AHDCEngine extends ReconstructionEngine {
@@ -51,7 +31,7 @@ public class AHDCEngine extends ReconstructionEngine {
5131 private boolean use_AI_for_trackfinding ;
5232 private String findingMethod ;
5333 private HashMap <String , Material > materialMap ;
54- private ZooModel < float [], Float > model ;
34+ private Model model ;
5535
5636 public AHDCEngine () {
5737 super ("ALERT" , "ouillon" , "1.0.1" );
@@ -67,36 +47,7 @@ public boolean init() {
6747 materialMap = MaterialMap .generateMaterials ();
6848 }
6949
70- Translator <float [], Float > my_translator = new Translator <float [], Float >() {
71- @ Override
72- public Float processOutput (TranslatorContext translatorContext , NDList ndList ) throws Exception {
73- return ndList .get (0 ).getFloat ();
74- }
75-
76- @ Override
77- public NDList processInput (TranslatorContext translatorContext , float [] floats ) throws Exception {
78- NDManager manager = NDManager .newBaseManager ();
79- NDArray samples = manager .zeros (new Shape (floats .length ));
80- samples .set (floats );
81- return new NDList (samples );
82- }
83- };
84-
85- String path = CLASResources .getResourcePath ("etc/nnet/ALERT/model_AHDC/" );
86- Criteria <float [], Float > my_model = Criteria .builder ().setTypes (float [].class , Float .class )
87- .optModelPath (Paths .get (path ))
88- .optEngine ("PyTorch" )
89- .optTranslator (my_translator )
90- .optProgress (new ProgressBar ())
91- .build ();
92-
93-
94- try {
95- model = my_model .loadModel ();
96- } catch (IOException | ModelNotFoundException | MalformedModelException e ) {
97- throw new RuntimeException (e );
98- }
99-
50+ model = new Model ();
10051
10152 return true ;
10253 }
@@ -182,8 +133,8 @@ public int compare(Hit a1, Hit a2) {
182133
183134 try {
184135 AIPrediction aiPrediction = new AIPrediction ();
185- predictions = aiPrediction .prediction (tracks , model );
186- } catch (ModelNotFoundException | MalformedModelException | IOException | TranslateException e ) {
136+ predictions = aiPrediction .prediction (tracks , model . getModel () );
137+ } catch (Exception e ) {
187138 throw new RuntimeException (e );
188139 }
189140
0 commit comments