18 | 18 | import torch |
19 | 19 | import torch.distributed.checkpoint as dcp |
20 | 20 | import torchstore as ts |
| 21 | + |
21 | 22 | from monarch.actor import current_rank, endpoint, ProcMesh |
22 | 23 | from torchstore.state_dict_utils import DELIM |
23 | 24 | from vllm.config import VllmConfig |
@@ -243,6 +244,15 @@ async def setup(self): |
243 | 244 | self.request_id = 0 |
244 | 245 | self.policy_version = 0 |
245 | 246 | self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} |
| 247 | + |
| 248 | + # TODO: Investigate whether this can be combined with `policy.running` |
| 249 | + # Whether this policy is accepting requests. |
| 250 | + self.accepting_requests = True |
| 251 | + # Guard for accepting_requests |
| 252 | + self.request_lock = asyncio.Condition() |
| 253 | + # Guard to serialize weight updates |
| 254 | + self.update_lock = asyncio.Condition() |
| 255 | + |
246 | 256 | self.vllm_config: VllmConfig = self.engine_config.create_vllm_config() |
247 | 257 |
248 | 258 | # Setup sampling params |
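The new fields added above rely on `asyncio.Condition` semantics: `wait_for()` releases the underlying lock while blocked and re-checks its predicate each time `notify_all()` fires, which is what lets `generate()` sleep without starving the rest of the actor. A tiny self-contained illustration of that behavior (the names here are made up for the example and are not part of this diff):

```python
import asyncio


async def main():
    cond = asyncio.Condition()
    ready = False

    async def waiter():
        async with cond:
            # wait_for() releases `cond` while blocked and re-checks the
            # predicate each time it is notified.
            await cond.wait_for(lambda: ready)
            print("gate opened")

    async def opener():
        nonlocal ready
        async with cond:  # acquirable because the waiter released it
            ready = True
            cond.notify_all()

    await asyncio.gather(waiter(), opener())


asyncio.run(main())
```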
@@ -332,33 +342,39 @@ async def generate(self, prompt: str, priority: int = 0) -> list[Completion]: |
332 | 342 | ) |
333 | 343 | t.step("process_inputs") |
334 | 344 |
335 | | - # Explicitly keeping the redundant logic to make it easier to pick up |
336 | | - # vllm changes |
337 | | - # TODO: Clean up before release |
338 | | - if (num_samples := self.sampling_params.n) == 1: |
339 | | - self.output_processor.add_request(request, prompt_str, None, 0) |
340 | | - request, _ = self.preprocess_add_request(request) |
341 | | - request_fut = asyncio.Future() |
342 | | - self.requests[request_id] = (None, request_fut) |
343 | | - |
344 | | - self.scheduler.add_request(request) |
345 | | - else: |
346 | | - parent_req = ParentRequest(request_id, self.sampling_params) |
347 | | - for idx in range(num_samples): |
348 | | - # Note: `get_child_info` mutates ParentRequest to track the |
349 | | - # generated child request |
350 | | - child_request_id, params = parent_req.get_child_info(idx) |
351 | | - child_request = request if idx == num_samples - 1 else copy(request) |
352 | | - child_request.request_id = child_request_id |
353 | | - child_request.sampling_params = params |
354 | | - self.output_processor.add_request( |
355 | | - child_request, prompt_str, parent_req, idx |
356 | | - ) |
357 | | - child_request, _ = self.preprocess_add_request(child_request) |
| 345 | + # Wait until we're accepting requests (releases lock while waiting) |
| 346 | + # If accepting_requests is True, continue immediately (holding the lock) |
| 347 | + # If False, release lock, wait for notification, re-acquire and recheck |
| 348 | + async with self.request_lock: |
| 349 | + await self.request_lock.wait_for(lambda: self.accepting_requests) |
| 350 | + |
| 351 | + # Explicitly keeping the redundant logic to make it easier to pick up |
| 352 | + # vllm changes |
| 353 | + # TODO: Clean up before release |
| 354 | + if (num_samples := self.sampling_params.n) == 1: |
| 355 | + self.output_processor.add_request(request, prompt_str, None, 0) |
| 356 | + request, _ = self.preprocess_add_request(request) |
| 357 | + request_fut = asyncio.Future() |
| 358 | + self.requests[request_id] = (None, request_fut) |
| 359 | + |
| 360 | + self.scheduler.add_request(request) |
| 361 | + else: |
| 362 | + parent_req = ParentRequest(request_id, self.sampling_params) |
| 363 | + for idx in range(num_samples): |
| 364 | + # Note: `get_child_info` mutates ParentRequest to track the |
| 365 | + # generated child request |
| 366 | + child_request_id, params = parent_req.get_child_info(idx) |
| 367 | + child_request = request if idx == num_samples - 1 else copy(request) |
| 368 | + child_request.request_id = child_request_id |
| 369 | + child_request.sampling_params = params |
| 370 | + self.output_processor.add_request( |
| 371 | + child_request, prompt_str, parent_req, idx |
| 372 | + ) |
| 373 | + child_request, _ = self.preprocess_add_request(child_request) |
358 | 374 |
359 | | - self.scheduler.add_request(child_request) |
360 | | - request_fut = asyncio.Future() |
361 | | - self.requests[request_id] = (parent_req, request_fut) |
| 375 | + self.scheduler.add_request(child_request) |
| 376 | + request_fut = asyncio.Future() |
| 377 | + self.requests[request_id] = (parent_req, request_fut) |
362 | 378 |
363 | 379 | completions = await request_fut |
364 | 380 | t.step("generate") |
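Note that only the admission check and scheduler registration above run under `request_lock`; the `await request_fut` happens outside the lock, so `run()` can complete requests and `update_weights()` can observe the queue draining. A hedged sketch of that shape (the `Gate` class and `requests` dict are illustrative stand-ins, not the actual `Policy` attributes):

```python
import asyncio


class Gate:
    # Illustrative stand-in for accepting_requests + request_lock.
    def __init__(self) -> None:
        self.open = True
        self.cond = asyncio.Condition()


async def generate_sketch(gate: Gate, requests: dict, request_id: str):
    # Hold the condition only long enough to be admitted and registered.
    async with gate.cond:
        await gate.cond.wait_for(lambda: gate.open)
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        requests[request_id] = fut
        # ... scheduler.add_request(...) would happen here ...
    # Await completion outside the lock; the run() loop resolves the future.
    return await fut
```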
@@ -428,34 +444,57 @@ async def run(self): |
428 | 444 | _, fut = self.requests.pop(request_output.request_id) |
429 | 445 | fut.set_result(completions) |
430 | 446 |
| 447 | + # Notify waiters if queue is drained |
| 448 | + async with self.request_lock: |
| 449 | + if len(self.requests) == 0: |
| 450 | + self.request_lock.notify_all() |
| 451 | + |
431 | 452 | @endpoint |
432 | 453 | async def update_weights(self, policy_version: int): |
433 | | - # TODO: If generating long sequences, this might be long and will block policy weight updates |
434 | | - curr_requests = [fut for _, fut in self.requests.values()] |
435 | | - if curr_requests: |
436 | | - # Record pending requests metrics |
437 | | - record_metric( |
438 | | - "policy_perf/update_weights/avg_pending_requests", |
439 | | - len(curr_requests), |
440 | | - Reduce.MEAN, |
441 | | - ) |
442 | | - record_metric( |
443 | | - "policy_perf/update_weights/max_pending_requests", |
444 | | - len(curr_requests), |
445 | | - Reduce.MAX, |
446 | | - ) |
447 | | - logger.debug(f"Waiting for {len(curr_requests)} pending requests") |
448 | | - await asyncio.gather(*curr_requests) |
| 454 | + # Serialize updates (only one update at a time) |
| 455 | + async with self.update_lock: |
| 456 | + # Grab the lock to stop accepting requests and wait on pending requests |
| 457 | + async with self.request_lock: |
| 458 | + self.accepting_requests = False |
| 459 | + |
| 460 | + curr_requests = [fut for _, fut in self.requests.values()] |
| 461 | + if curr_requests: |
| 462 | + # Record pending requests metrics |
| 463 | + record_metric( |
| 464 | + "policy_perf/update_weights/avg_pending_requests", |
| 465 | + len(curr_requests), |
| 466 | + Reduce.MEAN, |
| 467 | + ) |
| 468 | + record_metric( |
| 469 | + "policy_perf/update_weights/max_pending_requests", |
| 470 | + len(curr_requests), |
| 471 | + Reduce.MAX, |
| 472 | + ) |
| 473 | + logger.debug(f"Waiting for {len(curr_requests)} pending requests") |
449 | 474 |
450 | | - # Record weight update metrics |
451 | | - record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM) |
| 475 | + # Wait until all pending requests have been processed |
| 476 | + # TODO: If generating long sequences, this might be long and will block |
| 477 | + # policy weight updates |
| 478 | + await self.request_lock.wait_for(lambda: len(self.requests) == 0) |
| 479 | + |
| 480 | + # Record weight update metrics |
| 481 | + record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM) |
| 482 | + |
| 483 | + logger.debug(f"Starting weight update on {self.__class__.__name__}") |
| 484 | + if self.use_vllm_builtin_load: |
| 485 | + await self.policy_worker.update.call(version=policy_version) |
| 486 | + else: |
| 487 | + await self.policy_worker.update_DEPRECATED.call(version=policy_version) |
| 488 | + self.policy_version = policy_version |
| 489 | + |
| 490 | + # After updating the weights, we need to reset the KV cache |
| 491 | + self.scheduler.kv_cache_manager.reset_prefix_cache() |
| 492 | + |
| 493 | + # Resume accepting requests and wake up any waiting generate() calls |
| 494 | + async with self.request_lock: |
| 495 | + self.accepting_requests = True |
| 496 | + self.request_lock.notify_all() |
452 | 497 |
453 | | - logger.debug(f"Starting weight update on {self.__class__.__name__}") |
454 | | - if self.use_vllm_builtin_load: |
455 | | - await self.policy_worker.update.call(version=policy_version) |
456 | | - else: |
457 | | - await self.policy_worker.update_DEPRECATED.call(version=policy_version) |
458 | | - self.policy_version = policy_version |
459 | 498 | logger.info(f"Weight update completed (now v{self.policy_version})") |
460 | 499 |
461 | 500 | @endpoint |
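Taken together, `generate()`, `run()`, and `update_weights()` implement a pause/drain/resume protocol around these locks. The following is a minimal sketch of that protocol with illustrative names rather than the actual `Policy` API (the real code also records metrics, calls into the policy worker, bumps the version, and resets the prefix cache):

```python
import asyncio


class WeightUpdateGate:
    """Sketch of the accepting_requests / request_lock / update_lock protocol."""

    def __init__(self) -> None:
        self.accepting_requests = True
        self.pending: set[str] = set()  # stands in for self.requests
        self.request_lock = asyncio.Condition()
        # The diff uses an asyncio.Condition for update_lock; a plain Lock is
        # enough to serialize updates in this sketch.
        self.update_lock = asyncio.Lock()

    async def admit(self, request_id: str) -> None:
        # generate(): block new requests while an update is draining the queue.
        async with self.request_lock:
            await self.request_lock.wait_for(lambda: self.accepting_requests)
            self.pending.add(request_id)

    async def complete(self, request_id: str) -> None:
        # run(): a request finished; wake update_weights() once the queue drains.
        async with self.request_lock:
            self.pending.discard(request_id)
            if not self.pending:
                self.request_lock.notify_all()

    async def update_weights(self, do_update) -> None:
        async with self.update_lock:  # only one update at a time
            async with self.request_lock:
                self.accepting_requests = False
                # Wait until all in-flight requests have been processed.
                await self.request_lock.wait_for(lambda: not self.pending)
            await do_update()  # load new weights, reset caches, bump version
            async with self.request_lock:
                self.accepting_requests = True
                self.request_lock.notify_all()
```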