
Commit 8ca76fe

fix vllm configuration and load balancing
1 parent c095ec3 commit 8ca76fe

6 files changed: +53 −47 lines

applications/ColossalChat/coati/distributed/agent/agentic_producer.py

Lines changed: 18 additions & 16 deletions
@@ -52,6 +52,7 @@ def __init__(
         log_rollout_interval: int = 20,
         rollout_log_file: str = "./rollout_log.jsonl",
         enable_profiling: bool = False,
+        load_balancer=None,
         n_behind: int = 0,
     ):
         assert microbatch_size == 1  # microbatch_size must be 1 for agentic producer
@@ -84,6 +85,7 @@ def __init__(
             enable_profiling=enable_profiling,
             n_behind=n_behind,
         )
+        self.load_balancer = load_balancer
         self.tool_workers = tool_workers
         self.agentic_config = model_config if not agentic_config else agentic_config
         self.agentic_config.update({"model": model_config["path"]})
@@ -183,32 +185,26 @@ def _parse_response(self, response: str) -> Dict[str, Any]:
             assistant_message["tool_calls"] = tool_calls
         return assistant_message

-    def _select_tool_worker(self) -> ray.actor.ActorHandle:
+    def _select_tool_worker(self) -> int:
         """
         Select a tool worker based on the current load.
         """
-        loads = ray.get([worker.get_load.remote() for worker in self.tool_workers])
-        min_load = min(loads)
-        candidates = [i for i, l in enumerate(loads) if l == min_load]
-        selected_idx = random.choice(candidates)  # random tie break
-        ray.get(self.tool_workers[selected_idx].increase_load.remote())
-        return self.tool_workers[selected_idx]
+        selected_idx, current_loads = ray.get(self.load_balancer.get_next_worker.remote("tool", amount=1))
+        return selected_idx

-    def _select_async_producer(self, request_id) -> ray.actor.ActorHandle:
+    def _select_async_producer(self, request_id) -> int:
         """
         Select an async producer based on the current load.
         """
         # use the last used async producer if exists to reuse kv cache (as vllm use paged kv cache,
         # it will reuse most of the kv cache pages without recomputation)
         if request_id in self.async_llm_engine_map:
-            return self.async_producers[self.async_llm_engine_map[request_id]]
+            ray.get(self.load_balancer.increase_load.remote("async-llm", self.async_llm_engine_map[request_id], 1))
+            return self.async_llm_engine_map[request_id]
         # otherwise select the least loaded async producer
-        loads = ray.get([proc.get_producer_load.remote() for proc in self.async_producers])
-        min_load = min(loads)
-        candidates = [i for i, l in enumerate(loads) if l == min_load]
-        selected_idx = random.choice(candidates)  # random tie break
+        selected_idx, current_loads = ray.get(self.load_balancer.get_next_worker.remote("async-llm", amount=1))
         self.async_llm_engine_map[request_id] = selected_idx
-        return self.async_producers[selected_idx]
+        return selected_idx

     def _run_agentic_pipeline(self, messages):
         """
@@ -234,7 +230,7 @@ def _run_agentic_pipeline(self, messages):
                 )
                 del self.async_llm_engine_map[request_id]
                 return messages, response_input_ids, logprobs
-            async_producer = self._select_async_producer(request_id=request_id)
+            async_producer = self.async_producers[self._select_async_producer(request_id=request_id)]
             agentic_generate_config = copy.deepcopy(self.generate_config)
             agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048)
             response = ray.get(
@@ -246,6 +242,7 @@ def _run_agentic_pipeline(self, messages):
                 )
             )
             llm_call_count += 1
+            ray.get(self.load_balancer.decrease_load.remote("async-llm", self.async_llm_engine_map[request_id], 1))
             self.consumer_global_step = response.pop("consumer_global_step")
             response_input_ids = response["input_ids"]
             logprobs = response["action_log_probs"]
@@ -261,12 +258,17 @@ def _run_agentic_pipeline(self, messages):
                 return messages, response_input_ids, logprobs
             tool_call_count += len(assistant_message["tool_calls"])
             handlers = []
+            tool_workers_called = []
             for tool_call in assistant_message["tool_calls"]:
                 # select a tool worker to execute the tool call
-                tool_worker = self._select_tool_worker()
+                tool_worker_idx = self._select_tool_worker()
+                tool_workers_called.append(tool_worker_idx)
+                tool_worker = self.tool_workers[tool_worker_idx]
                 handler = tool_worker.call.remote(tool_call["function"]["name"], tool_call["function"]["arguments"])
                 handlers.append(handler)
             tool_results = ray.get(handlers)
+            for idx in tool_workers_called:
+                ray.get(self.load_balancer.decrease_load.remote("tool", idx, 1))
             for tool_call, tool_result in zip(assistant_message["tool_calls"], tool_results):
                 tool_message = {"role": "tool", "content": str(tool_result)}
                 messages.append(tool_message)
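Note on the accounting contract this diff introduces: every selection increments the chosen worker's counter on the shared balancer (get_next_worker does this as part of the pick; reusing a cached engine calls increase_load explicitly), and the counter is decremented once the LLM call or tool call returns. A minimal single-process sketch of that acquire/release cycle, using a plain dict in place of the Ray LoadBalancer actor and a hypothetical run_tool_call stand-in for tool_worker.call.remote:

import random

loads = {"tool": {0: 0, 1: 0, 2: 0}}  # worker index -> number of in-flight calls

def get_next_worker(worker_type: str, amount: int = 1) -> int:
    # least-loaded selection with a random tie break, mirroring LoadBalancer.get_next_worker
    min_load = min(loads[worker_type].values())
    candidates = [i for i, l in loads[worker_type].items() if l == min_load]
    chosen = random.choice(candidates)
    loads[worker_type][chosen] += amount  # the pick also acquires the slot
    return chosen

def run_tool_call(worker_idx: int, name: str, arguments: dict) -> str:
    # hypothetical stand-in for self.tool_workers[worker_idx].call.remote(...)
    return f"{name}({arguments}) executed on tool worker {worker_idx}"

idx = get_next_worker("tool")
try:
    print(run_tool_call(idx, "search", {"query": "load balancing"}))
finally:
    loads["tool"][idx] -= 1  # release the slot when the call finishes, as the producer does via decrease_load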

applications/ColossalChat/coati/distributed/agent/tool_worker.py

Lines changed: 0 additions & 13 deletions
@@ -19,17 +19,6 @@ def __init__(self, tools: List[BaseTool]):
             tools (List[BaseTool]): List of LangChain tools to register.
         """
         self._tool_registry: Dict[str, BaseTool] = {tool.name: tool for tool in tools}
-        self.pending = 0
-
-    @ray.method(concurrency_group="io")
-    def get_load(self) -> int:
-        """Return the current load of the worker."""
-        return self.pending
-
-    @ray.method(concurrency_group="io")
-    def increase_load(self):
-        """Increase the load counter."""
-        self.pending += 1

     @ray.method(concurrency_group="io")
     def list_tools(self) -> List[str]:
@@ -64,7 +53,6 @@ def call(self, tool_name: str, input_data: Union[str, Dict[str, Any]], **kwargs)
             Any: The tool's output.
         """
         if tool_name == "return_parsing_error":
-            self.pending -= 1
             return "Error: Tool call parsing error. Please use the correct JSON format."
         if tool_name not in self._tool_registry:
             return f"Error: Tool {tool_name} not found. Available tools: {self.list_tools()}"
@@ -73,5 +61,4 @@ def call(self, tool_name: str, input_data: Union[str, Dict[str, Any]], **kwargs)
             ret = tool.run(input_data, **kwargs)
         except Exception as e:
             ret = f"Error: Tool {tool_name} execution failed with error: {str(e)}"
-        self.pending -= 1
         return ret
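ToolWorker keeps its @ray.method(concurrency_group="io") annotations; only the per-worker pending counter is removed, since in-flight counting now lives in the central LoadBalancer. For readers unfamiliar with the Ray feature involved, here is a small self-contained actor using a concurrency group; the group name and size are illustrative, not taken from the repo:

import ray

@ray.remote(concurrency_groups={"io": 4})  # allow up to 4 concurrent "io" calls on one actor
class EchoWorker:
    @ray.method(concurrency_group="io")
    def call(self, tool_name: str, arguments: dict) -> str:
        # a real ToolWorker would dispatch to a registered LangChain tool here
        return f"{tool_name} called with {arguments}"

ray.init()
worker = EchoWorker.remote()
print(ray.get(worker.call.remote("search", {"query": "ray concurrency groups"})))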

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 6 additions & 6 deletions
@@ -10,6 +10,7 @@
 from .consumer import SimpleConsumer
 from .grpo_consumer import GRPOConsumer
 from .producer import AsyncSimpleProducer, SimpleProducer
+from .utils import LoadBalancer

 ALGO_MAP = {
     "Simple": SimpleConsumer,
@@ -86,7 +87,7 @@ def launch_distributed(
     num_samples = get_jsonl_size_fast(dataset_path)
     global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
-    num_recv_per_update = inference_batch_size // inference_microbatch_size if "async" not in inference_backend else 1
+    num_recv_per_update = inference_batch_size // inference_microbatch_size if "async-agentic" not in inference_backend else 1

     run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
     wandb_group_name = str(uuid.uuid4())
@@ -124,6 +125,7 @@
     enable_agentic = "agentic" in inference_backend
     if enable_agentic:
         inference_backend = inference_backend.replace("agentic-", "")
+        inference_microbatch_size = inference_microbatch_size * num_generations
     for i in range(num_producers):
         node_id = gpu_to_node_id[0]
         producer_ip_address = gpu_to_ip_address[0]
@@ -141,11 +143,7 @@
             model_config=inference_model_config,
             generate_config=generate_config,
             tokenizer_config=tokenizer_config,
-            microbatch_size=(
-                inference_microbatch_size * num_generations
-                if "async-agentic" in inference_backend
-                else inference_microbatch_size
-            ),
+            microbatch_size=inference_microbatch_size,
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
@@ -183,6 +181,7 @@
         assert (
             agentic_config["agentic_producer"] in AGENTIC_PRODUCER_MAP
         ), f"Only {list(AGENTIC_PRODUCER_MAP.keys())} are supported as agentic producer so far."
+        load_balancer = LoadBalancer.remote({"tool": len(tool_workers), "async-llm": num_producers})
        agentic_producer_cls = AGENTIC_PRODUCER_MAP[agentic_config["agentic_producer"]]
        agentic_config.pop("agentic_producer")
        producer_procs = [
@@ -214,6 +213,7 @@
                log_rollout_interval=log_rollout_interval,
                rollout_log_file=rollout_log_file,
                enable_profiling=enable_profiling,
+               load_balancer=load_balancer,
                n_behind=n_behind,
            )
            for producer_idx in range(num_producers * inference_batch_size)

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 0 additions & 7 deletions
@@ -636,12 +636,6 @@ async def rollout(self, input_ids, attention_mask, **kwargs):
         """
         raise NotImplementedError("rollout must be implemented in subclasses")

-    async def get_producer_load(self):
-        """
-        Get the load of each producer.
-        """
-        return len(self.model.running_requests)
-
     async def async_sync_model(self, episode, step, num_processes: int = 1) -> None:
         """
         Asyncronous version to sync model from consumer to producer.
@@ -853,7 +847,6 @@ class AsyncSimpleProducer(BaseAsyncProducer):
     Asyncronous version of the producer that uses vLLM for generation.
     This class is designed to handle multiple producer actors and distribute tasks among them.
     """
-
     @torch.no_grad()
     async def rollout(self, input_ids, attention_mask, **kwargs):
         # naive rollout strategy without load balancing

applications/ColossalChat/coati/distributed/utils.py

Lines changed: 25 additions & 3 deletions
@@ -1,12 +1,12 @@
 import json
 import os
 from typing import Any, Dict, List
-
+import asyncio
 import torch
 from filelock import FileLock
-
+import random
 from colossalai.shardformer.layer.loss import dist_log_prob
-
+import ray

 def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
     batches = []
@@ -165,3 +165,25 @@ def safe_append_to_jsonl_file(file_path, data):
         for entry in data:
             json_line = json.dumps(entry, ensure_ascii=False)
             f.write(json_line + "\n")
+
+@ray.remote
+class LoadBalancer:
+    def __init__(self, worker_counts):
+        self.load = {}
+        for type in worker_counts:
+            self.load[type] = {k: 0 for k in range(worker_counts[type])}
+
+    def get_next_worker(self, worker_type, amount=1):
+        loads = [(k, v) for k, v in self.load[worker_type].items()]
+        min_load = min(loads, key=lambda x: x[1])
+        candidates = [k for k, v in loads if v == min_load[1]]
+        chosen = random.choice(candidates)
+        self.load[worker_type][chosen] += amount
+        return chosen, self.load[worker_type]
+
+    def increase_load(self, worker_type, worker_id, amount=1):
+        self.load[worker_type][worker_id] += amount
+
+    def decrease_load(self, worker_type, worker_id, amount=1):
+        self.load[worker_type][worker_id] -= amount
+
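The new LoadBalancer is a plain Ray actor keyed by worker type, so all producers share a single view of in-flight work. A small driver sketch of how it can be exercised on its own, assuming the ColossalChat package is importable; the worker counts below are made up:

import ray
from coati.distributed.utils import LoadBalancer  # added in this commit

ray.init()
balancer = LoadBalancer.remote({"tool": 4, "async-llm": 2})

# pick the least-loaded tool worker; ties are broken randomly and the pick bumps its load
idx, loads = ray.get(balancer.get_next_worker.remote("tool", amount=1))
print(f"dispatching to tool worker {idx}, loads now {loads}")

# ... run the tool call on worker `idx` ...

# release the slot once the call has finished
ray.get(balancer.decrease_load.remote("tool", idx, 1))

increase_load exists for the KV-cache reuse path in the agentic producer, where a previously assigned engine is reselected without going through get_next_worker.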

applications/ColossalChat/rl_example.py

Lines changed: 4 additions & 2 deletions
@@ -281,8 +281,10 @@
         # os.environ["VLLM_DP_SIZE"] = str(args.producer_data_parallel_size)
         inference_model_config.update(
             dict(
-                gpu_memory_utilization=0.7,
-                enforce_eager=True,
+                gpu_memory_utilization=0.8,
+                max_num_batched_tokens=4096,
+                max_num_seqs=1024,
+                enforce_eager=False,
                 enable_chunked_prefill=True,
                 max_model_len=args.max_new_tokens + args.max_prompt_tokens,
                 tensor_parallel_size=args.producer_tensor_parallel_size,
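These keys match standard vLLM engine arguments that rl_example.py collects into inference_model_config. A rough standalone equivalent of the updated configuration, with the model path and length limits as placeholder values (the script derives max_model_len from args.max_new_tokens + args.max_prompt_tokens and the TP size from args.producer_tensor_parallel_size):

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model path
    gpu_memory_utilization=0.8,        # was 0.7; reserve more GPU memory for weights and KV cache
    max_num_batched_tokens=4096,       # per-step token budget used by chunked prefill
    max_num_seqs=1024,                 # cap on sequences scheduled concurrently
    enforce_eager=False,               # was True; False lets vLLM capture CUDA graphs for decoding
    enable_chunked_prefill=True,
    max_model_len=4096,                # placeholder for args.max_new_tokens + args.max_prompt_tokens
    tensor_parallel_size=1,            # placeholder for args.producer_tensor_parallel_size
)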
