11package com .google .mediapipe .tasks .genai .llminference ;
22
3- import static com .google .mediapipe .tasks .genai .llminference .LlmInferenceSession .decodeResponse ;
43
54import android .content .Context ;
65import com .google .auto .value .AutoValue ;
6+ import com .google .common .util .concurrent .ListenableFuture ;
77import com .google .mediapipe .tasks .genai .llminference .jni .proto .LlmOptionsProto .LlmModelSettings ;
88import com .google .mediapipe .tasks .genai .llminference .jni .proto .LlmOptionsProto .LlmModelSettings .LlmPreferredBackend ;
99import java .util .Collections ;
@@ -77,41 +77,12 @@ public static LlmInference createFromOptions(Context context, LlmInferenceOption
7777 }
7878 }
7979
80- return new LlmInference (context , STATS_TAG , modelSettings .build (), options . resultListener () );
80+ return new LlmInference (context , STATS_TAG , modelSettings .build ());
8181 }
8282
8383 /** Constructor to initialize an {@link LlmInference}. */
84- private LlmInference (
85- Context context ,
86- String taskName ,
87- LlmModelSettings modelSettings ,
88- Optional <ProgressListener <String >> resultListener ) {
89- Optional <ProgressListener <List <String >>> llmResultListener ;
90- if (resultListener .isPresent ()) {
91- llmResultListener =
92- Optional .of (
93- new ProgressListener <List <String >>() {
94- private boolean receivedFirstToken = false ;
95-
96- @ Override
97- public void run (List <String > partialResult , boolean done ) {
98- String result =
99- decodeResponse (
100- partialResult , /* stripLeadingWhitespace= */ !receivedFirstToken );
101- if (done ) {
102- receivedFirstToken = false ; // Reset to initial state
103- resultListener .get ().run (result , done );
104- } else if (!result .isEmpty ()) {
105- receivedFirstToken = true ;
106- resultListener .get ().run (result , done );
107- }
108- }
109- });
110- } else {
111- llmResultListener = Optional .empty ();
112- }
113-
114- this .taskRunner = new LlmTaskRunner (context , taskName , modelSettings , llmResultListener );
84+ private LlmInference (Context context , String taskName , LlmModelSettings modelSettings ) {
85+ this .taskRunner = new LlmTaskRunner (context , taskName , modelSettings );
11586 this .implicitSession = new AtomicReference <>();
11687 }
11788
@@ -136,23 +107,50 @@ public String generateResponse(String inputText) {
136107 }
137108
138109 /**
139- * Generates a response based on the input text. This method cannot be called while other queries
140- * are active.
110+ * Asynchronously generates a response based on the input text. This method cannot be called while
111+ * other queries are active.
141112 *
142- * <p>This function creates a new session for each call. If you want to have a stateful inference,
143- * use {@link LlmInferenceSession#generateResponseAsync()} instead.
113+ * <p>This function creates a new session for each call and returns the complete response as a
114+ * {@link ListenableFuture}. If you want to have a stateful inference, use {@link
115+ * LlmInferenceSession#generateResponseAsync()} instead.
144116 *
145117 * <p>Note: You cannot invoke simultaneous response generation calls on active sessions created
146118 * using the same {@link LlmInference}. You have to wait for the currently running response
147119 * generation call to complete before initiating another one.
148120 *
149121 * @param inputText a {@link String} for processing.
122+ * @return a {@link ListenableFuture} with the complete response once the inference is complete.
150123 * @throws IllegalStateException if the inference fails.
151124 */
152- public void generateResponseAsync (String inputText ) {
125+ public ListenableFuture < String > generateResponseAsync (String inputText ) {
153126 LlmInferenceSession session = resetImplicitSession ();
154127 session .addQueryChunk (inputText );
155- session .generateResponseAsync ();
128+ return session .generateResponseAsync ();
129+ }
130+
131+ /**
132+ * Asynchronously generates a response based on the input text and emits partial results. This
133+ * method cannot be called while other queries are active.
134+ *
135+ * <p>This function creates a new session for each call and returns the complete response as a
136+ * {@link ListenableFuture} and invokes the {@code progressListener} as the response is generated.
137+ * If you want to have a stateful inference, use {@link
138+ * LlmInferenceSession#generateResponseAsync()} instead.
139+ *
140+ * <p>Note: You cannot invoke simultaneous response generation calls on active sessions created
141+ * using the same {@link LlmInference}. You have to wait for the currently running response
142+ * generation call to complete before initiating another one.
143+ *
144+ * @param inputText a {@link String} for processing.
145+ * @param progressListener a {@link ProgressListener} to receive partial results.
146+ * @return a {@link ListenableFuture} with the complete response once the inference is complete.
147+ * @throws IllegalStateException if the inference fails.
148+ */
149+ public ListenableFuture <String > generateResponseAsync (
150+ String inputText , ProgressListener <String > progressListener ) {
151+ LlmInferenceSession session = resetImplicitSession ();
152+ session .addQueryChunk (inputText );
153+ return session .generateResponseAsync (progressListener );
156154 }
157155
158156 /**
@@ -211,12 +209,6 @@ public abstract static class Builder {
211209 /** Sets the model path for the text generator task. */
212210 public abstract Builder setModelPath (String modelPath );
213211
214- /** Sets the result listener to invoke with the async API. */
215- public abstract Builder setResultListener (ProgressListener <String > listener );
216-
217- /** Sets the error listener to invoke with the async API. */
218- public abstract Builder setErrorListener (ErrorListener listener );
219-
    /** Configures the total number of tokens for input and output. */
    public abstract Builder setMaxTokens(int maxTokens);
222214
@@ -263,12 +255,6 @@ public final LlmInferenceOptions build() {
263255 /** The supported lora ranks for the base model. Used by GPU only. */
264256 public abstract List <Integer > supportedLoraRanks ();
265257
266- /** The result listener to use for the {@link LlmInference#generateAsync} API. */
267- public abstract Optional <ProgressListener <String >> resultListener ();
268-
269- /** The error listener to use for the {@link LlmInference#generateAsync} API. */
270- public abstract Optional <ErrorListener > errorListener ();
271-
272258 /** The model options to for vision modality. */
273259 public abstract Optional <VisionModelOptions > visionModelOptions ();
274260
0 commit comments