@@ -38,6 +38,8 @@ use pyo3::types::PyType;
38
38
use tokio:: sync:: Mutex ;
39
39
use tokio:: sync:: mpsc;
40
40
41
/// A one-shot callback invoked when a `TrackedProcMesh` is stopped.
///
/// Calling the boxed closure consumes it (`FnOnce`) and yields a boxed future
/// resolving to `()`. Both the closure and the future it returns must be `Send`.
/// The future is a plain `Box<dyn Future>` rather than a pinned box, so it has
/// to be pinned (e.g. via `std::pin::Pin::from`) before it can be awaited.
+ type OnStopCallback = Box < dyn FnOnce ( ) -> Box < dyn std:: future:: Future < Output = ( ) > + Send > + Send > ;
42
+
41
43
use crate :: actor_mesh:: PythonActorMesh ;
42
44
use crate :: actor_mesh:: PythonActorMeshImpl ;
43
45
use crate :: alloc:: PyAlloc ;
@@ -55,6 +57,7 @@ pub struct TrackedProcMesh {
55
57
inner : SharedCellRef < ProcMesh > ,
56
58
cell : SharedCell < ProcMesh > ,
57
59
children : SharedCellPool ,
60
+ onstop_callbacks : Arc < Mutex < Vec < OnStopCallback > > > ,
58
61
}
59
62
60
63
impl Debug for TrackedProcMesh {
@@ -77,6 +80,7 @@ impl From<ProcMesh> for TrackedProcMesh {
77
80
inner,
78
81
cell,
79
82
children : SharedCellPool :: new ( ) ,
83
+ onstop_callbacks : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
80
84
}
81
85
}
82
86
}
@@ -107,8 +111,25 @@ impl TrackedProcMesh {
107
111
self . inner . client_proc ( )
108
112
}
109
113
110
- pub fn into_inner ( self ) -> ( SharedCell < ProcMesh > , SharedCellPool ) {
111
- ( self . cell , self . children )
114
+ pub fn into_inner (
115
+ self ,
116
+ ) -> (
117
+ SharedCell < ProcMesh > ,
118
+ SharedCellPool ,
119
+ Arc < Mutex < Vec < OnStopCallback > > > ,
120
+ ) {
121
+ ( self . cell , self . children , self . onstop_callbacks )
122
+ }
123
+
124
+ /// Register a callback to be called when this TrackedProcMesh is stopped
125
+ pub async fn register_onstop_callback < F , Fut > ( & self , callback : F ) -> Result < ( ) , anyhow:: Error >
126
+ where
127
+ F : FnOnce ( ) -> Fut + Send + ' static ,
128
+ Fut : std:: future:: Future < Output = ( ) > + Send + ' static ,
129
+ {
130
+ let mut callbacks = self . onstop_callbacks . lock ( ) . await ;
131
+ callbacks. push ( Box :: new ( || Box :: new ( callback ( ) ) ) ) ;
132
+ Ok ( ( ) )
112
133
}
113
134
}
114
135
@@ -230,7 +251,17 @@ impl PyProcMesh {
230
251
let tracked_proc_mesh = inner. take ( ) . await . map_err ( |e| {
231
252
PyRuntimeError :: new_err ( format ! ( "`ProcMesh` has already been stopped: {}" , e) )
232
253
} ) ?;
233
- let ( proc_mesh, children) = tracked_proc_mesh. into_inner ( ) ;
254
+ let ( proc_mesh, children, drop_callbacks) = tracked_proc_mesh. into_inner ( ) ;
255
+
256
+ // Call all registered drop callbacks before stopping
257
+ let mut callbacks = drop_callbacks. lock ( ) . await ;
258
+ let callbacks_to_call = callbacks. drain ( ..) . collect :: < Vec < _ > > ( ) ;
259
+ drop ( callbacks) ; // Release the lock
260
+
261
+ for callback in callbacks_to_call {
262
+ let future = callback ( ) ;
263
+ std:: pin:: Pin :: from ( future) . await ;
264
+ }
234
265
235
266
// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
236
267
// Discarding actor meshes that have been individually stopped will result in an expected error
@@ -488,3 +519,196 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
488
519
hyperactor_mod. add_class :: < PyProcEvent > ( ) ?;
489
520
Ok ( ( ) )
490
521
}
522
+
523
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::AtomicU32;
    use std::sync::atomic::Ordering;

    use anyhow::Result;
    use hyperactor_mesh::alloc::AllocSpec;
    use hyperactor_mesh::alloc::Allocator;
    use hyperactor_mesh::alloc::local::LocalAllocator;
    use hyperactor_mesh::proc_mesh::ProcMesh;
    use ndslice::extent;
    use tokio::sync::Mutex;

    use super::*;

    /// Allocates a `ProcMesh` with a single replica using the local allocator.
    ///
    /// Shared setup for every test below; the caller still owns the mesh and
    /// can extract its events before wrapping it in a `TrackedProcMesh`.
    async fn allocate_single_replica_mesh() -> Result<ProcMesh> {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent! { replica = 1 },
                constraints: Default::default(),
            })
            .await?;
        Ok(ProcMesh::allocate(alloc).await?)
    }

    /// A single registered callback runs when the mesh is stopped.
    #[tokio::test]
    async fn test_register_onstop_callback_single() -> Result<()> {
        let mut proc_mesh = allocate_single_replica_mesh().await?;

        // Events must be extracted before the mesh is wrapped in a TrackedProcMesh.
        let events = proc_mesh
            .events()
            .expect("events should be available on a freshly allocated mesh");
        let proc_events_cell = SharedCell::from(Mutex::new(events));

        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);

        // Flag flipped by the callback so we can observe that it ran.
        let callback_executed = Arc::new(AtomicBool::new(false));
        let flag = callback_executed.clone();

        tracked_proc_mesh
            .register_onstop_callback(move || async move {
                flag.store(true, Ordering::SeqCst);
            })
            .await?;

        // stop_mesh takes a SharedCell<TrackedProcMesh>; stopping should trigger the callback.
        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;

        assert!(
            callback_executed.load(Ordering::SeqCst),
            "Callback should have been executed"
        );

        Ok(())
    }

    /// Several registered callbacks all run, in registration order.
    #[tokio::test]
    async fn test_register_onstop_callback_multiple() -> Result<()> {
        let mut proc_mesh = allocate_single_replica_mesh().await?;

        // Events must be extracted before the mesh is wrapped in a TrackedProcMesh.
        let events = proc_mesh
            .events()
            .expect("events should be available on a freshly allocated mesh");
        let proc_events_cell = SharedCell::from(Mutex::new(events));

        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);

        // Shared state recording how many callbacks ran and in which order.
        let callback_count = Arc::new(AtomicU32::new(0));
        let execution_order = Arc::new(Mutex::new(Vec::<u32>::new()));

        for i in 1..=3 {
            let count = callback_count.clone();
            let order = execution_order.clone();
            tracked_proc_mesh
                .register_onstop_callback(move || async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    order.lock().await.push(i);
                })
                .await?;
        }

        // stop_mesh takes a SharedCell<TrackedProcMesh>; stopping should trigger all callbacks.
        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;

        assert_eq!(
            callback_count.load(Ordering::SeqCst),
            3,
            "All 3 callbacks should have been executed"
        );

        // Callbacks are expected to run in the order they were registered.
        let order_vec = execution_order.lock().await;
        assert_eq!(
            *order_vec,
            vec![1, 2, 3],
            "Callbacks should be executed in registration order"
        );

        Ok(())
    }

    /// Two independently registered callbacks both run on stop.
    ///
    /// NOTE: despite the test's name, no failure is injected here. Both
    /// callbacks complete normally, so this only checks that registering a
    /// second callback does not prevent either one from running.
    #[tokio::test]
    async fn test_register_onstop_callback_error_handling() -> Result<()> {
        let mut proc_mesh = allocate_single_replica_mesh().await?;

        // Events must be extracted before the mesh is wrapped in a TrackedProcMesh.
        let events = proc_mesh
            .events()
            .expect("events should be available on a freshly allocated mesh");
        let proc_events_cell = SharedCell::from(Mutex::new(events));

        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);

        // One flag per callback so each execution can be observed separately.
        let callback1_executed = Arc::new(AtomicBool::new(false));
        let callback2_executed = Arc::new(AtomicBool::new(false));

        let first_flag = callback1_executed.clone();
        let second_flag = callback2_executed.clone();

        // First callback: completes successfully.
        tracked_proc_mesh
            .register_onstop_callback(move || async move {
                first_flag.store(true, Ordering::SeqCst);
            })
            .await?;

        // Second callback: must also run once the first has finished.
        tracked_proc_mesh
            .register_onstop_callback(move || async move {
                second_flag.store(true, Ordering::SeqCst);
            })
            .await?;

        // stop_mesh takes a SharedCell<TrackedProcMesh>; stopping should trigger both callbacks.
        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;

        assert!(
            callback1_executed.load(Ordering::SeqCst),
            "First callback should have been executed"
        );
        assert!(
            callback2_executed.load(Ordering::SeqCst),
            "Second callback should have been executed"
        );

        Ok(())
    }
}
0 commit comments