Skip to content

Commit 1f0f038

Browse files
authored
[llm] Route generate_from_pos() JNI API to native llm runner API (#11952)
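Connect the JNI `generate_from_pos()` API to the C++ LLM runner's `generate_from_pos()`, so that text-only LLM models are routed to the native runner instead of returning `Error::NotSupported`.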
1 parent 1a1886a commit 1f0f038

File tree

1 file changed: +20 −9 lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -297,16 +297,27 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
297297
jlong start_pos,
298298
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
299299
jboolean echo) {
300-
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
301-
return static_cast<jint>(Error::NotSupported);
300+
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
301+
return static_cast<jint>(multi_modal_runner_->generate_from_pos(
302+
prompt->toStdString(),
303+
seq_len,
304+
start_pos,
305+
[callback](const std::string& result) { callback->onResult(result); },
306+
[callback](const llm::Stats& stats) { callback->onStats(stats); },
307+
echo));
308+
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
309+
executorch::extension::llm::GenerationConfig config{
310+
.echo = static_cast<bool>(echo),
311+
.seq_len = seq_len,
312+
.temperature = temperature_,
313+
};
314+
runner_->generate_from_pos(
315+
prompt->toStdString(),
316+
start_pos,
317+
config,
318+
[callback](std::string result) { callback->onResult(result); },
319+
[callback](const llm::Stats& stats) { callback->onStats(stats); });
302320
}
303-
return static_cast<jint>(multi_modal_runner_->generate_from_pos(
304-
prompt->toStdString(),
305-
seq_len,
306-
start_pos,
307-
[callback](const std::string& result) { callback->onResult(result); },
308-
[callback](const llm::Stats& stats) { callback->onStats(stats); },
309-
echo));
310321
}
311322

312323
void stop() {

0 commit comments

Comments
 (0)