Skip to content

Commit ac39aa5

Browse files
baltzelltongtongcao
andcommitted
Add DC denoising (#853)
* why not? * add denoising engine * relax * check model loading in init * cleanup * cleanup * cleanup * use real data * update bank * cleanup * cleanup * use index accessors * use more new index accessors * more prep * cleanup * cleanup * add real network * Revert "relax" This reverts commit 329ad8a. * cleanup, wupport both banks * use a pool * fix dependencies * cleanup * remove unused dependencies * improve printout * improve printout * bugfix * silence * rename * fix version number * fix package name * fix version numbers * update DCDenoiseEngine with optimized model * improve comments --------- Co-authored-by: tongtongcao <[email protected]>
1 parent 69e2484 commit ac39aa5

File tree

4 files changed

+265
-0
lines changed

4 files changed

+265
-0
lines changed
1.09 MB
Binary file not shown.

reconstruction/ai/pom.xml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
3+
<modelVersion>4.0.0</modelVersion>
4+
5+
<groupId>org.jlab.clas12.detector</groupId>
6+
<artifactId>clas12detector-ai</artifactId>
7+
<version>13.3.0-SNAPSHOT</version>
8+
<packaging>jar</packaging>
9+
10+
<parent>
11+
<groupId>org.jlab.clas12</groupId>
12+
<artifactId>reconstruction</artifactId>
13+
<version>13.3.0-SNAPSHOT</version>
14+
</parent>
15+
16+
<dependencies>
17+
18+
<dependency>
19+
<groupId>org.jlab.clas</groupId>
20+
<artifactId>clas-utils</artifactId>
21+
<version>13.3.0-SNAPSHOT</version>
22+
</dependency>
23+
24+
<dependency>
25+
<groupId>org.jlab.clas</groupId>
26+
<artifactId>clas-io</artifactId>
27+
<version>13.3.0-SNAPSHOT</version>
28+
</dependency>
29+
30+
<dependency>
31+
<groupId>org.jlab.clas</groupId>
32+
<artifactId>clas-reco</artifactId>
33+
<version>13.3.0-SNAPSHOT</version>
34+
</dependency>
35+
36+
<dependency>
37+
<groupId>ai.djl</groupId>
38+
<artifactId>api</artifactId>
39+
</dependency>
40+
41+
</dependencies>
42+
43+
</project>
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
package org.jlab.service.ai;
2+
3+
import ai.djl.MalformedModelException;
4+
import java.nio.file.Paths;
5+
import ai.djl.ndarray.NDArray;
6+
import ai.djl.ndarray.NDList;
7+
import ai.djl.ndarray.NDManager;
8+
import ai.djl.ndarray.types.Shape;
9+
import ai.djl.repository.zoo.Criteria;
10+
import ai.djl.repository.zoo.ZooModel;
11+
import ai.djl.training.util.ProgressBar;
12+
import ai.djl.translate.Translator;
13+
import ai.djl.translate.TranslatorContext;
14+
import ai.djl.inference.Predictor;
15+
import ai.djl.repository.zoo.ModelNotFoundException;
16+
import ai.djl.translate.Batchifier;
17+
import ai.djl.translate.TranslateException;
18+
import java.io.IOException;
19+
import java.util.concurrent.BlockingQueue;
20+
import java.util.concurrent.LinkedBlockingQueue;
21+
22+
import org.jlab.clas.reco.ReconstructionEngine;
23+
import org.jlab.io.base.DataBank;
24+
import org.jlab.io.base.DataEvent;
25+
import org.jlab.utils.system.ClasUtilsFile;
26+
27+
public class DCDenoiseEngine extends ReconstructionEngine {
28+
29+
final static String[] BANK_NAMES = {"DC::tot","DC::tdc"};
30+
final static String CONF_THRESHOLD = "threshold";
31+
final static int LAYERS = 36;
32+
final static int WIRES = 112;
33+
34+
float threshold = 0.025f;
35+
Criteria<float[][],float[][]> criteria;
36+
ZooModel<float[][], float[][]> model;
37+
PredictorPool predictors;
38+
39+
public static class PredictorPool {
40+
final BlockingQueue<Predictor> pool;
41+
public PredictorPool(int size, ZooModel model) {
42+
pool = new LinkedBlockingQueue<>(size);
43+
for (int i=0; i<size; i++) pool.offer(model.newPredictor());
44+
}
45+
public Predictor get() throws InterruptedException {
46+
return pool.take();
47+
}
48+
public void put(Predictor p) {
49+
if (p != null) pool.offer(p);
50+
}
51+
}
52+
53+
public DCDenoiseEngine() {
54+
super("DenoiseEngine","lleztlab","1.0");
55+
}
56+
57+
@Override
58+
public boolean init() {
59+
if (getEngineConfigString(CONF_THRESHOLD) != null)
60+
threshold = Float.parseFloat(getEngineConfigString(CONF_THRESHOLD));
61+
try {
62+
criteria = Criteria.builder()
63+
.setTypes(float[][].class, float[][].class)
64+
.optModelPath(Paths.get(ClasUtilsFile.getResourceDir("CLAS12DIR","etc/nnet/dn/cnn_autoenc_sector1_nBlocks2.pt")))
65+
.optEngine("PyTorch")
66+
.optTranslator(DCDenoiseEngine.getTranslator())
67+
.optProgress(new ProgressBar())
68+
.build();
69+
model = criteria.loadModel();
70+
predictors = new PredictorPool(64, model);
71+
return true;
72+
} catch (NullPointerException | MalformedModelException | IOException | ModelNotFoundException ex) {
73+
System.getLogger(DCDenoiseEngine.class.getName()).log(System.Logger.Level.ERROR, (String) null, ex);
74+
return false;
75+
}
76+
}
77+
78+
@Override
79+
public boolean processDataEvent(DataEvent event) {
80+
81+
//if (true) return processFakeEvent();
82+
83+
for (int i=0; i<BANK_NAMES.length; i++){
84+
if (event.hasBank(BANK_NAMES[i])) {
85+
DataBank bank = event.getBank(BANK_NAMES[i]);
86+
try {
87+
// WARNING: Predictor is *not* thread safe.
88+
Predictor<float[][], float[][]> predictor = predictors.get();
89+
for (int sector=0; sector<6; sector++) {
90+
float[][] input = DCDenoiseEngine.read(bank, sector+1);
91+
float[][] output = predictor.predict(input);
92+
//System.out.println("IN:");show(input);
93+
//System.out.println("OUT:");show(output);
94+
update(bank, threshold, output, sector);
95+
}
96+
predictors.put(predictor);
97+
event.removeBank(BANK_NAMES[i]);
98+
event.appendBank(bank);
99+
}
100+
catch (TranslateException | InterruptedException e) {
101+
throw new RuntimeException(e);
102+
}
103+
break;
104+
}
105+
}
106+
return true;
107+
}
108+
109+
boolean processFakeEvent() {
110+
try {
111+
Predictor<float[][], float[][]> predictor = model.newPredictor();
112+
float[][] input = getAlmostStraightSlightlyBendingTrack();
113+
float[][] output = predictor.predict(input);
114+
//System.out.println("IN:");show(input);
115+
//System.out.println("OUT:");show(output);
116+
}
117+
catch (TranslateException e) {
118+
throw new RuntimeException(e);
119+
}
120+
return true;
121+
}
122+
123+
/**
124+
* Reject sub-threshold hits by modifying the bank's order variable.
125+
* WARNING: This is not a full implementation of OrderType enum and
126+
* all its names, but for now a copy of the subset in C++ DC denoising, see:
127+
* https://code.jlab.org/hallb/clas12/coatjava/denoising/-/blob/main/denoising/code/drift.cc?ref_type=heads#L162-198
128+
*/
129+
static void update(DataBank b, float threshold, float[][] data, int sector) {
130+
//System.out.println("IN:");b.show();
131+
for (int row=0; row<b.rows(); row++) {
132+
if (b.getByte(0,row)-1 != sector) continue;
133+
if (data[b.getByte(1,row)-1][b.getShort(2,row)-1] < threshold) {
134+
if(b.getByte(3,row) == 0) b.setByte(3, row, (byte)(60));
135+
if(b.getByte(3,row) == 10) b.setByte(3, row, (byte)(90));
136+
}
137+
}
138+
//System.out.println("OUT:");b.show();
139+
}
140+
141+
/**
142+
* Get one-sector data with weights set to 0/1.
143+
*/
144+
static float[][] read(DataBank bank, int sector) {
145+
float[][] data = new float[LAYERS][WIRES];
146+
for (int i=0; i<bank.rows(); ++i) {
147+
if (bank.getByte(0,i) == sector) {
148+
byte o = bank.getByte(3,i);
149+
if (0==o || 10==o)
150+
// got a hit, set weight to one:
151+
data[bank.getByte(1,i)-1][bank.getShort(2,i)-1] = 1.0f;
152+
}
153+
}
154+
return data;
155+
}
156+
157+
/**
158+
* Print all hits for one sector.
159+
*/
160+
static void show(float[][] data) {
161+
System.out.println("Shape: [" + data.length + "," + data[0].length + "]");
162+
for (int i = 0; i < LAYERS; i++) {
163+
for (int j = 0; j < WIRES; j++)
164+
System.out.printf("%.3f ", data[i][j]);
165+
System.out.println();
166+
}
167+
}
168+
169+
/**
170+
* @return a dummy sector/track
171+
*/
172+
static float[][] getAlmostStraightSlightlyBendingTrack() {
173+
float[][] data = new float[LAYERS][WIRES];
174+
for (int y = 0; y < LAYERS; y++) {
175+
int x = 50 + (y / 10);
176+
data[y][x] = 1.0f;
177+
}
178+
return data;
179+
}
180+
181+
public static Translator<float[][],float[][]> getTranslator() {
182+
return new Translator<float[][],float[][]>() {
183+
@Override
184+
public NDList processInput(TranslatorContext ctx, float[][] input) throws Exception {
185+
NDManager manager = ctx.getNDManager();
186+
int height = input.length;
187+
int width = input[0].length;
188+
float[] flat = new float[height * width];
189+
for (int i = 0; i < height; i++) {
190+
System.arraycopy(input[i], 0, flat, i * width, width);
191+
}
192+
NDArray x = manager.create(flat, new Shape(height, width));
193+
// Add batch and channel dims -> [1,1,36,112]
194+
x = x.expandDims(0).expandDims(0);
195+
return new NDList(x);
196+
}
197+
@Override
198+
public float[][] processOutput(TranslatorContext ctx, NDList list) throws Exception {
199+
NDArray result = list.get(0);
200+
// Remove batch and channel dims -> [36,112]
201+
result = result.squeeze();
202+
// Convert to 1D float array
203+
float[] flat = result.toFloatArray();
204+
// Reshape manually into 2D array
205+
long[] shape = result.getShape().getShape();
206+
int height = (int) shape[0];
207+
int width = (int) shape[1];
208+
float[][] output2d = new float[height][width];
209+
for (int i = 0; i < height; i++) {
210+
System.arraycopy(flat, i * width, output2d[i], 0, width);
211+
}
212+
return output2d;
213+
}
214+
@Override
215+
public Batchifier getBatchifier() {
216+
return null; // no batching
217+
}
218+
};
219+
}
220+
221+
}

reconstruction/pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
</parent>
1515

1616
<modules>
17+
<module>ai</module>
1718
<module>dc</module>
1819
<module>tof</module>
1920
<module>cvt</module>

0 commit comments

Comments
 (0)