@@ -120,6 +120,7 @@ class ExecuTorchLlmCallbackJni
120120class ExecuTorchLlmJni : public facebook ::jni::HybridClass<ExecuTorchLlmJni> {
121121 private:
122122 friend HybridBase;
123+ float temperature_;
123124 int model_type_category_;
124125 std::unique_ptr<llm::IRunner> runner_;
125126 std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -175,20 +176,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
175176 runner_ = std::make_unique<example::Runner>(
176177 model_path->toStdString ().c_str (),
177178 tokenizer_path->toStdString ().c_str (),
178- temperature,
179179 data_path->toStdString ().c_str ());
180180 } else {
181181 runner_ = std::make_unique<example::Runner>(
182182 model_path->toStdString ().c_str (),
183- tokenizer_path->toStdString ().c_str (),
184- temperature);
183+ tokenizer_path->toStdString ().c_str ());
185184 }
186185#if defined(EXECUTORCH_BUILD_MEDIATEK)
187186 } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
188187 runner_ = std::make_unique<MTKLlamaRunner>(
189188 model_path->toStdString ().c_str (),
190- tokenizer_path->toStdString ().c_str (),
191- temperature);
189+ tokenizer_path->toStdString ().c_str ());
192190 // Interpret the model type as LLM
193191 model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
194192#endif
@@ -228,6 +226,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
228226 executorch::extension::llm::GenerationConfig config{
229227 .echo = static_cast <bool >(echo),
230228 .seq_len = seq_len,
229+ .temperature = temperature_,
231230 };
232231 runner_->generate (
233232 prompt->toStdString (),
0 commit comments