1- import os
21import asyncio
2+ import os
33from typing import Any , Dict , List , Optional , Tuple
44
55try :
88 e .msg = """Cannot import Ray. Please install 'ray' package to use ray orchestrator"""
99 raise
1010
11- from ray .util .placement_group import (PlacementGroup ,
12- PlacementGroupSchedulingStrategy ,
11+ from ray .util .placement_group import (PlacementGroupSchedulingStrategy ,
1312 get_current_placement_group ,
1413 placement_group )
1514
@@ -79,15 +78,15 @@ def __init__(self,
7978 self .master_address = ray .util .get_node_ip_address ()
8079 self .master_port = get_free_port ()
8180
82- self .worker_kwargs = dict (** worker_kwargs ,
83- postproc_worker_config = postproc_worker_config ,
84- is_llm_executor = is_llm_executor )
85- if not has_event_loop ():
86- self .init_workers_sync ()
81+ self .worker_kwargs = dict (
82+ ** worker_kwargs ,
83+ postproc_worker_config = postproc_worker_config ,
84+ is_llm_executor = is_llm_executor )
8785
8886 self .init_rpc_executor ()
8987 worker_kwargs ['rpc_addr' ] = self .rpc_addr
90- self .create_workers (RayGPUWorker , worker_kwargs )
88+ if not has_event_loop ():
89+ self .init_workers_sync ()
9190 self .setup_engine_remote ()
9291 self .setup_mainloop (tasks = [self ._fetch_responses_loop_async ],
9392 thread_name = "ray_executor_main_loop" )
@@ -99,9 +98,13 @@ def __init__(self,
9998 raise e
10099
101100 def create_workers (self , worker_cls , worker_kwargs ):
101+ llm_args = worker_kwargs .get ("llm_args" )
102+
102103 # When set to be a fraction, it allows Ray to schedule
103104 # multiple actors on a single GPU for colocate use cases.
104- num_gpus = float (os .getenv ("TRTLLM_RAY_PER_WORKER_GPUS" , "1.0" ))
105+ num_gpus = (llm_args .per_worker_gpu_share if llm_args
106+ and llm_args .per_worker_gpu_share is not None else float (
107+ os .getenv ("TRTLLM_RAY_PER_WORKER_GPUS" , "1.0" )))
105108 logger .debug (f"{ num_gpus = } for each worker." )
106109
107110 runtime_env = ray .runtime_env .RuntimeEnv ()
@@ -112,42 +115,40 @@ def create_workers(self, worker_cls, worker_kwargs):
112115 "MASTER_PORT" : str (self .master_port )
113116 })
114117
115- self . placement_group , self .bundle_indices = self ._get_placement_group (
116- tp_size = self .tp_size )
118+ placement_groups , self .bundle_indices = self ._get_placement_group (
119+ tp_size = self .tp_size , worker_kwargs = worker_kwargs )
117120
118- self .workers = [
119- RayWorkerWrapper .options (
121+ if isinstance (placement_groups , list ):
122+ self .placement_group = None
123+ else :
124+ self .placement_group = placement_groups
125+
126+ self .workers = []
127+ for rank in range (self .world_size ):
128+ pg = placement_groups [rank ] if isinstance (
129+ placement_groups , list ) else placement_groups
130+ worker = RayWorkerWrapper .options (
120131 num_gpus = num_gpus ,
121- runtime_env = runtime_env , # per-actor env
132+ runtime_env = runtime_env ,
122133 scheduling_strategy = PlacementGroupSchedulingStrategy (
123- placement_group = self . placement_group ,
134+ placement_group = pg ,
124135 placement_group_bundle_index = self .bundle_indices [rank ],
125136 )).remote (worker_cls , worker_kwargs , self .world_size , rank )
126- for rank in range (self .world_size )
127- ]
137+ self .workers .append (worker )
128138
def init_workers_sync(self):
    """Create the Ray workers and wait synchronously for their readiness.

    Workers are created from ``self.worker_kwargs`` via
    ``create_workers``; the readiness futures of every worker are then
    resolved with ``ray.get``.

    Raises:
        RuntimeError: if ``ray.exceptions.ActorDiedError`` is raised while
            waiting on the readiness futures. The Ray error is kept as the
            ``__cause__`` of the raised exception.
    """
    self.create_workers(RayGPUWorker, self.worker_kwargs)
    try:
        ray.get(self._get_worker_ready_futures())
    except ray.exceptions.ActorDiedError as actor_err:
        # Surface a single, stable error type to callers while keeping the
        # original Ray exception chained for debugging.
        failure = RuntimeError("RayGPUWorker died during initialization")
        raise failure from actor_err
139145
async def init_workers_async(self):
    """Create the Ray workers and await their readiness.

    Coroutine counterpart of ``init_workers_sync``: workers are created
    from ``self.worker_kwargs`` and the readiness futures are awaited
    together with ``asyncio.gather``.

    Raises:
        RuntimeError: if ``ray.exceptions.ActorDiedError`` is raised while
            awaiting the readiness futures. The Ray error is kept as the
            ``__cause__`` of the raised exception.
    """
    self.create_workers(RayGPUWorker, self.worker_kwargs)
    try:
        await asyncio.gather(*self._get_worker_ready_futures())
    except ray.exceptions.ActorDiedError as actor_err:
        # Same wrapping as the synchronous path so both entry points fail
        # with the same exception type and message.
        failure = RuntimeError("RayGPUWorker died during initialization")
        raise failure from actor_err
151152
152153 @unwrap_ray_errors ()
153154 def call_all_ray_workers (self , func : str , leader_only : bool ,
@@ -187,6 +188,20 @@ def collective_rpc(self,
187188 ** kwargs ))
188189 return refs if non_block else ray .get (refs )
189190
# NOTE(review): confirm that unwrap_ray_errors handles coroutine functions;
# a purely synchronous wrapper would not observe errors raised at await time.
@unwrap_ray_errors()
async def collective_rpc_async(
        self,
        method: str,
        args: tuple = (),
        kwargs: Optional[dict] = None,
        unique_reply_rank: Optional[int] = None) -> list[Any]:
    """Awaitable variant of ``collective_rpc``.

    Dispatches the call through ``collective_rpc`` with ``non_block=True``
    and awaits all of the returned references together.

    Args:
        method: Forwarded to ``collective_rpc`` as its first argument.
        args: Positional arguments forwarded to ``collective_rpc``.
        kwargs: Keyword-argument dict forwarded to ``collective_rpc``
            (``None`` is passed through unchanged).
        unique_reply_rank: Forwarded to ``collective_rpc`` unchanged.

    Returns:
        The list produced by ``asyncio.gather`` over the references, in
        the same order as ``collective_rpc`` returned them.
    """
    pending = self.collective_rpc(method,
                                  args,
                                  kwargs,
                                  non_block=True,
                                  unique_reply_rank=unique_reply_rank)
    results = await asyncio.gather(*pending)
    return results
204+
190205 def submit (self , request : "GenerationRequest" ) -> "GenerationResult" :
191206 """
192207 Low-level API to the executor. Return a "future" GenerationResult
@@ -281,15 +296,51 @@ def shutdown(self):
281296 logger .debug ("Shutting down Ray cluster" )
282297 ray .shutdown ()
283298
284- def _get_placement_group (self ,
285- tp_size : int ) -> Tuple [PlacementGroup , List [int ]]:
def _get_worker_ready_futures(self):
    """Return the ``__ray_ready__.remote()`` result of every worker.

    The returned list has one entry per element of ``self.workers`` and
    preserves their order; it is empty when there are no workers.
    """
    ready_refs = []
    for actor in self.workers:
        ready_refs.append(actor.__ray_ready__.remote())
    return ready_refs
301+
302+ def _get_placement_group (
303+ self ,
304+ tp_size : int ,
305+ worker_kwargs : Dict = None ) -> Tuple [Any , List [int ]]:
286306 """
287307 Either use the existing placement group from driver script (e.g., in the case of RL FW integration),
288308 or create a default PACK placement group where each bundle has tp_size GPUs.
289309 - When tp_size ≤ GPUs per node, keep one TP group per node.
290310 - When tp_size > GPUs per node, allow a TP group span nodes.
291311 - rank 0 must be put on the driver node
312+
313+ Returns:
314+ Tuple of (placement_group(s), bundle_indices)
315+ - placement_group(s) can be a single PlacementGroup or a List[PlacementGroup]
316+ - bundle_indices is always a List[int]
292317 """
318+ llm_args = worker_kwargs .get ("llm_args" ) if worker_kwargs else None
319+
320+ if llm_args and hasattr (
321+ llm_args ,
322+ 'placement_groups' ) and llm_args .placement_groups is not None :
323+ total_workers = sum (
324+ len (indices ) for indices in llm_args .placement_bundle_indices )
325+ if total_workers != self .world_size :
326+ raise ValueError (
327+ f"Total bundle indices ({ total_workers } ) must equal world_size ({ self .world_size } )"
328+ )
329+
330+ logger .info (
331+ f"Creating { self .world_size } workers with external placement groups"
332+ )
333+
334+ flat_pgs = []
335+ flat_indices = []
336+ for pg , indices in zip (llm_args .placement_groups ,
337+ llm_args .placement_bundle_indices ):
338+ for idx in indices :
339+ flat_pgs .append (pg )
340+ flat_indices .append (idx )
341+
342+ return flat_pgs , flat_indices
343+
293344 bundle_indices = os .getenv ("TRTLLM_RAY_BUNDLE_INDICES" , None )
294345
295346 if bundle_indices :
0 commit comments