@@ -78,7 +78,7 @@ impl WorkflowsWorker {
78
78
79
79
if let Some ( task) = task {
80
80
log:: info!( "Processing single workflow for task {}" , task. task_id) ;
81
- WorkflowsWorker :: execute ( ( task, self . publish_tx . clone ( ) ) ) . await
81
+ WorkflowsWorker :: execute ( ( task, & self . publish_tx ) ) . await
82
82
} else {
83
83
return self . shutdown ( ) ;
84
84
} ;
@@ -93,76 +93,85 @@ impl WorkflowsWorker {
93
93
///
94
94
/// Batch size must NOT be larger than `MAX_BATCH_SIZE`, otherwise will panic.
95
95
pub async fn run_batch ( & mut self , batch_size : usize ) {
96
- // TODO: need some better batch_size error handling here
96
+ assert ! (
97
+ batch_size <= Self :: MAX_BATCH_SIZE ,
98
+ "Batch size must not be larger than {}" ,
99
+ Self :: MAX_BATCH_SIZE
100
+ ) ;
101
+
97
102
loop {
98
- // get tasks in batch from the channel
99
- let mut task_buffer = Vec :: new ( ) ;
100
- let num_tasks = self
101
- . workflow_rx
102
- . recv_many ( & mut task_buffer, batch_size)
103
- . await ;
104
-
105
- if num_tasks == 0 {
106
- return self . shutdown ( ) ;
103
+ let mut tasks = Vec :: new ( ) ;
104
+
105
+ // get tasks in batch from the channel, we enter the loop if:
106
+ // (1) there are no tasks, or,
107
+ // (2) there are tasks less than the batch size and the channel is not empty
108
+ while tasks. len ( ) == 0 || ( tasks. len ( ) < batch_size && !self . workflow_rx . is_empty ( ) ) {
109
+ let limit = batch_size - tasks. len ( ) ;
110
+ match self . workflow_rx . recv_many ( & mut tasks, limit) . await {
111
+ // 0 tasks returned means that the channel is closed
112
+ 0 => return self . shutdown ( ) ,
113
+ _ => {
114
+ // wait a small amount of time to allow for more tasks to be sent into the channel
115
+ tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 256 ) ) . await ;
116
+ }
117
+ }
107
118
}
108
119
109
120
// process the batch
121
+ let num_tasks = tasks. len ( ) ;
122
+ debug_assert ! (
123
+ num_tasks <= batch_size,
124
+ "number of tasks cant be larger than batch size"
125
+ ) ;
126
+ debug_assert ! ( num_tasks != 0 , "number of tasks cant be zero" ) ;
110
127
log:: info!( "Processing {} workflows in batch" , num_tasks) ;
111
- let mut batch = task_buffer
112
- . into_iter ( )
113
- . map ( |b| ( b, self . publish_tx . clone ( ) ) ) ;
128
+ let mut batch = tasks. into_iter ( ) . map ( |b| ( b, & self . publish_tx ) ) ;
114
129
match num_tasks {
115
130
1 => {
116
- let r0 = WorkflowsWorker :: execute ( batch. next ( ) . unwrap ( ) ) . await ;
117
- vec ! [ r0]
131
+ WorkflowsWorker :: execute ( batch. next ( ) . unwrap ( ) ) . await ;
118
132
}
119
133
2 => {
120
- let ( r0 , r1 ) = tokio:: join!(
134
+ tokio:: join!(
121
135
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
122
136
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) )
123
137
) ;
124
- vec ! [ r0, r1]
125
138
}
126
139
3 => {
127
- let ( r0 , r1 , r2 ) = tokio:: join!(
140
+ tokio:: join!(
128
141
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
129
142
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
130
143
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) )
131
144
) ;
132
- vec ! [ r0, r1, r2]
133
145
}
134
146
4 => {
135
- let ( r0 , r1 , r2 , r3 ) = tokio:: join!(
147
+ tokio:: join!(
136
148
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
137
149
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
138
150
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
139
151
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) )
140
152
) ;
141
- vec ! [ r0, r1, r2, r3]
142
153
}
143
154
5 => {
144
- let ( r0 , r1 , r2 , r3 , r4 ) = tokio:: join!(
155
+ tokio:: join!(
145
156
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
146
157
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
147
158
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
148
159
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
149
160
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) )
150
161
) ;
151
- vec ! [ r0, r1, r2, r3, r4]
152
162
}
153
163
6 => {
154
- let ( r0 , r1 , r2 , r3 , r4 , r5 ) = tokio:: join!(
164
+ tokio:: join!(
155
165
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
156
166
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
157
167
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
158
168
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
159
169
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
160
170
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) )
161
171
) ;
162
- vec ! [ r0, r1, r2, r3, r4, r5]
163
172
}
164
173
7 => {
165
- let ( r0 , r1 , r2 , r3 , r4 , r5 , r6 ) = tokio:: join!(
174
+ tokio:: join!(
166
175
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
167
176
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
168
177
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
@@ -171,10 +180,9 @@ impl WorkflowsWorker {
171
180
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
172
181
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) )
173
182
) ;
174
- vec ! [ r0, r1, r2, r3, r4, r5, r6]
175
183
}
176
184
8 => {
177
- let ( r0 , r1 , r2 , r3 , r4 , r5 , r6 , r7 ) = tokio:: join!(
185
+ tokio:: join!(
178
186
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
179
187
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
180
188
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
@@ -184,7 +192,6 @@ impl WorkflowsWorker {
184
192
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) ) ,
185
193
WorkflowsWorker :: execute( batch. next( ) . unwrap( ) )
186
194
) ;
187
- vec ! [ r0, r1, r2, r3, r4, r5, r6, r7]
188
195
}
189
196
_ => {
190
197
unreachable ! (
@@ -199,23 +206,28 @@ impl WorkflowsWorker {
199
206
200
207
/// Executes a single task, and publishes the output.
201
208
pub async fn execute (
202
- ( input, publish_tx) : ( WorkflowsWorkerInput , mpsc:: Sender < WorkflowsWorkerOutput > ) ,
209
+ ( input, publish_tx) : ( WorkflowsWorkerInput , & mpsc:: Sender < WorkflowsWorkerOutput > ) ,
203
210
) {
211
+ let mut stats = input. stats ;
212
+
204
213
let mut memory = ProgramMemory :: new ( ) ;
205
214
215
+ // TODO: will be removed later
206
216
let started_at = std:: time:: Instant :: now ( ) ;
217
+ stats = stats. record_execution_started_at ( ) ;
207
218
let result = input
208
219
. executor
209
220
. execute ( input. entry . as_ref ( ) , & input. workflow , & mut memory)
210
221
. await ;
222
+ stats = stats. record_execution_ended_at ( ) ;
211
223
212
224
let output = WorkflowsWorkerOutput {
213
225
result,
214
226
public_key : input. public_key ,
215
227
task_id : input. task_id ,
216
228
model_name : input. model_name ,
217
229
batchable : input. batchable ,
218
- stats : input . stats . record_execution_time ( started_at) ,
230
+ stats : stats. record_execution_time ( started_at) ,
219
231
} ;
220
232
221
233
if let Err ( e) = publish_tx. send ( output) . await {
0 commit comments