File tree Expand file tree Collapse file tree 1 file changed +20
-9
lines changed
Expand file tree Collapse file tree 1 file changed +20
-9
lines changed Original file line number Diff line number Diff line change @@ -297,16 +297,27 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
297297 jlong start_pos,
298298 facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
299299 jboolean echo) {
300- if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
301- return static_cast <jint>(Error::NotSupported);
300+ if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
301+ return static_cast <jint>(multi_modal_runner_->generate_from_pos (
302+ prompt->toStdString (),
303+ seq_len,
304+ start_pos,
305+ [callback](const std::string& result) { callback->onResult (result); },
306+ [callback](const llm::Stats& stats) { callback->onStats (stats); },
307+ echo));
308+ } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
309+ executorch::extension::llm::GenerationConfig config{
310+ .echo = static_cast <bool >(echo),
311+ .seq_len = seq_len,
312+ .temperature = temperature_,
313+ };
314+ runner_->generate_from_pos (
315+ prompt->toStdString (),
316+ start_pos,
317+ config,
318+ [callback](std::string result) { callback->onResult (result); },
319+ [callback](const llm::Stats& stats) { callback->onStats (stats); });
302320 }
303- return static_cast <jint>(multi_modal_runner_->generate_from_pos (
304- prompt->toStdString (),
305- seq_len,
306- start_pos,
307- [callback](const std::string& result) { callback->onResult (result); },
308- [callback](const llm::Stats& stats) { callback->onStats (stats); },
309- echo));
310321 }
311322
312323 void stop () {
You can’t perform that action at this time.
0 commit comments