@@ -1,6 +1,6 @@
-use crate::ModelProvider;
+use crate::{Model, ModelProvider, TaskBody};
 use rig::completion::PromptError;
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
 
 mod ollama;
 use ollama::OllamaClient;
@@ -35,7 +35,7 @@ impl DriaExecutor {
     }
 
     /// Executes the given task using the appropriate provider.
-    pub async fn execute(&self, task: crate::TaskBody) -> Result<String, PromptError> {
+    pub async fn execute(&self, task: TaskBody) -> Result<String, PromptError> {
         match self {
             DriaExecutor::Ollama(provider) => provider.execute(task).await,
             DriaExecutor::OpenAI(provider) => provider.execute(task).await,
@@ -47,7 +47,10 @@ impl DriaExecutor {
     /// Checks if the requested models exist and are available in the provider's account.
     ///
     /// For Ollama in particular, it also checks if the models are performant enough.
-    pub async fn check(&self, models: &mut HashSet<crate::Model>) -> eyre::Result<()> {
+    pub async fn check(
+        &self,
+        models: &mut HashSet<Model>,
+    ) -> eyre::Result<HashMap<Model, ModelPerformanceMetric>> {
         match self {
             DriaExecutor::Ollama(provider) => provider.check(models).await,
             DriaExecutor::OpenAI(provider) => provider.check(models).await,
@@ -56,3 +59,9 @@ impl DriaExecutor {
         }
     }
 }
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub enum ModelPerformanceMetric {
+    Latency(f64), // in seconds
+    TPS(f64),     // (eval) tokens per second
+}
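For reference, a minimal sketch of how a caller might consume the per-model metrics that the new `check` signature returns. It is purely illustrative and not part of this diff: the map is keyed by a plain model-name `String` instead of the crate's `Model` type, the serde derives are omitted, and the model names and printing loop are placeholders, so the snippet compiles on its own.

```rust
use std::collections::HashMap;

// Mirrors the enum added in this diff (serde derives omitted to keep the sketch dependency-free).
#[derive(Debug, Clone)]
pub enum ModelPerformanceMetric {
    Latency(f64), // in seconds
    TPS(f64),     // (eval) tokens per second
}

fn main() {
    // Hypothetical result of `executor.check(&mut models)`, keyed by a model-name string
    // instead of the crate's `Model` type so this stands alone.
    let metrics: HashMap<String, ModelPerformanceMetric> = HashMap::from([
        ("llama3.1:8b".to_string(), ModelPerformanceMetric::TPS(42.0)),
        ("gpt-4o-mini".to_string(), ModelPerformanceMetric::Latency(0.8)),
    ]);

    // A caller could log or filter models based on the reported metric.
    for (model, metric) in &metrics {
        match metric {
            ModelPerformanceMetric::Latency(secs) => println!("{model}: {secs:.2} s latency"),
            ModelPerformanceMetric::TPS(tps) => println!("{model}: {tps:.1} tokens/s"),
        }
    }
}
```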