Skip to content

Commit 25f0d32

Browse files
committed
use Model instead of (ModelProvider, Model)
1 parent c713798 commit 25f0d32

File tree

8 files changed

+621
-494
lines changed

8 files changed

+621
-494
lines changed

Cargo.lock

Lines changed: 565 additions & 372 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ docker compose --profile=ollama-rocm up
9090
Note that we are very dependent on Ollama packages, and it is important to check their versions if relevant:
9191

9292
```sh
93-
@cat Cargo.lock | grep "https://github.com/andthattoo/ollama-workflows"
94-
@cat Cargo.lock | grep "https://github.com/andthattoo/ollama-rs"
93+
cat Cargo.lock | grep "https://github.com/andthattoo/ollama-workflows"
94+
cat Cargo.lock | grep "https://github.com/andthattoo/ollama-rs"
9595
```
9696

9797
### Testing

compute/src/main.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,7 @@ async fn main() -> Result<()> {
8888
}?;
8989
log::warn!(
9090
"Using models: {}",
91-
config
92-
.workflows
93-
.models
94-
.iter()
95-
.map(|(p, m)| format!("{}/{}", p, m))
96-
.collect::<Vec<_>>()
97-
.join(", ")
91+
config.workflows.get_model_names().join(", ")
9892
);
9993

10094
// check network-specific configurations

compute/src/node/diagnostic.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,7 @@ impl DriaComputeNode {
4747
// print models
4848
diagnostics.push(format!(
4949
"Models: {}",
50-
self.config
51-
.workflows
52-
.models
53-
.iter()
54-
.map(|(p, m)| format!("{}/{}", p, m))
55-
.collect::<Vec<String>>()
56-
.join(", ")
50+
self.config.workflows.get_model_names().join(", ")
5751
));
5852

5953
// if we have not received pings for a while, we are considered offline

compute/src/reqres/heartbeat.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,7 @@ impl HeartbeatRequester {
5858
let heartbeat_request = HeartbeatRequest {
5959
heartbeat_id: uuid,
6060
deadline,
61-
models: node
62-
.config
63-
.workflows
64-
.models
65-
.iter()
66-
.map(|m| m.1.to_string())
67-
.collect(),
61+
models: node.config.workflows.get_model_names(),
6862
pending_tasks: node.get_pending_task_count(),
6963
};
7064

compute/src/reqres/task.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,15 @@ impl TaskResponder {
6565
let task_public_key = PublicKey::parse_slice(&task_public_key_bytes, None)?;
6666

6767
// read model / provider from the task
68-
let (model_provider, model) = node
68+
let model = node
6969
.config
7070
.workflows
7171
.get_any_matching_model(task.input.model)?;
7272
let model_name = model.to_string(); // get model name, we will pass it in payload
7373
log::info!("Using model {} for task {}", model_name, task.task_id);
7474

7575
// prepare workflow executor
76-
let (executor, batchable) = if model_provider == ModelProvider::Ollama {
76+
let (executor, batchable) = if model.provider() == ModelProvider::Ollama {
7777
(
7878
Executor::new_at(
7979
model,

workflows/src/config.rs

Lines changed: 44 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ use crate::{
44
Model, ModelProvider,
55
};
66
use dkn_utils::split_csv_line;
7-
use eyre::{eyre, Result};
7+
use eyre::{eyre, OptionExt, Result};
88
use rand::seq::IteratorRandom; // provides Vec<_>.choose
99

1010
#[derive(Debug, Clone)]
1111
pub struct DriaWorkflowsConfig {
12-
/// List of models with their providers.
13-
pub models: Vec<(ModelProvider, Model)>,
12+
/// List of models.
13+
///
14+
/// You can do `model.provider()` to get its provider.
15+
pub models: Vec<Model>,
1416
/// Ollama configurations, in case Ollama is used.
1517
/// Otherwise, can be ignored.
1618
pub ollama: OllamaConfig,
@@ -40,13 +42,8 @@ impl Default for DriaWorkflowsConfig {
4042
impl DriaWorkflowsConfig {
4143
/// Creates a new config with the given models.
4244
pub fn new(models: Vec<Model>) -> Self {
43-
let models_and_providers = models
44-
.into_iter()
45-
.map(|model| (model.clone().into(), model))
46-
.collect::<Vec<_>>();
47-
4845
Self {
49-
models: models_and_providers,
46+
models,
5047
ollama: OllamaConfig::new(),
5148
openai: OpenAIConfig::new(),
5249
openrouter: OpenRouterConfig::new(),
@@ -84,24 +81,23 @@ impl DriaWorkflowsConfig {
8481
pub fn get_models_for_provider(&self, provider: ModelProvider) -> Vec<Model> {
8582
self.models
8683
.iter()
87-
.filter_map(|(p, m)| {
88-
if *p == provider {
89-
Some(m.clone())
90-
} else {
91-
None
92-
}
93-
})
84+
.filter(|m| m.provider() == provider)
85+
.cloned()
9486
.collect()
9587
}
9688

9789
/// Returns `true` if the configuration contains models that can be processed in parallel, e.g. API calls.
9890
pub fn has_batchable_models(&self) -> bool {
99-
self.models.iter().any(|(p, _)| *p != ModelProvider::Ollama)
91+
self.models
92+
.iter()
93+
.any(|m| m.provider() != ModelProvider::Ollama)
10094
}
10195

10296
/// Returns `true` if the configuration contains a model that can't be run in parallel, e.g. an Ollama model.
10397
pub fn has_non_batchable_models(&self) -> bool {
104-
self.models.iter().any(|(p, _)| *p == ModelProvider::Ollama)
98+
self.models
99+
.iter()
100+
.any(|m| m.provider() == ModelProvider::Ollama)
105101
}
106102

107103
/// Given a raw model name or provider (as a string), returns the first matching model.
@@ -112,51 +108,46 @@ impl DriaWorkflowsConfig {
112108
/// - If input is a provider, the first matching model in the node config is returned.
113109
///
114110
/// If there are no matching models with this logic, an error is returned.
115-
pub fn get_matching_model(&self, model_or_provider: String) -> Result<(ModelProvider, Model)> {
111+
pub fn get_matching_model(&self, model_or_provider: String) -> Result<Model> {
116112
if model_or_provider == "*" {
117113
// return a random model
118114
self.models
119115
.iter()
120116
.choose(&mut rand::thread_rng())
121-
.ok_or_else(|| eyre!("No models to randomly pick for '*'."))
117+
.ok_or_eyre("could not find models to randomly pick for '*'")
122118
.cloned()
123119
} else if model_or_provider == "!" {
124120
// return the first model
125121
self.models
126122
.first()
127-
.ok_or_else(|| eyre!("No models to choose first for '!'."))
123+
.ok_or_eyre("could not find models to choose first for '!'")
128124
.cloned()
129125
} else if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) {
130126
// this is a valid provider, return the first matching model in the config
131127
self.models
132128
.iter()
133-
.find(|(p, _)| *p == provider)
134-
.ok_or(eyre!(
135-
"Provider {} is not supported by this node.",
136-
provider
129+
.find(|&m| m.provider() == provider)
130+
.ok_or_eyre(format!(
131+
"Provider {provider} is not supported by this node."
137132
))
138133
.cloned()
139134
} else if let Ok(model) = Model::try_from(model_or_provider.clone()) {
140135
// this is a valid model, return it if it is supported by the node
141136
self.models
142137
.iter()
143-
.find(|(_, m)| *m == model)
144-
.ok_or(eyre!("Model {} is not supported by this node.", model))
138+
.find(|&m| *m == model)
139+
.ok_or_eyre(format!("Model {model} is not supported by this node."))
145140
.cloned()
146141
} else {
147142
// this is neither a valid provider nor a valid model for this node
148143
Err(eyre!(
149-
"Given string '{}' is neither a model nor provider.",
150-
model_or_provider
144+
"Given string '{model_or_provider}' is neither a model nor provider.",
151145
))
152146
}
153147
}
154148

155149
/// From a list of model or provider names, returns a random matching model.
156-
pub fn get_any_matching_model(
157-
&self,
158-
list_model_or_provider: Vec<String>,
159-
) -> Result<(ModelProvider, Model)> {
150+
pub fn get_any_matching_model(&self, list_model_or_provider: Vec<String>) -> Result<Model> {
160151
// filter models w.r.t supported ones
161152
let matching_models = list_model_or_provider
162153
.into_iter()
@@ -182,23 +173,21 @@ impl DriaWorkflowsConfig {
182173
/// Returns the list of unique providers in the config.
183174
#[inline]
184175
pub fn get_providers(&self) -> Vec<ModelProvider> {
185-
self.models
186-
.iter()
187-
.fold(Vec::new(), |mut unique, (provider, _)| {
188-
if !unique.contains(provider) {
189-
unique.push(provider.clone());
190-
}
191-
unique
192-
})
176+
self.models.iter().fold(Vec::new(), |mut unique, m| {
177+
let provider = m.provider();
178+
179+
if !unique.contains(&provider) {
180+
unique.push(provider);
181+
}
182+
183+
unique
184+
})
193185
}
194186

195-
/// Returns the list of all models in the config.
196-
#[inline]
187+
/// Returns the names of all models in the config.
188+
#[inline(always)]
197189
pub fn get_model_names(&self) -> Vec<String> {
198-
self.models
199-
.iter()
200-
.map(|(_, model)| model.to_string())
201-
.collect()
190+
self.models.iter().map(|m| m.to_string()).collect()
202191
}
203192

204193
/// Check if the required compute services are running.
@@ -226,49 +215,25 @@ impl DriaWorkflowsConfig {
226215
// if Ollama is a provider, check that it is running & Ollama models are pulled (or pull them)
227216
if unique_providers.contains(&ModelProvider::Ollama) {
228217
let provider_models = self.get_models_for_provider(ModelProvider::Ollama);
229-
good_models.extend(
230-
self.ollama
231-
.check(provider_models)
232-
.await?
233-
.into_iter()
234-
.map(|m| (ModelProvider::Ollama, m)),
235-
);
218+
good_models.extend(self.ollama.check(provider_models).await?);
236219
}
237220

238221
// if OpenAI is a provider, check that the API key is set & models are available
239222
if unique_providers.contains(&ModelProvider::OpenAI) {
240223
let provider_models = self.get_models_for_provider(ModelProvider::OpenAI);
241-
good_models.extend(
242-
self.openai
243-
.check(provider_models)
244-
.await?
245-
.into_iter()
246-
.map(|m| (ModelProvider::OpenAI, m)),
247-
);
224+
good_models.extend(self.openai.check(provider_models).await?);
248225
}
249226

250227
// if Gemini is a provider, check that the API key is set & models are available
251228
if unique_providers.contains(&ModelProvider::Gemini) {
252229
let provider_models = self.get_models_for_provider(ModelProvider::Gemini);
253-
good_models.extend(
254-
self.gemini
255-
.check(provider_models)
256-
.await?
257-
.into_iter()
258-
.map(|m| (ModelProvider::Gemini, m)),
259-
);
230+
good_models.extend(self.gemini.check(provider_models).await?);
260231
}
261232

262233
// if OpenRouter is a provider, check that the API key is set
263234
if unique_providers.contains(&ModelProvider::OpenRouter) {
264235
let provider_models = self.get_models_for_provider(ModelProvider::OpenRouter);
265-
good_models.extend(
266-
self.openrouter
267-
.check(provider_models)
268-
.await?
269-
.into_iter()
270-
.map(|m| (ModelProvider::OpenRouter, m)),
271-
);
236+
good_models.extend(self.openrouter.check(provider_models).await?);
272237
}
273238

274239
// update good models
@@ -286,7 +251,7 @@ impl std::fmt::Display for DriaWorkflowsConfig {
286251
let models_str = self
287252
.models
288253
.iter()
289-
.map(|(provider, model)| format!("{:?}:{}", provider, model))
254+
.map(|model| format!("{}:{}", model.provider(), model))
290255
.collect::<Vec<_>>()
291256
.join(",");
292257
write!(f, "{}", models_str)
@@ -312,15 +277,14 @@ mod tests {
312277
fn test_model_matching() {
313278
let cfg = DriaWorkflowsConfig::new_from_csv("gpt-4o,llama3.1:latest");
314279
assert_eq!(
315-
cfg.get_matching_model("openai".to_string()).unwrap().1,
280+
cfg.get_matching_model("openai".to_string()).unwrap(),
316281
Model::GPT4o,
317282
"Should find existing model"
318283
);
319284

320285
assert_eq!(
321286
cfg.get_matching_model("llama3.1:latest".to_string())
322-
.unwrap()
323-
.1,
287+
.unwrap(),
324288
Model::Llama3_1_8B,
325289
"Should find existing model"
326290
);
@@ -347,7 +311,7 @@ mod tests {
347311
"ollama".to_string(),
348312
]);
349313
assert_eq!(
350-
result.unwrap().1,
314+
result.unwrap(),
351315
Model::Llama3_1_8B,
352316
"Should find existing model"
353317
);

workflows/tests/models_test.rs

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use dkn_workflows::{DriaWorkflowsConfig, Model, ModelProvider};
1+
use dkn_workflows::{DriaWorkflowsConfig, Model};
22
use eyre::Result;
33

44
#[inline(always)]
@@ -24,10 +24,7 @@ async fn test_ollama_check() -> Result<()> {
2424
let mut model_config = DriaWorkflowsConfig::new(models);
2525
model_config.check_services().await?;
2626

27-
assert_eq!(
28-
model_config.models[0],
29-
(ModelProvider::Ollama, Model::Phi3_5Mini)
30-
);
27+
assert_eq!(model_config.models[0], Model::Phi3_5Mini);
3128

3229
Ok(())
3330
}
@@ -41,10 +38,7 @@ async fn test_openai_check() -> Result<()> {
4138
let mut model_config = DriaWorkflowsConfig::new(models);
4239
model_config.check_services().await?;
4340

44-
assert_eq!(
45-
model_config.models[0],
46-
(ModelProvider::OpenAI, Model::GPT4Turbo)
47-
);
41+
assert_eq!(model_config.models[0], Model::GPT4Turbo);
4842
Ok(())
4943
}
5044

@@ -57,10 +51,7 @@ async fn test_gemini_check() -> Result<()> {
5751
let mut model_config = DriaWorkflowsConfig::new(models);
5852
model_config.check_services().await?;
5953

60-
assert_eq!(
61-
model_config.models[0],
62-
(ModelProvider::Gemini, Model::Gemini15Flash)
63-
);
54+
assert_eq!(model_config.models[0], Model::Gemini15Flash);
6455
Ok(())
6556
}
6657

@@ -73,10 +64,7 @@ async fn test_openrouter_check() -> Result<()> {
7364
let mut model_config = DriaWorkflowsConfig::new(models);
7465
model_config.check_services().await?;
7566

76-
assert_eq!(
77-
model_config.models[0],
78-
(ModelProvider::OpenRouter, Model::ORDeepSeek2_5)
79-
);
67+
assert_eq!(model_config.models[0], Model::ORDeepSeek2_5);
8068
Ok(())
8169
}
8270

0 commit comments

Comments
 (0)