Skip to content

Commit 2d68b88

Browse files
Android audio input API (#15169)
### Summary: Expose all LLM runner APIs to Java. Mark some APIs as experimental rather than deprecated. ### Test plan: CI. cc @cbilgin Co-authored-by: Hansong Zhang <[email protected]>
1 parent 487d584 commit 2d68b88

File tree

2 files changed

+103
-7
lines changed

2 files changed

+103
-7
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ public int generate(
167167
}
168168

169169
/**
170-
* Prefill an LLaVA Module with the given images input.
170+
* Prefill a multimodal Module with the given images input.
171171
*
172172
* @param image Input image as a byte array
173173
* @param width Input image width
@@ -177,7 +177,7 @@ public int generate(
177177
* exposed to user.
178178
* @throws RuntimeException if the prefill failed
179179
*/
180-
@Deprecated
180+
@Experimental
181181
public long prefillImages(int[] image, int width, int height, int channels) {
182182
int nativeResult = appendImagesInput(image, width, height, channels);
183183
if (nativeResult != 0) {
@@ -189,7 +189,7 @@ public long prefillImages(int[] image, int width, int height, int channels) {
189189
private native int appendImagesInput(int[] image, int width, int height, int channels);
190190

191191
/**
192-
* Prefill an LLaVA Module with the given images input.
192+
* Prefill a multimodal Module with the given images input.
193193
*
194194
* @param image Input normalized image as a float array
195195
* @param width Input image width
@@ -199,7 +199,7 @@ public long prefillImages(int[] image, int width, int height, int channels) {
199199
* exposed to user.
200200
* @throws RuntimeException if the prefill failed
201201
*/
202-
@Deprecated
202+
@Experimental
203203
public long prefillImages(float[] image, int width, int height, int channels) {
204204
int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
205205
if (nativeResult != 0) {
@@ -212,14 +212,59 @@ private native int appendNormalizedImagesInput(
212212
float[] image, int width, int height, int channels);
213213

214214
/**
215-
* Prefill an LLaVA Module with the given text input.
215+
* Prefill a multimodal Module with the given audio input.
216216
*
217-
* @param prompt The text prompt to LLaVA.
217+
* @param audio Input preprocessed audio as a byte array
218+
* @param batch_size Input batch size
219+
* @param n_bins Input number of bins
220+
* @param n_frames Input number of frames
218221
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
219222
* exposed to user.
220223
* @throws RuntimeException if the prefill failed
221224
*/
222-
@Deprecated
225+
@Experimental
226+
public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
227+
int nativeResult = appendAudioInput(audio, batch_size, n_bins, n_frames);
228+
if (nativeResult != 0) {
229+
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
230+
}
231+
return 0;
232+
}
233+
234+
// Native (JNI) counterpart of prefillAudio. Returns a status code; prefillAudio
// treats any non-zero value as a failure.
private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
235+
236+
/**
237+
* Prefill a multimodal Module with the given raw audio input.
238+
*
239+
* @param audio Input raw audio as a byte array
240+
* @param batch_size Input batch size
241+
* @param n_channels Input number of channels
242+
* @param n_samples Input number of samples
243+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
244+
* exposed to user.
245+
* @throws RuntimeException if the prefill failed
246+
*/
247+
@Experimental
248+
public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
249+
int nativeResult = appendRawAudioInput(audio, batch_size, n_channels, n_samples);
250+
if (nativeResult != 0) {
251+
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
252+
}
253+
return 0;
254+
}
255+
256+
// Native (JNI) counterpart of prefillRawAudio. Returns a status code;
// prefillRawAudio treats any non-zero value as a failure.
private native int appendRawAudioInput(
257+
byte[] audio, int batch_size, int n_channels, int n_samples);
258+
259+
/**
260+
* Prefill a multimodal Module with the given text input.
261+
*
262+
* @param prompt The text prompt to prefill.
263+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
264+
* exposed to user.
265+
* @throws RuntimeException if the prefill failed
266+
*/
267+
@Experimental
223268
public long prefillPrompt(String prompt) {
224269
int nativeResult = appendTextInput(prompt);
225270
if (nativeResult != 0) {

extension/android/jni/jni_layer_llama.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,53 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
302302
return 0;
303303
}
304304

305+
// Returns status_code
306+
jint append_audio_input(
307+
facebook::jni::alias_ref<jbyteArray> data,
308+
jint batch_size,
309+
jint n_bins,
310+
jint n_frames) {
311+
if (data == nullptr) {
312+
return static_cast<jint>(Error::EndOfMethod);
313+
}
314+
auto data_size = data->size();
315+
if (data_size != 0) {
316+
std::vector<jbyte> data_jbyte(data_size);
317+
std::vector<uint8_t> data_u8(data_size);
318+
data->getRegion(0, data_size, data_jbyte.data());
319+
for (int i = 0; i < data_size; i++) {
320+
data_u8[i] = data_jbyte[i];
321+
}
322+
llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames};
323+
prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
324+
}
325+
return 0;
326+
}
327+
328+
// Returns status_code
329+
jint append_raw_audio_input(
330+
facebook::jni::alias_ref<jbyteArray> data,
331+
jint batch_size,
332+
jint n_channels,
333+
jint n_samples) {
334+
if (data == nullptr) {
335+
return static_cast<jint>(Error::EndOfMethod);
336+
}
337+
auto data_size = data->size();
338+
if (data_size != 0) {
339+
std::vector<jbyte> data_jbyte(data_size);
340+
std::vector<uint8_t> data_u8(data_size);
341+
data->getRegion(0, data_size, data_jbyte.data());
342+
for (int i = 0; i < data_size; i++) {
343+
data_u8[i] = data_jbyte[i];
344+
}
345+
llm::RawAudio audio{
346+
std::move(data_u8), batch_size, n_channels, n_samples};
347+
prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
348+
}
349+
return 0;
350+
}
351+
305352
void stop() {
306353
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
307354
multi_modal_runner_->stop();
@@ -339,6 +386,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
339386
makeNativeMethod(
340387
"appendNormalizedImagesInput",
341388
ExecuTorchLlmJni::append_normalized_images_input),
389+
makeNativeMethod(
390+
"appendAudioInput", ExecuTorchLlmJni::append_audio_input),
391+
makeNativeMethod(
392+
"appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input),
342393
makeNativeMethod(
343394
"appendTextInput", ExecuTorchLlmJni::append_text_input),
344395
makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),

0 commit comments

Comments
 (0)