Skip to content

Commit d020306

Browse files
committed
strange things
1 parent 05378f5 commit d020306

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ public int generate(
167167
}
168168

169169
/**
170-
* Prefill an LLaVA Module with the given images input.
170+
* Prefill a multimodal Module with the given image input.
171171
*
172172
* @param image Input image as a byte array
173173
* @param width Input image width
@@ -189,9 +189,9 @@ public long prefillImages(int[] image, int width, int height, int channels) {
189189
private native int appendImagesInput(int[] image, int width, int height, int channels);
190190

191191
/**
192-
* Prefill an LLaVA Module with the given text input.
192+
* Prefill a multimodal Module with the given text input.
193193
*
194-
* @param prompt The text prompt to LLaVA.
194+
* @param prompt The text prompt to the multimodal model.
195195
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
196196
* exposed to user.
197197
* @throws RuntimeException if the prefill failed
@@ -208,6 +208,35 @@ public long prefillPrompt(String prompt) {
208208
// returns status
209209
private native int appendTextInput(String prompt);
210210

211+
/**
212+
* Prefill a multimodal Module with audio input read from the given file.
213+
*
214+
* @param filePath Path to the audio input file; it is read whole and treated as 4-byte floats in batches of 128 bins x 3000 frames.
215+
* @return the status code returned by the native appendAudioInput call; no KV cache position is
216+
* exposed to user.
217+
* @throws RuntimeException if the file could not be read completely
218+
*/
219+
public int prefillAudio(String filePath) {
220+
java.io.File file = new java.io.File(filePath);
221+
try (java.io.FileInputStream fis = new java.io.FileInputStream(file)) {
222+
byte[] fileBytes = new byte[(int) file.length()];
223+
int bytesRead = fis.read(fileBytes);
224+
if (bytesRead != fileBytes.length) {
225+
throw new RuntimeException("Could not completely read file " + file.getName());
226+
}
227+
int nFloats = fileBytes.length / 4;
228+
int batchSize = nFloats / (128 * 3000);
229+
return appendAudioInput(fileBytes, batchSize, 128, 3000);
230+
} catch (java.io.IOException e) {
231+
throw new RuntimeException("Failed to read file: " + e);
232+
}
233+
}
234+
235+
// For Audio (option B), not RawAudio
236+
// Use batch_size = ceil(n_floats / (n_bins * n_frames)), n_bins = 128, n_frames = 3000
237+
// returns status
238+
private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
239+
211240
/**
212241
* Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
213242
*

extension/android/jni/jni_layer_llama.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
286286
}
287287

288288
jint prefill_audio_input(
289-
facebook::jni::alias_ref<jintArray> audio,
289+
facebook::jni::alias_ref<jbyteArray> audio,
290290
jint batch_size,
291-
jint n_channels,
292-
jint n_samples) {
291+
jint n_bins,
292+
jint n_frames) {
293293
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
294294
return static_cast<jint>(Error::InvalidArgument);
295295
}
@@ -299,12 +299,12 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
299299
auto audio_size = audio->size();
300300
std::vector<uint8_t> audio_data(audio_size);
301301
if (audio_size != 0) {
302-
std::vector<jint> audio_data_jint(audio_size);
303-
audio->getRegion(0, audio_size, audio_data_jint.data());
302+
std::vector<jbyte> audio_data_jbyte(audio_size);
303+
audio->getRegion(0, audio_size, audio_data_jbyte.data());
304304
for (int i = 0; i < audio_size; i++) {
305-
audio_data[i] = audio_data_jint[i];
305+
audio_data[i] = audio_data_jbyte[i];
306306
}
307-
llm::RawAudio audio_input{audio_data, batch_size, n_channels, n_samples};
307+
llm::Audio audio_input{std::move(audio_data), batch_size, n_bins, n_frames};
308308
multi_modal_runner_->prefill(
309309
{llm::MultimodalInput{std::move(audio_input)}});
310310
}

0 commit comments

Comments
 (0)