Commit 1f27b54

Add locking to request queue while weight updates (#264)
* Add basic locking
* remove file
* Renamed/Wrap Lock, Update Tests, Added reset_prefix_cache
* Update comment
* Inline
* test_policy_update broken in trunk, fixing here
* Added Doc suggestions and extending lock for adding request
1 parent 3186797 commit 1f27b54

File tree

src/forge/actors/policy.py
tests/integration_tests/test_policy_update.py

2 files changed: +94 -60 lines changed


src/forge/actors/policy.py

Lines changed: 89 additions & 50 deletions
@@ -18,6 +18,7 @@
 import torch
 import torch.distributed.checkpoint as dcp
 import torchstore as ts
+
 from monarch.actor import current_rank, endpoint, ProcMesh
 from torchstore.state_dict_utils import DELIM
 from vllm.config import VllmConfig
@@ -243,6 +244,15 @@ async def setup(self):
         self.request_id = 0
         self.policy_version = 0
         self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
+
+        # TODO: Investigate whether this can be combined with `policy.running`
+        # Whether this policy is accepting requests.
+        self.accepting_requests = True
+        # Guard for accepting_requests
+        self.request_lock = asyncio.Condition()
+        # Guard for updating requests
+        self.update_lock = asyncio.Condition()
+
         self.vllm_config: VllmConfig = self.engine_config.create_vllm_config()

         # Setup sampling params
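
The two `asyncio.Condition` objects split the coordination work: `request_lock` gates whether `generate()` may enqueue new work, while `update_lock` ensures only one `update_weights()` call runs at a time. A minimal, self-contained sketch of the gate idiom (illustrative names, not Forge code):

    import asyncio

    class Gate:
        """Minimal admission gate; illustrative only."""

        def __init__(self) -> None:
            self.accepting = True
            # A Condition pairs a lock with wait/notify: waiters release the
            # lock while suspended and re-acquire it before re-checking.
            self.cond = asyncio.Condition()

        async def enter(self) -> None:
            async with self.cond:
                # Returns at once if accepting is True; otherwise sleeps until
                # notify_all() wakes it and the predicate holds.
                await self.cond.wait_for(lambda: self.accepting)

        async def close(self) -> None:
            async with self.cond:
                self.accepting = False

        async def open(self) -> None:
            async with self.cond:
                self.accepting = True
                self.cond.notify_all()

This is why a bare `asyncio.Lock` would not suffice here: the updater must be able to sleep until a state change (queue drained, gate reopened) without holding the lock, which is exactly what `Condition.wait_for` provides.
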
@@ -332,33 +342,39 @@ async def generate(self, prompt: str, priority: int = 0) -> list[Completion]:
         )
         t.step("process_inputs")

-        # Explicitly keeping the redundant logic to make it easier to pick up
-        # vllm changes
-        # TODO: Clean up before release
-        if (num_samples := self.sampling_params.n) == 1:
-            self.output_processor.add_request(request, prompt_str, None, 0)
-            request, _ = self.preprocess_add_request(request)
-            request_fut = asyncio.Future()
-            self.requests[request_id] = (None, request_fut)
-
-            self.scheduler.add_request(request)
-        else:
-            parent_req = ParentRequest(request_id, self.sampling_params)
-            for idx in range(num_samples):
-                # Note: `get_child_info` mutates ParentRequest to track the
-                # generated child request
-                child_request_id, params = parent_req.get_child_info(idx)
-                child_request = request if idx == num_samples - 1 else copy(request)
-                child_request.request_id = child_request_id
-                child_request.sampling_params = params
-                self.output_processor.add_request(
-                    child_request, prompt_str, parent_req, idx
-                )
-                child_request, _ = self.preprocess_add_request(child_request)
+        # Wait until we're accepting requests (releases lock while waiting)
+        # If accepting_requests is True, continue immediately (holding the lock)
+        # If False, release lock, wait for notification, re-acquire and recheck
+        async with self.request_lock:
+            await self.request_lock.wait_for(lambda: self.accepting_requests)
+
+            # Explicitly keeping the redundant logic to make it easier to pick up
+            # vllm changes
+            # TODO: Clean up before release
+            if (num_samples := self.sampling_params.n) == 1:
+                self.output_processor.add_request(request, prompt_str, None, 0)
+                request, _ = self.preprocess_add_request(request)
+                request_fut = asyncio.Future()
+                self.requests[request_id] = (None, request_fut)
+
+                self.scheduler.add_request(request)
+            else:
+                parent_req = ParentRequest(request_id, self.sampling_params)
+                for idx in range(num_samples):
+                    # Note: `get_child_info` mutates ParentRequest to track the
+                    # generated child request
+                    child_request_id, params = parent_req.get_child_info(idx)
+                    child_request = request if idx == num_samples - 1 else copy(request)
+                    child_request.request_id = child_request_id
+                    child_request.sampling_params = params
+                    self.output_processor.add_request(
+                        child_request, prompt_str, parent_req, idx
+                    )
+                    child_request, _ = self.preprocess_add_request(child_request)

-            self.scheduler.add_request(child_request)
-            request_fut = asyncio.Future()
-            self.requests[request_id] = (parent_req, request_fut)
+                self.scheduler.add_request(child_request)
+                request_fut = asyncio.Future()
+                self.requests[request_id] = (parent_req, request_fut)

         completions = await request_fut
         t.step("generate")
@@ -428,34 +444,57 @@ async def run(self):
             _, fut = self.requests.pop(request_output.request_id)
             fut.set_result(completions)

+            # Notify waiters if queue is drained
+            async with self.request_lock:
+                if len(self.requests) == 0:
+                    self.request_lock.notify_all()
+
     @endpoint
     async def update_weights(self, policy_version: int):
-        # TODO: If generating long sequences, this might be long and will block policy weight updates
-        curr_requests = [fut for _, fut in self.requests.values()]
-        if curr_requests:
-            # Record pending requests metrics
-            record_metric(
-                "policy_perf/update_weights/avg_pending_requests",
-                len(curr_requests),
-                Reduce.MEAN,
-            )
-            record_metric(
-                "policy_perf/update_weights/max_pending_requests",
-                len(curr_requests),
-                Reduce.MAX,
-            )
-            logger.debug(f"Waiting for {len(curr_requests)} pending requests")
-            await asyncio.gather(*curr_requests)
+        # Serialize updates (only one update at a time)
+        async with self.update_lock:
+            # Grab the lock to stop accepting requests and wait on pending requests
+            async with self.request_lock:
+                self.accepting_requests = False
+
+                curr_requests = [fut for _, fut in self.requests.values()]
+                if curr_requests:
+                    # Record pending requests metrics
+                    record_metric(
+                        "policy_perf/update_weights/avg_pending_requests",
+                        len(curr_requests),
+                        Reduce.MEAN,
+                    )
+                    record_metric(
+                        "policy_perf/update_weights/max_pending_requests",
+                        len(curr_requests),
+                        Reduce.MAX,
+                    )
+                    logger.debug(f"Waiting for {len(curr_requests)} pending requests")

-        # Record weight update metrics
-        record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM)
+                # Wait until all pending requests have been processed
+                # TODO: If generating long sequences, this might be long and will block
+                # policy weight updates
+                await self.request_lock.wait_for(lambda: len(self.requests) == 0)
+
+            # Record weight update metrics
+            record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM)
+
+            logger.debug(f"Starting weight update on {self.__class__.__name__}")
+            if self.use_vllm_builtin_load:
+                await self.policy_worker.update.call(version=policy_version)
+            else:
+                await self.policy_worker.update_DEPRECATED.call(version=policy_version)
+            self.policy_version = policy_version
+
+            # After updating the weights, we need to reset the KV cache
+            self.scheduler.kv_cache_manager.reset_prefix_cache()
+
+            # Resume accepting requests and wake up any waiting generate() calls
+            async with self.request_lock:
+                self.accepting_requests = True
+                self.request_lock.notify_all()

-        logger.debug(f"Starting weight update on {self.__class__.__name__}")
-        if self.use_vllm_builtin_load:
-            await self.policy_worker.update.call(version=policy_version)
-        else:
-            await self.policy_worker.update_DEPRECATED.call(version=policy_version)
-        self.policy_version = policy_version
         logger.info(f"Weight update completed (now v{self.policy_version})")

     @endpoint
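
Taken together, the pieces above form a simple admission gate: `generate()` waits for `accepting_requests`, `run()` notifies when the queue drains, and `update_weights()` closes the gate, waits for the drain, swaps weights, resets the prefix cache, and reopens. A self-contained sketch of the same protocol in plain asyncio (`ToyPolicy` and all names below are illustrative, not the Forge API):

    import asyncio

    class ToyPolicy:
        """Illustrative condensation of the locking protocol above."""

        def __init__(self) -> None:
            self.accepting_requests = True
            self.requests: dict[int, asyncio.Future] = {}
            self.request_lock = asyncio.Condition()  # guards accepting_requests / requests
            self.update_lock = asyncio.Condition()   # serializes weight updates
            self.next_id = 0

        async def generate(self) -> str:
            async with self.request_lock:
                await self.request_lock.wait_for(lambda: self.accepting_requests)
                loop = asyncio.get_running_loop()
                rid, fut = self.next_id, loop.create_future()
                self.next_id += 1
                self.requests[rid] = fut  # registered while still holding the lock
            # Stand-in for the scheduler producing a completion later.
            loop.call_later(0.05, fut.set_result, f"completion-{rid}")
            result = await fut
            async with self.request_lock:
                del self.requests[rid]
                if not self.requests:  # what run() does when the queue drains
                    self.request_lock.notify_all()
            return result

        async def update_weights(self, version: int) -> None:
            async with self.update_lock:  # one update at a time
                async with self.request_lock:
                    self.accepting_requests = False
                    await self.request_lock.wait_for(lambda: not self.requests)
                print(f"swapping weights -> v{version}")  # queue empty, gate closed
                async with self.request_lock:
                    self.accepting_requests = True
                    self.request_lock.notify_all()

    async def main() -> None:
        p = ToyPolicy()
        results = await asyncio.gather(
            *[p.generate() for _ in range(3)], p.update_weights(1)
        )
        print(results[:3])  # three completions, all drained before the swap

    asyncio.run(main())

Running this prints the weight swap only after all three in-flight completions resolve, and any `generate()` issued during the swap would block until the gate reopens.
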

tests/integration_tests/test_policy_update.py

Lines changed: 5 additions & 10 deletions
@@ -165,7 +165,7 @@ async def test_sanity_check(self, request):
         use_dcp_override = request.config.getoption("--use_dcp")
         cfg = self._load_config(config_path=config_path)

-        trainer_proc_size = cfg.services.trainer.procs
+        trainer_proc_size = cfg.actors.trainer.procs
         policy_tp_size = cfg.policy.engine_config.tensor_parallel_size

         if policy_tp_size != cfg.services.policy.procs:
@@ -188,9 +188,6 @@ async def test_sanity_check(self, request):
         services_policy_cfg = cfg.services.policy
         services_policy_cfg.num_replicas = 1

-        services_trainer_cfg = cfg.services.trainer
-        services_trainer_cfg.num_replicas = 1
-
         trainer_cfg = cfg.trainer
         trainer_cfg.checkpoint = {
             "enable": True,
@@ -207,20 +204,18 @@ async def test_sanity_check(self, request):
         policy, rl_trainer = await asyncio.gather(
             *[
                 Policy.options(**services_policy_cfg).as_service(**cfg.policy),
-                MockRLTrainer.options(**services_trainer_cfg).as_service(
-                    **trainer_cfg
-                ),
+                MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
             ]
         )

         # Main logic begins here
         v0 = uuid.uuid4().int
         v1 = v0 + 1

-        await rl_trainer.push_weights.fanout(policy_version=v0)
+        await rl_trainer.push_weights.call(policy_version=v0)
         # Setting everything to zero
-        await rl_trainer.zero_out_model_states.fanout()
-        await rl_trainer.push_weights.fanout(policy_version=v1)
+        await rl_trainer.zero_out_model_states.call()
+        await rl_trainer.push_weights.call(policy_version=v1)
         await policy._test_save_model_params.fanout()

         # Sanity check that before update all the tests pass
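
With the trainer now a single actor rather than a replicated service, the broadcast-style `fanout` calls become point-to-point `call`s. Separately, if you wanted to exercise the drain behavior without Monarch at all, a pure-asyncio check against the illustrative `ToyPolicy` sketch above might look like this (hypothetical, not part of this test file):

    import asyncio

    async def check_update_waits_for_drain() -> None:
        # Hypothetical check against the ToyPolicy sketch above (not Forge code).
        p = ToyPolicy()
        gens = [asyncio.create_task(p.generate()) for _ in range(4)]
        await asyncio.sleep(0)             # let each generate() register itself
        await p.update_weights(version=1)  # must wait for all 4 to drain
        assert all(g.done() for g in gens), "update finished before the queue drained"

    asyncio.run(check_update_waits_for_drain())
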

0 commit comments
