
Commit f4fdfd7

use dummy gen instead of embedding for warm-ups for Ollama
1 parent ba0ebb7
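This commit swaps the warm-up strategy in both workflow crates: instead of loading a model by generating a throwaway embedding, the node now fires a throwaway text generation. A minimal sketch of the new pattern, assuming ollama-rs's default local client; the model name and prompt below are placeholders, not taken from this repository:

use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};

#[tokio::main]
async fn main() {
    // Default client talks to a local Ollama server (http://localhost:11434).
    let ollama = Ollama::default();

    // Warm-up: a dummy generation forces Ollama to load the model into memory,
    // exercising the same code path that the node benchmarks afterwards.
    let request = GenerationRequest::new(
        "llama3.2:1b".to_string(),    // placeholder model name
        "What is 2 + 2?".to_string(), // placeholder prompt
    );
    match ollama.generate(request).await {
        Ok(_) => println!("model loaded and responding"),
        Err(e) => eprintln!("warm-up failed: {e}"),
    }
}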

File tree

4 files changed (+26, -37 lines)

Cargo.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ default-members = ["compute"]
 
 [workspace.package]
 edition = "2021"
-version = "0.5.0"
+version = "0.5.1"
 license = "Apache-2.0"
 readme = "README.md"
 

workflows-v2/src/providers/ollama.rs

Lines changed: 12 additions & 14 deletions
@@ -1,8 +1,5 @@
 use eyre::{eyre, Context, Result};
-use ollama_rs::generation::{
-    completion::request::GenerationRequest,
-    embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest},
-};
+use ollama_rs::generation::completion::request::GenerationRequest;
 use rig::completion::{Chat, PromptError};
 use rig::providers::ollama;
 use std::time::Duration;
@@ -165,18 +162,19 @@ impl OllamaClient {
     pub async fn test_performance(&self, model: &Model) -> bool {
         log::info!("Testing model {}", model);
 
-        // first generate a dummy embedding to load the model into memory (warm-up)
-        let request = GenerateEmbeddingsRequest::new(
-            model.to_string(),
-            EmbeddingsInput::Single("embedme".into()),
-        );
-        if let Err(err) = self.ollama_rs_client.generate_embeddings(request).await {
-            log::error!("Failed to generate embedding for model {}: {}", model, err);
-            return false;
-        };
-
         let generation_request = GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string());
 
+        // run a dummy generation for warm-up
+        log::debug!("Warming up Ollama for model {}", model);
+        if let Err(e) = self
+            .ollama_rs_client
+            .generate(generation_request.clone())
+            .await
+        {
+            log::warn!("Ignoring model {}: Workflow failed with error {}", model, e);
+            return false;
+        }
+
         // then, run a sample generation with timeout and measure tps
         tokio::select! {
             _ = tokio::time::sleep(PERFORMANCE_TIMEOUT) => {
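Taken together, the patched test_performance follows a warm-up-then-measure flow. A rough, self-contained sketch of that flow, assuming a PERFORMANCE_TIMEOUT value; the helper name and the pass/fail criterion are simplifications, and the real function also measures tokens per second from the response, which this sketch omits:

use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
use std::time::Duration;

const PERFORMANCE_TIMEOUT: Duration = Duration::from_secs(30); // assumed value

async fn passes_performance_check(ollama: &Ollama, model: &str, prompt: &str) -> bool {
    let request = GenerationRequest::new(model.to_string(), prompt.to_string());

    // Warm-up: a dummy generation loads the model into memory; a failure here
    // means the model is unusable, so it is skipped rather than crashing the caller.
    if ollama.generate(request.clone()).await.is_err() {
        return false;
    }

    // Measured run: whichever future completes first decides the outcome.
    tokio::select! {
        _ = tokio::time::sleep(PERFORMANCE_TIMEOUT) => {
            false // timed out before producing a response
        }
        result = ollama.generate(request) => {
            result.is_ok()
        }
    }
}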

workflows/src/providers/ollama.rs

Lines changed: 8 additions & 17 deletions
@@ -1,12 +1,6 @@
 use eyre::{eyre, Context, Result};
 use ollama_workflows::{
-    ollama_rs::{
-        generation::{
-            completion::request::GenerationRequest,
-            embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest},
-        },
-        Ollama,
-    },
+    ollama_rs::{generation::completion::request::GenerationRequest, Ollama},
     Model,
 };
 use std::env;
@@ -187,18 +181,15 @@ impl OllamaConfig {
     pub async fn test_performance(&self, ollama: &Ollama, model: &Model) -> bool {
         log::info!("Testing model {}", model);
 
-        // first generate a dummy embedding to load the model into memory (warm-up)
-        let request = GenerateEmbeddingsRequest::new(
-            model.to_string(),
-            EmbeddingsInput::Single("embedme".into()),
-        );
-        if let Err(err) = ollama.generate_embeddings(request).await {
-            log::error!("Failed to generate embedding for model {}: {}", model, err);
-            return false;
-        };
-
         let generation_request = GenerationRequest::new(model.to_string(), TEST_PROMPT.to_string());
 
+        // run a dummy generation for warm-up
+        log::debug!("Warming up Ollama for model {}", model);
+        if let Err(e) = ollama.generate(generation_request.clone()).await {
+            log::warn!("Ignoring model {}: Workflow failed with error {}", model, e);
+            return false;
+        }
+
         // then, run a sample generation with timeout and measure tps
         tokio::select! {
             _ = tokio::time::sleep(self.timeout) => {
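The two crates apply the same change with slightly different plumbing: workflows-v2 calls its own self.ollama_rs_client and a PERFORMANCE_TIMEOUT constant, while workflows receives ollama: &Ollama as a parameter and uses self.timeout. In both, the warm-up reuses generation_request.clone(), so the request that loads the model is the same one later timed, and dropping the embeddings path shrinks the imports to just GenerationRequest.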
