@@ -1,13 +1,19 @@
 use std::time::Duration;
 
-use ollama_workflows::{ollama_rs::Ollama, Executor, Model, ProgramMemory, Workflow};
+use ollama_workflows::{
+    ollama_rs::{generation::completion::request::GenerationRequest, Ollama},
+    Model,
+};
 
 const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1";
 const DEFAULT_OLLAMA_PORT: u16 = 11434;
 
 /// Some models, such as small embedding models, are hardcoded into the node.
 const HARDCODED_MODELS: [&str; 1] = ["hellord/mxbai-embed-large-v1:f16"];
 
+/// Prompt used to test Ollama's generation performance.
+const TEST_PROMPT: &str = "Please write a poem about Kapadokya.";
+
 /// Ollama-specific configurations.
 #[derive(Debug, Clone)]
 pub struct OllamaConfig {
@@ -66,12 +72,14 @@ impl OllamaConfig {
     pub async fn check(
         &self,
         external_models: Vec<Model>,
-        test_workflow_timeout: Duration,
+        timeout: Duration,
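+        // minimum tokens-per-second a model must sustain to be accepted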
+        min_tps: f64,
     ) -> Result<Vec<Model>, String> {
         log::info!(
             "Checking Ollama requirements (auto-pull {}, workflow timeout: {}s)",
             if self.auto_pull { "on" } else { "off" },
-            test_workflow_timeout.as_secs()
+            timeout.as_secs()
         );
 
         let ollama = Ollama::new(&self.host, self.port);
@@ -108,7 +116,8 @@ impl OllamaConfig {
         }
 
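+        // keep only the models that pass the performance test below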
         if self
-            .test_workflow(model.clone(), test_workflow_timeout)
+            .test_performance(&ollama, &model, timeout, min_tps)
             .await
         {
             good_models.push(model);
@@ -149,71 +158,50 @@ impl OllamaConfig {
     ///
     /// This is to see if a given system can execute Ollama workflows for their chosen models,
     /// e.g. if they have enough RAM/CPU and such.
-    pub async fn test_workflow(&self, model: Model, timeout: Duration) -> bool {
-        // this is the test workflow that we will run
-        // TODO: when Workflow's have `Clone`, we can remove the repetitive parsing here
-        let workflow = serde_json::from_value::<Workflow>(serde_json::json!({
-            "name": "Simple",
-            "description": "This is a simple workflow",
-            "config":{
-                "max_steps": 5,
-                "max_time": 100,
-                "max_tokens": 100,
-                "tools": []
-            },
-            "tasks":[
-                {
-                    "id": "A",
-                    "name": "Random Poem",
-                    "description": "Writes a poem about Kapadokya.",
-                    "prompt": "Please write a poem about Kapadokya.",
-                    "inputs":[],
-                    "operator": "generation",
-                    "outputs":[
-                        {
-                            "type": "write",
-                            "key": "poem",
-                            "value": "__result"
-                        }
-                    ]
-                },
-                {
-                    "id": "__end",
-                    "name": "end",
-                    "description": "End of the task",
-                    "prompt": "End of the task",
-                    "inputs": [],
-                    "operator": "end",
-                    "outputs": []
-                }
-            ],
-            "steps":[
-                {
-                    "source":"A",
-                    "target":"end"
-                }
-            ],
-            "return_value":{
-                "input":{
-                    "type": "read",
-                    "key": "poem"
-                }
-            }
-        }))
-        .expect("Preset workflow should be parsed");
-
+    pub async fn test_performance(
+        &self,
+        ollama: &Ollama,
+        model: &Model,
+        timeout: Duration,
+        min_tps: f64,
+    ) -> bool {
         log::info!("Testing model {}", model);
-        let executor = Executor::new_at(model.clone(), &self.host, self.port);
-        let mut memory = ProgramMemory::new();
+
+        // first generate a dummy embedding to load the model into memory (warm-up)
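+        // (a failure here typically means the model is unavailable or could not be loaded)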
+        if let Err(err) = ollama
+            .generate_embeddings(model.to_string(), "foobar".to_string(), Default::default())
+            .await
+        {
+            log::error!("Failed to generate embedding for model {}: {}", model, err);
+            return false;
+        };
+
+        // then run a sample generation with a timeout, and measure tokens-per-second (tps)
         tokio::select! {
             _ = tokio::time::sleep(timeout) => {
                 log::warn!("Ignoring model {}: Workflow timed out", model);
             },
-            result = executor.execute(None, workflow, &mut memory) => {
+            result = ollama.generate(GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string())) => {
                 match result {
-                    Ok(_) => {
-                        log::info!("Accepting model {}", model);
-                        return true;
+                    Ok(response) => {
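+                        // `eval_duration` is in nanoseconds, so scale by 1e9 to get tokens/s
+                        // (e.g. 120 tokens over 8e9 ns yields 15.0 tps)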
+                        let tps = (response.eval_count.unwrap_or_default() as f64)
+                            / (response.eval_duration.unwrap_or(1) as f64)
+                            * 1_000_000_000f64;
+
+                        if tps >= min_tps {
+                            log::info!("Model {} passed the test with tps: {}", model, tps);
+                            return true;
+                        }
+
+                        log::warn!(
+                            "Ignoring model {}: tps too low ({:.3} < {:.3})",
+                            model,
+                            tps,
+                            min_tps
+                        );
                     }
                     Err(e) => {
                         log::warn!("Ignoring model {}: Workflow failed with error {}", model, e);