@@ -4,13 +4,15 @@ use crate::{
4
4
Model , ModelProvider ,
5
5
} ;
6
6
use dkn_utils:: split_csv_line;
7
- use eyre:: { eyre, Result } ;
7
+ use eyre:: { eyre, OptionExt , Result } ;
8
8
use rand:: seq:: IteratorRandom ; // provides Vec<_>.choose
9
9
10
10
#[ derive( Debug , Clone ) ]
11
11
pub struct DriaWorkflowsConfig {
12
- /// List of models with their providers.
13
- pub models : Vec < ( ModelProvider , Model ) > ,
12
+ /// List of models.
13
+ ///
14
+ /// You can do `model.provider()` to get its provider.
15
+ pub models : Vec < Model > ,
14
16
/// Ollama configurations, in case Ollama is used.
15
17
/// Otherwise, can be ignored.
16
18
pub ollama : OllamaConfig ,
@@ -40,13 +42,8 @@ impl Default for DriaWorkflowsConfig {
40
42
impl DriaWorkflowsConfig {
41
43
/// Creates a new config with the given models.
42
44
pub fn new ( models : Vec < Model > ) -> Self {
43
- let models_and_providers = models
44
- . into_iter ( )
45
- . map ( |model| ( model. clone ( ) . into ( ) , model) )
46
- . collect :: < Vec < _ > > ( ) ;
47
-
48
45
Self {
49
- models : models_and_providers ,
46
+ models,
50
47
ollama : OllamaConfig :: new ( ) ,
51
48
openai : OpenAIConfig :: new ( ) ,
52
49
openrouter : OpenRouterConfig :: new ( ) ,
@@ -84,24 +81,23 @@ impl DriaWorkflowsConfig {
84
81
pub fn get_models_for_provider ( & self , provider : ModelProvider ) -> Vec < Model > {
85
82
self . models
86
83
. iter ( )
87
- . filter_map ( |( p, m) | {
88
- if * p == provider {
89
- Some ( m. clone ( ) )
90
- } else {
91
- None
92
- }
93
- } )
84
+ . filter ( |m| m. provider ( ) == provider)
85
+ . cloned ( )
94
86
. collect ( )
95
87
}
96
88
97
89
/// Returns `true` if the configuration contains models that can be processed in parallel, e.g. API calls.
98
90
pub fn has_batchable_models ( & self ) -> bool {
99
- self . models . iter ( ) . any ( |( p, _) | * p != ModelProvider :: Ollama )
91
+ self . models
92
+ . iter ( )
93
+ . any ( |m| m. provider ( ) != ModelProvider :: Ollama )
100
94
}
101
95
102
96
/// Returns `true` if the configuration contains a model that cant be run in parallel, e.g. a Ollama model.
103
97
pub fn has_non_batchable_models ( & self ) -> bool {
104
- self . models . iter ( ) . any ( |( p, _) | * p == ModelProvider :: Ollama )
98
+ self . models
99
+ . iter ( )
100
+ . any ( |m| m. provider ( ) == ModelProvider :: Ollama )
105
101
}
106
102
107
103
/// Given a raw model name or provider (as a string), returns the first matching model & provider.
@@ -112,51 +108,46 @@ impl DriaWorkflowsConfig {
112
108
/// - If input is a provider, the first matching model in the node config is returned.
113
109
///
114
110
/// If there are no matching models with this logic, an error is returned.
115
- pub fn get_matching_model ( & self , model_or_provider : String ) -> Result < ( ModelProvider , Model ) > {
111
+ pub fn get_matching_model ( & self , model_or_provider : String ) -> Result < Model > {
116
112
if model_or_provider == "*" {
117
113
// return a random model
118
114
self . models
119
115
. iter ( )
120
116
. choose ( & mut rand:: thread_rng ( ) )
121
- . ok_or_else ( || eyre ! ( "No models to randomly pick for '*'." ) )
117
+ . ok_or_eyre ( "could not find models to randomly pick for '*'" )
122
118
. cloned ( )
123
119
} else if model_or_provider == "!" {
124
120
// return the first model
125
121
self . models
126
122
. first ( )
127
- . ok_or_else ( || eyre ! ( "No models to choose first for '!'." ) )
123
+ . ok_or_eyre ( "could not find models to choose first for '!'" )
128
124
. cloned ( )
129
125
} else if let Ok ( provider) = ModelProvider :: try_from ( model_or_provider. clone ( ) ) {
130
126
// this is a valid provider, return the first matching model in the config
131
127
self . models
132
128
. iter ( )
133
- . find ( |( p, _) | * p == provider)
134
- . ok_or ( eyre ! (
135
- "Provider {} is not supported by this node." ,
136
- provider
129
+ . find ( |& m| m. provider ( ) == provider)
130
+ . ok_or_eyre ( format ! (
131
+ "Provider {provider} is not supported by this node."
137
132
) )
138
133
. cloned ( )
139
134
} else if let Ok ( model) = Model :: try_from ( model_or_provider. clone ( ) ) {
140
135
// this is a valid model, return it if it is supported by the node
141
136
self . models
142
137
. iter ( )
143
- . find ( |( _ , m ) | * m == model)
144
- . ok_or ( eyre ! ( "Model {} is not supported by this node." , model ) )
138
+ . find ( |& m | * m == model)
139
+ . ok_or_eyre ( format ! ( "Model {model } is not supported by this node." ) )
145
140
. cloned ( )
146
141
} else {
147
142
// this is neither a valid provider or model for this node
148
143
Err ( eyre ! (
149
- "Given string '{}' is neither a model nor provider." ,
150
- model_or_provider
144
+ "Given string '{model_or_provider}' is neither a model nor provider." ,
151
145
) )
152
146
}
153
147
}
154
148
155
149
/// From a list of model or provider names, return a random matching model & provider.
156
- pub fn get_any_matching_model (
157
- & self ,
158
- list_model_or_provider : Vec < String > ,
159
- ) -> Result < ( ModelProvider , Model ) > {
150
+ pub fn get_any_matching_model ( & self , list_model_or_provider : Vec < String > ) -> Result < Model > {
160
151
// filter models w.r.t supported ones
161
152
let matching_models = list_model_or_provider
162
153
. into_iter ( )
@@ -182,23 +173,21 @@ impl DriaWorkflowsConfig {
182
173
/// Returns the list of unique providers in the config.
183
174
#[ inline]
184
175
pub fn get_providers ( & self ) -> Vec < ModelProvider > {
185
- self . models
186
- . iter ( )
187
- . fold ( Vec :: new ( ) , |mut unique, ( provider, _) | {
188
- if !unique. contains ( provider) {
189
- unique. push ( provider. clone ( ) ) ;
190
- }
191
- unique
192
- } )
176
+ self . models . iter ( ) . fold ( Vec :: new ( ) , |mut unique, m| {
177
+ let provider = m. provider ( ) ;
178
+
179
+ if !unique. contains ( & provider) {
180
+ unique. push ( provider) ;
181
+ }
182
+
183
+ unique
184
+ } )
193
185
}
194
186
195
- /// Returns the list of all models in the config.
196
- #[ inline]
187
+ /// Returns the names of all models in the config.
188
+ #[ inline( always ) ]
197
189
pub fn get_model_names ( & self ) -> Vec < String > {
198
- self . models
199
- . iter ( )
200
- . map ( |( _, model) | model. to_string ( ) )
201
- . collect ( )
190
+ self . models . iter ( ) . map ( |m| m. to_string ( ) ) . collect ( )
202
191
}
203
192
204
193
/// Check if the required compute services are running.
@@ -226,49 +215,25 @@ impl DriaWorkflowsConfig {
226
215
// if Ollama is a provider, check that it is running & Ollama models are pulled (or pull them)
227
216
if unique_providers. contains ( & ModelProvider :: Ollama ) {
228
217
let provider_models = self . get_models_for_provider ( ModelProvider :: Ollama ) ;
229
- good_models. extend (
230
- self . ollama
231
- . check ( provider_models)
232
- . await ?
233
- . into_iter ( )
234
- . map ( |m| ( ModelProvider :: Ollama , m) ) ,
235
- ) ;
218
+ good_models. extend ( self . ollama . check ( provider_models) . await ?) ;
236
219
}
237
220
238
221
// if OpenAI is a provider, check that the API key is set & models are available
239
222
if unique_providers. contains ( & ModelProvider :: OpenAI ) {
240
223
let provider_models = self . get_models_for_provider ( ModelProvider :: OpenAI ) ;
241
- good_models. extend (
242
- self . openai
243
- . check ( provider_models)
244
- . await ?
245
- . into_iter ( )
246
- . map ( |m| ( ModelProvider :: OpenAI , m) ) ,
247
- ) ;
224
+ good_models. extend ( self . openai . check ( provider_models) . await ?) ;
248
225
}
249
226
250
227
// if Gemini is a provider, check that the API key is set & models are available
251
228
if unique_providers. contains ( & ModelProvider :: Gemini ) {
252
229
let provider_models = self . get_models_for_provider ( ModelProvider :: Gemini ) ;
253
- good_models. extend (
254
- self . gemini
255
- . check ( provider_models)
256
- . await ?
257
- . into_iter ( )
258
- . map ( |m| ( ModelProvider :: Gemini , m) ) ,
259
- ) ;
230
+ good_models. extend ( self . gemini . check ( provider_models) . await ?) ;
260
231
}
261
232
262
233
// if OpenRouter is a provider, check that the API key is set
263
234
if unique_providers. contains ( & ModelProvider :: OpenRouter ) {
264
235
let provider_models = self . get_models_for_provider ( ModelProvider :: OpenRouter ) ;
265
- good_models. extend (
266
- self . openrouter
267
- . check ( provider_models)
268
- . await ?
269
- . into_iter ( )
270
- . map ( |m| ( ModelProvider :: OpenRouter , m) ) ,
271
- ) ;
236
+ good_models. extend ( self . openrouter . check ( provider_models) . await ?) ;
272
237
}
273
238
274
239
// update good models
@@ -286,7 +251,7 @@ impl std::fmt::Display for DriaWorkflowsConfig {
286
251
let models_str = self
287
252
. models
288
253
. iter ( )
289
- . map ( |( provider , model) | format ! ( "{:? }:{}" , provider, model) )
254
+ . map ( |model| format ! ( "{}:{}" , model . provider( ) , model) )
290
255
. collect :: < Vec < _ > > ( )
291
256
. join ( "," ) ;
292
257
write ! ( f, "{}" , models_str)
@@ -312,15 +277,14 @@ mod tests {
312
277
fn test_model_matching ( ) {
313
278
let cfg = DriaWorkflowsConfig :: new_from_csv ( "gpt-4o,llama3.1:latest" ) ;
314
279
assert_eq ! (
315
- cfg. get_matching_model( "openai" . to_string( ) ) . unwrap( ) . 1 ,
280
+ cfg. get_matching_model( "openai" . to_string( ) ) . unwrap( ) ,
316
281
Model :: GPT4o ,
317
282
"Should find existing model"
318
283
) ;
319
284
320
285
assert_eq ! (
321
286
cfg. get_matching_model( "llama3.1:latest" . to_string( ) )
322
- . unwrap( )
323
- . 1 ,
287
+ . unwrap( ) ,
324
288
Model :: Llama3_1_8B ,
325
289
"Should find existing model"
326
290
) ;
@@ -347,7 +311,7 @@ mod tests {
347
311
"ollama" . to_string( ) ,
348
312
] ) ;
349
313
assert_eq ! (
350
- result. unwrap( ) . 1 ,
314
+ result. unwrap( ) ,
351
315
Model :: Llama3_1_8B ,
352
316
"Should find existing model"
353
317
) ;
0 commit comments