Skip to content

Commit faf41ba

Browse files
authored
Switch to generic predictor pool (#1040)
* add generic predictor pool * cleanup
1 parent 499a107 commit faf41ba

File tree

3 files changed

+39
-42
lines changed

3 files changed

+39
-42
lines changed

reconstruction/ai/src/main/java/org/jlab/service/ai/DCClsComboEngine.java

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import ai.djl.ndarray.NDArray;
1313
import ai.djl.ndarray.NDList;
1414
import ai.djl.ndarray.NDManager;
15-
import ai.djl.ndarray.types.Shape;
1615
import ai.djl.repository.zoo.Criteria;
1716
import ai.djl.repository.zoo.ModelNotFoundException;
1817
import ai.djl.repository.zoo.ZooModel;
@@ -22,8 +21,6 @@
2221
import ai.djl.translate.TranslatorContext;
2322
import java.io.IOException;
2423
import java.nio.file.Paths;
25-
import java.util.concurrent.ArrayBlockingQueue;
26-
import java.util.concurrent.BlockingQueue;
2724
import java.util.logging.Level;
2825
import java.util.logging.Logger;
2926
import org.jlab.clas.reco.ReconstructionEngine;
@@ -57,27 +54,8 @@ public class DCClsComboEngine extends ReconstructionEngine {
5754

5855
final static int SUPERLAYERS = 6;
5956

60-
// -------- Predictor Pool --------
61-
public static class PredictorPool {
62-
final BlockingQueue<Predictor<float[][], float[]>> pool;
63-
public PredictorPool(int size, ZooModel<float[][], float[]> model) {
64-
pool = new ArrayBlockingQueue<>(size);
65-
for (int i=0; i<size; i++) {
66-
try {
67-
pool.add(model.newPredictor());
68-
} catch (Exception e) {
69-
Logger.getLogger(PredictorPool.class.getName()).log(Level.WARNING, "Failed to create predictor", e);
70-
}
71-
}
72-
}
73-
public Predictor<float[][], float[]> take() throws InterruptedException { return pool.take(); }
74-
public void put(Predictor<float[][], float[]> p) throws InterruptedException { if (p!=null) pool.put(p); }
75-
public void shutdownAll() { for (Predictor p: pool) { try { p.close(); } catch (Exception ignored) {} } }
76-
}
77-
7857
public DCClsComboEngine() {
7958
super("DCClsComboEngine","tongtong","1.0");
80-
8159
}
8260

8361
@Override

reconstruction/ai/src/main/java/org/jlab/service/ai/DCDenoiseEngine.java

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
import ai.djl.translate.TranslateException;
1818

1919
import java.io.IOException;
20-
import java.util.concurrent.BlockingQueue;
21-
import java.util.concurrent.ArrayBlockingQueue;
2220
import java.util.logging.Level;
2321
import java.util.logging.Logger;
2422

@@ -44,24 +42,6 @@ public class DCDenoiseEngine extends ReconstructionEngine {
4442
ZooModel<float[][][], float[][][]> model;
4543
PredictorPool predictors;
4644

47-
// -------- Predictor Pool --------
48-
public static class PredictorPool {
49-
final BlockingQueue<Predictor<float[][][], float[][][]>> pool;
50-
public PredictorPool(int size, ZooModel<float[][][], float[][][]> model) {
51-
pool = new ArrayBlockingQueue<>(size);
52-
for (int i=0; i<size; i++) {
53-
try {
54-
pool.add(model.newPredictor());
55-
} catch (Exception e) {
56-
Logger.getLogger(PredictorPool.class.getName()).log(Level.WARNING, "Failed to create predictor", e);
57-
}
58-
}
59-
}
60-
public Predictor<float[][][], float[][][]> take() throws InterruptedException { return pool.take(); }
61-
public void put(Predictor<float[][][], float[][][]> p) throws InterruptedException { if (p!=null) pool.put(p); }
62-
public void shutdownAll() { for (Predictor p: pool) { try { p.close(); } catch (Exception ignored) {} } }
63-
}
64-
6545
public DCDenoiseEngine() {
6646
super("DenoiseEngine","lleztlab","1.0");
6747
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package org.jlab.service.ai;
2+
3+
import ai.djl.inference.Predictor;
4+
import ai.djl.repository.zoo.ZooModel;
5+
import java.util.concurrent.ArrayBlockingQueue;
6+
import java.util.concurrent.BlockingQueue;
7+
import java.util.logging.Level;
8+
import java.util.logging.Logger;
9+
10+
public class PredictorPool <T,U> {
11+
12+
final BlockingQueue<Predictor<T,U>> pool;
13+
14+
public PredictorPool(int size, ZooModel<T,U> model) {
15+
pool = new ArrayBlockingQueue<>(size);
16+
for (int i=0; i<size; i++) {
17+
try {
18+
pool.add(model.newPredictor());
19+
} catch (Exception e) {
20+
Logger.getLogger(PredictorPool.class.getName()).log(Level.WARNING, "Failed to create predictor", e);
21+
}
22+
}
23+
}
24+
25+
public Predictor<T,U> take() throws InterruptedException {
26+
return pool.take();
27+
}
28+
29+
public void put(Predictor<T,U> p) throws InterruptedException {
30+
if (p!=null) pool.put(p);
31+
}
32+
33+
public void shutdownAll() {
34+
for (Predictor p: pool) {
35+
try { p.close(); }
36+
catch (Exception ignored) {}
37+
}
38+
}
39+
}

0 commit comments

Comments
 (0)