Skip to content

Commit 8a0b403

Browse files
committed
audio float API
1 parent e3e8e60 commit 8a0b403

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,28 @@ public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames)
252252

253253
private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
254254

255+
/**
256+
* Prefill a multimodal Module with the given audio input.
257+
*
258+
* @param audio Input preprocessed audio as a float array
259+
* @param batch_size Input batch size
260+
* @param n_bins Input number of bins
261+
* @param n_frames Input number of frames
262+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
263+
* exposed to user.
264+
* @throws RuntimeException if the prefill failed
265+
*/
266+
@Experimental
267+
public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) {
268+
int nativeResult = appendAudioInputFloat(audio, batch_size, n_bins, n_frames);
269+
if (nativeResult != 0) {
270+
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
271+
}
272+
return 0;
273+
}
274+
275+
private native int appendAudioInputFloat(float[] audio, int batch_size, int n_bins, int n_frames);
276+
255277
/**
256278
* Prefill a multimodal Module with the given raw audio input.
257279
*

extension/android/jni/jni_layer_llama.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,29 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
339339
return 0;
340340
}
341341

342+
// Returns status_code
343+
jint append_audio_input_float(
344+
facebook::jni::alias_ref<jfloatArray> data,
345+
jint batch_size,
346+
jint n_bins,
347+
jint n_frames) {
348+
if (data == nullptr) {
349+
return static_cast<jint>(Error::EndOfMethod);
350+
}
351+
auto data_size = data->size();
352+
if (data_size != 0) {
353+
std::vector<jfloat> data_jfloat(data_size);
354+
std::vector<float> data_f(data_size);
355+
data->getRegion(0, data_size, data_jfloat.data());
356+
for (int i = 0; i < data_size; i++) {
357+
data_f[i] = data_jfloat[i];
358+
}
359+
llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames};
360+
prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
361+
}
362+
return 0;
363+
}
364+
342365
// Returns status_code
343366
jint append_raw_audio_input(
344367
facebook::jni::alias_ref<jbyteArray> data,
@@ -402,6 +425,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
402425
ExecuTorchLlmJni::append_normalized_images_input),
403426
makeNativeMethod(
404427
"appendAudioInput", ExecuTorchLlmJni::append_audio_input),
428+
makeNativeMethod(
429+
"appendAudioInputFloat",
430+
ExecuTorchLlmJni::append_audio_input_float),
405431
makeNativeMethod(
406432
"appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input),
407433
makeNativeMethod(

0 commit comments

Comments
 (0)