Skip to content

Commit e77c629

Browse files
tongtongcao and baltzell committed
Make denoising deterministic and speed it up (#968)
* update DCDenoiseEngine.java
* speed up denoising
* tiny changes
* change default threshold
* remove network file

---------

Co-authored-by: Nathan Baltzell <[email protected]>
1 parent 33b5307 commit e77c629

File tree

1 file changed

+122
-146
lines changed

1 file changed

+122
-146
lines changed

reconstruction/ai/src/main/java/org/jlab/service/ai/DCDenoiseEngine.java

Lines changed: 122 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@
1111
import ai.djl.training.util.ProgressBar;
1212
import ai.djl.translate.Translator;
1313
import ai.djl.translate.TranslatorContext;
14+
import ai.djl.translate.Batchifier;
1415
import ai.djl.inference.Predictor;
1516
import ai.djl.repository.zoo.ModelNotFoundException;
16-
import ai.djl.translate.Batchifier;
1717
import ai.djl.translate.TranslateException;
18+
1819
import java.io.IOException;
1920
import java.util.concurrent.BlockingQueue;
2021
import java.util.concurrent.ArrayBlockingQueue;
22+
import java.util.logging.Level;
23+
import java.util.logging.Logger;
2124

2225
import org.jlab.clas.reco.ReconstructionEngine;
2326
import org.jlab.io.base.DataBank;
@@ -27,28 +30,36 @@
2730
public class DCDenoiseEngine extends ReconstructionEngine {
2831

2932
final static String[] BANK_NAMES = {"DC::tot","DC::tdc"};
30-
final static String CONF_THRESHOLD = "threshold";
33+
final static String CONF_MODEL_FILE = "modelFile";
34+
final static String CONF_THRESHOLD = "threshold";
3135
final static String CONF_THREADS = "threads";
36+
3237
final static int LAYERS = 36;
3338
final static int WIRES = 112;
39+
final static int SECTORS= 6;
3440

35-
float threshold = 0.025f;
36-
Criteria<float[][],float[][]> criteria;
37-
ZooModel<float[][], float[][]> model;
41+
String modelFile = "cnn_autoenc_sector1_2b_48f_4x6k.pt";
42+
float threshold = 0.03f;
43+
Criteria<float[][][], float[][][]> criteria;
44+
ZooModel<float[][][], float[][][]> model;
3845
PredictorPool predictors;
39-
46+
47+
// -------- Predictor Pool --------
4048
public static class PredictorPool {
41-
final BlockingQueue<Predictor> pool;
42-
public PredictorPool(int size, ZooModel model) {
49+
final BlockingQueue<Predictor<float[][][], float[][][]>> pool;
50+
public PredictorPool(int size, ZooModel<float[][][], float[][][]> model) {
4351
pool = new ArrayBlockingQueue<>(size);
44-
for (int i=0; i<size; i++) pool.add(model.newPredictor());
45-
}
46-
public Predictor get() throws InterruptedException {
47-
return pool.poll();
48-
}
49-
public void put(Predictor p) {
50-
if (p != null) pool.offer(p);
52+
for (int i=0; i<size; i++) {
53+
try {
54+
pool.add(model.newPredictor());
55+
} catch (Exception e) {
56+
Logger.getLogger(PredictorPool.class.getName()).log(Level.WARNING, "Failed to create predictor", e);
57+
}
58+
}
5159
}
60+
public Predictor<float[][][], float[][][]> take() throws InterruptedException { return pool.take(); }
61+
public void put(Predictor<float[][][], float[][][]> p) throws InterruptedException { if (p!=null) pool.put(p); }
62+
public void shutdownAll() { for (Predictor p: pool) { try { p.close(); } catch (Exception ignored) {} } }
5263
}
5364

5465
public DCDenoiseEngine() {
@@ -62,173 +73,138 @@ public boolean init() {
6273
System.setProperty("ai.djl.pytorch.graph_optimizer", "false");
6374
if (getEngineConfigString(CONF_THRESHOLD) != null)
6475
threshold = Float.parseFloat(getEngineConfigString(CONF_THRESHOLD));
76+
if (getEngineConfigString(CONF_MODEL_FILE) != null)
77+
modelFile = getEngineConfigString(CONF_MODEL_FILE);
78+
6579
try {
80+
String modelPath = ClasUtilsFile.getResourceDir("CLAS12DIR", "etc/data/nnet/dn/" + modelFile);
81+
6682
criteria = Criteria.builder()
67-
.setTypes(float[][].class, float[][].class)
68-
.optModelPath(Paths.get(ClasUtilsFile.getResourceDir("CLAS12DIR","etc/data/nnet/dn/cnn_autoenc_sector1_nBlocks2.pt")))
83+
.setTypes(float[][][].class, float[][][].class)
84+
.optModelPath(Paths.get(modelPath))
6985
.optEngine("PyTorch")
70-
.optTranslator(DCDenoiseEngine.getTranslator())
86+
.optTranslator(DCDenoiseEngine.getBatchTranslator())
7187
.optProgress(new ProgressBar())
7288
.build();
89+
7390
model = criteria.loadModel();
91+
7492
int threads = Integer.parseInt(getEngineConfigString(CONF_THREADS,"64"));
7593
predictors = new PredictorPool(threads, model);
7694
return true;
7795
} catch (NullPointerException | MalformedModelException | IOException | ModelNotFoundException ex) {
78-
System.getLogger(DCDenoiseEngine.class.getName()).log(System.Logger.Level.ERROR, (String) null, ex);
96+
Logger.getLogger(DCDenoiseEngine.class.getName()).log(Level.SEVERE, null, ex);
7997
return false;
8098
}
8199
}
82100

83-
public static void main(String args[]){
84-
DCDenoiseEngine dn = new DCDenoiseEngine();
85-
dn.init();
86-
for (int i=0; i<10000; i++) {
87-
dn.processFakeEvent();
88-
}
89-
}
90-
91101
@Override
92102
public boolean processDataEvent(DataEvent event) {
103+
for (String bankName : BANK_NAMES) {
104+
if (!event.hasBank(bankName)) continue;
93105

94-
//if (true) return processFakeEvent();
95-
96-
for (int i=0; i<BANK_NAMES.length; i++){
97-
if (event.hasBank(BANK_NAMES[i])) {
98-
DataBank bank = event.getBank(BANK_NAMES[i]);
99-
try {
100-
// WARNING: Predictor is *not* thread safe.
101-
Predictor<float[][], float[][]> predictor = predictors.get();
102-
for (int sector=0; sector<6; sector++) {
103-
float[][] input = DCDenoiseEngine.read(bank, sector+1);
104-
float[][] output = predictor.predict(input);
105-
//System.out.println("IN:");show(input);
106-
//System.out.println("OUT:");show(output);
107-
update(bank, threshold, output, sector);
106+
DataBank bank = event.getBank(bankName);
107+
try {
108+
// Build batch for 6 sectors
109+
float[][][] batchInput = new float[SECTORS][LAYERS][WIRES];
110+
boolean anySectorPresent = false;
111+
int rows = bank.rows();
112+
for (int r=0; r<rows; r++) {
113+
int sector = bank.getByte(0,r); // 1..6
114+
if (sector < 1 || sector > SECTORS) continue;
115+
int layer = bank.getByte(1,r);
116+
int wire = bank.getShort(2,r);
117+
byte order = bank.getByte(3,r);
118+
if ((order==0)||(order==10)) {
119+
batchInput[sector-1][layer-1][wire-1]=1.0f;
120+
anySectorPresent = true;
108121
}
109-
predictors.put(predictor);
110-
event.removeBank(BANK_NAMES[i]);
111-
event.appendBank(bank);
112122
}
113-
catch (TranslateException | InterruptedException e) {
114-
throw new RuntimeException(e);
115-
}
116-
break;
117-
}
118-
}
119-
return true;
120-
}
121123

122-
boolean processFakeEvent() {
123-
try {
124-
Predictor<float[][], float[][]> predictor = model.newPredictor();
125-
float[][] input = getAlmostStraightSlightlyBendingTrack();
126-
float[][] output = predictor.predict(input);
127-
//System.out.println("IN:");show(input);
128-
//System.out.println("OUT:");show(output);
129-
}
130-
catch (TranslateException e) {
131-
throw new RuntimeException(e);
132-
}
133-
return true;
134-
}
135-
136-
/**
137-
* Reject sub-threshold hits by modifying the bank's order variable.
138-
* WARNING: This is not a full implementation of OrderType enum and
139-
* all its names, but for now a copy of the subset in C++ DC denoising, see:
140-
* https://code.jlab.org/hallb/clas12/coatjava/denoising/-/blob/main/denoising/code/drift.cc?ref_type=heads#L162-198
141-
*/
142-
static void update(DataBank b, float threshold, float[][] data, int sector) {
143-
//System.out.println("IN:");b.show();
144-
for (int row=0; row<b.rows(); row++) {
145-
if (b.getByte(0,row)-1 != sector) continue;
146-
if (data[b.getByte(1,row)-1][b.getShort(2,row)-1] < threshold) {
147-
if(b.getByte(3,row) == 0) b.setByte(3, row, (byte)(60));
148-
if(b.getByte(3,row) == 10) b.setByte(3, row, (byte)(90));
149-
}
150-
}
151-
//System.out.println("OUT:");b.show();
152-
}
124+
if (!anySectorPresent) continue;
153125

154-
/**
155-
* Get one-sector data with weights set to 0/1.
156-
*/
157-
static float[][] read(DataBank bank, int sector) {
158-
float[][] data = new float[LAYERS][WIRES];
159-
for (int i=0; i<bank.rows(); ++i) {
160-
if (bank.getByte(0,i) == sector) {
161-
byte o = bank.getByte(3,i);
162-
if (0==o || 10==o)
163-
// got a hit, set weight to one:
164-
data[bank.getByte(1,i)-1][bank.getShort(2,i)-1] = 1.0f;
165-
}
166-
}
167-
return data;
168-
}
126+
Predictor<float[][][], float[][][]> predictor = predictors.take();
127+
float[][][] batchOutput;
128+
try {
129+
batchOutput = predictor.predict(batchInput);
130+
} finally {
131+
predictors.put(predictor);
132+
}
169133

170-
/**
171-
* Print all hits for one sector.
172-
*/
173-
static void show(float[][] data) {
174-
System.out.println("Shape: [" + data.length + "," + data[0].length + "]");
175-
for (int i = 0; i < LAYERS; i++) {
176-
for (int j = 0; j < WIRES; j++)
177-
System.out.printf("%.3f ", data[i][j]);
178-
System.out.println();
179-
}
180-
}
134+
for (int sectorIdx=0; sectorIdx<SECTORS; sectorIdx++) {
135+
update(bank, threshold, batchOutput[sectorIdx], sectorIdx);
136+
}
181137

182-
/**
183-
* @return a dummy sector/track
184-
*/
185-
static float[][] getAlmostStraightSlightlyBendingTrack() {
186-
float[][] data = new float[LAYERS][WIRES];
187-
for (int y = 0; y < LAYERS; y++) {
188-
int x = 50 + (y / 10);
189-
data[y][x] = 1.0f;
138+
event.removeBank(bankName);
139+
event.appendBank(bank);
140+
} catch (TranslateException | InterruptedException e) {
141+
throw new RuntimeException(e);
142+
}
143+
break;
190144
}
191-
return data;
145+
return true;
192146
}
193147

194-
public static Translator<float[][],float[][]> getTranslator() {
195-
return new Translator<float[][],float[][]>() {
148+
// -------- Translator for batch --------
149+
public static Translator<float[][][], float[][][]> getBatchTranslator() {
150+
return new Translator<float[][][], float[][][]>() {
196151
@Override
197-
public NDList processInput(TranslatorContext ctx, float[][] input) throws Exception {
152+
public NDList processInput(TranslatorContext ctx, float[][][] input) {
153+
int batch = input.length;
154+
int height = input[0].length;
155+
int width = input[0][0].length;
156+
float[] flat = new float[batch*height*width];
157+
int pos=0;
158+
for (int b=0; b<batch; b++)
159+
for (int h=0; h<height; h++) {
160+
System.arraycopy(input[b][h],0,flat,pos,width);
161+
pos+=width;
162+
}
198163
NDManager manager = ctx.getNDManager();
199-
int height = input.length;
200-
int width = input[0].length;
201-
float[] flat = new float[height * width];
202-
for (int i = 0; i < height; i++) {
203-
System.arraycopy(input[i], 0, flat, i * width, width);
204-
}
205-
NDArray x = manager.create(flat, new Shape(height, width));
206-
// Add batch and channel dims -> [1,1,36,112]
207-
x = x.expandDims(0).expandDims(0);
164+
NDArray x = manager.create(flat, new Shape(batch,1,height,width));
208165
return new NDList(x);
209166
}
167+
210168
@Override
211-
public float[][] processOutput(TranslatorContext ctx, NDList list) throws Exception {
169+
public float[][][] processOutput(TranslatorContext ctx, NDList list) {
212170
NDArray result = list.get(0);
213-
// Remove batch and channel dims -> [36,112]
214-
result = result.squeeze();
215-
// Convert to 1D float array
216-
float[] flat = result.toFloatArray();
217-
// Reshape manually into 2D array
218171
long[] shape = result.getShape().getShape();
219-
int height = (int) shape[0];
220-
int width = (int) shape[1];
221-
float[][] output2d = new float[height][width];
222-
for (int i = 0; i < height; i++) {
223-
System.arraycopy(flat, i * width, output2d[i], 0, width);
224-
}
225-
return output2d;
172+
int batch = (int)shape[0];
173+
int height, width;
174+
if (shape.length==4 && shape[1]==1) {
175+
height=(int)shape[2]; width=(int)shape[3];
176+
result = result.squeeze(1);
177+
} else if (shape.length==3) {
178+
height=(int)shape[1]; width=(int)shape[2];
179+
} else throw new IllegalStateException("Unexpected output shape: "+java.util.Arrays.toString(shape));
180+
float[] flat = result.toFloatArray();
181+
float[][][] out = new float[batch][height][width];
182+
int pos=0;
183+
for (int b=0;b<batch;b++)
184+
for (int h=0;h<height;h++) {
185+
System.arraycopy(flat,pos,out[b][h],0,width);
186+
pos+=width;
187+
}
188+
return out;
226189
}
190+
227191
@Override
228-
public Batchifier getBatchifier() {
229-
return null; // no batching
230-
}
192+
public Batchifier getBatchifier() { return null; }
231193
};
232194
}
233195

196+
// -------- Update single sector in bank --------
197+
static void update(DataBank b, float threshold, float[][] data, int sectorIdx) {
198+
for (int row=0; row<b.rows(); row++) {
199+
if (b.getByte(0,row)-1 != sectorIdx) continue;
200+
int layer=b.getByte(1,row)-1;
201+
int wire=b.getShort(2,row)-1;
202+
if (layer<0 || layer>=data.length) continue;
203+
if (wire<0 || wire>=data[0].length) continue;
204+
if (data[layer][wire]<threshold) {
205+
if(b.getByte(3,row)==0) b.setByte(3,row,(byte)60);
206+
if(b.getByte(3,row)==10) b.setByte(3,row,(byte)90);
207+
}
208+
}
209+
}
234210
}

0 commit comments

Comments
 (0)