Skip to content

Commit 10ae46f

Browse files
committed
add thread env arg to ollama check
1 parent 2206702 commit 10ae46f

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

src/config/ollama.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use std::time::Duration;
22

33
use ollama_workflows::{
4-
ollama_rs::{generation::completion::request::GenerationRequest, Ollama},
4+
ollama_rs::{
5+
generation::{completion::request::GenerationRequest, options::GenerationOptions},
6+
Ollama,
7+
},
58
Model,
69
};
710

@@ -174,12 +177,26 @@ impl OllamaConfig {
174177
return false;
175178
};
176179

180+
let mut generation_request =
181+
GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string());
182+
183+
// FIXME: temporary workaround: the thread count can be supplied from outside via the OLLAMA_NUM_THREAD environment variable
184+
if let Ok(num_thread) = std::env::var("OLLAMA_NUM_THREAD") {
185+
generation_request = generation_request.options(
186+
GenerationOptions::default().num_thread(
187+
num_thread
188+
.parse()
189+
.expect("num threads should be a positive integer"),
190+
),
191+
);
192+
}
193+
177194
// then, run a sample generation with timeout and measure tps
178195
tokio::select! {
179196
_ = tokio::time::sleep(timeout) => {
180197
log::warn!("Ignoring model {}: Workflow timed out", model);
181198
},
182-
result = ollama.generate(GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string())) => {
199+
result = ollama.generate(generation_request) => {
183200
match result {
184201
Ok(response) => {
185202
let tps = (response.eval_count.unwrap_or_default() as f64)
@@ -189,7 +206,6 @@ impl OllamaConfig {
189206
if tps >= min_tps {
190207
log::info!("Model {} passed the test with tps: {}", model, tps);
191208
return true;
192-
193209
}
194210

195211
log::warn!(

0 commit comments

Comments
 (0)