diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 4a98ccb7c82..1f78a87f4ee 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -120,6 +120,7 @@ class ExecuTorchLlmCallbackJni
 class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
  private:
   friend HybridBase;
+  float temperature_;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
   std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -175,20 +176,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       runner_ = std::make_unique<example::Runner>(
           model_path->toStdString().c_str(),
           tokenizer_path->toStdString().c_str(),
-          temperature,
           data_path->toStdString().c_str());
     } else {
       runner_ = std::make_unique<example::Runner>(
           model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
-          temperature);
+          tokenizer_path->toStdString().c_str());
     }
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
       runner_ = std::make_unique<MTKLlamaRunner>(
           model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
-          temperature);
+          tokenizer_path->toStdString().c_str());
       // Interpret the model type as LLM
       model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
@@ -228,6 +226,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     executorch::extension::llm::GenerationConfig config{
         .echo = static_cast<bool>(echo),
         .seq_len = seq_len,
+        .temperature = temperature_,
     };
     runner_->generate(
         prompt->toStdString(),
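
Net effect of the diff: sampling temperature is no longer baked into the runner at construction time; the JNI layer caches it in the new temperature_ field and forwards it on every call through GenerationConfig. (The assignment of temperature_ from the temperature constructor argument presumably happens in initHybrid, outside the hunks shown.) Below is a minimal, self-contained sketch of the same pattern; Runner and GenerationConfig here are illustrative stand-ins, not the real ExecuTorch types.

// Sketch of moving a parameter from the constructor to a per-call config.
// All names below are hypothetical; only the shape mirrors the diff above.
#include <functional>
#include <iostream>
#include <string>
#include <utility>

struct GenerationConfig {
  bool echo = true;
  int seq_len = 128;
  float temperature = 0.8f;  // supplied per call, not at construction
};

class Runner {
 public:
  explicit Runner(std::string model_path)
      : model_path_(std::move(model_path)) {}

  // Temperature arrives through the config, so one Runner instance can
  // serve successive requests with different sampling settings.
  void generate(
      const std::string& prompt,
      const GenerationConfig& config,
      const std::function<void(const std::string&)>& token_cb) {
    token_cb("sampled '" + prompt + "' with T=" +
             std::to_string(config.temperature));
  }

 private:
  std::string model_path_;
};

int main() {
  Runner runner("model.pte");
  // Two calls on the same runner, each with its own temperature.
  runner.generate(
      "hello",
      GenerationConfig{.echo = true, .seq_len = 64, .temperature = 0.7f},
      [](const std::string& t) { std::cout << t << '\n'; });
  runner.generate(
      "hello again",
      GenerationConfig{.echo = true, .seq_len = 64, .temperature = 1.0f},
      [](const std::string& t) { std::cout << t << '\n'; });
}

The design payoff is that a single long-lived runner can serve generations with different sampling settings, rather than being rebuilt whenever the caller changes temperature.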