@@ -106,6 +106,7 @@ impl WorkflowsWorker {
106
106
// (1) there are no tasks, or,
107
107
// (2) there are tasks less than the batch size and the channel is not empty
108
108
while tasks. is_empty ( ) || ( tasks. len ( ) < batch_size && !self . workflow_rx . is_empty ( ) ) {
109
+ log:: info!( "Waiting for more workflows to process ({})" , tasks. len( ) ) ;
109
110
let limit = batch_size - tasks. len ( ) ;
110
111
match self . workflow_rx . recv_many ( & mut tasks, limit) . await {
111
112
// 0 tasks returned means that the channel is closed
@@ -235,3 +236,107 @@ impl WorkflowsWorker {
235
236
}
236
237
}
237
238
}
239
+
240
+ #[ cfg( test) ]
241
+ mod tests {
242
+ use super :: * ;
243
+ use crate :: payloads:: TaskStats ;
244
+
245
+ use dkn_workflows:: { Executor , Model } ;
246
+ use libsecp256k1:: { PublicKey , SecretKey } ;
247
+ use tokio:: sync:: mpsc;
248
+
249
+ // cargo test --package dkn-compute --lib --all-features -- workers::workflow::tests::test_workflows_worker --exact --show-output --nocapture --ignored
250
+ #[ tokio:: test]
251
+ #[ ignore = "run manually" ]
252
+ async fn test_workflows_worker ( ) {
253
+ let _ = env_logger:: builder ( )
254
+ . filter_level ( log:: LevelFilter :: Off )
255
+ . filter_module ( "dkn_compute" , log:: LevelFilter :: Debug )
256
+ . is_test ( true )
257
+ . try_init ( ) ;
258
+
259
+ let ( publish_tx, mut publish_rx) = mpsc:: channel ( 1024 ) ;
260
+ let ( mut worker, workflow_tx) = WorkflowsWorker :: new ( publish_tx) ;
261
+
262
+ // create batch workflow worker
263
+ let worker_handle = tokio:: spawn ( async move {
264
+ worker. run_batch ( 4 ) . await ;
265
+ } ) ;
266
+
267
+ let num_tasks = 4 ;
268
+ let model = Model :: O1Preview ;
269
+ let workflow = serde_json:: json!( {
270
+ "config" : {
271
+ "max_steps" : 10 ,
272
+ "max_time" : 250 ,
273
+ "tools" : [ "" ]
274
+ } ,
275
+ "tasks" : [
276
+ {
277
+ "id" : "A" ,
278
+ "name" : "" ,
279
+ "description" : "" ,
280
+ "operator" : "generation" ,
281
+ "messages" : [ { "role" : "user" , "content" : "Write a 4 paragraph poem about Julius Caesar." } ] ,
282
+ "inputs" : [ ] ,
283
+ "outputs" : [ { "type" : "write" , "key" : "result" , "value" : "__result" } ]
284
+ } ,
285
+ {
286
+ "id" : "__end" ,
287
+ "name" : "end" ,
288
+ "description" : "End of the task" ,
289
+ "operator" : "end" ,
290
+ "messages" : [ { "role" : "user" , "content" : "End of the task" } ] ,
291
+ "inputs" : [ ] ,
292
+ "outputs" : [ ]
293
+ }
294
+ ] ,
295
+ "steps" : [ { "source" : "A" , "target" : "__end" } ] ,
296
+ "return_value" : { "input" : { "type" : "read" , "key" : "result" }
297
+ }
298
+ } ) ;
299
+
300
+ for i in 0 ..num_tasks {
301
+ log:: info!( "Sending task {}" , i + 1 ) ;
302
+
303
+ let workflow = serde_json:: from_value ( workflow. clone ( ) ) . unwrap ( ) ;
304
+
305
+ let executor = Executor :: new ( model. clone ( ) ) ;
306
+ let input = WorkflowsWorkerInput {
307
+ entry : None ,
308
+ executor,
309
+ workflow,
310
+ public_key : PublicKey :: from_secret_key ( & SecretKey :: default ( ) ) ,
311
+ task_id : "task_id" . to_string ( ) ,
312
+ model_name : model. to_string ( ) ,
313
+ stats : TaskStats :: default ( ) ,
314
+ batchable : true ,
315
+ } ;
316
+
317
+ // send workflow to worker
318
+ workflow_tx. send ( input) . await . unwrap ( ) ;
319
+ }
320
+
321
+ // now wait for all results
322
+ let mut results = Vec :: new ( ) ;
323
+ for i in 0 ..num_tasks {
324
+ log:: info!( "Waiting for result {}" , i + 1 ) ;
325
+ let result = publish_rx. recv ( ) . await . unwrap ( ) ;
326
+ log:: info!(
327
+ "Got result {} (exeuction time: {})" ,
328
+ i + 1 ,
329
+ ( result. stats. execution_time as f64 ) / 1_000_000_000f64
330
+ ) ;
331
+ if result. result . is_err ( ) {
332
+ println ! ( "Error: {:?}" , result. result) ;
333
+ }
334
+ results. push ( result) ;
335
+ }
336
+
337
+ log:: info!( "Got all results, closing channel." ) ;
338
+ publish_rx. close ( ) ;
339
+ workflow_tx. worker_handle . await . unwrap ( ) ;
340
+ log:: info!( "Done." ) ;
341
+ }
342
+ }
0 commit comments