Commit 263894a

added tps check
1 parent 71c1078 commit 263894a

File tree: 5 files changed (+65, -73 lines)


Makefile

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,10 @@ profile-mem:
 version:
 	@cargo pkgid | cut -d@ -f2
 
+.PHONY: ollama-cpu # | Run Ollama CPU container
+ollama-cpu:
+	docker run -p=11434:11434 -v=${HOME}/.ollama:/root/.ollama ollama/ollama
+
 ###############################################################################
 .PHONY: test # | Run tests
 test:

docs/NODE_GUIDE.md

Lines changed: 1 addition & 3 deletions
@@ -192,8 +192,6 @@ For the models that you choose (see list of models just below [here](#1-choose-m
 ollama pull llama3.1:latest
 ```
 
-> [!TIP]
-
 #### Optional Services
 
 Based on presence of API keys, [Ollama Workflows](https://github.com/andthattoo/ollama-workflows/) may use more superior services instead of free alternatives, e.g. [Serper](https://serper.dev/) instead of [DuckDuckGo](https://duckduckgo.com/) or [Jina](https://jina.ai/) without rate-limit instead of with rate-limit. Add these within your `.env` as:
@@ -213,7 +211,7 @@ Based on the resources of your machine, you must decide which models that you wi
 
 #### Ollama Models
 
-- `adrienbrault/nous-hermes2theta-llama3-8b:q8_0`
+- `finalend/hermes-3-llama-3.1:8b-q8_0`
 - `phi3:14b-medium-4k-instruct-q4_1`
 - `phi3:14b-medium-128k-instruct-q4_1`
 - `phi3.5:3.8b`

src/config/mod.rs

Lines changed: 7 additions & 1 deletion
@@ -11,6 +11,12 @@ use openai::OpenAIConfig;
 
 use std::{env, time::Duration};
 
+/// Timeout duration for checking model performance during a generation.
+const CHECK_TIMEOUT_DURATION: Duration = Duration::from_secs(80);
+
+/// Minimum tokens per second (TPS) for checking model performance during a generation.
+const CHECK_TPS: f64 = 5.0;
+
 #[derive(Debug, Clone)]
 pub struct DriaComputeNodeConfig {
     /// Wallet secret/private key.
@@ -139,7 +145,7 @@ impl DriaComputeNodeConfig {
         // ensure that the models are pulled / pull them if not
         let good_ollama_models = self
             .ollama_config
-            .check(ollama_models, Duration::from_secs(30))
+            .check(ollama_models, CHECK_TIMEOUT_DURATION, CHECK_TPS)
             .await?;
         good_models.extend(
            good_ollama_models
src/config/models.rs

Lines changed: 6 additions & 6 deletions
@@ -136,25 +136,25 @@ mod tests {
         assert_eq!(cfg.models.len(), 0);
 
         let cfg = ModelConfig::new_from_csv(Some(
-            "phi3:3.8b,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(),
+            "gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(),
         ));
         assert_eq!(cfg.models.len(), 2);
     }
 
     #[test]
     fn test_model_matching() {
-        let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,phi3:3.8b".to_string()));
+        let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string()));
         assert_eq!(
             cfg.get_matching_model("openai".to_string()).unwrap().1,
             Model::GPT3_5Turbo,
             "Should find existing model"
         );
 
         assert_eq!(
-            cfg.get_matching_model(Model::default().to_string())
+            cfg.get_matching_model("llama3.1:latest".to_string())
                 .unwrap()
                 .1,
-            Model::default(),
+            Model::Llama3_1_8B,
             "Should find existing model"
         );
 
@@ -172,7 +172,7 @@ mod tests {
 
     #[test]
     fn test_get_any_matching_model() {
-        let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,phi3:3.8b".to_string()));
+        let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string()));
         let result = cfg.get_any_matching_model(vec![
             "i-dont-exist".to_string(),
             "llama3.1:latest".to_string(),
@@ -181,7 +181,7 @@
         ]);
         assert_eq!(
             result.unwrap().1,
-            Model::default(),
+            Model::Llama3_1_8B,
             "Should find existing model"
         );
     }

src/config/ollama.rs

Lines changed: 47 additions & 63 deletions
@@ -1,13 +1,19 @@
 use std::time::Duration;
 
-use ollama_workflows::{ollama_rs::Ollama, Executor, Model, ProgramMemory, Workflow};
+use ollama_workflows::{
+    ollama_rs::{generation::completion::request::GenerationRequest, Ollama},
+    Model,
+};
 
 const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1";
 const DEFAULT_OLLAMA_PORT: u16 = 11434;
 
 /// Some models such as small embedding models, are hardcoded into the node.
 const HARDCODED_MODELS: [&str; 1] = ["hellord/mxbai-embed-large-v1:f16"];
 
+/// Prompt to be used to see Ollama performance.
+const TEST_PROMPT: &str = "Please write a poem about Kapadokya.";
+
 /// Ollama-specific configurations.
 #[derive(Debug, Clone)]
 pub struct OllamaConfig {
@@ -66,12 +72,13 @@ impl OllamaConfig {
     pub async fn check(
         &self,
         external_models: Vec<Model>,
-        test_workflow_timeout: Duration,
+        timeout: Duration,
+        min_tps: f64,
     ) -> Result<Vec<Model>, String> {
         log::info!(
             "Checking Ollama requirements (auto-pull {}, workflow timeout: {}s)",
             if self.auto_pull { "on" } else { "off" },
-            test_workflow_timeout.as_secs()
+            timeout.as_secs()
         );
 
         let ollama = Ollama::new(&self.host, self.port);
@@ -108,7 +115,7 @@ impl OllamaConfig {
             }
 
             if self
-                .test_workflow(model.clone(), test_workflow_timeout)
+                .test_performance(&ollama, &model, timeout, min_tps)
                 .await
             {
                 good_models.push(model);
@@ -149,71 +156,48 @@ impl OllamaConfig {
     ///
     /// This is to see if a given system can execute Ollama workflows for their chosen models,
     /// e.g. if they have enough RAM/CPU and such.
-    pub async fn test_workflow(&self, model: Model, timeout: Duration) -> bool {
-        // this is the test workflow that we will run
-        // TODO: when Workflow's have `Clone`, we can remove the repetitive parsing here
-        let workflow = serde_json::from_value::<Workflow>(serde_json::json!({
-            "name": "Simple",
-            "description": "This is a simple workflow",
-            "config":{
-                "max_steps": 5,
-                "max_time": 100,
-                "max_tokens": 100,
-                "tools": []
-            },
-            "tasks":[
-                {
-                    "id": "A",
-                    "name": "Random Poem",
-                    "description": "Writes a poem about Kapadokya.",
-                    "prompt": "Please write a poem about Kapadokya.",
-                    "inputs":[],
-                    "operator": "generation",
-                    "outputs":[
-                        {
-                            "type": "write",
-                            "key": "poem",
-                            "value": "__result"
-                        }
-                    ]
-                },
-                {
-                    "id": "__end",
-                    "name": "end",
-                    "description": "End of the task",
-                    "prompt": "End of the task",
-                    "inputs": [],
-                    "operator": "end",
-                    "outputs": []
-                }
-            ],
-            "steps":[
-                {
-                    "source":"A",
-                    "target":"end"
-                }
-            ],
-            "return_value":{
-                "input":{
-                    "type": "read",
-                    "key": "poem"
-                }
-            }
-        }))
-        .expect("Preset workflow should be parsed");
-
+    pub async fn test_performance(
+        &self,
+        ollama: &Ollama,
+        model: &Model,
+        timeout: Duration,
+        min_tps: f64,
+    ) -> bool {
         log::info!("Testing model {}", model);
-        let executor = Executor::new_at(model.clone(), &self.host, self.port);
-        let mut memory = ProgramMemory::new();
+
+        // first generate a dummy embedding to load the model into memory (warm-up)
+        if let Err(err) = ollama
+            .generate_embeddings(model.to_string(), "foobar".to_string(), Default::default())
+            .await
+        {
+            log::error!("Failed to generate embedding for model {}: {}", model, err);
+            return false;
+        };
+
+        // then, run a sample generation with timeout and measure tps
        tokio::select! {
            _ = tokio::time::sleep(timeout) => {
                log::warn!("Ignoring model {}: Workflow timed out", model);
            },
-            result = executor.execute(None, workflow, &mut memory) => {
+            result = ollama.generate(GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string())) => {
                match result {
-                    Ok(_) => {
-                        log::info!("Accepting model {}", model);
-                        return true;
+                    Ok(response) => {
+                        let tps = (response.eval_count.unwrap_or_default() as f64)
+                            / (response.eval_duration.unwrap_or(1) as f64)
+                            * 1_000_000_000f64;
+
+                        if tps >= min_tps {
+                            log::info!("Model {} passed the test with tps: {}", model, tps);
+                            return true;
+                        }
+
+                        log::warn!(
+                            "Ignoring model {}: tps too low ({:.3} < {:.3})",
+                            model,
+                            tps,
+                            min_tps
+                        );
                    }
                    Err(e) => {
                        log::warn!("Ignoring model {}: Workflow failed with error {}", model, e);