@@ -38,6 +38,8 @@ use pyo3::types::PyType;
38
38
use tokio:: sync:: Mutex ;
39
39
use tokio:: sync:: mpsc;
40
40
41
+ type OnStopCallback = Box < dyn FnOnce ( ) -> Box < dyn std:: future:: Future < Output = ( ) > + Send > + Send > ;
42
+
41
43
use crate :: actor_mesh:: PythonActorMesh ;
42
44
use crate :: actor_mesh:: PythonActorMeshImpl ;
43
45
use crate :: alloc:: PyAlloc ;
@@ -55,6 +57,7 @@ pub struct TrackedProcMesh {
55
57
inner : SharedCellRef < ProcMesh > ,
56
58
cell : SharedCell < ProcMesh > ,
57
59
children : SharedCellPool ,
60
+ onstop_callbacks : Arc < Mutex < Vec < OnStopCallback > > > ,
58
61
}
59
62
60
63
impl Debug for TrackedProcMesh {
@@ -77,6 +80,7 @@ impl From<ProcMesh> for TrackedProcMesh {
77
80
inner,
78
81
cell,
79
82
children : SharedCellPool :: new ( ) ,
83
+ onstop_callbacks : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
80
84
}
81
85
}
82
86
}
@@ -107,8 +111,25 @@ impl TrackedProcMesh {
107
111
self . inner . client_proc ( )
108
112
}
109
113
110
- pub fn into_inner ( self ) -> ( SharedCell < ProcMesh > , SharedCellPool ) {
111
- ( self . cell , self . children )
114
+ pub fn into_inner (
115
+ self ,
116
+ ) -> (
117
+ SharedCell < ProcMesh > ,
118
+ SharedCellPool ,
119
+ Arc < Mutex < Vec < OnStopCallback > > > ,
120
+ ) {
121
+ ( self . cell , self . children , self . onstop_callbacks )
122
+ }
123
+
124
+ /// Register a callback to be called when this TrackedProcMesh is stopped
125
+ pub async fn register_onstop_callback < F , Fut > ( & self , callback : F ) -> Result < ( ) , anyhow:: Error >
126
+ where
127
+ F : FnOnce ( ) -> Fut + Send + ' static ,
128
+ Fut : std:: future:: Future < Output = ( ) > + Send + ' static ,
129
+ {
130
+ let mut callbacks = self . onstop_callbacks . lock ( ) . await ;
131
+ callbacks. push ( Box :: new ( || Box :: new ( callback ( ) ) ) ) ;
132
+ Ok ( ( ) )
112
133
}
113
134
}
114
135
@@ -230,7 +251,17 @@ impl PyProcMesh {
230
251
let tracked_proc_mesh = inner. take ( ) . await . map_err ( |e| {
231
252
PyRuntimeError :: new_err ( format ! ( "`ProcMesh` has already been stopped: {}" , e) )
232
253
} ) ?;
233
- let ( proc_mesh, children) = tracked_proc_mesh. into_inner ( ) ;
254
+ let ( proc_mesh, children, drop_callbacks) = tracked_proc_mesh. into_inner ( ) ;
255
+
256
+ // Call all registered drop callbacks before stopping
257
+ let mut callbacks = drop_callbacks. lock ( ) . await ;
258
+ let callbacks_to_call = callbacks. drain ( ..) . collect :: < Vec < _ > > ( ) ;
259
+ drop ( callbacks) ; // Release the lock
260
+
261
+ for callback in callbacks_to_call {
262
+ let future = callback ( ) ;
263
+ std:: pin:: Pin :: from ( future) . await ;
264
+ }
234
265
235
266
// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
236
267
children. discard_all ( ) . await ?;
@@ -486,3 +517,196 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
486
517
hyperactor_mod. add_class :: < PyProcEvent > ( ) ?;
487
518
Ok ( ( ) )
488
519
}
520
+
521
+ #[ cfg( test) ]
522
+ mod tests {
523
+ use std:: sync:: Arc ;
524
+ use std:: sync:: atomic:: AtomicBool ;
525
+ use std:: sync:: atomic:: AtomicU32 ;
526
+ use std:: sync:: atomic:: Ordering ;
527
+
528
+ use anyhow:: Result ;
529
+ use hyperactor_mesh:: alloc:: AllocSpec ;
530
+ use hyperactor_mesh:: alloc:: Allocator ;
531
+ use hyperactor_mesh:: alloc:: local:: LocalAllocator ;
532
+ use hyperactor_mesh:: proc_mesh:: ProcMesh ;
533
+ use ndslice:: extent;
534
+ use tokio:: sync:: Mutex ;
535
+
536
+ use super :: * ;
537
+
538
+ #[ tokio:: test]
539
+ async fn test_register_onstop_callback_single ( ) -> Result < ( ) > {
540
+ // Create a TrackedProcMesh
541
+ let alloc = LocalAllocator
542
+ . allocate ( AllocSpec {
543
+ extent : extent ! { replica = 1 } ,
544
+ constraints : Default :: default ( ) ,
545
+ } )
546
+ . await ?;
547
+
548
+ let mut proc_mesh = ProcMesh :: allocate ( alloc) . await ?;
549
+
550
+ // Extract events before wrapping in TrackedProcMesh
551
+ let events = proc_mesh. events ( ) . unwrap ( ) ;
552
+ let proc_events_cell = SharedCell :: from ( tokio:: sync:: Mutex :: new ( events) ) ;
553
+
554
+ let tracked_proc_mesh = TrackedProcMesh :: from ( proc_mesh) ;
555
+
556
+ // Create a flag to track if callback was executed
557
+ let callback_executed = Arc :: new ( AtomicBool :: new ( false ) ) ;
558
+ let callback_executed_clone = callback_executed. clone ( ) ;
559
+
560
+ // Register a callback
561
+ tracked_proc_mesh
562
+ . register_onstop_callback ( move || {
563
+ let flag = callback_executed_clone. clone ( ) ;
564
+ async move {
565
+ flag. store ( true , Ordering :: SeqCst ) ;
566
+ }
567
+ } )
568
+ . await ?;
569
+
570
+ // Create a SharedCell<TrackedProcMesh> for stop_mesh
571
+ let tracked_proc_mesh_cell = SharedCell :: from ( tracked_proc_mesh) ;
572
+
573
+ // Call stop_mesh (this should trigger the callback)
574
+ PyProcMesh :: stop_mesh ( tracked_proc_mesh_cell, proc_events_cell) . await ?;
575
+
576
+ // Verify the callback was executed
577
+ assert ! (
578
+ callback_executed. load( Ordering :: SeqCst ) ,
579
+ "Callback should have been executed"
580
+ ) ;
581
+
582
+ Ok ( ( ) )
583
+ }
584
+
585
+ #[ tokio:: test]
586
+ async fn test_register_onstop_callback_multiple ( ) -> Result < ( ) > {
587
+ // Create a TrackedProcMesh
588
+ let alloc = LocalAllocator
589
+ . allocate ( AllocSpec {
590
+ extent : extent ! { replica = 1 } ,
591
+ constraints : Default :: default ( ) ,
592
+ } )
593
+ . await ?;
594
+
595
+ let mut proc_mesh = ProcMesh :: allocate ( alloc) . await ?;
596
+
597
+ // Extract events before wrapping in TrackedProcMesh
598
+ let events = proc_mesh. events ( ) . unwrap ( ) ;
599
+ let proc_events_cell = SharedCell :: from ( tokio:: sync:: Mutex :: new ( events) ) ;
600
+
601
+ let tracked_proc_mesh = TrackedProcMesh :: from ( proc_mesh) ;
602
+
603
+ // Create counters to track callback executions
604
+ let callback_count = Arc :: new ( AtomicU32 :: new ( 0 ) ) ;
605
+ let execution_order = Arc :: new ( Mutex :: new ( Vec :: < u32 > :: new ( ) ) ) ;
606
+
607
+ // Register multiple callbacks
608
+ for i in 1 ..=3 {
609
+ let count = callback_count. clone ( ) ;
610
+ let order = execution_order. clone ( ) ;
611
+ tracked_proc_mesh
612
+ . register_onstop_callback ( move || {
613
+ let count_clone = count. clone ( ) ;
614
+ let order_clone = order. clone ( ) ;
615
+ async move {
616
+ count_clone. fetch_add ( 1 , Ordering :: SeqCst ) ;
617
+ let mut order_vec = order_clone. lock ( ) . await ;
618
+ order_vec. push ( i) ;
619
+ }
620
+ } )
621
+ . await ?;
622
+ }
623
+
624
+ // Create a SharedCell<TrackedProcMesh> for stop_mesh
625
+ let tracked_proc_mesh_cell = SharedCell :: from ( tracked_proc_mesh) ;
626
+
627
+ // Call stop_mesh (this should trigger all callbacks)
628
+ PyProcMesh :: stop_mesh ( tracked_proc_mesh_cell, proc_events_cell) . await ?;
629
+
630
+ // Verify all callbacks were executed
631
+ assert_eq ! (
632
+ callback_count. load( Ordering :: SeqCst ) ,
633
+ 3 ,
634
+ "All 3 callbacks should have been executed"
635
+ ) ;
636
+
637
+ // Verify execution order (callbacks should be executed in registration order)
638
+ let order_vec = execution_order. lock ( ) . await ;
639
+ assert_eq ! (
640
+ * order_vec,
641
+ vec![ 1 , 2 , 3 ] ,
642
+ "Callbacks should be executed in registration order"
643
+ ) ;
644
+
645
+ Ok ( ( ) )
646
+ }
647
+
648
+ #[ tokio:: test]
649
+ async fn test_register_onstop_callback_error_handling ( ) -> Result < ( ) > {
650
+ // Create a TrackedProcMesh
651
+ let alloc = LocalAllocator
652
+ . allocate ( AllocSpec {
653
+ extent : extent ! { replica = 1 } ,
654
+ constraints : Default :: default ( ) ,
655
+ } )
656
+ . await ?;
657
+
658
+ let mut proc_mesh = ProcMesh :: allocate ( alloc) . await ?;
659
+
660
+ // Extract events before wrapping in TrackedProcMesh
661
+ let events = proc_mesh. events ( ) . unwrap ( ) ;
662
+ let proc_events_cell = SharedCell :: from ( tokio:: sync:: Mutex :: new ( events) ) ;
663
+
664
+ let tracked_proc_mesh = TrackedProcMesh :: from ( proc_mesh) ;
665
+
666
+ // Create flags to track callback executions
667
+ let callback1_executed = Arc :: new ( AtomicBool :: new ( false ) ) ;
668
+ let callback2_executed = Arc :: new ( AtomicBool :: new ( false ) ) ;
669
+
670
+ let callback1_executed_clone = callback1_executed. clone ( ) ;
671
+ let callback2_executed_clone = callback2_executed. clone ( ) ;
672
+
673
+ // Register a callback that panics
674
+ tracked_proc_mesh
675
+ . register_onstop_callback ( move || {
676
+ let flag = callback1_executed_clone. clone ( ) ;
677
+ async move {
678
+ flag. store ( true , Ordering :: SeqCst ) ;
679
+ // This callback completes successfully
680
+ }
681
+ } )
682
+ . await ?;
683
+
684
+ // Register another callback that should still execute even if the first one had issues
685
+ tracked_proc_mesh
686
+ . register_onstop_callback ( move || {
687
+ let flag = callback2_executed_clone. clone ( ) ;
688
+ async move {
689
+ flag. store ( true , Ordering :: SeqCst ) ;
690
+ }
691
+ } )
692
+ . await ?;
693
+
694
+ // Create a SharedCell<TrackedProcMesh> for stop_mesh
695
+ let tracked_proc_mesh_cell = SharedCell :: from ( tracked_proc_mesh) ;
696
+
697
+ // Call stop_mesh (this should trigger both callbacks)
698
+ PyProcMesh :: stop_mesh ( tracked_proc_mesh_cell, proc_events_cell) . await ?;
699
+
700
+ // Verify both callbacks were executed
701
+ assert ! (
702
+ callback1_executed. load( Ordering :: SeqCst ) ,
703
+ "First callback should have been executed"
704
+ ) ;
705
+ assert ! (
706
+ callback2_executed. load( Ordering :: SeqCst ) ,
707
+ "Second callback should have been executed"
708
+ ) ;
709
+
710
+ Ok ( ( ) )
711
+ }
712
+ }
0 commit comments