@@ -9,6 +9,10 @@ use crate::p2p::P2PMessage;
9
9
use crate :: utils:: get_current_time_nanos;
10
10
use crate :: utils:: payload:: { TaskRequest , TaskRequestPayload } ;
11
11
12
+ use super :: ComputeHandler ;
13
+
14
+ pub struct WorkflowHandler ;
15
+
12
16
#[ derive( Debug , Deserialize ) ]
13
17
struct WorkflowPayload {
14
18
/// Workflow object to be parsed.
@@ -23,18 +27,9 @@ struct WorkflowPayload {
23
27
}
24
28
25
29
#[ async_trait]
26
- pub trait HandlesWorkflow {
27
- async fn handle_workflow (
28
- & mut self ,
29
- message : P2PMessage ,
30
- result_topic : & str ,
31
- ) -> NodeResult < MessageAcceptance > ;
32
- }
33
-
34
- #[ async_trait]
35
- impl HandlesWorkflow for DriaComputeNode {
36
- async fn handle_workflow (
37
- & mut self ,
30
+ impl ComputeHandler for WorkflowHandler {
31
+ async fn handle_compute (
32
+ node : & mut DriaComputeNode ,
38
33
message : P2PMessage ,
39
34
result_topic : & str ,
40
35
) -> NodeResult < MessageAcceptance > {
@@ -55,7 +50,7 @@ impl HandlesWorkflow for DriaComputeNode {
55
50
}
56
51
57
52
// check task inclusion via the bloom filter
58
- if !task. filter . contains ( & self . config . address ) ? {
53
+ if !task. filter . contains ( & node . config . address ) ? {
59
54
log:: info!(
60
55
"Task {} does not include this node within the filter." ,
61
56
task. task_id
@@ -75,7 +70,7 @@ impl HandlesWorkflow for DriaComputeNode {
75
70
} ;
76
71
77
72
// read model / provider from the task
78
- let ( model_provider, model) = self
73
+ let ( model_provider, model) = node
79
74
. config
80
75
. model_config
81
76
. get_any_matching_model ( task. input . model ) ?;
@@ -85,8 +80,8 @@ impl HandlesWorkflow for DriaComputeNode {
85
80
let executor = if model_provider == ModelProvider :: Ollama {
86
81
Executor :: new_at (
87
82
model,
88
- & self . config . ollama_config . host ,
89
- self . config . ollama_config . port ,
83
+ & node . config . ollama_config . host ,
84
+ node . config . ollama_config . port ,
90
85
)
91
86
} else {
92
87
Executor :: new ( model)
@@ -98,7 +93,7 @@ impl HandlesWorkflow for DriaComputeNode {
98
93
. map ( |prompt| Entry :: try_value_or_str ( & prompt) ) ;
99
94
let result: Option < String > ;
100
95
tokio:: select! {
101
- _ = self . cancellation. cancelled( ) => {
96
+ _ = node . cancellation. cancelled( ) => {
102
97
log:: info!( "Received cancellation, quitting all tasks." ) ;
103
98
return Ok ( MessageAcceptance :: Accept )
104
99
} ,
@@ -113,7 +108,7 @@ impl HandlesWorkflow for DriaComputeNode {
113
108
let result = result. ok_or :: < String > ( format ! ( "No result for task {}" , task. task_id) ) ?;
114
109
115
110
// publish the result
116
- self . send_result ( result_topic, & task. public_key , & task. task_id , result) ?;
111
+ node . send_result ( result_topic, & task. public_key , & task. task_id , result) ?;
117
112
118
113
// accept message, someone else may be included in the filter
119
114
Ok ( MessageAcceptance :: Accept )
0 commit comments