Skip to content

Commit 042c977

Browse files
authored
[api] Provide concrete implementation of TranslatorContext (#3622)
1 parent c6d78bc commit 042c977

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

api/src/main/java/ai/djl/inference/Predictor.java

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,9 @@ protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
157157
* @return a list of output objects defined by the user
158158
* @throws TranslateException if an error occurs during prediction
159159
*/
160-
@SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches", "unchecked"})
160+
@SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches"})
161161
public List<O> batchPredict(List<I> inputs) throws TranslateException {
162-
try (PredictorContext context = new PredictorContext()) {
162+
try (PredictorContext context = new PredictorContext(model, manager, metrics)) {
163163
if (!prepared) {
164164
translator.prepare(context);
165165
prepared = true;
@@ -220,7 +220,7 @@ public StreamOutput<O> streamingPredict(I input) throws TranslateException {
220220
StreamingTranslator<I, O> streamingTranslator = (StreamingTranslator<I, O>) translator;
221221

222222
try {
223-
PredictorContext context = new PredictorContext();
223+
PredictorContext context = new PredictorContext(model, manager, metrics);
224224
if (!prepared) {
225225
translator.prepare(context);
226226
prepared = true;
@@ -357,14 +357,21 @@ protected void finalize() throws Throwable {
357357
super.finalize();
358358
}
359359

360-
protected class PredictorContext implements TranslatorContext {
360+
/** An implementation of {@link TranslatorContext}. */
361+
public static final class PredictorContext implements TranslatorContext {
361362

363+
private Model model;
364+
private NDManager predictorManager;
365+
private Metrics metrics;
362366
private NDManager ctxManager;
363367
private Map<String, Object> attachments;
364368

365369
/** Constructs a new {@code PredictorContext} instance. */
366-
public PredictorContext() {
367-
ctxManager = manager.newSubManager();
370+
public PredictorContext(Model model, NDManager predictorManager, Metrics metrics) {
371+
this.model = model;
372+
this.predictorManager = predictorManager;
373+
this.metrics = metrics;
374+
ctxManager = predictorManager.newSubManager();
368375
ctxManager.setName("predictor ctx");
369376
attachments = new ConcurrentHashMap<>();
370377
}
@@ -384,13 +391,13 @@ public NDManager getNDManager() {
384391
/** {@inheritDoc} */
385392
@Override
386393
public NDManager getPredictorManager() {
387-
return manager;
394+
return predictorManager;
388395
}
389396

390397
/** {@inheritDoc} */
391398
@Override
392399
public Block getBlock() {
393-
return block;
400+
return model.getBlock();
394401
}
395402

396403
/** {@inheritDoc} */

0 commit comments

Comments
 (0)