
Commit e69292e

better model parsing added as subcrate

1 parent 1f2f15d commit e69292e

11 files changed: +191 -150 lines changed

.env.example

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,7 +10,7 @@ DKN_ADMIN_PUBLIC_KEY=0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110b
 DKN_MODELS=
 
 ## DRIA (optional) ##
-# P2P address, you don't need to change this unless you really want this port.
+# P2P address, you don't need to change this unless this port is already in use.
 DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001
 # Comma-separated static relay nodes
 DKN_RELAY_NODES=
```
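For reference, `DKN_MODELS` above holds the comma-separated model list that the reworked `ModelConfig::new_from_csv` parses below; a hypothetical example value (model names illustrative, taken from the doc-comment removed in this commit):

```
DKN_MODELS=gpt-4-turbo,gpt-4o-mini
```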

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -17,7 +17,7 @@ serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 async-trait = "0.1.81"
 reqwest = "0.12.5"
-
+rand = "0.8.5"
 env_logger = "0.11.3"
 log = "0.4.21"
 eyre = "0.6.12"
```

workflows/Cargo.toml

Lines changed: 2 additions & 1 deletion
```diff
@@ -15,8 +15,9 @@ serde.workspace = true
 serde_json.workspace = true
 async-trait.workspace = true
 reqwest.workspace = true
-
+rand.workspace = true
 log.workspace = true
 eyre.workspace = true
 
+# ollama-rs is re-exported from ollama-workflows
 ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" }
```
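Here `rand.workspace = true` inherits the version pinned at the workspace root; a sketch of the assumed root `Cargo.toml` entry, matching the `rand = "0.8.5"` line added in the root manifest above (the section header is an assumption, as the diff only shows surrounding dependency lines):

```toml
# root Cargo.toml (assumed section)
[workspace.dependencies]
rand = "0.8.5"
```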

workflows/src/models.rs renamed to workflows/src/config.rs

Lines changed: 92 additions & 77 deletions
```diff
@@ -1,89 +1,21 @@
-use crate::{utils::split_comma_separated, OllamaConfig, OpenAIConfig};
+use crate::{split_comma_separated, OllamaConfig, OpenAIConfig};
 use eyre::{eyre, Result};
 use ollama_workflows::{Model, ModelProvider};
 use rand::seq::IteratorRandom; // provides Vec<_>.choose
 
 #[derive(Debug, Clone)]
 pub struct ModelConfig {
+    /// List of models with their providers.
     pub models: Vec<(ModelProvider, Model)>,
+    /// Even if Ollama is not used, we store the host & port here.
+    /// If Ollama is used, this config will be respected during its instantiations.
     pub ollama: OllamaConfig,
+    /// OpenAI API key & its service check implementation.
     pub openai: OpenAIConfig,
 }
 
-impl std::fmt::Display for ModelConfig {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let models_str = self
-            .models
-            .iter()
-            .map(|(provider, model)| format!("{:?}:{}", provider, model))
-            .collect::<Vec<_>>()
-            .join(",");
-        write!(f, "{}", models_str)
-    }
-}
-
 impl ModelConfig {
-    /// Creates a new config with the given list of models.
-    pub fn new(models: Vec<Model>) -> Self {
-        // map models to (provider, model) pairs
-        let models_providers = models
-            .into_iter()
-            .map(|m| (m.clone().into(), m))
-            .collect::<Vec<_>>();
-
-        let mut providers = Vec::new();
-
-        // get ollama models & config
-        let ollama_models = models_providers
-            .iter()
-            .filter_map(|(p, m)| {
-                if *p == ModelProvider::Ollama {
-                    Some(m.clone())
-                } else {
-                    None
-                }
-            })
-            .collect::<Vec<_>>();
-        let ollama_config = if !ollama_models.is_empty() {
-            providers.push(ModelProvider::Ollama);
-            Some(OllamaConfig::new(ollama_models))
-        } else {
-            None
-        };
-
-        // get openai models & config
-        let openai_models = models_providers
-            .iter()
-            .filter_map(|(p, m)| {
-                if *p == ModelProvider::OpenAI {
-                    Some(m.clone())
-                } else {
-                    None
-                }
-            })
-            .collect::<Vec<_>>();
-        let openai_config = if !openai_models.is_empty() {
-            providers.push(ModelProvider::OpenAI);
-            Some(OpenAIConfig::new(openai_models))
-        } else {
-            None
-        };
-
-        Self {
-            models_providers,
-            providers,
-            ollama_config,
-            openai_config,
-        }
-    }
-
     /// Parses Ollama-Workflows compatible models from a comma-separated values string.
-    ///
-    /// ## Example
-    ///
-    /// ```
-    /// let config = ModelConfig::new_from_csv("gpt-4-turbo,gpt-4o-mini");
-    /// ```
     pub fn new_from_csv(input: Option<String>) -> Self {
         let models_str = split_comma_separated(input);
 
@@ -98,7 +30,11 @@ impl ModelConfig {
             })
             .collect::<Vec<_>>();
 
-        Self { models }
+        Self {
+            models,
+            openai: OpenAIConfig::new(),
+            ollama: OllamaConfig::new(),
+        }
     }
 
     /// Returns the models that belong to a given provider from the config.
@@ -117,12 +53,27 @@ impl ModelConfig {
 
     /// Given a raw model name or provider (as a string), returns the first matching model & provider.
     ///
-    /// If this is a model and is supported by this node, it is returned directly.
-    /// If this is a provider, the first matching model in the node config is returned.
+    /// - If input is `*` or `all`, a random model is returned.
+    /// - If input is `!`, the first model is returned.
+    /// - If input is a model and is supported by this node, it is returned directly.
+    /// - If input is a provider, the first matching model in the node config is returned.
     ///
     /// If there are no matching models with this logic, an error is returned.
     pub fn get_matching_model(&self, model_or_provider: String) -> Result<(ModelProvider, Model)> {
-        if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) {
+        if model_or_provider == "*" {
+            // return a random model
+            self.models
+                .iter()
+                .choose(&mut rand::thread_rng())
+                .ok_or_else(|| eyre!("No models to randomly pick for '*'."))
+                .cloned()
+        } else if model_or_provider == "!" {
+            // return the first model
+            self.models
+                .first()
+                .ok_or_else(|| eyre!("No models to choose first for '!'."))
+                .cloned()
+        } else if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) {
             // this is a valid provider, return the first matching model in the config
             self.models
                 .iter()
@@ -186,6 +137,70 @@ impl ModelConfig {
             unique
         })
     }
+
+    /// Check if the required compute services are running.
+    /// This has several steps:
+    ///
+    /// - If Ollama models are used, hardcoded models are checked locally, and for
+    ///   external models, the workflow is tested with a simple task with timeout.
+    /// - If OpenAI models are used, the API key is checked and the models are tested.
+    ///
+    /// If both types of models are used, both services are checked.
+    /// In the end, bad models are filtered out and we simply check whether we are left with any valid models at all.
+    /// If not, an error is returned.
+    pub async fn check_services(&mut self) -> Result<()> {
+        log::info!("Checking configured services.");
+
+        // TODO: can refactor (provider, model) logic here
+        let unique_providers = self.get_providers();
+
+        let mut good_models = Vec::new();
+
+        // if Ollama is a provider, check that it is running & Ollama models are pulled (or pull them)
+        if unique_providers.contains(&ModelProvider::Ollama) {
+            let ollama_models = self.get_models_for_provider(ModelProvider::Ollama);
+
+            // ensure that the models are pulled / pull them if not
+            let good_ollama_models = self.ollama.check(ollama_models).await?;
+            good_models.extend(
+                good_ollama_models
+                    .into_iter()
+                    .map(|m| (ModelProvider::Ollama, m)),
+            );
+        }
+
+        // if OpenAI is a provider, check that the API key is set
+        if unique_providers.contains(&ModelProvider::OpenAI) {
+            let openai_models = self.get_models_for_provider(ModelProvider::OpenAI);
+
+            let good_openai_models = self.openai.check(openai_models).await?;
+            good_models.extend(
+                good_openai_models
+                    .into_iter()
+                    .map(|m| (ModelProvider::OpenAI, m)),
+            );
+        }
+
+        // update good models
+        if good_models.is_empty() {
+            Err(eyre!("No good models found, please check logs for errors."))
+        } else {
+            self.models = good_models;
+            Ok(())
+        }
+    }
+}
+
+impl std::fmt::Display for ModelConfig {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let models_str = self
+            .models
+            .iter()
+            .map(|(provider, model)| format!("{:?}:{}", provider, model))
+            .collect::<Vec<_>>()
+            .join(",");
+        write!(f, "{}", models_str)
+    }
 }
 
 #[cfg(test)]
```
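For orientation, a minimal usage sketch of the reworked `ModelConfig` (assuming a tokio runtime; the crate path `dkn_workflows` and the model names are illustrative, not taken from this commit):

```rust
use dkn_workflows::ModelConfig; // hypothetical import path for the workflows subcrate

#[tokio::main]
async fn main() -> eyre::Result<()> {
    // parse models from a comma-separated string, e.g. the DKN_MODELS env var
    let mut config = ModelConfig::new_from_csv(Some("gpt-4-turbo,gpt-4o-mini".to_string()));

    // filter out models whose services are unreachable; errors if none remain
    config.check_services().await?;

    // "*" picks a random model, "!" picks the first; a provider or
    // model name returns the first matching entry in the config
    let (provider, model) = config.get_matching_model("*".to_string())?;
    println!("using {:?}:{}", provider, model);
    Ok(())
}
```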

workflows/src/lib.rs

Lines changed: 6 additions & 20 deletions
```diff
@@ -1,22 +1,8 @@
-use async_trait::async_trait;
-use eyre::Result;
+mod utils;
+pub use utils::*;
 
-mod models;
-pub use models::ModelConfig;
+mod providers;
+pub use providers::*;
 
-/// Ollama configurations & service checks
-mod ollama;
-pub(crate) use ollama::OllamaConfig;
-
-/// OpenAI configurations & service checks
-mod openai;
-pub(crate) use openai::OpenAIConfig;
-
-/// Extension trait for model providers to check if they are ready, and describe themselves.
-#[async_trait]
-pub trait ProvidersExt {
-    const PROVIDER_NAME: &str;
-
-    /// Ensures that the required provider is online & ready.
-    async fn check_service(&self) -> Result<()>;
-}
+mod config;
+pub use config::ModelConfig;
```

workflows/src/providers/mod.rs

Lines changed: 5 additions & 0 deletions
```diff
@@ -0,0 +1,5 @@
+mod ollama;
+pub use ollama::OllamaConfig;
+
+mod openai;
+pub use openai::OpenAIConfig;
```
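Note that `OllamaConfig` and `OpenAIConfig` are now `pub use`d here, whereas the old lib.rs kept them `pub(crate)`; a sketch of what downstream code can now do (crate path illustrative, as above):

```rust
// hypothetical downstream usage; `dkn_workflows` stands in for the actual crate name
use dkn_workflows::{OllamaConfig, OpenAIConfig};

fn main() {
    // both constructors take no arguments, as seen in ModelConfig::new_from_csv above
    let _ollama = OllamaConfig::new(); // host & port config, respected if Ollama is used
    let _openai = OpenAIConfig::new(); // OpenAI API key & its service checks
}
```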
