Skip to content

Commit c7e2ea5

Browse files
schmidt-sebastian authored and copybara-github committed
Make generateResponseAsync() return a ListenableFuture and add ProgressCallback to its arguments
PiperOrigin-RevId: 732288089
1 parent a0452b4 commit c7e2ea5

File tree

5 files changed

+96
-80
lines changed

5 files changed

+96
-80
lines changed

mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,7 @@ typedef struct {
180180
// LlmResponseContext is the return type for
181181
// LlmInferenceEngine_Session_PredictSync.
182182
typedef struct {
183-
// An array of string. The size of the array depends on the number of
183+
// An array of strings. The size of the array depends on the number of
184184
// responses.
185185
char** response_array;
186186

mediapipe/tasks/java/com/google/mediapipe/tasks/genai/BUILD

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,8 @@ android_library(
5353
"//mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/jni/proto:llm_options_java_proto_lite",
5454
"//mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/jni/proto:llm_response_context_java_proto_lite",
5555
"//third_party:autovalue",
56+
"//third_party/java/android_libs/guava_jdk5:concurrent",
57+
"//third_party/java/android_libs/guava_jdk5:listenablefuture",
5658
"@com_google_protobuf//:protobuf_javalite",
5759
"@maven//:com_google_guava_guava",
5860
],

mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInference.java

Lines changed: 37 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
package com.google.mediapipe.tasks.genai.llminference;
22

3-
import static com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession.decodeResponse;
43

54
import android.content.Context;
65
import com.google.auto.value.AutoValue;
6+
import com.google.common.util.concurrent.ListenableFuture;
77
import com.google.mediapipe.tasks.genai.llminference.jni.proto.LlmOptionsProto.LlmModelSettings;
88
import com.google.mediapipe.tasks.genai.llminference.jni.proto.LlmOptionsProto.LlmModelSettings.LlmPreferredBackend;
99
import java.util.Collections;
@@ -77,41 +77,12 @@ public static LlmInference createFromOptions(Context context, LlmInferenceOption
7777
}
7878
}
7979

80-
return new LlmInference(context, STATS_TAG, modelSettings.build(), options.resultListener());
80+
return new LlmInference(context, STATS_TAG, modelSettings.build());
8181
}
8282

8383
/** Constructor to initialize an {@link LlmInference}. */
84-
private LlmInference(
85-
Context context,
86-
String taskName,
87-
LlmModelSettings modelSettings,
88-
Optional<ProgressListener<String>> resultListener) {
89-
Optional<ProgressListener<List<String>>> llmResultListener;
90-
if (resultListener.isPresent()) {
91-
llmResultListener =
92-
Optional.of(
93-
new ProgressListener<List<String>>() {
94-
private boolean receivedFirstToken = false;
95-
96-
@Override
97-
public void run(List<String> partialResult, boolean done) {
98-
String result =
99-
decodeResponse(
100-
partialResult, /* stripLeadingWhitespace= */ !receivedFirstToken);
101-
if (done) {
102-
receivedFirstToken = false; // Reset to initial state
103-
resultListener.get().run(result, done);
104-
} else if (!result.isEmpty()) {
105-
receivedFirstToken = true;
106-
resultListener.get().run(result, done);
107-
}
108-
}
109-
});
110-
} else {
111-
llmResultListener = Optional.empty();
112-
}
113-
114-
this.taskRunner = new LlmTaskRunner(context, taskName, modelSettings, llmResultListener);
84+
private LlmInference(Context context, String taskName, LlmModelSettings modelSettings) {
85+
this.taskRunner = new LlmTaskRunner(context, taskName, modelSettings);
11586
this.implicitSession = new AtomicReference<>();
11687
}
11788

@@ -136,23 +107,50 @@ public String generateResponse(String inputText) {
136107
}
137108

138109
/**
139-
* Generates a response based on the input text. This method cannot be called while other queries
140-
* are active.
110+
* Asynchronously generates a response based on the input text. This method cannot be called while
111+
* other queries are active.
141112
*
142-
* <p>This function creates a new session for each call. If you want to have a stateful inference,
143-
* use {@link LlmInferenceSession#generateResponseAsync()} instead.
113+
* <p>This function creates a new session for each call and returns the complete response as a
114+
* {@link ListenableFuture}. If you want to have a stateful inference, use {@link
115+
* LlmInferenceSession#generateResponseAsync()} instead.
144116
*
145117
* <p>Note: You cannot invoke simultaneous response generation calls on active sessions created
146118
* using the same {@link LlmInference}. You have to wait for the currently running response
147119
* generation call to complete before initiating another one.
148120
*
149121
* @param inputText a {@link String} for processing.
122+
* @return a {@link ListenableFuture} with the complete response once the inference is complete.
150123
* @throws IllegalStateException if the inference fails.
151124
*/
152-
public void generateResponseAsync(String inputText) {
125+
public ListenableFuture<String> generateResponseAsync(String inputText) {
153126
LlmInferenceSession session = resetImplicitSession();
154127
session.addQueryChunk(inputText);
155-
session.generateResponseAsync();
128+
return session.generateResponseAsync();
129+
}
130+
131+
/**
132+
* Asynchronously generates a response based on the input text and emits partial results. This
133+
* method cannot be called while other queries are active.
134+
*
135+
* <p>This function creates a new session for each call and returns the complete response as a
136+
* {@link ListenableFuture} and invokes the {@code progressListener} as the response is generated.
137+
* If you want to have a stateful inference, use {@link
138+
* LlmInferenceSession#generateResponseAsync()} instead.
139+
*
140+
* <p>Note: You cannot invoke simultaneous response generation calls on active sessions created
141+
* using the same {@link LlmInference}. You have to wait for the currently running response
142+
* generation call to complete before initiating another one.
143+
*
144+
* @param inputText a {@link String} for processing.
145+
* @param progressListener a {@link ProgressListener} to receive partial results.
146+
* @return a {@link ListenableFuture} with the complete response once the inference is complete.
147+
* @throws IllegalStateException if the inference fails.
148+
*/
149+
public ListenableFuture<String> generateResponseAsync(
150+
String inputText, ProgressListener<String> progressListener) {
151+
LlmInferenceSession session = resetImplicitSession();
152+
session.addQueryChunk(inputText);
153+
return session.generateResponseAsync(progressListener);
156154
}
157155

158156
/**
@@ -211,12 +209,6 @@ public abstract static class Builder {
211209
/** Sets the model path for the text generator task. */
212210
public abstract Builder setModelPath(String modelPath);
213211

214-
/** Sets the result listener to invoke with the async API. */
215-
public abstract Builder setResultListener(ProgressListener<String> listener);
216-
217-
/** Sets the error listener to invoke with the async API. */
218-
public abstract Builder setErrorListener(ErrorListener listener);
219-
220212
/** Configures the total number of tokens for input and output). */
221213
public abstract Builder setMaxTokens(int maxTokens);
222214

@@ -263,12 +255,6 @@ public final LlmInferenceOptions build() {
263255
/** The supported lora ranks for the base model. Used by GPU only. */
264256
public abstract List<Integer> supportedLoraRanks();
265257

266-
/** The result listener to use for the {@link LlmInference#generateAsync} API. */
267-
public abstract Optional<ProgressListener<String>> resultListener();
268-
269-
/** The error listener to use for the {@link LlmInference#generateAsync} API. */
270-
public abstract Optional<ErrorListener> errorListener();
271-
272258
/** The model options to for vision modality. */
273259
public abstract Optional<VisionModelOptions> visionModelOptions();
274260

mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInferenceSession.java

Lines changed: 46 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
package com.google.mediapipe.tasks.genai.llminference;
22

33
import com.google.auto.value.AutoValue;
4+
import com.google.common.util.concurrent.ListenableFuture;
5+
import com.google.common.util.concurrent.SettableFuture;
46
import com.google.mediapipe.framework.image.MPImage;
57
import com.google.mediapipe.tasks.genai.llminference.LlmTaskRunner.LlmSession;
68
import com.google.mediapipe.tasks.genai.llminference.jni.proto.LlmOptionsProto.LlmSessionConfig;
@@ -97,20 +99,57 @@ public String generateResponse() {
9799
}
98100

99101
/**
100-
* Generates a response based on the previously added query chunks asynchronously.
102+
* Asynchronously generates a response based on the input text. This method cannot be called while
103+
* other queries are active.
101104
*
102-
* <p>The {@code resultListener} callback of the {@link LlmInference} instance returns the partial
103-
* responses from the LLM. Use {@link #addQueryChunk(String)} to add at least one query chunk
104-
* before calling this function.
105+
* <p>The method returns the complete response as a {@link ListenableFuture}. Use {@link
106+
* #addQueryChunk(String)} to add at least one query chunk before calling this function.
107+
*
108+
* <p>Note: You cannot invoke simultaneous response generation calls on active sessions created
109+
* using the same {@link LlmInference}. You have to wait for the currently running response
110+
* generation call to complete before initiating another one.
111+
*
112+
* @return a {@link ListenableFuture} with the complete response once the inference is complete.
113+
* @throws IllegalStateException if the inference fails.
114+
*/
115+
public ListenableFuture<String> generateResponseAsync() {
116+
return generateResponseAsync((unused1, unused2) -> {});
117+
}
118+
119+
/**
120+
* Asynchronously generates a response based on the input text and emits partial results. This
121+
* method cannot be called while other queries are active.
122+
*
123+
* <p>The method returns the complete response as a {@link ListenableFuture} and invokes the
124+
* {@code progressListener} as the response is generated. Use {@link #addQueryChunk(String)} to
125+
* add at least one query chunk before calling this function.
105126
*
106127
* <p>Note: You cannot invoke simultaneous response generation calls on active sessions created
107128
* using the same {@link LlmInference}. You have to wait for the currently running response
108129
* generation call to complete before initiating another one.
109130
*
131+
* @param progressListener a {@link ProgressListener} to receive partial results.
132+
* @return a {@link ListenableFuture} with the complete response once the inference is complete.
110133
* @throws IllegalStateException if the inference fails.
111134
*/
112-
public void generateResponseAsync() {
113-
taskRunner.predictAsync(session);
135+
public ListenableFuture<String> generateResponseAsync(ProgressListener<String> progressListener) {
136+
SettableFuture<String> future = SettableFuture.create();
137+
StringBuilder response = new StringBuilder();
138+
taskRunner.predictAsync(
139+
session,
140+
(partialResult, done) -> {
141+
// Not using isEmpty() because it's not available on Android < 30.
142+
boolean stripLeadingWhitespace = response.length() == 0;
143+
String partialResultDecoded = decodeResponse(partialResult, stripLeadingWhitespace);
144+
response.append(partialResultDecoded);
145+
if (done) {
146+
progressListener.run(partialResultDecoded, done);
147+
future.set(response.toString());
148+
} else if (!partialResultDecoded.isEmpty()) {
149+
progressListener.run(partialResultDecoded, done);
150+
}
151+
});
152+
return future;
114153
}
115154

116155
/**
@@ -126,7 +165,7 @@ public int sizeInTokens(String text) {
126165
}
127166

128167
/** Decodes the response from the LLM engine and returns a human-readable string. */
129-
static String decodeResponse(List<String> responses, boolean stripLeadingWhitespace) {
168+
private static String decodeResponse(List<String> responses, boolean stripLeadingWhitespace) {
130169
if (responses.isEmpty()) {
131170
// Technically, this is an error. We should always get at least one response.
132171
return "";

mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmTaskRunner.java

Lines changed: 10 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929
import com.google.protobuf.InvalidProtocolBufferException;
3030
import java.nio.ByteBuffer;
3131
import java.util.List;
32-
import java.util.Optional;
3332
import java.util.concurrent.atomic.AtomicBoolean;
3433

3534
/**
@@ -39,10 +38,11 @@
3938
*/
4039
public final class LlmTaskRunner implements AutoCloseable {
4140
private final long engineHandle;
42-
private final Optional<ProgressListener<List<String>>> resultListener;
4341
private final long callbackHandle;
4442
private final AtomicBoolean isProcessing;
4543

44+
private ProgressListener<List<String>> resultListener = (unused1, unused2) -> {};
45+
4646
/**
4747
* Describes how pixel bits encode color. A pixel may be an alpha mask, a grayscale, RGB, or ARGB.
4848
*
@@ -155,20 +155,9 @@ public static final class LlmSession {
155155
}
156156
}
157157

158-
public LlmTaskRunner(
159-
Context context,
160-
String taskName,
161-
LlmModelSettings modelSettings,
162-
Optional<ProgressListener<List<String>>> resultListener) {
158+
public LlmTaskRunner(Context context, String taskName, LlmModelSettings modelSettings) {
163159
this.engineHandle = nativeCreateEngine(modelSettings.toByteArray());
164-
165-
this.resultListener = resultListener;
166-
if (resultListener.isPresent()) {
167-
this.callbackHandle = nativeRegisterCallback(this);
168-
} else {
169-
this.callbackHandle = 0;
170-
}
171-
160+
this.callbackHandle = nativeRegisterCallback(this);
172161
this.isProcessing = new AtomicBoolean(false);
173162
}
174163

@@ -213,20 +202,18 @@ public List<String> predictSync(LlmSession session) {
213202
}
214203

215204
/** Invokes the LLM with the given session and calls the callback with the result. */
216-
public void predictAsync(LlmSession session) {
205+
public void predictAsync(LlmSession session, ProgressListener<List<String>> resultListener) {
217206
validateState();
218207

219-
if (callbackHandle == 0) {
220-
throw new IllegalStateException("No result listener provided.");
221-
}
222-
223208
try {
224209
isProcessing.set(true);
210+
this.resultListener = resultListener;
225211
nativePredictAsync(session.sessionHandle, callbackHandle);
226212
} catch (Throwable t) {
227213
// Only reset `isProcessing` if we fail to start the async inference. For successful
228214
// inferences, we reset `isProcessing` when we receive `done=true`.
229215
isProcessing.set(false);
216+
this.resultListener = (unused1, unused2) -> {};
230217
throw t;
231218
}
232219
}
@@ -265,10 +252,12 @@ private LlmResponseContext parseResponse(byte[] response) {
265252

266253
private void onAsyncResponse(byte[] responseBytes) {
267254
LlmResponseContext response = parseResponse(responseBytes);
255+
ProgressListener<List<String>> resultListener = this.resultListener;
268256
if (response.getDone()) {
269257
isProcessing.set(false);
258+
this.resultListener = (unused1, unused2) -> {};
270259
}
271-
resultListener.get().run(response.getResponsesList(), response.getDone());
260+
resultListener.run(response.getResponsesList(), response.getDone());
272261
}
273262

274263
@Override

0 commit comments

Comments (0)