@@ -406,3 +406,103 @@ fn test_keccak256_cuda_tracegen() {
406
406
. simple_test ( )
407
407
. unwrap ( ) ;
408
408
}
409
+
410
/// Runs the Keccak256 CUDA trace-generation test concurrently from several workers.
///
/// Configuration comes from the environment:
/// - `NUM_THREADS`: number of worker tasks, also used as the tokio blocking-thread cap.
///   Defaults to 2 when unset, unparsable, or zero.
/// - `NUM_TASKS`: total number of test runs, split into contiguous ranges across the
///   workers. Defaults to `num_threads * 4` when unset or unparsable.
///
/// Each worker processes its range one task at a time (it awaits every blocking task
/// before starting the next); every task builds its own tester, harness and RNG inside
/// the blocking closure.
///
/// # Panics
///
/// Panics if the tokio runtime cannot be built, or if any task or worker panics.
#[cfg(feature = "cuda")]
#[test]
fn test_keccak256_cuda_tracegen_multi() {
    // Reads a `usize` from the environment; unset or unparsable values yield `None`.
    let env_usize = |name: &str| {
        std::env::var(name)
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
    };

    // Zero is rejected here: tokio's builder panics on `max_blocking_threads(0)` and the
    // per-worker split below divides by this value.
    let num_threads = env_usize("NUM_THREADS").filter(|&n| n > 0).unwrap_or(2);
    let num_tasks = env_usize("NUM_TASKS").unwrap_or(num_threads * 4);

    let runtime = tokio::runtime::Builder::new_multi_thread()
        .max_blocking_threads(num_threads)
        .enable_all()
        .build()
        .expect("failed to build tokio runtime");

    runtime.block_on(async {
        let tasks_per_thread = num_tasks.div_ceil(num_threads);
        let mut worker_handles = Vec::with_capacity(num_threads);

        for worker_idx in 0..num_threads {
            // Contiguous slice of task ids owned by this worker. When there are fewer
            // tasks than workers, trailing workers get an empty range
            // (`start_task >= end_task`) and simply do nothing.
            let start_task = worker_idx * tasks_per_thread;
            let end_task = std::cmp::min(start_task + tasks_per_thread, num_tasks);

            let worker_handle = tokio::task::spawn(async move {
                for task_id in start_task..end_task {
                    // The test body is synchronous, so it is run on the blocking pool.
                    tokio::task::spawn_blocking(move || {
                        println!("[worker {}, task {}] Starting test", worker_idx, task_id);

                        let mut rng = create_seeded_rng();
                        let mut tester = GpuChipTestBuilder::default()
                            .with_bitwise_op_lookup(default_bitwise_lookup_bus());

                        let mut harness = create_cuda_harness(&tester);

                        // First a batch of executions with every optional argument unset.
                        let num_ops: usize = 10;
                        for _ in 0..num_ops {
                            set_and_execute(
                                &mut tester,
                                &mut harness.executor,
                                &mut harness.dense_arena,
                                &mut rng,
                                KECCAK256,
                                None,
                                None,
                                None,
                            );
                        }

                        // Then a fixed set of explicit `len` values.
                        // NOTE(review): presumably these are input lengths in bytes chosen
                        // around a block boundary — confirm against `set_and_execute`.
                        for len in [0, 135, 136, 137, 2000] {
                            set_and_execute(
                                &mut tester,
                                &mut harness.executor,
                                &mut harness.dense_arena,
                                &mut rng,
                                KECCAK256,
                                None,
                                Some(len),
                                None,
                            );
                        }

                        // Move the recorded data from the dense arena into the matrix
                        // arena before the harness is handed to the tester.
                        harness
                            .dense_arena
                            .get_record_seeker::<KeccakVmRecordMut, _>()
                            .transfer_to_matrix_arena(&mut harness.matrix_arena);

                        tester
                            .build()
                            .load_gpu_harness(harness)
                            .finalize()
                            .simple_test()
                            .unwrap();

                        println!(
                            "[worker {}, task {}] Test completed ✅",
                            worker_idx, task_id
                        );
                    })
                    .await
                    .expect("task failed");
                }
            });
            worker_handles.push(worker_handle);
        }

        // Surface any worker panic as a test failure.
        for handle in worker_handles {
            handle.await.expect("worker failed");
        }

        println!(
            "\n All {} tasks completed on {} workers.",
            num_tasks, num_threads
        );
    });
}
0 commit comments