@@ -114,6 +114,7 @@ class ExecuTorchLlmCallbackJni
 class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
  private:
   friend HybridBase;
+  float temperature_;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
   std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -159,6 +160,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
 #endif
 
     model_type_category_ = model_type_category;
+    temperature_ = temperature;
     if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
           model_path->toStdString().c_str(),
@@ -181,8 +183,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
       runner_ = std::make_unique<MTKLlamaRunner>(
           model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
-          temperature);
+          tokenizer_path->toStdString().c_str());
       // Interpret the model type as LLM
       model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
@@ -222,6 +223,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
           .seq_len = seq_len,
+          .temperature = temperature_,
       };
       runner_->generate(
           prompt->toStdString(),
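
For reference, below is a minimal, self-contained sketch of the pattern this change introduces: temperature is cached as a member when the JNI wrapper is constructed and then supplied per call through the generation config, rather than being fixed in the runner constructor. `SketchRunner`, `LlmJniSketch`, and the simplified `GenerationConfig`/callback here are illustrative stand-ins, not the real ExecuTorch or JNI API.

```cpp
#include <functional>
#include <memory>
#include <string>
#include <utility>

// Simplified per-call config mirroring the fields used in the diff.
struct GenerationConfig {
  bool echo = false;
  int seq_len = 0;
  float temperature = 0.0f;  // now supplied per call instead of per runner
};

// Stand-in for a runner: takes prompt, per-call config, and a token callback.
class SketchRunner {
 public:
  void generate(const std::string& prompt,
                const GenerationConfig& config,
                const std::function<void(const std::string&)>& on_token) {
    // A real runner would sample with config.temperature; here we just echo.
    if (config.echo) {
      on_token(prompt);
    }
    on_token("<generated at temperature " +
             std::to_string(config.temperature) + ">");
  }
};

// Stand-in for the JNI wrapper: caches temperature at construction and
// threads it into each generate() call via the config struct.
class LlmJniSketch {
 public:
  LlmJniSketch(std::unique_ptr<SketchRunner> runner, float temperature)
      : runner_(std::move(runner)), temperature_(temperature) {}

  void generate(const std::string& prompt, int seq_len, bool echo) {
    GenerationConfig config{
        .echo = echo,
        .seq_len = seq_len,
        .temperature = temperature_,  // cached member, as in the patch
    };
    runner_->generate(prompt, config, [](const std::string& token) {
      // Forward each token to the Java-side callback in the real code.
    });
  }

 private:
  std::unique_ptr<SketchRunner> runner_;
  float temperature_;
};
```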