1111import ai .djl .training .util .ProgressBar ;
1212import ai .djl .translate .Translator ;
1313import ai .djl .translate .TranslatorContext ;
14+ import ai .djl .translate .Batchifier ;
1415import ai .djl .inference .Predictor ;
1516import ai .djl .repository .zoo .ModelNotFoundException ;
16- import ai .djl .translate .Batchifier ;
1717import ai .djl .translate .TranslateException ;
18+
1819import java .io .IOException ;
1920import java .util .concurrent .BlockingQueue ;
2021import java .util .concurrent .ArrayBlockingQueue ;
22+ import java .util .logging .Level ;
23+ import java .util .logging .Logger ;
2124
2225import org .jlab .clas .reco .ReconstructionEngine ;
2326import org .jlab .io .base .DataBank ;
2730public class DCDenoiseEngine extends ReconstructionEngine {
2831
2932 final static String [] BANK_NAMES = {"DC::tot" ,"DC::tdc" };
30- final static String CONF_THRESHOLD = "threshold" ;
33+ final static String CONF_MODEL_FILE = "modelFile" ;
34+ final static String CONF_THRESHOLD = "threshold" ;
3135 final static String CONF_THREADS = "threads" ;
36+
3237 final static int LAYERS = 36 ;
3338 final static int WIRES = 112 ;
39+ final static int SECTORS = 6 ;
3440
35- float threshold = 0.025f ;
36- Criteria <float [][],float [][]> criteria ;
37- ZooModel <float [][], float [][]> model ;
41+ String modelFile = "cnn_autoenc_sector1_2b_48f_4x6k.pt" ;
42+ float threshold = 0.03f ;
43+ Criteria <float [][][], float [][][]> criteria ;
44+ ZooModel <float [][][], float [][][]> model ;
3845 PredictorPool predictors ;
39-
46+
47+ // -------- Predictor Pool --------
4048 public static class PredictorPool {
41- final BlockingQueue <Predictor > pool ;
42- public PredictorPool (int size , ZooModel model ) {
49+ final BlockingQueue <Predictor < float [][][], float [][][]> > pool ;
50+ public PredictorPool (int size , ZooModel < float [][][], float [][][]> model ) {
4351 pool = new ArrayBlockingQueue <>(size );
44- for (int i =0 ; i <size ; i ++) pool . add ( model . newPredictor ());
45- }
46- public Predictor get () throws InterruptedException {
47- return pool . poll ();
48- }
49- public void put ( Predictor p ) {
50- if ( p != null ) pool . offer ( p );
52+ for (int i =0 ; i <size ; i ++) {
53+ try {
54+ pool . add ( model . newPredictor ());
55+ } catch ( Exception e ) {
56+ Logger . getLogger ( PredictorPool . class . getName ()). log ( Level . WARNING , "Failed to create predictor" , e );
57+ }
58+ }
5159 }
60+ public Predictor <float [][][], float [][][]> take () throws InterruptedException { return pool .take (); }
61+ public void put (Predictor <float [][][], float [][][]> p ) throws InterruptedException { if (p !=null ) pool .put (p ); }
62+ public void shutdownAll () { for (Predictor p : pool ) { try { p .close (); } catch (Exception ignored ) {} } }
5263 }
5364
5465 public DCDenoiseEngine () {
@@ -62,173 +73,138 @@ public boolean init() {
6273 System .setProperty ("ai.djl.pytorch.graph_optimizer" , "false" );
6374 if (getEngineConfigString (CONF_THRESHOLD ) != null )
6475 threshold = Float .parseFloat (getEngineConfigString (CONF_THRESHOLD ));
76+ if (getEngineConfigString (CONF_MODEL_FILE ) != null )
77+ modelFile = getEngineConfigString (CONF_MODEL_FILE );
78+
6579 try {
80+ String modelPath = ClasUtilsFile .getResourceDir ("CLAS12DIR" , "etc/data/nnet/dn/" + modelFile );
81+
6682 criteria = Criteria .builder ()
67- .setTypes (float [][].class , float [][].class )
68- .optModelPath (Paths .get (ClasUtilsFile . getResourceDir ( "CLAS12DIR" , "etc/data/nnet/dn/cnn_autoenc_sector1_nBlocks2.pt" ) ))
83+ .setTypes (float [][][] .class , float [] [][].class )
84+ .optModelPath (Paths .get (modelPath ))
6985 .optEngine ("PyTorch" )
70- .optTranslator (DCDenoiseEngine .getTranslator ())
86+ .optTranslator (DCDenoiseEngine .getBatchTranslator ())
7187 .optProgress (new ProgressBar ())
7288 .build ();
89+
7390 model = criteria .loadModel ();
91+
7492 int threads = Integer .parseInt (getEngineConfigString (CONF_THREADS ,"64" ));
7593 predictors = new PredictorPool (threads , model );
7694 return true ;
7795 } catch (NullPointerException | MalformedModelException | IOException | ModelNotFoundException ex ) {
78- System .getLogger (DCDenoiseEngine .class .getName ()).log (System . Logger . Level .ERROR , ( String ) null , ex );
96+ Logger .getLogger (DCDenoiseEngine .class .getName ()).log (Level .SEVERE , null , ex );
7997 return false ;
8098 }
8199 }
82100
83- public static void main (String args []){
84- DCDenoiseEngine dn = new DCDenoiseEngine ();
85- dn .init ();
86- for (int i =0 ; i <10000 ; i ++) {
87- dn .processFakeEvent ();
88- }
89- }
90-
91101 @ Override
92102 public boolean processDataEvent (DataEvent event ) {
103+ for (String bankName : BANK_NAMES ) {
104+ if (!event .hasBank (bankName )) continue ;
93105
94- //if (true) return processFakeEvent();
95-
96- for (int i =0 ; i <BANK_NAMES .length ; i ++){
97- if (event .hasBank (BANK_NAMES [i ])) {
98- DataBank bank = event .getBank (BANK_NAMES [i ]);
99- try {
100- // WARNING: Predictor is *not* thread safe.
101- Predictor <float [][], float [][]> predictor = predictors .get ();
102- for (int sector =0 ; sector <6 ; sector ++) {
103- float [][] input = DCDenoiseEngine .read (bank , sector +1 );
104- float [][] output = predictor .predict (input );
105- //System.out.println("IN:");show(input);
106- //System.out.println("OUT:");show(output);
107- update (bank , threshold , output , sector );
106+ DataBank bank = event .getBank (bankName );
107+ try {
108+ // Build batch for 6 sectors
109+ float [][][] batchInput = new float [SECTORS ][LAYERS ][WIRES ];
110+ boolean anySectorPresent = false ;
111+ int rows = bank .rows ();
112+ for (int r =0 ; r <rows ; r ++) {
113+ int sector = bank .getByte (0 ,r ); // 1..6
114+ if (sector < 1 || sector > SECTORS ) continue ;
115+ int layer = bank .getByte (1 ,r );
116+ int wire = bank .getShort (2 ,r );
117+ byte order = bank .getByte (3 ,r );
118+ if ((order ==0 )||(order ==10 )) {
119+ batchInput [sector -1 ][layer -1 ][wire -1 ]=1.0f ;
120+ anySectorPresent = true ;
108121 }
109- predictors .put (predictor );
110- event .removeBank (BANK_NAMES [i ]);
111- event .appendBank (bank );
112122 }
113- catch (TranslateException | InterruptedException e ) {
114- throw new RuntimeException (e );
115- }
116- break ;
117- }
118- }
119- return true ;
120- }
121123
122- boolean processFakeEvent () {
123- try {
124- Predictor <float [][], float [][]> predictor = model .newPredictor ();
125- float [][] input = getAlmostStraightSlightlyBendingTrack ();
126- float [][] output = predictor .predict (input );
127- //System.out.println("IN:");show(input);
128- //System.out.println("OUT:");show(output);
129- }
130- catch (TranslateException e ) {
131- throw new RuntimeException (e );
132- }
133- return true ;
134- }
135-
136- /**
137- * Reject sub-threshold hits by modifying the bank's order variable.
138- * WARNING: This is not a full implementation of OrderType enum and
139- * all its names, but for now a copy of the subset in C++ DC denoising, see:
140- * https://code.jlab.org/hallb/clas12/coatjava/denoising/-/blob/main/denoising/code/drift.cc?ref_type=heads#L162-198
141- */
142- static void update (DataBank b , float threshold , float [][] data , int sector ) {
143- //System.out.println("IN:");b.show();
144- for (int row =0 ; row <b .rows (); row ++) {
145- if (b .getByte (0 ,row )-1 != sector ) continue ;
146- if (data [b .getByte (1 ,row )-1 ][b .getShort (2 ,row )-1 ] < threshold ) {
147- if (b .getByte (3 ,row ) == 0 ) b .setByte (3 , row , (byte )(60 ));
148- if (b .getByte (3 ,row ) == 10 ) b .setByte (3 , row , (byte )(90 ));
149- }
150- }
151- //System.out.println("OUT:");b.show();
152- }
124+ if (!anySectorPresent ) continue ;
153125
154- /**
155- * Get one-sector data with weights set to 0/1.
156- */
157- static float [][] read (DataBank bank , int sector ) {
158- float [][] data = new float [LAYERS ][WIRES ];
159- for (int i =0 ; i <bank .rows (); ++i ) {
160- if (bank .getByte (0 ,i ) == sector ) {
161- byte o = bank .getByte (3 ,i );
162- if (0 ==o || 10 ==o )
163- // got a hit, set weight to one:
164- data [bank .getByte (1 ,i )-1 ][bank .getShort (2 ,i )-1 ] = 1.0f ;
165- }
166- }
167- return data ;
168- }
126+ Predictor <float [][][], float [][][]> predictor = predictors .take ();
127+ float [][][] batchOutput ;
128+ try {
129+ batchOutput = predictor .predict (batchInput );
130+ } finally {
131+ predictors .put (predictor );
132+ }
169133
170- /**
171- * Print all hits for one sector.
172- */
173- static void show (float [][] data ) {
174- System .out .println ("Shape: [" + data .length + "," + data [0 ].length + "]" );
175- for (int i = 0 ; i < LAYERS ; i ++) {
176- for (int j = 0 ; j < WIRES ; j ++)
177- System .out .printf ("%.3f " , data [i ][j ]);
178- System .out .println ();
179- }
180- }
134+ for (int sectorIdx =0 ; sectorIdx <SECTORS ; sectorIdx ++) {
135+ update (bank , threshold , batchOutput [sectorIdx ], sectorIdx );
136+ }
181137
182- /**
183- * @return a dummy sector/track
184- */
185- static float [][] getAlmostStraightSlightlyBendingTrack () {
186- float [][] data = new float [LAYERS ][WIRES ];
187- for (int y = 0 ; y < LAYERS ; y ++) {
188- int x = 50 + (y / 10 );
189- data [y ][x ] = 1.0f ;
138+ event .removeBank (bankName );
139+ event .appendBank (bank );
140+ } catch (TranslateException | InterruptedException e ) {
141+ throw new RuntimeException (e );
142+ }
143+ break ;
190144 }
191- return data ;
145+ return true ;
192146 }
193147
194- public static Translator <float [][],float [][]> getTranslator () {
195- return new Translator <float [][],float [][]>() {
148+ // -------- Translator for batch --------
149+ public static Translator <float [][][], float [][][]> getBatchTranslator () {
150+ return new Translator <float [][][], float [][][]>() {
196151 @ Override
197- public NDList processInput (TranslatorContext ctx , float [][] input ) throws Exception {
152+ public NDList processInput (TranslatorContext ctx , float [][][] input ) {
153+ int batch = input .length ;
154+ int height = input [0 ].length ;
155+ int width = input [0 ][0 ].length ;
156+ float [] flat = new float [batch *height *width ];
157+ int pos =0 ;
158+ for (int b =0 ; b <batch ; b ++)
159+ for (int h =0 ; h <height ; h ++) {
160+ System .arraycopy (input [b ][h ],0 ,flat ,pos ,width );
161+ pos +=width ;
162+ }
198163 NDManager manager = ctx .getNDManager ();
199- int height = input .length ;
200- int width = input [0 ].length ;
201- float [] flat = new float [height * width ];
202- for (int i = 0 ; i < height ; i ++) {
203- System .arraycopy (input [i ], 0 , flat , i * width , width );
204- }
205- NDArray x = manager .create (flat , new Shape (height , width ));
206- // Add batch and channel dims -> [1,1,36,112]
207- x = x .expandDims (0 ).expandDims (0 );
164+ NDArray x = manager .create (flat , new Shape (batch ,1 ,height ,width ));
208165 return new NDList (x );
209166 }
167+
210168 @ Override
211- public float [][] processOutput (TranslatorContext ctx , NDList list ) throws Exception {
169+ public float [][][] processOutput (TranslatorContext ctx , NDList list ) {
212170 NDArray result = list .get (0 );
213- // Remove batch and channel dims -> [36,112]
214- result = result .squeeze ();
215- // Convert to 1D float array
216- float [] flat = result .toFloatArray ();
217- // Reshape manually into 2D array
218171 long [] shape = result .getShape ().getShape ();
219- int height = (int ) shape [0 ];
220- int width = (int ) shape [1 ];
221- float [][] output2d = new float [height ][width ];
222- for (int i = 0 ; i < height ; i ++) {
223- System .arraycopy (flat , i * width , output2d [i ], 0 , width );
224- }
225- return output2d ;
172+ int batch = (int )shape [0 ];
173+ int height , width ;
174+ if (shape .length ==4 && shape [1 ]==1 ) {
175+ height =(int )shape [2 ]; width =(int )shape [3 ];
176+ result = result .squeeze (1 );
177+ } else if (shape .length ==3 ) {
178+ height =(int )shape [1 ]; width =(int )shape [2 ];
179+ } else throw new IllegalStateException ("Unexpected output shape: " +java .util .Arrays .toString (shape ));
180+ float [] flat = result .toFloatArray ();
181+ float [][][] out = new float [batch ][height ][width ];
182+ int pos =0 ;
183+ for (int b =0 ;b <batch ;b ++)
184+ for (int h =0 ;h <height ;h ++) {
185+ System .arraycopy (flat ,pos ,out [b ][h ],0 ,width );
186+ pos +=width ;
187+ }
188+ return out ;
226189 }
190+
227191 @ Override
228- public Batchifier getBatchifier () {
229- return null ; // no batching
230- }
192+ public Batchifier getBatchifier () { return null ; }
231193 };
232194 }
233195
196+ // -------- Update single sector in bank --------
197+ static void update (DataBank b , float threshold , float [][] data , int sectorIdx ) {
198+ for (int row =0 ; row <b .rows (); row ++) {
199+ if (b .getByte (0 ,row )-1 != sectorIdx ) continue ;
200+ int layer =b .getByte (1 ,row )-1 ;
201+ int wire =b .getShort (2 ,row )-1 ;
202+ if (layer <0 || layer >=data .length ) continue ;
203+ if (wire <0 || wire >=data [0 ].length ) continue ;
204+ if (data [layer ][wire ]<threshold ) {
205+ if (b .getByte (3 ,row )==0 ) b .setByte (3 ,row ,(byte )60 );
206+ if (b .getByte (3 ,row )==10 ) b .setByte (3 ,row ,(byte )90 );
207+ }
208+ }
209+ }
234210}
0 commit comments