1717"""
1818
1919import logging
20- import queue
20+ import sys
2121import threading
2222from contextlib import contextmanager , nullcontext
2323from dataclasses import dataclass
6363from torch .futures import Future
6464from torch .utils ._pytree import tree_any
6565
66+ from torchft .multiprocessing import _MonitoredQueue
67+
6668if TYPE_CHECKING :
6769 from torchft .manager import Manager
6870
7779T = TypeVar ("T" )
7880
7981
80- def _get (q : mp .Queue , timeout : Union [float , timedelta ]) -> object :
81- """
82- Gets an item from a queue with a timeout. If the timeout is exceeded then
83- a TimeoutError is raised.
84-
85- If an exception is returned from the queue then it is raised.
86-
87- Args:
88- q: queue to get from
89- timeout: timeout in seconds
90- """
91- if isinstance (timeout , timedelta ):
92- timeout = timeout .total_seconds ()
93- try :
94- v = q .get (timeout = timeout )
95- except queue .Empty as e :
96- raise TimeoutError (f"queue.get() timed out after { timeout } seconds" ) from e
97- if isinstance (v , Exception ):
98- raise v
99- return v
100-
101-
10282def create_store_client (store_addr : str ) -> Store :
10383 """
10484 Creates a PrefixStore(TCPStore(...)) client from an address in the format:
@@ -573,8 +553,8 @@ class _BabyWork(Work):
573553 def __init__ (
574554 self ,
575555 pg : "ProcessGroupBaby" ,
576- tx : mp . Queue ,
577- rx : mp . Queue ,
556+ tx : _MonitoredQueue ,
557+ rx : _MonitoredQueue ,
578558 op_id : int ,
579559 timeout : float ,
580560 ) -> None :
@@ -592,7 +572,7 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
592572 self ._tx .put (("wait" , self ._op_id ), timeout = self ._timeout )
593573 op_id , event = cast (
594574 Tuple [int , Optional [torch .cuda .Event ]],
595- _get ( self ._rx , timeout or self ._timeout ),
575+ self ._rx . get ( timeout or self ._timeout ),
596576 )
597577 assert op_id == self ._op_id
598578 if event is not None :
@@ -649,9 +629,9 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
649629 self ._world_size = - 1
650630
651631 self ._p : Optional [mp .Process ] = None
652- self ._tx : Optional [mp . Queue ] = None
653- self ._rx : Optional [mp . Queue ] = None
654- self ._future_queue : Optional [mp . Queue ] = None
632+ self ._tx : Optional [_MonitoredQueue ] = None
633+ self ._rx : Optional [_MonitoredQueue ] = None
634+ self ._future_queue : Optional [_MonitoredQueue ] = None
655635 self ._future_thread : Optional [threading .Thread ] = None
656636 self ._futures : Dict [int , Future [object ]] = {}
657637 self ._futures_lock = threading .Lock ()
@@ -661,60 +641,80 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
661641
662642 self ._timeout : float = timeout
663643
664- def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
665- if self ._p is not None :
666- self ._p .kill ()
644+ def shutdown (self ) -> None :
645+ """
646+ Shutdown the process group. This will kill the underlying process and
647+ close all queues.
667648
668- self ._world_size = world_size
649+ This is a no-op if the process group is already shutdown.
650+
651+ ProcessGroup can be reconfigured after shutdown.
652+ """
669653
670654 if self ._tx is not None :
671655 self ._tx .close ()
672656 if self ._rx is not None :
673657 self ._rx .close ()
674- if self ._future_queue is not None :
658+
659+ future_queue = self ._future_queue
660+ if future_queue is not None :
675661 # wait for the future thread to exit and then close the queue
676- self ._future_queue .put (_QUEUE_CLOSE )
677- assert self ._future_thread is not None
678- self ._future_thread .join (timeout = 10.0 )
679- # pyre-ignore[16]: optional value is checked above
680- if self ._future_thread .is_alive ():
662+ future_queue .put (_QUEUE_CLOSE , timeout = timedelta (seconds = 10.0 ))
663+
664+ future_thread = self ._future_thread
665+ assert future_thread is not None
666+ future_thread .join (timeout = 10.0 )
667+ if future_thread .is_alive ():
681668 raise RuntimeError ("future thread did not exit" )
682- # pyre-ignore[16]: optional value is checked above
683- self ._future_queue .close ()
669+
670+ future_queue .close ()
671+
672+ # Kill after closing queues to avoid log spam.
673+ if self ._p is not None :
674+ self ._p .kill ()
675+
676+ def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
677+ self ._world_size = world_size
678+
679+ self .shutdown ()
684680
685681 ctx = mp .get_context ("spawn" )
686- self ._tx = ctx .Queue ()
687- self ._rx = rx = ctx .Queue ()
682+ tx = ctx .Queue ()
683+ rx = ctx .Queue ()
684+ future_queue = ctx .Queue ()
685+
686+ self ._p = p = ctx .Process (
687+ target = self ._worker ,
688+ args = (
689+ store_addr ,
690+ rank ,
691+ world_size ,
692+ tx ,
693+ rx ,
694+ future_queue ,
695+ ),
696+ daemon = True ,
697+ )
698+ p .start ()
699+
700+ self ._tx = tx = _MonitoredQueue (p , tx )
701+ self ._rx = rx = _MonitoredQueue (p , rx )
702+ self ._future_queue = future_queue = _MonitoredQueue (p , future_queue )
688703
689704 # futures need thread to fire callbacks
690- self ._future_queue = ctx .Queue ()
691705 # this lock needs to be held when manipulating _futures
692706 self ._futures_lock = threading .Lock ()
693707 self ._futures = {}
694708 self ._future_thread = threading .Thread (
695709 target = self ._future_handler ,
696- args = (self . _future_queue ,),
710+ args = (future_queue ,),
697711 daemon = True ,
698712 )
699713 self ._future_thread .start ()
700714
701- self ._p = ctx .Process (
702- target = self ._worker ,
703- args = (
704- store_addr ,
705- rank ,
706- world_size ,
707- self ._tx ,
708- self ._rx ,
709- self ._future_queue ,
710- ),
711- daemon = True ,
712- )
713- self ._p .start ()
714-
715715 # fetch the status of the PG init
716- # if an exception was returned _get will throw
717- assert _get ( rx , self ._timeout ) is None
716+ # if an exception was returned get will throw
717+ assert rx . get ( self ._timeout ) is None
718718
719719 @classmethod
720720 def _create_pg (cls , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
@@ -739,7 +739,7 @@ def _worker(
739739 try :
740740 pg = cls ._create_pg (store , rank , world_size )
741741 except Exception as e :
742- logger . exception (f"got exception in worker: { e } " )
742+ print (f"got exception in worker: { e } " , file = sys . stderr )
743743 tx .put (e )
744744 return
745745 tx .put (None )
@@ -829,17 +829,21 @@ def callback(fut: Future[object]) -> None:
829829 raise ValueError (f"unknown cmd: { cmd } " )
830830
831831 except Exception as e :
832- logger . exception ( "worker errored" )
832+ print ( f "worker errored: { e } " , file = sys . stderr )
833833 tx .put (e )
834834 raise
835835
836- def _future_handler (self , future_queue : mp . Queue ) -> None :
836+ def _future_handler (self , future_queue : _MonitoredQueue ) -> None :
837837 try :
838838 while True :
839- cmd = future_queue .get ()
839+ try :
840+ # timeout doesn't really matter here
841+ cmd = future_queue .get (timeout = timedelta (seconds = 10.0 ))
842+ except TimeoutError :
843+ continue
840844 if cmd == _QUEUE_CLOSE :
841845 break
842- op_id , mode , data = cmd
846+ op_id , mode , data = cast ( Tuple [ int , str , object ], cmd )
843847 with self ._futures_lock :
844848 fut = self ._futures [op_id ]
845849 del self ._futures [op_id ]
@@ -862,7 +866,7 @@ def _get_future(self, op_id: int) -> Future[object]:
862866 self ._tx .put (("future" , op_id ), timeout = self ._timeout )
863867
864868 assert self ._rx is not None
865- assert _get ( self ._rx , self ._timeout ) == op_id
869+ assert self ._rx . get ( self ._timeout ) == op_id
866870 # TODO: return correct tensor instead of None
867871 return fut
868872
@@ -899,7 +903,7 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
899903 timeout = self ._timeout ,
900904 )
901905
902- op_id = _get ( rx , self ._timeout )
906+ op_id = rx . get ( self ._timeout )
903907 assert isinstance (op_id , int ), f"invalid return { op_id } "
904908
905909 return _BabyWork (pg = self , tx = tx , rx = rx , op_id = op_id , timeout = self ._timeout )
@@ -968,7 +972,7 @@ def num_active_work(self) -> int:
968972 self ._tx .put (("num_active_work" ,), timeout = self ._timeout )
969973
970974 assert self ._rx is not None
971- return cast (int , _get ( self ._rx , self ._timeout ))
975+ return cast (int , self ._rx . get ( self ._timeout ))
972976
973977
974978@dataclass
0 commit comments