@@ -157,9 +157,9 @@ protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
157157 * @return a list of output objects defined by the user
158158 * @throws TranslateException if an error occurs during prediction
159159 */
160- @ SuppressWarnings ({"PMD.AvoidRethrowingException" , "PMD.IdenticalCatchBranches" , "unchecked" })
160+ @ SuppressWarnings ({"PMD.AvoidRethrowingException" , "PMD.IdenticalCatchBranches" })
161161 public List <O > batchPredict (List <I > inputs ) throws TranslateException {
162- try (PredictorContext context = new PredictorContext ()) {
162+ try (PredictorContext context = new PredictorContext (model , manager , metrics )) {
163163 if (!prepared ) {
164164 translator .prepare (context );
165165 prepared = true ;
@@ -220,7 +220,7 @@ public StreamOutput<O> streamingPredict(I input) throws TranslateException {
220220 StreamingTranslator <I , O > streamingTranslator = (StreamingTranslator <I , O >) translator ;
221221
222222 try {
223- PredictorContext context = new PredictorContext ();
223+ PredictorContext context = new PredictorContext (model , manager , metrics );
224224 if (!prepared ) {
225225 translator .prepare (context );
226226 prepared = true ;
@@ -357,14 +357,21 @@ protected void finalize() throws Throwable {
357357 super .finalize ();
358358 }
359359
360- protected class PredictorContext implements TranslatorContext {
360+ /** An implementation of {@link TranslatorContext}. */
361+ public static final class PredictorContext implements TranslatorContext {
361362
363+ private Model model ;
364+ private NDManager predictorManager ;
365+ private Metrics metrics ;
362366 private NDManager ctxManager ;
363367 private Map <String , Object > attachments ;
364368
365369 /** Constructs a new {@code PredictorContext} instance. */
366- public PredictorContext () {
367- ctxManager = manager .newSubManager ();
370+ public PredictorContext (Model model , NDManager predictorManager , Metrics metrics ) {
371+ this .model = model ;
372+ this .predictorManager = predictorManager ;
373+ this .metrics = metrics ;
374+ ctxManager = predictorManager .newSubManager ();
368375 ctxManager .setName ("predictor ctx" );
369376 attachments = new ConcurrentHashMap <>();
370377 }
@@ -384,13 +391,13 @@ public NDManager getNDManager() {
384391 /** {@inheritDoc} */
385392 @ Override
386393 public NDManager getPredictorManager () {
387- return manager ;
394+ return predictorManager ;
388395 }
389396
390397 /** {@inheritDoc} */
391398 @ Override
392399 public Block getBlock () {
393- return block ;
400+ return model . getBlock () ;
394401 }
395402
396403 /** {@inheritDoc} */
0 commit comments