- use crate::{utils::split_comma_separated, OllamaConfig, OpenAIConfig};
+ use crate::{split_comma_separated, OllamaConfig, OpenAIConfig};
use eyre::{eyre, Result};
use ollama_workflows::{Model, ModelProvider};
use rand::seq::IteratorRandom; // provides Vec<_>.choose

#[derive(Debug, Clone)]
pub struct ModelConfig {
+     /// List of models with their providers.
    pub models: Vec<(ModelProvider, Model)>,
+     /// Even if Ollama is not used, we store the host & port here.
+     /// If Ollama is used, this config will be respected during its instantiations.
    pub ollama: OllamaConfig,
+     /// OpenAI API key & its service check implementation.
    pub openai: OpenAIConfig,
}

- impl std::fmt::Display for ModelConfig {
-     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-         let models_str = self
-             .models
-             .iter()
-             .map(|(provider, model)| format!("{:?}:{}", provider, model))
-             .collect::<Vec<_>>()
-             .join(",");
-         write!(f, "{}", models_str)
-     }
- }
-
impl ModelConfig {
-     /// Creates a new config with the given list of models.
-     pub fn new(models: Vec<Model>) -> Self {
-         // map models to (provider, model) pairs
-         let models_providers = models
-             .into_iter()
-             .map(|m| (m.clone().into(), m))
-             .collect::<Vec<_>>();
-
-         let mut providers = Vec::new();
-
-         // get ollama models & config
-         let ollama_models = models_providers
-             .iter()
-             .filter_map(|(p, m)| {
-                 if *p == ModelProvider::Ollama {
-                     Some(m.clone())
-                 } else {
-                     None
-                 }
-             })
-             .collect::<Vec<_>>();
-         let ollama_config = if !ollama_models.is_empty() {
-             providers.push(ModelProvider::Ollama);
-             Some(OllamaConfig::new(ollama_models))
-         } else {
-             None
-         };
-
-         // get openai models & config
-         let openai_models = models_providers
-             .iter()
-             .filter_map(|(p, m)| {
-                 if *p == ModelProvider::OpenAI {
-                     Some(m.clone())
-                 } else {
-                     None
-                 }
-             })
-             .collect::<Vec<_>>();
-         let openai_config = if !openai_models.is_empty() {
-             providers.push(ModelProvider::OpenAI);
-             Some(OpenAIConfig::new(openai_models))
-         } else {
-             None
-         };
-
-         Self {
-             models_providers,
-             providers,
-             ollama_config,
-             openai_config,
-         }
-     }
-
    /// Parses Ollama-Workflows compatible models from a comma-separated values string.
-     ///
-     /// ## Example
-     ///
-     /// ```
-     /// let config = ModelConfig::new_from_csv("gpt-4-turbo,gpt-4o-mini");
-     /// ```
    pub fn new_from_csv(input: Option<String>) -> Self {
        let models_str = split_comma_separated(input);

@@ -98,7 +30,11 @@ impl ModelConfig {
            })
            .collect::<Vec<_>>();

-         Self { models }
+         Self {
+             models,
+             openai: OpenAIConfig::new(),
+             ollama: OllamaConfig::new(),
+         }
    }
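
A minimal usage sketch for the constructor above (not part of the change itself); the model names reuse the removed doc example, and the printed form assumes both names resolve to OpenAI models under the `Display` implementation added later in this diff.

    // Sketch: construct a config from a CSV string and print it.
    // Assumes both model names parse to known OpenAI models.
    fn example_csv_config() {
        let config = ModelConfig::new_from_csv(Some("gpt-4-turbo,gpt-4o-mini".to_string()));
        println!("{}", config); // e.g. "OpenAI:gpt-4-turbo,OpenAI:gpt-4o-mini"
    }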

    /// Returns the models that belong to a given provider from the config.
@@ -117,12 +53,27 @@ impl ModelConfig {

    /// Given a raw model name or provider (as a string), returns the first matching model & provider.
    ///
-     /// If this is a model and is supported by this node, it is returned directly.
-     /// If this is a provider, the first matching model in the node config is returned.
+     /// - If the input is `*` or `all`, a random model is returned.
+     /// - If the input is `!`, the first model is returned.
+     /// - If the input is a model and is supported by this node, it is returned directly.
+     /// - If the input is a provider, the first matching model in the node config is returned.
    ///
    /// If there are no matching models with this logic, an error is returned.
    pub fn get_matching_model(&self, model_or_provider: String) -> Result<(ModelProvider, Model)> {
-         if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) {
+         if model_or_provider == "*" {
+             // return a random model
+             self.models
+                 .iter()
+                 .choose(&mut rand::thread_rng())
+                 .ok_or_else(|| eyre!("No models to randomly pick for '*'."))
+                 .cloned()
+         } else if model_or_provider == "!" {
+             // return the first model
+             self.models
+                 .first()
+                 .ok_or_else(|| eyre!("No models to choose first for '!'."))
+                 .cloned()
+         } else if let Ok(provider) = ModelProvider::try_from(model_or_provider.clone()) {
            // this is a valid provider, return the first matching model in the config
            self.models
                .iter()
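
A sketch of the selection rules documented on `get_matching_model` (not part of the change itself); the lowercase provider string and the exact model names are assumptions about what `ModelProvider::try_from` and the model parser accept.

    // Sketch: assumes `config` is a populated ModelConfig and that
    // ModelProvider::try_from accepts lowercase names such as "openai".
    fn example_model_selection(config: &ModelConfig) -> eyre::Result<()> {
        let random = config.get_matching_model("*".to_string())?;           // any random model
        let first = config.get_matching_model("!".to_string())?;            // first configured model
        let by_provider = config.get_matching_model("openai".to_string())?; // first OpenAI model
        let exact = config.get_matching_model("gpt-4o-mini".to_string())?;  // exact model, if configured
        for (provider, model) in [random, first, by_provider, exact] {
            println!("{:?}:{}", provider, model);
        }
        Ok(())
    }
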
@@ -186,6 +137,70 @@ impl ModelConfig {
                unique
            })
    }
+
+     /// Check if the required compute services are running.
+     /// This has several steps:
+     ///
+     /// - If Ollama models are used, hardcoded models are checked locally, and for
+     ///   external models, the workflow is tested with a simple task with timeout.
+     /// - If OpenAI models are used, the API key is checked and the models are tested.
+     ///
+     /// If both types of models are used, both services are checked.
+     /// In the end, bad models are filtered out, and we simply check whether any valid models are left at all.
+     /// If not, an error is returned.
+     pub async fn check_services(&mut self) -> Result<()> {
+         log::info!("Checking configured services.");
+
+         // TODO: can refactor (provider, model) logic here
+         let unique_providers = self.get_providers();
+
+         let mut good_models = Vec::new();
+
+         // if Ollama is a provider, check that it is running & Ollama models are pulled (or pull them)
+         if unique_providers.contains(&ModelProvider::Ollama) {
+             let ollama_models = self.get_models_for_provider(ModelProvider::Ollama);
+
+             // ensure that the models are pulled / pull them if not
+             let good_ollama_models = self.ollama.check(ollama_models).await?;
+             good_models.extend(
+                 good_ollama_models
+                     .into_iter()
+                     .map(|m| (ModelProvider::Ollama, m)),
+             );
+         }
+
+         // if OpenAI is a provider, check that the API key is set
+         if unique_providers.contains(&ModelProvider::OpenAI) {
+             let openai_models = self.get_models_for_provider(ModelProvider::OpenAI);
+
+             let good_openai_models = self.openai.check(openai_models).await?;
+             good_models.extend(
+                 good_openai_models
+                     .into_iter()
+                     .map(|m| (ModelProvider::OpenAI, m)),
+             );
+         }
+
+         // update good models
+         if good_models.is_empty() {
+             Err(eyre!("No good models found, please check logs for errors."))
+         } else {
+             self.models = good_models;
+             Ok(())
+         }
+     }
+ }
+
+ impl std::fmt::Display for ModelConfig {
+     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+         let models_str = self
+             .models
+             .iter()
+             .map(|(provider, model)| format!("{:?}:{}", provider, model))
+             .collect::<Vec<_>>()
+             .join(",");
+         write!(f, "{}", models_str)
+     }
}
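
A sketch of how the new `check_services` could be wired into node startup (not part of the change itself); the environment variable name is hypothetical.

    // Sketch: typical startup flow for this config. The env var name is a placeholder.
    async fn example_startup() -> eyre::Result<()> {
        let mut config = ModelConfig::new_from_csv(std::env::var("MODELS").ok());
        // filters out models whose services fail the check; errors if none remain
        config.check_services().await?;
        log::info!("Using models: {}", config);
        Ok(())
    }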

#[cfg(test)]