@@ -123,7 +123,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
123
123
std::unique_ptr<llm::IRunner> runner_;
124
124
std::unique_ptr<executorch::extension::llm::MultimodalRunner>
125
125
multi_modal_runner_;
126
- std::vector<llm::MultimodalInput> prefill_inputs_;
127
126
128
127
public:
129
128
constexpr static auto kJavaDescriptor =
@@ -213,8 +212,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
213
212
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
214
213
jboolean echo) {
215
214
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
216
- std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
217
- prefill_inputs_.clear ();
215
+ std::vector<llm::MultimodalInput> inputs;
218
216
if (!prompt->toStdString ().empty ()) {
219
217
inputs.emplace_back (llm::MultimodalInput{prompt->toStdString ()});
220
218
}
@@ -247,9 +245,11 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
247
245
// Contract is valid within an AAR (JNI + corresponding Java code)
248
246
/**
 * Prefills the active runner with a text prompt.
 *
 * Dispatches on model_type_category_:
 *  - MODEL_TYPE_CATEGORY_LLM: forwards the prompt string to runner_->prefill
 *    with a default-constructed second argument.
 *  - MODEL_TYPE_CATEGORY_MULTIMODAL: wraps the prompt in a single
 *    llm::MultimodalInput and forwards it to multi_modal_runner_->prefill.
 *
 * @param prompt Java string holding the text to prefill.
 * @return 0 after a prefill call was issued for a supported category;
 *         Error::InvalidArgument (as jint) for any other category.
 */
jint prefill_text_input(facebook::jni::alias_ref<jstring> prompt) {
  if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
    // NOTE(review): the result of prefill() is discarded here, so a failing
    // prefill is still reported as 0 -- confirm whether it should propagate.
    runner_->prefill(prompt->toStdString(), {});
    return 0;
  }
  if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
    // NOTE(review): result of prefill() is discarded here as well.
    multi_modal_runner_->prefill(
        {llm::MultimodalInput{prompt->toStdString()}});
    return 0;
  }
  // Previously control fell off the end of this non-void function for an
  // unrecognized category, which is undefined behaviour. Report the error
  // the same way the other prefill entry points of this class do.
  return static_cast<jint>(Error::InvalidArgument);
}
@@ -260,10 +260,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
260
260
jint height,
261
261
jint channels) {
262
262
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
263
- return Error::InvalidArgument;
263
+ return static_cast <jint>( Error::InvalidArgument) ;
264
264
}
265
265
if (image == nullptr ) {
266
- return Error::InvalidArgument;
266
+ return static_cast <jint>( Error::InvalidArgument) ;
267
267
}
268
268
std::vector<llm::Image> images;
269
269
if (image == nullptr ) {
@@ -278,7 +278,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
278
278
image_data[i] = image_data_jint[i];
279
279
}
280
280
llm::Image image_runner{std::move (image_data), width, height, channels};
281
- multi_modal_runner_->prefill (llm::MultimodalInput{std::move (image_runner)});
281
+ multi_modal_runner_->prefill (
282
+ {llm::MultimodalInput{std::move (image_runner)}});
282
283
}
283
284
284
285
return 0 ;
@@ -290,10 +291,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
290
291
jint n_channels,
291
292
jint n_samples) {
292
293
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
293
- return Error::InvalidArgument;
294
+ return static_cast <jint>( Error::InvalidArgument) ;
294
295
}
295
296
if (audio == nullptr ) {
296
- return Error::InvalidArgument;
297
+ return static_cast <jint>( Error::InvalidArgument) ;
297
298
}
298
299
auto audio_size = audio->size ();
299
300
std::vector<uint8_t > audio_data (audio_size);
@@ -304,7 +305,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
304
305
audio_data[i] = audio_data_jint[i];
305
306
}
306
307
llm::RawAudio audio_input{audio_data, batch_size, n_channels, n_samples};
307
- multi_modal_runner_->prefill (llm::MultimodalInput{std::move (audio_input)});
308
+ multi_modal_runner_->prefill (
309
+ {llm::MultimodalInput{std::move (audio_input)}});
308
310
}
309
311
return 0 ;
310
312
}
0 commit comments