1+ import functools
12import os
3+ from abc import ABCMeta , abstractmethod
24from dataclasses import asdict
3- from typing import Any , Iterable , List , Optional , Tuple , TypeVar , Union
5+ from typing import Any , Dict , Iterable , List , Optional , Tuple , Type , TypeVar , Union
46
57import ray
68from ray import ObjectRef
1012
1113__all__ = [
1214 "init_parallel_backend" ,
13- "available_cpus" ,
1415]
1516
1617T = TypeVar ("T" )
1718
18- _PARALLEL_BACKEND : Optional [ "RayParallelBackend " ] = None
19+ _PARALLEL_BACKENDS : Dict [ str , "Type[BaseParallelBackend] " ] = {}
1920
2021
21- class RayParallelBackend :
22- """Class used to wrap ray to make it transparent to algorithms. It shouldn't
22+ class NoPublicConstructor (ABCMeta ):
23+ """Metaclass that ensures a private constructor
24+
25+ If a class uses this metaclass like this:
26+
27+ class SomeClass(metaclass=NoPublicConstructor):
28+ pass
29+
30+ If you try to instantiate your class (`SomeClass()`),
31+ a `TypeError` will be thrown.
32+
33+ Taken almost verbatim from:
34+ https://stackoverflow.com/a/64682734
35+ """
36+
37+ def __call__ (cls , * args , ** kwargs ):
38+ raise TypeError (
39+ f"{ cls .__module__ } .{ cls .__qualname__ } cannot be initialized directly. "
40+ "Use init_parallel_backend() instead."
41+ )
42+
43+ def _create (cls , * args : Any , ** kwargs : Any ):
44+ return super ().__call__ (* args , ** kwargs )
45+
46+
47+ class BaseParallelBackend (metaclass = NoPublicConstructor ):
48+ """Abstract base class for all parallel backends"""
49+
50+ config : Dict [str , Any ] = {}
51+
52+ def __init_subclass__ (cls , * , backend_name : str , ** kwargs ):
53+ global _PARALLEL_BACKENDS
54+ _PARALLEL_BACKENDS [backend_name ] = cls
55+ super ().__init_subclass__ (** kwargs )
56+
57+ @abstractmethod
58+ def get (self , v : Any , * args , ** kwargs ):
59+ ...
60+
61+ @abstractmethod
62+ def put (self , v : Any , * args , ** kwargs ) -> Any :
63+ ...
64+
65+ @abstractmethod
66+ def wrap (self , * args , ** kwargs ) -> Any :
67+ ...
68+
69+ @abstractmethod
70+ def wait (self , v : Any , * args , ** kwargs ) -> Any :
71+ ...
72+
73+ @abstractmethod
74+ def _effective_n_jobs (self , n_jobs : int ) -> int :
75+ ...
76+
77+ def effective_n_jobs (self , n_jobs : int = - 1 ) -> int :
78+ if n_jobs == 0 :
79+ raise ValueError ("n_jobs == 0 in Parallel has no meaning" )
80+ n_jobs = self ._effective_n_jobs (n_jobs )
81+ return n_jobs
82+
83+ def __repr__ (self ) -> str :
84+ return f"<{ self .__class__ .__name__ } : { self .config } >"
85+
86+
87+ class SequentialParallelBackend (BaseParallelBackend , backend_name = "sequential" ):
88+ """Class used to run jobs sequentially and locally. It shouldn't
2389 be initialized directly. You should instead call `init_parallel_backend`.
2490
25- :param config: instance of :class:`~pydvl.utils.config.ParallelConfig` with
26- cluster address, number of cpus, etc.
91+ :param config: instance of :class:`~pydvl.utils.config.ParallelConfig` with number of cpus
92+ """
2793
28- :Example:
94+ def __init__ (self , config : ParallelConfig ):
95+ config_dict = asdict (config )
96+ config_dict .pop ("backend" )
97+ config_dict .pop ("address" )
98+ config_dict ["num_cpus" ] = config_dict .pop ("n_local_workers" )
99+ self .config = config_dict
29100
30- >>> from pydvl.utils.parallel.backend import RayParallelBackend
31- >>> from pydvl.utils.config import ParallelConfig
32- >>> config = ParallelConfig(backend="ray")
33- >>> parallel_backend = RayParallelBackend(config)
34- >>> parallel_backend
35- <RayParallelBackend: {'address': None, 'num_cpus': None}>
101+ def get (self , v : Any , * args , ** kwargs ):
102+ return v
103+
104+ def put (self , v : Any , * args , ** kwargs ) -> Any :
105+ return v
106+
107+ def wrap (self , * args , ** kwargs ) -> Any :
108+ assert len (args ) == 1
109+ return functools .partial (args [0 ], ** kwargs )
36110
111+ def wait (self , v : Any , * args , ** kwargs ) -> Tuple [list , list ]:
112+ return v , []
113+
114+ def _effective_n_jobs (self , n_jobs : int ) -> int :
115+ if n_jobs < 0 :
116+ if self .config ["num_cpus" ]:
117+ eff_n_jobs : int = self .config ["num_cpus" ]
118+ else :
119+ eff_n_jobs = available_cpus ()
120+ else :
121+ eff_n_jobs = n_jobs
122+ return eff_n_jobs
123+
124+
125+ class RayParallelBackend (BaseParallelBackend , backend_name = "ray" ):
126+ """Class used to wrap ray to make it transparent to algorithms. It shouldn't
127+ be initialized directly. You should instead call `init_parallel_backend`.
128+
129+ :param config: instance of :class:`~pydvl.utils.config.ParallelConfig` with
130+ cluster address, number of cpus, etc.
37131 """
38132
39133 def __init__ (self , config : ParallelConfig ):
40134 config_dict = asdict (config )
41135 config_dict .pop ("backend" )
42- config_dict ["num_cpus" ] = config_dict .pop ("num_workers " )
136+ config_dict ["num_cpus" ] = config_dict .pop ("n_local_workers " )
43137 self .config = config_dict
138+ if self .config ["address" ] is None :
139+ self .config ["ignore_reinit_error" ] = True
44140 ray .init (** self .config )
45141
46142 def get (
47143 self ,
48144 v : Union [ObjectRef , Iterable [ObjectRef ], T ],
49- * ,
50- timeout : Optional [ float ] = None ,
145+ * args ,
146+ ** kwargs ,
51147 ) -> Union [T , Any ]:
148+ timeout : Optional [float ] = kwargs .get ("timeout" , None )
52149 if isinstance (v , ObjectRef ):
53150 return ray .get (v , timeout = timeout )
54151 elif isinstance (v , Iterable ):
55152 return [self .get (x , timeout = timeout ) for x in v ]
56153 else :
57154 return v
58155
59- def put (self , x : Any , ** kwargs ) -> ObjectRef :
60- return ray .put (x , ** kwargs ) # type: ignore
156+ def put (self , v : T , * args , ** kwargs ) -> Union ["ObjectRef[T]" , T ]:
157+ try :
158+ return ray .put (v , ** kwargs ) # type: ignore
159+ except TypeError :
160+ return v # type: ignore
61161
62162 def wrap (self , * args , ** kwargs ) -> RemoteFunction :
63163 return ray .remote (* args , ** kwargs ) # type: ignore
64164
65165 def wait (
66166 self ,
67- object_refs : List ["ray.ObjectRef" ],
68- * ,
69- num_returns : int = 1 ,
70- timeout : Optional [float ] = None ,
167+ v : List ["ObjectRef" ],
168+ * args ,
169+ ** kwargs ,
71170 ) -> Tuple [List [ObjectRef ], List [ObjectRef ]]:
171+ num_returns : int = kwargs .get ("num_returns" , 1 )
172+ timeout : Optional [float ] = kwargs .get ("timeout" , None )
72173 return ray .wait ( # type: ignore
73- object_refs ,
174+ v ,
74175 num_returns = num_returns ,
75176 timeout = timeout ,
76177 )
77178
78- def effective_n_jobs (self , n_jobs : Optional [int ]) -> int :
79- if n_jobs == 0 :
80- raise ValueError ("n_jobs == 0 in Parallel has no meaning" )
81- elif n_jobs is None or n_jobs < 0 :
179+ def _effective_n_jobs (self , n_jobs : int ) -> int :
180+ if n_jobs < 0 :
82181 ray_cpus = int (ray ._private .state .cluster_resources ()["CPU" ]) # type: ignore
83182 eff_n_jobs = ray_cpus
84183 else :
85184 eff_n_jobs = n_jobs
86185 return eff_n_jobs
87186
88- def __repr__ (self ) -> str :
89- return f"<RayParallelBackend: { self .config } >"
90-
91187
92- def init_parallel_backend (config : ParallelConfig ) -> "RayParallelBackend" :
188+ def init_parallel_backend (
189+ config : ParallelConfig ,
190+ ) -> BaseParallelBackend :
93191 """Initializes the parallel backend and returns an instance of it.
94192
95193 :param config: instance of :class:`~pydvl.utils.config.ParallelConfig` with cluster address, number of cpus, etc.
@@ -101,16 +199,15 @@ def init_parallel_backend(config: ParallelConfig) -> "RayParallelBackend":
101199 >>> config = ParallelConfig(backend="ray")
102200 >>> parallel_backend = init_parallel_backend(config)
103201 >>> parallel_backend
104- <RayParallelBackend: {'address': None, 'num_cpus': None}>
202+ <RayParallelBackend: {'address': None, 'num_cpus': None, 'ignore_reinit_error': True }>
105203
106204 """
107- global _PARALLEL_BACKEND
108- if _PARALLEL_BACKEND is None :
109- if config .backend == "ray" :
110- _PARALLEL_BACKEND = RayParallelBackend (config )
111- else :
112- raise NotImplementedError (f"Unexpected parallel type { config .backend } " )
113- return _PARALLEL_BACKEND
205+ try :
206+ parallel_backend_cls = _PARALLEL_BACKENDS [config .backend ]
207+ except KeyError :
208+ raise NotImplementedError (f"Unexpected parallel backend { config .backend } " )
209+ parallel_backend = parallel_backend_cls ._create (config )
210+ return parallel_backend # type: ignore
114211
115212
116213def available_cpus () -> int :
0 commit comments