diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 145b6cd48..7e9dd8bc7 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -6,21 +6,42 @@
 import asyncio
 import logging
+import os
+import shutil
+import tempfile
 import time
 from dataclasses import dataclass
 from typing import Callable
 
+import safetensors.torch as safetensors
 import torch
+
+from absl import app, flags
 from datasets import load_dataset
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.controller.actor import ForgeActor
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
+from forge.data.weights_handle import WeightsHandle, WeightsHandleType
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
+
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+FLAGS = flags.FLAGS
+
+flags.DEFINE_integer("batch_size", 8, "Replay buffer sample batch size.")
+flags.DEFINE_integer(
+    "update_period", 5, "Number of training steps between policy weight updates."
+)
+flags.DEFINE_integer("group_size", 16, "Number of completions sampled per prompt.")
+flags.DEFINE_integer(
+    "max_policy_age", 10, "Maximum policy staleness (in steps) of replay samples."
+)
+
+
+def clean_up_temp_dir(temp_dir: str) -> None:
+    """Recursively remove a temporary directory, ignoring errors."""
+    shutil.rmtree(temp_dir, ignore_errors=True)
+
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -77,6 +98,9 @@ def __init__(self, episode_id: int, prompt: str, target: str, policy_version: in
     def add_group(self, group: Group):
         self.groups.append(group)
 
+    def add_groups(self, groups: list[Group]):
+        self.groups.extend(groups)
+
 
 class Trainer(ForgeActor):
     """GRPO Trainer implementation for policy optimization."""
@@ -189,24 +213,49 @@ async def train_step(self, batch: list[Episode]):
         return {"loss": avg_loss, "groups_processed": num_groups_processed}
 
     @endpoint
-    async def update_weights(self, policy_actor):
-        """Update policy model weights with trainer's current weights."""
-        # Time how long it takes to update weights
+    async def export_weights(self, step: int) -> WeightsHandle:
+        """Export current model weights to a SafeTensors file and return a handle."""
+        # Time how long the export takes
         start_time = time.time()
 
         # Set model to eval mode for weight extraction
         self.model.eval()
 
-        # Extract current model state dict
-        model_state_dict = self.model.state_dict()
+        # Save weights to a memory-backed temporary directory using /dev/shm
+        shm_path = "/dev/shm"
+        if os.path.exists(shm_path):
+            # Use the shared-memory filesystem for memory-backed storage
+            temp_dir = tempfile.mkdtemp(
+                prefix=f"model_weights_step_{step:08d}_", dir=shm_path
+            )
+        else:
+            # Fall back to the system temp dir if /dev/shm is not available
+            temp_dir = tempfile.mkdtemp(prefix=f"model_weights_step_{step:08d}_")
 
-        # Convert tensors to CPU for transfer (if they're on GPU)
-        cpu_state_dict = {}
-        for key, tensor in model_state_dict.items():
-            cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor
-
-        # Update the policy actor's model weights
-        await policy_actor.update_model_weights.choose(cpu_state_dict)
+        try:
+            # Save weights directly to a SafeTensors file
+            weights_file = os.path.join(temp_dir, "model_weights.safetensors")
+            state_dict = dict(self.model.named_parameters())
+            safetensors.save_file(state_dict, weights_file)
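+            # Caveat: safetensors.save_file rejects tensors that share
+            # storage; if the model ties weights (e.g. embedding and
+            # lm_head), those tensors may need to be cloned before saving.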
+
+            # Create a weights handle pointing at the SafeTensors file
+            param_names = list(state_dict.keys())
+            weights_handle = WeightsHandle(
+                handle_type=WeightsHandleType.FILE,
+                version=step,
+                payload={
+                    "param_names": param_names,
+                    "model_path": weights_file,
+                    "model_name": self.model_name,
+                },
+            )
+
+        except Exception:
+            # Clean up the temporary directory if something goes wrong
+            clean_up_temp_dir(temp_dir)
+            raise
 
         # Set model back to training mode
         self.model.train()
@@ -214,6 +263,7 @@ async def update_weights(self, policy_actor):
         # Log the time taken
         end_time = time.time()
-        self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")
+        self.logger.info(f"Exporting weights took {end_time - start_time:.2f} seconds")
+        return weights_handle
 
 
 class RewardActor(ForgeActor):
@@ -342,9 +392,9 @@ async def __next__(self) -> dict[str, str] | None:
         return None
 
 
-async def main():
+async def _main():
     """Main GRPO training loop with rollout and training processes."""
-    group_size = 1
+    group_size = FLAGS.group_size
     model = "Qwen/Qwen3-1.7B"
 
     # ---- Setup WandB Logger ---- #
@@ -412,7 +462,6 @@ async def main():
             reward_functions=[MathReward(), ThinkingReward()],
         ),
     )
-    print("All services initialized successfully!")
 
     # ---- Core RL loops ---- #
@@ -424,27 +473,32 @@ async def continuous_rollouts():
                 print("Dataloader is empty, exiting continuous rollout")
                 return
             prompt, target = sample["question"], sample["answer"]
-            version = 0  # await policy.get_current_version.choose()
+
+            response = await policy.generate.choose(prompt)
+            actions = response.completions
+            version = response.policy_version
             episode = Episode(
                 episode_id=rollout_count,
                 prompt=prompt,
                 target=target,
                 policy_version=version,
             )
-            actions = await policy.generate.choose(prompt)
-            for action in actions:
-                ref_logprobs = await ref_model.forward.choose(action.token_ids)
-                reward = await reward_actor.evaluate_response.choose(
-                    prompt=prompt, response=action.text, target=target
-                )
-                episode.add_group(
-                    Group(
-                        response=action.text,
-                        ref_logprobs=ref_logprobs,
-                        reward=reward,
-                    )
-                )
+
+            async def _get_group(action):
+                # Run the reference-model forward and the reward evaluation
+                # for one completion concurrently.
+                ref_logprobs, reward = await asyncio.gather(
+                    ref_model.forward.choose(action.token_ids),
+                    reward_actor.evaluate_response.choose(
+                        prompt=prompt, response=action.text, target=target
+                    ),
+                )
+                return Group(
+                    response=action.text, ref_logprobs=ref_logprobs, reward=reward
+                )
+
+            groups = await asyncio.gather(*[_get_group(action) for action in actions])
+
+            episode.add_groups(groups)
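+            # asyncio.gather preserves argument order, so episode.groups stays
+            # aligned with `actions` and with the advantages zipped below.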
+
             advantages = await compute_advantages.__call__.choose(episode.groups)
             for advantage, group in zip(advantages, episode.groups):
                 group.advantage = advantage
@@ -452,33 +506,47 @@ async def continuous_rollouts():
 
             await replay_buffer.add.choose(episode)
             rollout_count += 1
             if rollout_count % 10 == 0:
                 avg_reward = sum(group.reward for group in episode.groups) / len(
                     episode.groups
                 )
                 print(
-                    f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
+                    f"Generated {rollout_count} rollouts, latest episode avg reward = {avg_reward}"
                 )
                 logger.log("reward/rollout", avg_reward, rollout_count)
 
     async def continuous_training():
         training_step = 0
+        update_period = FLAGS.update_period
+        # using training_step as the policy version for now, open to suggestions
         while True:
-            batch = await replay_buffer.sample.choose(curr_policy_version=0)
+            batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
+                # TODO: switch to trainer.train_step.call_one(batch) once
+                # call_one is available on the service handle.
                 training_result = await trainer.train_step.choose(batch)
                 training_step += 1
-                if training_step % 10 == 0:
-                    print(f"Completed {training_step} training steps")
-                    if training_result:
-                        loss_value = training_result.get("loss", 0.0)
-                        print(f"Latest loss: {loss_value}")
-                        logger.log("loss/training_step", loss_value, training_step)
-                    # await trainer.update_weights(policy)
+                print(f"Completed {training_step} training steps")
+                if training_result:
+                    loss_value = training_result.get("loss", 0.0)
+                    print(f"Latest loss: {loss_value}")
+                    logger.log("loss/training_step", loss_value, training_step)
+                if training_step % update_period == 0:
+                    print(f"Exporting policy weights @ {training_step=}")
+                    weights_handle = await trainer.export_weights.choose(training_step)
+                    print(f"Exported weights @ {training_step=}")
+                    await policy.update_weights.call(weights_handle)
+                    print(f"Updated policy weights to version {training_step}")
+                    # The handle points at the SafeTensors file; remove its
+                    # parent temp directory now that all replicas have loaded it.
+                    clean_up_temp_dir(
+                        os.path.dirname(weights_handle.payload["model_path"])
+                    )
 
     print("Starting GRPO training loops...")
+
     rollout_task = asyncio.create_task(continuous_rollouts())
     training_task = asyncio.create_task(continuous_training())
 
@@ -501,5 +569,9 @@ async def continuous_training():
     )
 
 
+def main(argv):
+    asyncio.run(_main())
+
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    app.run(main)
diff --git a/apps/vllm/main.py b/apps/vllm/main.py
index 2d3c81ad9..7ed7fddf1 100644
--- a/apps/vllm/main.py
+++ b/apps/vllm/main.py
@@ -15,7 +15,13 @@
 from argparse import Namespace
 from typing import List
 
-from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.policy import (
+    CompletionPolicyResponse,
+    Policy,
+    PolicyConfig,
+    SamplingOverrides,
+    WorkerConfig,
+)
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 
 from vllm.outputs import CompletionOutput
@@ -89,7 +95,9 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:
 
     async with policy.session():
         print("Requesting generation...")
-        responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
+        responses: List[CompletionOutput] = (
+            await policy.generate.choose(prompt=prompt)
+        ).completions
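+        # The parentheses are required: attribute access binds tighter than
+        # ``await``, so without them ``.completions`` would be read off the
+        # un-awaited future.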
 
         print("\nGeneration Results:")
         print("=" * 80)
diff --git a/pyproject.toml b/pyproject.toml
index ccb0cb12c..fdded605f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
     # Miscellaneous
     "omegaconf",
     "wandb",
+    "aiorwlock",
 ]
 dynamic = ["version"]
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 4a51f7225..b9cb482d7 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -13,7 +13,18 @@
 from typing import Dict, List
 
+import safetensors.torch as safetensors
 import torch
-from monarch.actor import current_rank, endpoint, ProcMesh
+
+from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
+from forge.controller.synchronization import AsyncRWLock
+
+from forge.data.sharding import VLLMSharding
+from forge.data.weights_handle import WeightsHandle, WeightsHandleType
+from forge.interfaces import Policy as PolicyInterface
+from forge.types import ProcessConfig
+from monarch.actor import Actor, current_rank, endpoint, proc_mesh, ProcMesh
+
+from tenacity import retry, wait_exponential
+
 from torchstore import MultiProcessStore
 
 from torchstore._state_dict_utils import DELIM
@@ -37,13 +48,6 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
-from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
-
-from forge.data.sharding import VLLMSharding
-from forge.interfaces import Policy as PolicyInterface
-from forge.types import ProcessConfig
-
-
 logger = logging.getLogger(__name__)
@@ -93,6 +97,12 @@ class PolicyConfig:
     available_devices: str = None
 
 
+@dataclass
+class CompletionPolicyResponse:
+    """Completions from the policy plus the policy version that produced them."""
+
+    policy_version: int
+    completions: list[CompletionOutput]
+
+
 @dataclass
 class Policy(PolicyInterface):
     config: PolicyConfig
@@ -101,6 +111,9 @@ class Policy(PolicyInterface):
     lora_request: LoRARequest | None = None
     tokenization_kwargs: dict = field(default_factory=dict)
     policy_worker: "PolicyWorker" = None
+    # TODO: figure out if a finer-grained versioning/locking scheme is needed/possible.
+    policy_version: int = 0
+    model_lock: AsyncRWLock = field(default_factory=AsyncRWLock)
 
     def __post_init__(self):
         self._run_task: asyncio.Task | None = None
@@ -215,8 +228,10 @@ def start_processing(self):
         if self._run_task is None or self._run_task.done():
             self._run_task = asyncio.create_task(self.run())
 
-    @endpoint
-    async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]:
+    async def _generate_nolock(
+        self, prompt: str, priority: int = 0
+    ) -> CompletionPolicyResponse:
+        """Return the completions and policy version. Caller must hold the read lock."""
         self.request_id += 1 % sys.maxsize
         request_id = str(self.request_id)  # implement from a counter
 
@@ -273,7 +288,19 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu
         request_fut = asyncio.Future()
         self.requests[request_id] = (parent_req, request_fut)
 
-        return await request_fut
+        completions = await request_fut
+        return CompletionPolicyResponse(
+            policy_version=self.policy_version, completions=completions
+        )
+
+    @endpoint
+    @retry(wait=wait_exponential(multiplier=1, min=10, max=60))
+    async def generate(
+        self, prompt: str, priority: int = 0
+    ) -> CompletionPolicyResponse:
+        """Return the completions and the policy version that produced them.
+
+        Retries with exponential backoff when the read lock is rejected while
+        a weight update is pending.
+        """
+        async with self.model_lock.read_lock():
+            return await self._generate_nolock(prompt, priority)
 
     # Abstracted to match vllm
     # https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419
@@ -313,9 +340,16 @@ async def run(self):
                 fut.set_result(request_output.outputs)
 
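+    # update_weights takes the write lock; AsyncRWLock prioritizes writers, so
+    # concurrent generate() calls fail fast (and are retried with backoff)
+    # rather than starving the update.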
     @endpoint
-    async def update_weights(self):
+    async def update_weights(self, weights_handle: WeightsHandle):
         """Update the policy weights."""
-        pass
+        logger.debug("Policy.update_weights() called")
+        async with self.model_lock.write_lock():
+            assert self.policy_version < weights_handle.version, (
+                "current policy version must be older than the version being "
+                f"applied: {self.policy_version=}, {weights_handle.version=}"
+            )
+            await self.policy_worker.update.call(weights_handle)
+            self.policy_version = weights_handle.version
 
     @endpoint
     async def stop(self):
@@ -365,7 +399,7 @@ def __post_init__(self):
         for key in cfg:
             value = getattr(self, key) if key != "data_parallel_size" else 1
             if getattr(self.vllm_args, key) != value:
-                logger.warning(
+                self.logger.warning(
                     f"{key} args don't match value in EngineArgs, overriding with {value}"
                 )
                 setattr(self.vllm_args, key, value)
@@ -408,8 +442,7 @@ async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
 
                 updated_count += 1
 
-    @endpoint
-    async def update(self):
+    async def _update_torchstore(self):
         """Update model weights by reading state dict from torchstore"""
 
         if self.torchstore is None:
@@ -428,6 +461,37 @@ async def update(self):
 
         logger.debug("Successfully updated model weights from torchstore")
 
+    async def _update_from_file(self, weights_handle: WeightsHandle):
+        """Update model weights by reading a state dict from a SafeTensors file."""
+        file_path = weights_handle.payload["model_path"]
+
+        self.logger.debug(f"Loading model weights from SafeTensors file: {file_path}")
+
+        state_dict = safetensors.load_file(file_path)
+
+        # Convert to the format expected by load_weights: Iterable[tuple[str, Tensor]]
+        new_weights = [(name, tensor) for name, tensor in state_dict.items()]
+
+        # Use the model's built-in load_weights method, which handles tensor
+        # parallelism and the conversion from Hugging Face to vLLM format
+        model = self.worker.model_runner.model
+        model.load_weights(new_weights)
+
+    @endpoint
+    async def update(self, weights_handle: WeightsHandle):
+        logger.debug("PolicyWorker.update() called")
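+        # Dispatch on handle type. Note that the TORCH_STORE path ignores the
+        # handle payload and re-reads the full state dict from torchstore.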
+        match weights_handle.handle_type:
+            case WeightsHandleType.FILE:
+                await self._update_from_file(weights_handle)
+            case WeightsHandleType.TORCH_STORE:
+                await self._update_torchstore()
+            case _:
+                raise ValueError(
+                    f"Unsupported weights handle type: {weights_handle.handle_type}"
+                )
+
     @endpoint
     async def setup_kv_cache(self):
         """Based on vllm/v1/engine/core.py:EngineCore._initialize_kv_caches
diff --git a/src/forge/controller/synchronization.py b/src/forge/controller/synchronization.py
new file mode 100644
index 000000000..605e1f566
--- /dev/null
+++ b/src/forge/controller/synchronization.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""Asynchronous reader-writer lock implementation for concurrent access control."""
+
+import asyncio
+from contextlib import asynccontextmanager
+from typing import AsyncGenerator
+
+
+class AsyncRWLock:
+    """
+    An asynchronous reader-writer lock that allows multiple readers or a single writer.
+
+    This implementation prioritizes writers: when a writer is waiting, new read
+    requests are rejected (with RuntimeError) to prevent writer starvation.
+
+    Note: writer priority is best-effort. With multiple queued writers, the
+    flag that rejects new readers is cleared as soon as one writer acquires
+    the lock, so readers may be admitted ahead of a second queued writer.
+
+    Usage:
+        lock = AsyncRWLock()
+
+        async with lock.read_lock():
+            # Multiple readers can access concurrently
+            pass
+
+        async with lock.write_lock():
+            # Only one writer can access at a time
+            pass
+    """
+
+    def __init__(self) -> None:
+        """Initialize the reader-writer lock."""
+        self._readers = 0
+        self._writer = False
+        self._writer_waiting = False
+        self._condition = asyncio.Condition()
+
+    async def acquire_read(self) -> None:
+        """
+        Acquire a read lock.
+
+        Raises:
+            RuntimeError: If a writer is waiting (prevents writer starvation).
+        """
+        async with self._condition:
+            if self._writer_waiting:
+                raise RuntimeError("Read lock request canceled: writer is waiting")
+
+            while self._writer:
+                await self._condition.wait()
+
+            self._readers += 1
+
+    async def release_read(self) -> None:
+        """Release a read lock."""
+        async with self._condition:
+            self._readers -= 1
+            if self._readers == 0:
+                self._condition.notify_all()
+
+    async def acquire_write(self) -> None:
+        """
+        Acquire a write lock.
+
+        Waits until no other readers or writers are active.
+        """
+        async with self._condition:
+            self._writer_waiting = True
+
+            while self._writer or self._readers > 0:
+                await self._condition.wait()
+
+            self._writer = True
+            self._writer_waiting = False
+
+    async def release_write(self) -> None:
+        """Release a write lock."""
+        async with self._condition:
+            self._writer = False
+            self._condition.notify_all()
+
+    @asynccontextmanager
+    async def read_lock(self) -> AsyncGenerator[None, None]:
+        """
+        Context manager for acquiring a read lock.
+
+        Example:
+            async with lock.read_lock():
+                # Multiple readers can access this section concurrently
+                data = shared_resource.read()
+        """
+        await self.acquire_read()
+        try:
+            yield
+        finally:
+            await self.release_read()
+
+    @asynccontextmanager
+    async def write_lock(self) -> AsyncGenerator[None, None]:
+        """
+        Context manager for acquiring a write lock.
+
+        Example:
+            async with lock.write_lock():
+                # Only one writer can access this section at a time
+                shared_resource.write(data)
+        """
+        await self.acquire_write()
+        try:
+            yield
+        finally:
+            await self.release_write()
diff --git a/src/forge/data/weights_handle.py b/src/forge/data/weights_handle.py
new file mode 100644
index 000000000..a7d9962bd
--- /dev/null
+++ b/src/forge/data/weights_handle.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+
+class WeightsHandleType(Enum):
+    SIMPLE = 1  # pass the state_dict directly (unverified)
+    TORCH_STORE = 2  # use torchstore
+    FILE = 3  # use the file system
+
+
+@dataclass
+class WeightsHandle:
+    """Handle for weights to be transferred between the trainer and the policy."""
+
+    handle_type: WeightsHandleType = field(default=WeightsHandleType.SIMPLE)
+    version: int = field(default=0)
+    payload: Any = field(default_factory=dict)
diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py
index 3dbbd560e..ed4d433db 100644
--- a/src/forge/interfaces.py
+++ b/src/forge/interfaces.py
@@ -7,12 +7,12 @@
 from abc import ABC, abstractmethod
 from typing import Any, Mapping
 
-from monarch.actor import endpoint
-
 from forge.controller import ForgeActor
 from forge.types import Action, Message, Observation, Scalar, State
 
+from monarch.actor import Actor, endpoint
+
 
 class Transform(ABC):
     """Abstract base class for observation transforms.
diff --git a/tests/unit_tests/test_synchronization.py b/tests/unit_tests/test_synchronization.py new file mode 100644 index 000000000..d2e396412 --- /dev/null +++ b/tests/unit_tests/test_synchronization.py @@ -0,0 +1,255 @@ +"""Tests for AsyncRWLock synchronization primitives.""" + +import asyncio + +import pytest +from forge.controller.synchronization import AsyncRWLock + + +class TestAsyncRWLock: + """Test suite for AsyncRWLock reader-writer lock implementation.""" + + @pytest.fixture + def lock(self): + """Create a fresh AsyncRWLock for each test.""" + return AsyncRWLock() + + @pytest.mark.asyncio + async def test_writer_blocks_readers(self, lock): + """Test that when writer has acquired the lock, no reader can read.""" + writer_acquired = asyncio.Event() + reader_attempted = asyncio.Event() + reader_acquired = asyncio.Event() + writer_released = asyncio.Event() + + async def writer_task(): + """Writer that holds the lock for a period.""" + async with lock.write_lock(): + writer_acquired.set() + # Wait for reader to attempt access + await reader_attempted.wait() + # Hold the lock briefly to ensure reader is blocked + await asyncio.sleep(0.1) + writer_released.set() + + async def reader_task(): + """Reader that attempts to acquire lock after writer.""" + # Wait for writer to acquire lock first + await writer_acquired.wait() + reader_attempted.set() + + # This should block until writer releases + async with lock.read_lock(): + reader_acquired.set() + + # Start both tasks + writer_task_handle = asyncio.create_task(writer_task()) + reader_task_handle = asyncio.create_task(reader_task()) + + # Wait for writer to acquire and reader to attempt + await writer_acquired.wait() + await reader_attempted.wait() + + # Reader should not have acquired yet (writer still holds lock) + assert not reader_acquired.is_set() + + # Wait for writer to finish + await writer_task_handle + assert writer_released.is_set() + + # Now reader should be able to acquire + await reader_task_handle + assert reader_acquired.is_set() + + @pytest.mark.asyncio + async def test_waiting_writer_blocks_new_readers(self, lock): + """Test that when writer is waiting for lock, new readers can't acquire lock.""" + reader1_acquired = asyncio.Event() + writer_waiting = asyncio.Event() + reader2_attempted = asyncio.Event() + reader2_failed = asyncio.Event() + + async def reader1_task(): + """First reader that holds the lock.""" + async with lock.read_lock(): + reader1_acquired.set() + # Hold lock to make writer wait + await asyncio.sleep(0.2) + + async def writer_task(): + """Writer that waits for reader1 to finish.""" + # Wait for reader1 to acquire lock + await reader1_acquired.wait() + writer_waiting.set() + + # This will block until reader1 releases + async with lock.write_lock(): + pass + + async def reader2_task(): + """Second reader that should be rejected due to waiting writer.""" + # Wait for writer to start waiting + await writer_waiting.wait() + reader2_attempted.set() + + # This should raise RuntimeError + try: + async with lock.read_lock(): + # Should not reach here + assert ( + False + ), "Reader2 should not acquire lock when writer is waiting" + except RuntimeError as e: + assert "writer is waiting" in str(e) + reader2_failed.set() + + # Start all tasks + reader1_task_handle = asyncio.create_task(reader1_task()) + writer_task_handle = asyncio.create_task(writer_task()) + reader2_task_handle = asyncio.create_task(reader2_task()) + + # Wait for sequence of events + await reader1_acquired.wait() + await writer_waiting.wait() 
+ await reader2_attempted.wait() + await reader2_failed.wait() + + # Verify reader2 was properly rejected + assert reader2_failed.is_set() + + # Clean up + await reader1_task_handle + await writer_task_handle + await reader2_task_handle + + @pytest.mark.asyncio + async def test_writer_blocks_second_writer(self, lock): + """Test that when writer has acquired lock, second writer can't write.""" + writer1_acquired = asyncio.Event() + writer2_attempted = asyncio.Event() + writer2_acquired = asyncio.Event() + writer1_released = asyncio.Event() + + async def writer1_task(): + """First writer that holds the lock.""" + async with lock.write_lock(): + writer1_acquired.set() + # Wait for writer2 to attempt access + await writer2_attempted.wait() + # Hold the lock to ensure writer2 is blocked + await asyncio.sleep(0.1) + writer1_released.set() + + async def writer2_task(): + """Second writer that should wait for first writer.""" + # Wait for writer1 to acquire lock + await writer1_acquired.wait() + writer2_attempted.set() + + # This should block until writer1 releases + async with lock.write_lock(): + writer2_acquired.set() + + # Start both tasks + writer1_task_handle = asyncio.create_task(writer1_task()) + writer2_task_handle = asyncio.create_task(writer2_task()) + + # Wait for writer1 to acquire and writer2 to attempt + await writer1_acquired.wait() + await writer2_attempted.wait() + + # Writer2 should not have acquired yet (writer1 still holds lock) + assert not writer2_acquired.is_set() + + # Wait for writer1 to finish + await writer1_task_handle + assert writer1_released.is_set() + + # Now writer2 should be able to acquire + await writer2_task_handle + assert writer2_acquired.is_set() + + @pytest.mark.asyncio + async def test_multiple_readers_concurrent_access(self, lock): + """Test that multiple readers can access the resource concurrently.""" + num_readers = 3 + readers_acquired = [] + all_readers_active = asyncio.Event() + + async def reader_task(reader_id): + """Reader task that signals when it has acquired the lock.""" + async with lock.read_lock(): + readers_acquired.append(reader_id) + # Wait for all readers to be active + if len(readers_acquired) == num_readers: + all_readers_active.set() + await all_readers_active.wait() + # Hold lock briefly to ensure concurrent access + await asyncio.sleep(0.1) + + # Start multiple reader tasks + reader_tasks = [asyncio.create_task(reader_task(i)) for i in range(num_readers)] + + # Wait for all readers to complete + await asyncio.gather(*reader_tasks) + + # Verify all readers were able to acquire locks + assert len(readers_acquired) == num_readers + assert set(readers_acquired) == set(range(num_readers)) + + @pytest.mark.asyncio + async def test_lock_state_after_exception(self, lock): + """Test that lock state is properly cleaned up after exceptions.""" + reader_acquired = asyncio.Event() + exception_handled = asyncio.Event() + + async def reader_with_exception(): + """Reader that raises an exception while holding lock.""" + try: + async with lock.read_lock(): + reader_acquired.set() + raise ValueError("Test exception") + except ValueError: + exception_handled.set() + + async def subsequent_reader(): + """Reader that should be able to acquire lock after exception.""" + await exception_handled.wait() + async with lock.read_lock(): + # Should be able to acquire lock successfully + pass + + # Run tasks + await asyncio.gather( + asyncio.create_task(reader_with_exception()), + asyncio.create_task(subsequent_reader()), + ) + + # Verify exception was handled 
and lock is available + assert exception_handled.is_set() + + # Verify lock state is clean (no readers or writers) + assert lock._readers == 0 + assert lock._writer is False + assert lock._writer_waiting is False + + @pytest.mark.asyncio + async def test_context_manager_cleanup(self, lock): + """Test that context managers properly clean up lock state.""" + # Test read lock cleanup + async with lock.read_lock(): + assert lock._readers == 1 + assert lock._readers == 0 + + # Test write lock cleanup + async with lock.write_lock(): + assert lock._writer is True + assert lock._writer is False + + # Test nested operations don't interfere + async with lock.read_lock(): + assert lock._readers == 1 + async with lock.read_lock(): + assert lock._readers == 2 + assert lock._readers == 1 + assert lock._readers == 0
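+
+    @pytest.mark.asyncio
+    async def test_reader_admitted_after_writer_cycle(self, lock):
+        """Writer-priority bookkeeping resets once a writer releases.
+
+        A minimal extra check: after a full write_lock() cycle the
+        _writer_waiting flag must be clear, so new readers are admitted
+        again instead of being rejected.
+        """
+        async with lock.write_lock():
+            assert lock._writer is True
+        # The writer has released; priority flags must be clear again.
+        assert lock._writer is False
+        assert lock._writer_waiting is False
+        # A reader must now be admitted without raising RuntimeError.
+        async with lock.read_lock():
+            assert lock._readers == 1
+        assert lock._readers == 0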