+ use dkn_utils::payloads::SpecModelPerformance;
use eyre::{Context, Result};
use ollama_rs::generation::completion::request::GenerationRequest;
use rig::completion::{Chat, PromptError};
use rig::providers::ollama;
+ use std::collections::HashMap;
use std::time::Duration;
use std::{collections::HashSet, env};
@@ -78,7 +80,10 @@ impl OllamaClient {
    }

    /// Check if requested models exist in Ollama & test them using a dummy prompt.
-    pub async fn check(&self, models: &mut HashSet<Model>) -> Result<()> {
+    pub async fn check(
+        &self,
+        models: &mut HashSet<Model>,
+    ) -> Result<HashMap<Model, SpecModelPerformance>> {
        log::info!(
            "Checking Ollama requirements (auto-pull {}, timeout: {}s, min tps: {})",
            if self.auto_pull { "on" } else { "off" },
@@ -101,6 +106,7 @@ impl OllamaClient {
        // check external models & pull them if available
        // iterate over models and remove bad ones
        let mut models_to_remove = Vec::new();
+        let mut model_performances = HashMap::new();
        for model in models.iter() {
            // pull the model if it is not in the local models
            if !local_models.contains(&model.to_string()) {
@@ -117,8 +123,13 @@ impl OllamaClient {
            }

            // test its performance
-            if !self.test_performance(model).await {
+            let perf = self.measure_tps_with_warmup(model).await;
+            if let SpecModelPerformance::PassedWithTPS(_) = perf {
+                model_performances.insert(*model, perf);
+            } else {
+                // if it's anything but PassedWithTPS, remove the model
                models_to_remove.push(*model);
+                model_performances.insert(*model, perf);
            }
        }

@@ -133,7 +144,7 @@ impl OllamaClient {
            log::info!("Ollama checks are finished, using models: {:#?}", models);
        }

-        Ok(())
+        Ok(model_performances)
    }

    /// Pulls a model from Ollama.
@@ -154,7 +165,7 @@ impl OllamaClient {
    ///
    /// This is to see if a given system can execute tasks for their chosen models,
    /// e.g. if they have enough RAM/CPU and such.
-    pub async fn test_performance(&self, model: &Model) -> bool {
+    pub async fn measure_tps_with_warmup(&self, model: &Model) -> SpecModelPerformance {
        const TEST_PROMPT: &str = "Please write a poem about Kapadokya.";
        const WARMUP_PROMPT: &str = "Write a short poem about hedgehogs and squirrels.";

@@ -171,44 +182,46 @@ impl OllamaClient {
        .await
        {
            log::warn!("Ignoring model {model}: {err}");
-            return false;
+            return SpecModelPerformance::ExecutionFailed;
        }

        // then, run a sample generation with timeout and measure tps
-        tokio::select! {
-            _ = tokio::time::sleep(PERFORMANCE_TIMEOUT) => {
-                log::warn!("Ignoring model {model}: Timed out");
-            },
-            result = self.ollama_rs_client.generate(GenerationRequest::new(
+        let Ok(result) = tokio::time::timeout(
+            PERFORMANCE_TIMEOUT,
+            self.ollama_rs_client.generate(GenerationRequest::new(
                model.to_string(),
                TEST_PROMPT.to_string(),
-            )) => {
-                match result {
-                    Ok(response) => {
-                        let tps = (response.eval_count.unwrap_or_default() as f64)
-                            / (response.eval_duration.unwrap_or(1) as f64)
-                            * 1_000_000_000f64;
-
-                        if tps >= PERFORMANCE_MIN_TPS {
-                            log::info!("Model {} passed the test with tps: {}", model, tps);
-                            return true;
-                        }
-
-                        log::warn!(
-                            "Ignoring model {}: tps too low ({:.3} < {:.3})",
-                            model,
-                            tps,
-                            PERFORMANCE_MIN_TPS
-                        );
-                    }
-                    Err(e) => {
-                        log::warn!("Ignoring model {}: Task failed with error {}", model, e);
-                    }
-                }
-            }
+            )),
+        )
+        .await
+        else {
+            log::warn!("Ignoring model {model}: Timed out");
+            return SpecModelPerformance::Timeout;
        };

-        false
+        // check the result
+        match result {
+            Ok(response) => {
+                let tps = (response.eval_count.unwrap_or_default() as f64)
+                    / (response.eval_duration.unwrap_or(1) as f64)
+                    * 1_000_000_000f64;
+
+                if tps >= PERFORMANCE_MIN_TPS {
+                    log::info!("Model {model} passed the test with tps: {tps}");
+                    SpecModelPerformance::PassedWithTPS(tps)
+                } else {
+                    log::warn!(
+                        "Ignoring model {model}: tps too low ({tps:.3} < {:.3})",
+                        PERFORMANCE_MIN_TPS
+                    );
+                    SpecModelPerformance::FailedWithTPS(tps)
+                }
+            }
+            Err(err) => {
+                log::warn!("Ignoring model {model} due to: {err}");
+                SpecModelPerformance::ExecutionFailed
+            }
+        }
    }
}
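
Note: the diff relies on a `SpecModelPerformance` type exported from `dkn_utils::payloads`. The sketch below is a hypothetical stand-in reconstructed only from the variants used above (`PassedWithTPS`, `FailedWithTPS`, `ExecutionFailed`, `Timeout`); the real enum may have additional variants, derives, or payload types, and the `&str` keys stand in for the crate's `Model` type. The `f64` payload is assumed because `tps` is computed as an `f64`: `eval_count / eval_duration * 1e9`, where the `1e9` factor converts `eval_duration` (reported by Ollama in nanoseconds) into seconds so the result is tokens per second.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for dkn_utils::payloads::SpecModelPerformance,
// inferred from the variants referenced in the diff above.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SpecModelPerformance {
    /// Model answered the dummy prompt fast enough; payload is tokens-per-second.
    PassedWithTPS(f64),
    /// Model answered, but below the minimum tokens-per-second threshold.
    FailedWithTPS(f64),
    /// Warmup or generation returned an error.
    ExecutionFailed,
    /// Generation did not finish within the performance timeout.
    Timeout,
}

fn main() {
    // Example of consuming the map now returned by `check`:
    // keep the models that passed, report why the rest were dropped.
    let mut performances: HashMap<&str, SpecModelPerformance> = HashMap::new();
    performances.insert("model-a", SpecModelPerformance::PassedWithTPS(21.4));
    performances.insert("model-b", SpecModelPerformance::FailedWithTPS(3.2));
    performances.insert("model-c", SpecModelPerformance::Timeout);

    for (model, perf) in &performances {
        match perf {
            SpecModelPerformance::PassedWithTPS(tps) => {
                println!("{model}: usable ({tps:.3} tps)");
            }
            other => println!("{model}: skipped ({other:?})"),
        }
    }
}
```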
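The control-flow change in the last hunk swaps `tokio::select!` against a `sleep` for `tokio::time::timeout` combined with `let ... else`, so the timeout case can early-return a dedicated variant instead of falling through. A minimal, self-contained sketch of that pattern (the task, durations, and messages here are invented for illustration):

```rust
use std::time::Duration;

// A stand-in for the generation call; deliberately slower than the timeout.
async fn slow_task() -> Result<u64, &'static str> {
    tokio::time::sleep(Duration::from_millis(50)).await;
    Ok(42)
}

#[tokio::main]
async fn main() {
    // `timeout` wraps the future; `let ... else` peels off the Elapsed case
    // so the happy path continues without an extra nesting level.
    let Ok(result) = tokio::time::timeout(Duration::from_millis(10), slow_task()).await else {
        println!("timed out");
        return;
    };

    // `result` is still the task's own Result, matched separately,
    // mirroring the `match result` block in the diff.
    match result {
        Ok(value) => println!("got {value}"),
        Err(err) => println!("task failed: {err}"),
    }
}
```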