Skip to content

Commit 96a5ef3

Browse files
committed
some fixes and better interface
1 parent e69292e commit 96a5ef3

File tree

9 files changed

+98
-53
lines changed

9 files changed

+98
-53
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

workflows/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,6 @@ eyre.workspace = true
2121

2222
# ollama-rs is re-exported from ollama-workflows
2323
ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" }
24+
25+
[dev-dependencies]
26+
env_logger.workspace = true

workflows/README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,15 @@ This crate handles the configurations of models to be used, and implements vario
55

66
- **OpenAI**: We check that the chosen models are enabled for the user's profile by fetching their models with their API key. We filter out the disabled models.
77
- **Ollama**: We provide a sample workflow to measure TPS and then pick models that are above some TPS threshold. While calculating TPS, there is also a timeout so that beyond that timeout the TPS is not even considered and the model becomes invalid.
8+
9+
## Environment Variables
10+
11+
DKN Workflows makes use of several environment variables, depending on the providers in use.
12+
13+
- `OPENAI_API_KEY` is used for OpenAI requests
14+
- `OLLAMA_HOST` is used to connect to Ollama server
15+
- `OLLAMA_PORT` is used to connect to Ollama server
16+
- `OLLAMA_AUTO_PULL` indicates whether we should pull missing models automatically or not
17+
18+
- `SERPER_API_KEY` is used for Serper requests (TODO: confirm usage)
- `JINA_API_KEY` is used for Jina requests (TODO: confirm usage)

workflows/src/config.rs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,29 @@ pub struct ModelConfig {
1515
}
1616

1717
impl ModelConfig {
18-
/// Parses Ollama-Workflows compatible models from a comma-separated values string.
19-
pub fn new_from_csv(input: Option<String>) -> Self {
20-
let models_str = split_comma_separated(input);
21-
22-
let models = models_str
18+
pub fn new(models: Vec<Model>) -> Self {
19+
let models_and_providers = models
2320
.into_iter()
24-
.filter_map(|s| match Model::try_from(s) {
25-
Ok(model) => Some((model.clone().into(), model)),
26-
Err(e) => {
27-
log::warn!("Error parsing model: {}", e);
28-
None
29-
}
30-
})
21+
.map(|model| (model.clone().into(), model))
3122
.collect::<Vec<_>>();
3223

3324
Self {
34-
models,
25+
models: models_and_providers,
3526
openai: OpenAIConfig::new(),
3627
ollama: OllamaConfig::new(),
3728
}
3829
}
30+
/// Parses Ollama-Workflows compatible models from a comma-separated values string.
31+
pub fn new_from_csv(input: &str) -> Self {
32+
let models_str = split_comma_separated(input);
33+
34+
let models = models_str
35+
.into_iter()
36+
.filter_map(|s| Model::try_from(s).ok())
37+
.collect();
38+
39+
Self::new(models)
40+
}
3941

4042
/// Returns the models that belong to a given providers from the config.
4143
pub fn get_models_for_provider(&self, provider: ModelProvider) -> Vec<Model> {
@@ -209,19 +211,18 @@ mod tests {
209211

210212
#[test]
211213
fn test_csv_parser() {
212-
let cfg =
213-
ModelConfig::new_from_csv(Some("idontexist,i dont either,i332287648762".to_string()));
214+
let cfg = ModelConfig::new_from_csv("idontexist,i dont either,i332287648762");
214215
assert_eq!(cfg.models.len(), 0);
215216

216-
let cfg = ModelConfig::new_from_csv(Some(
217-
"gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl".to_string(),
218-
));
217+
let cfg = ModelConfig::new_from_csv(
218+
"gemma2:9b-instruct-q8_0,phi3:14b-medium-4k-instruct-q4_1,balblablabl",
219+
);
219220
assert_eq!(cfg.models.len(), 2);
220221
}
221222

222223
#[test]
223224
fn test_model_matching() {
224-
let cfg = ModelConfig::new_from_csv(Some("gpt-4o,llama3.1:latest".to_string()));
225+
let cfg = ModelConfig::new_from_csv("gpt-4o,llama3.1:latest");
225226
assert_eq!(
226227
cfg.get_matching_model("openai".to_string()).unwrap().1,
227228
Model::GPT4o,
@@ -250,7 +251,7 @@ mod tests {
250251

251252
#[test]
252253
fn test_get_any_matching_model() {
253-
let cfg = ModelConfig::new_from_csv(Some("gpt-3.5-turbo,llama3.1:latest".to_string()));
254+
let cfg = ModelConfig::new_from_csv("gpt-3.5-turbo,llama3.1:latest");
254255
let result = cfg.get_any_matching_model(vec![
255256
"i-dont-exist".to_string(),
256257
"llama3.1:latest".to_string(),

workflows/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ mod utils;
22
pub use utils::*;
33

44
mod providers;
5-
pub use providers::*;
5+
use providers::*;
66

77
mod config;
88
pub use config::ModelConfig;

workflows/src/providers/ollama.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ mod tests {
227227
use ollama_workflows::{Executor, Model, ProgramMemory, Workflow};
228228

229229
#[tokio::test]
230-
#[ignore = "run this manually"]
230+
#[ignore = "requires Ollama"]
231231
async fn test_ollama_prompt() {
232232
let model = Model::default().to_string();
233233
let ollama = Ollama::default();
@@ -246,7 +246,7 @@ mod tests {
246246
}
247247

248248
#[tokio::test]
249-
#[ignore = "run this manually"]
249+
#[ignore = "requires Ollama"]
250250
async fn test_ollama_workflow() {
251251
let workflow = r#"{
252252
"name": "Simple",

workflows/src/providers/openai.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use eyre::{eyre, Context, Result};
22
use ollama_workflows::Model;
3+
use reqwest::Client;
34
use serde::Deserialize;
45

56
const OPENAI_MODELS_API: &str = "https://api.openai.com/v1/models";
@@ -29,6 +30,7 @@ struct OpenAIModelsResponse {
2930

3031
#[derive(Debug, Clone, Default)]
3132
pub struct OpenAIConfig {
33+
/// API key, if available.
3234
pub(crate) api_key: Option<String>,
3335
}
3436

@@ -50,17 +52,17 @@ impl OpenAIConfig {
5052
};
5153

5254
// fetch models
53-
let client = reqwest::Client::new();
55+
let client = Client::new();
5456
let request = client
5557
.get(OPENAI_MODELS_API)
5658
.header("Authorization", format!("Bearer {}", api_key))
5759
.build()
58-
.wrap_err("Failed to build request")?;
60+
.wrap_err("failed to build request")?;
5961

6062
let response = client
6163
.execute(request)
6264
.await
63-
.wrap_err("Failed to send request")?;
65+
.wrap_err("failed to send request")?;
6466

6567
// parse response
6668
if response.status().is_client_error() {

workflows/src/utils.rs

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,35 @@
22
///
33
/// - Trims `"` from both ends at the start
44
/// - For each item, trims whitespace from both ends
5-
pub fn split_comma_separated(input: Option<String>) -> Vec<String> {
6-
match input {
7-
Some(s) => s
8-
.trim_matches('"')
9-
.split(',')
10-
.filter_map(|s| {
11-
let s = s.trim().to_string();
12-
if s.is_empty() {
13-
None
14-
} else {
15-
Some(s)
16-
}
17-
})
18-
.collect::<Vec<_>>(),
19-
None => vec![],
20-
}
5+
pub fn split_comma_separated(input: &str) -> Vec<String> {
6+
input
7+
.trim_matches('"')
8+
.split(',')
9+
.filter_map(|s| {
10+
let s = s.trim().to_string();
11+
if s.is_empty() {
12+
None
13+
} else {
14+
Some(s)
15+
}
16+
})
17+
.collect::<Vec<_>>()
2118
}
2219

2320
#[cfg(test)]
2421
mod tests {
2522
use super::*;
2623

2724
#[test]
28-
fn test_split_comma_separated() {
25+
fn test_example() {
2926
// should ignore whitespaces and `"` at both ends, and ignore empty items
30-
let input = Some("\"a, b , c ,, \"".to_string());
27+
let input = "\"a, b , c ,, \"";
3128
let expected = vec!["a".to_string(), "b".to_string(), "c".to_string()];
3229
assert_eq!(split_comma_separated(input), expected);
3330
}
31+
32+
#[test]
33+
fn test_empty() {
34+
assert!(split_comma_separated(Default::default()).is_empty());
35+
}
3436
}

workflows/tests/models_test.rs

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,38 @@
1+
use std::env;
2+
13
use dkn_workflows::ModelConfig;
24
use eyre::Result;
5+
use ollama_workflows::Model;
6+
7+
#[tokio::test]
8+
#[ignore = "requires Ollama"]
9+
async fn test_ollama() -> Result<()> {
10+
env::set_var("RUST_LOG", "none,dkn_workflows=debug");
11+
let _ = env_logger::try_init();
12+
13+
let models = vec![Model::Phi3_5Mini];
14+
let mut model_config = ModelConfig::new(models);
15+
16+
model_config.check_services().await
17+
}
18+
19+
#[tokio::test]
20+
async fn test_openai() -> Result<()> {
21+
env::set_var("RUST_LOG", "debug");
22+
let _ = env_logger::try_init();
23+
24+
let models = vec![Model::GPT4Turbo];
25+
let mut model_config = ModelConfig::new(models);
26+
27+
model_config.check_services().await
28+
}
329

4-
// #[tokio::test]
5-
// async fn test_ollama() -> Result<()> {}
30+
#[tokio::test]
31+
async fn test_empty() -> Result<()> {
32+
let mut model_config = ModelConfig::new(vec![]);
633

7-
// #[tokio::test]
8-
// async fn test_openai() -> Result<()> {}
34+
let result = model_config.check_services().await;
35+
assert!(result.is_err());
936

10-
// #[tokio::test]
11-
// async fn test_empty() -> Result<()> {
12-
// let mut model_config = ModelConfig::default();
13-
// model_config.check_services().await
14-
// }
37+
Ok(())
38+
}

0 commit comments

Comments
 (0)