11import abc
22import inspect
33import logging
4- from typing import Any , Dict , Optional , Type , Union , cast
4+ from time import sleep
5+ from typing import Generic , List , Optional , Type , TypeVar , cast
56
6- from pydvl .utils .config import ParallelConfig
7- from pydvl .utils .parallel .backend import RayParallelBackend , init_parallel_backend
7+ from ..config import ParallelConfig
8+ from ..status import Status
9+ from .backend import RayParallelBackend , init_parallel_backend
810
911__all__ = ["RayActorWrapper" , "Coordinator" , "Worker" ]
1012
@@ -65,50 +67,45 @@ def wrapper(
6567 setattr (self , name , remote_caller (name ))
6668
6769
68- class Coordinator (abc .ABC ):
70+ Result = TypeVar ("Result" ) # Avoids circular import with ValuationResult
71+
72+
73+ class Coordinator (Generic [Result ], abc .ABC ):
6974 """The coordinator has two main tasks: aggregating the results of the
7075 workers and terminating the process once a certain accuracy or total
7176 number of iterations is reached.
72-
73- :param progress: Whether to display a progress bar
7477 """
7578
76- def __init__ (self , * , progress : Optional [bool ] = True ):
77- self .progress = progress
78- # For each worker: values, stddev, num_iterations
79- self .workers_results : Dict [int , Dict [str , float ]] = dict ()
80- self ._total_iterations = 0
81- self ._is_done = False
79+ _status : Status
8280
83- def add_results (self , worker_id : int , results : Dict [str , Union [float , int ]]):
81+ def __init__ (self ):
82+ self .worker_results : List [Result ] = []
83+ self ._status = Status .Pending
84+
85+ def add_results (self , results : Result ):
8486 """Used by workers to report their results. Stores the results directly
85- into the `worker_status` dictionary.
87+ into :attr:`worker_results`
8688
87- :param worker_id: id of the worker
88- :param results: results of worker calculations
89+ :param results: results of worker's calculations
8990 """
90- self .workers_results [ worker_id ] = results
91+ self .worker_results . append ( results )
9192
9293 # this should be a @property, but with it ray.get messes up
9394 def is_done (self ) -> bool :
9495 """Used by workers to check whether to terminate their process.
9596
96- :return: `True` if workers must terminate, `False` otherwise.
97+ :return: `` True`` if workers must terminate, `` False` ` otherwise.
9798 """
98- return self ._is_done
99+ return bool ( self ._status )
99100
100101 @abc .abstractmethod
101- def get_results (self ) -> Any :
102+ def accumulate (self ) -> Result :
102103 """Aggregates the results of the different workers."""
103104 raise NotImplementedError ()
104105
105106 @abc .abstractmethod
106- def check_done (self ) -> bool :
107- """Checks whether the accuracy of the calculation or the total number
108- of iterations have crossed the set thresholds.
109-
110- If so, it sets the `is_done` label to `True`.
111- """
107+ def check_convergence (self ) -> bool :
108+ """Evaluates the convergence criteria on the aggregated results."""
112109 raise NotImplementedError ()
113110
114111
@@ -117,25 +114,22 @@ class Worker(abc.ABC):
117114
118115 def __init__ (
119116 self ,
120- coordinator : " Coordinator" ,
117+ coordinator : Coordinator ,
121118 worker_id : int ,
122119 * ,
123- progress : bool = False ,
124120 update_period : int = 30 ,
125121 ):
126122 """A worker
127123
128124 :param coordinator: worker results will be pushed to this coordinator
129125 :param worker_id: id used for reporting through maybe_progress
130- :param progress: set to True to report progress, else False
131126 :param update_period: interval in seconds between different updates
132127 to and from the coordinator
133128 """
134129 super ().__init__ ()
135130 self .worker_id = worker_id
136131 self .coordinator = coordinator
137132 self .update_period = update_period
138- self .progress = progress
139133
140134 def run (self , * args , ** kwargs ):
141135 """Runs the worker."""
0 commit comments