@@ -52,6 +52,7 @@ def __init__(
         log_rollout_interval: int = 20,
         rollout_log_file: str = "./rollout_log.jsonl",
         enable_profiling: bool = False,
+        load_balancer=None,
         n_behind: int = 0,
     ):
         assert microbatch_size == 1  # microbatch_size must be 1 for agentic producer
@@ -84,6 +85,7 @@ def __init__(
             enable_profiling=enable_profiling,
             n_behind=n_behind,
         )
+        self.load_balancer = load_balancer
         self.tool_workers = tool_workers
         self.agentic_config = model_config if not agentic_config else agentic_config
         self.agentic_config.update({"model": model_config["path"]})
@@ -183,32 +185,26 @@ def _parse_response(self, response: str) -> Dict[str, Any]:
             assistant_message["tool_calls"] = tool_calls
         return assistant_message

-    def _select_tool_worker(self) -> ray.actor.ActorHandle:
+    def _select_tool_worker(self) -> int:
         """
         Select a tool worker based on the current load.
         """
-        loads = ray.get([worker.get_load.remote() for worker in self.tool_workers])
-        min_load = min(loads)
-        candidates = [i for i, l in enumerate(loads) if l == min_load]
-        selected_idx = random.choice(candidates)  # random tie break
-        ray.get(self.tool_workers[selected_idx].increase_load.remote())
-        return self.tool_workers[selected_idx]
+        selected_idx, current_loads = ray.get(self.load_balancer.get_next_worker.remote("tool", amount=1))
+        return selected_idx

-    def _select_async_producer(self, request_id) -> ray.actor.ActorHandle:
+    def _select_async_producer(self, request_id) -> int:
         """
         Select an async producer based on the current load.
         """
         # use the last used async producer if exists to reuse kv cache (as vllm use paged kv cache,
         # it will reuse most of the kv cache pages without recomputation)
         if request_id in self.async_llm_engine_map:
-            return self.async_producers[self.async_llm_engine_map[request_id]]
+            ray.get(self.load_balancer.increase_load.remote("async-llm", self.async_llm_engine_map[request_id], 1))
+            return self.async_llm_engine_map[request_id]
         # otherwise select the least loaded async producer
-        loads = ray.get([proc.get_producer_load.remote() for proc in self.async_producers])
-        min_load = min(loads)
-        candidates = [i for i, l in enumerate(loads) if l == min_load]
-        selected_idx = random.choice(candidates)  # random tie break
+        selected_idx, current_loads = ray.get(self.load_balancer.get_next_worker.remote("async-llm", amount=1))
         self.async_llm_engine_map[request_id] = selected_idx
-        return self.async_producers[selected_idx]
+        return selected_idx

     def _run_agentic_pipeline(self, messages):
         """
@@ -234,7 +230,7 @@ def _run_agentic_pipeline(self, messages):
                 )
                 del self.async_llm_engine_map[request_id]
                 return messages, response_input_ids, logprobs
-            async_producer = self._select_async_producer(request_id=request_id)
+            async_producer = self.async_producers[self._select_async_producer(request_id=request_id)]
             agentic_generate_config = copy.deepcopy(self.generate_config)
             agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048)
             response = ray.get(
@@ -246,6 +242,7 @@ def _run_agentic_pipeline(self, messages):
                 )
             )
             llm_call_count += 1
+            ray.get(self.load_balancer.decrease_load.remote("async-llm", self.async_llm_engine_map[request_id], 1))
             self.consumer_global_step = response.pop("consumer_global_step")
             response_input_ids = response["input_ids"]
             logprobs = response["action_log_probs"]
@@ -261,12 +258,17 @@ def _run_agentic_pipeline(self, messages):
                 return messages, response_input_ids, logprobs
             tool_call_count += len(assistant_message["tool_calls"])
             handlers = []
+            tool_workers_called = []
             for tool_call in assistant_message["tool_calls"]:
                 # select a tool worker to execute the tool call
-                tool_worker = self._select_tool_worker()
+                tool_worker_idx = self._select_tool_worker()
+                tool_workers_called.append(tool_worker_idx)
+                tool_worker = self.tool_workers[tool_worker_idx]
                 handler = tool_worker.call.remote(tool_call["function"]["name"], tool_call["function"]["arguments"])
                 handlers.append(handler)
             tool_results = ray.get(handlers)
+            for idx in tool_workers_called:
+                ray.get(self.load_balancer.decrease_load.remote("tool", idx, 1))
             for tool_call, tool_result in zip(assistant_message["tool_calls"], tool_results):
                 tool_message = {"role": "tool", "content": str(tool_result)}
                 messages.append(tool_message)
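
The load balancer actor itself is not part of this diff; the producer only relies on a `get_next_worker(category, amount)` call that returns `(selected_idx, current_loads)`, plus `increase_load` and `decrease_load` keyed by the `"tool"` and `"async-llm"` categories. A minimal sketch of a compatible Ray actor is shown below; the class name `LoadBalancer`, its constructor, and the random tie-breaking policy are assumptions inferred from the call sites, not code from this change.

```python
import random

import ray


@ray.remote
class LoadBalancer:
    """Tracks in-flight work per worker group and picks the least-loaded worker."""

    def __init__(self, worker_counts):
        # e.g. {"tool": len(tool_workers), "async-llm": len(async_producers)}
        self.loads = {category: [0] * count for category, count in worker_counts.items()}

    def get_next_worker(self, category, amount=1):
        # Least-loaded selection with a random tie break; the increment happens
        # inside the actor, so selection and accounting are a single atomic step.
        loads = self.loads[category]
        min_load = min(loads)
        candidates = [i for i, load in enumerate(loads) if load == min_load]
        selected_idx = random.choice(candidates)
        loads[selected_idx] += amount
        return selected_idx, list(loads)

    def increase_load(self, category, idx, amount=1):
        self.loads[category][idx] += amount

    def decrease_load(self, category, idx, amount=1):
        self.loads[category][idx] -= amount
```

Because every selection now goes through one actor, the read-and-increment on the load counters is serialized. This avoids the race in the removed per-worker `get_load`/`increase_load` pattern, where several producers could query the loads concurrently, all observe the same idle worker, and pile their requests onto it.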