1
1
use std:: time:: Duration ;
2
2
3
3
use ollama_workflows:: {
4
- ollama_rs:: { generation:: completion:: request:: GenerationRequest , Ollama } ,
4
+ ollama_rs:: {
5
+ generation:: { completion:: request:: GenerationRequest , options:: GenerationOptions } ,
6
+ Ollama ,
7
+ } ,
5
8
Model ,
6
9
} ;
7
10
@@ -174,12 +177,26 @@ impl OllamaConfig {
174
177
return false ;
175
178
} ;
176
179
180
+ let mut generation_request =
181
+ GenerationRequest :: new ( model. to_string ( ) , TEST_PROMPT . to_string ( ) ) ;
182
+
183
+ // FIXME: temporary workaround that lets the number of threads be set from outside via the OLLAMA_NUM_THREAD environment variable
184
+ if let Ok ( num_thread) = std:: env:: var ( "OLLAMA_NUM_THREAD" ) {
185
+ generation_request = generation_request. options (
186
+ GenerationOptions :: default ( ) . num_thread (
187
+ num_thread
188
+ . parse ( )
189
+ . expect ( "num threads should be a positive integer" ) ,
190
+ ) ,
191
+ ) ;
192
+ }
193
+
177
194
// then, run a sample generation with timeout and measure tps
178
195
tokio:: select! {
179
196
_ = tokio:: time:: sleep( timeout) => {
180
197
log:: warn!( "Ignoring model {}: Workflow timed out" , model) ;
181
198
} ,
182
- result = ollama. generate( GenerationRequest :: new ( model . to_string ( ) , TEST_PROMPT . to_string ( ) ) ) => {
199
+ result = ollama. generate( generation_request ) => {
183
200
match result {
184
201
Ok ( response) => {
185
202
let tps = ( response. eval_count. unwrap_or_default( ) as f64 )
@@ -189,7 +206,6 @@ impl OllamaConfig {
189
206
if tps >= min_tps {
190
207
log:: info!( "Model {} passed the test with tps: {}" , model, tps) ;
191
208
return true ;
192
-
193
209
}
194
210
195
211
log:: warn!(
0 commit comments