diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 7f31c26c9..b568705e4 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -286,6 +286,7 @@ async def continuous_rollouts():
                 return
             prompt, target = sample["request"], sample["target"]
             responses = await policy.generate.choose(prompt)
+            # TODO: this should be part of the response metadata instead of a separate call
             version = await policy.get_version.choose()
             group = Group.new_group(
                 group_id=rollout_count,
@@ -343,11 +344,8 @@ async def continuous_rollouts():
 
     async def continuous_training():
         training_step = 0
-        policy_version = 0
         while True:
-            batch = await replay_buffer.sample.choose(
-                curr_policy_version=policy_version
-            )
+            batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
@@ -355,9 +353,8 @@ async def continuous_training():
                 loss = await trainer.train_step.choose(inputs, targets)
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
-                await trainer.push_weights.call(policy_version)
-                policy_version += 1
-                await policy.update_weights.call()
+                await trainer.push_weights.call(training_step)
+                await policy.update_weights.call(training_step)
 
     print("Starting GRPO training loops...")
     # TODO: Start multiple rollouts once all services support it
diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py
index 6b1d8d763..d4780f2e6 100644
--- a/apps/toy_rl/sumdigits.py
+++ b/apps/toy_rl/sumdigits.py
@@ -464,21 +464,18 @@ async def continuous_rollouts():
 
     async def continuous_training():
         training_step = 0
-        policy_version = 0
         while True:
-            batch = await replay_buffer.sample.choose(
-                curr_policy_version=policy_version
-            )
+            batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
                 loss = await trainer.train_step.choose(batch[0])
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
-                print(f"loss/training_step: {loss} at {training_step}")
-                await trainer.push_weights.call(policy_version)
-                policy_version += 1
-                await policy.update_weights.call()
+                print(f"loss/training_step: {loss} at training step {training_step}")
+                await trainer.push_weights.call(training_step)
+                await policy.update_weights.call(training_step)
+                # NOTE: hard-coded to be on-policy for faster convergence
                 await replay_buffer.clear.call()
 
     print("Starting training loop.")
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 070c00798..7e2f181c7 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -123,12 +123,12 @@ class Policy(PolicyInterface):
     lora_request: LoRARequest | None = None
     tokenization_kwargs: dict = field(default_factory=dict)
     policy_worker: "PolicyWorker" = None
+    policy_version: int | None = None
 
     def __post_init__(self):
         self._run_task: asyncio.Task | None = None
         self._policy_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
-        self.weights_version: int = 0
         self.running = False
         if isinstance(self.engine_config, Mapping):
             self.engine_config = EngineConfig.from_dict(self.engine_config)
@@ -212,6 +212,7 @@ async def setup(self):
         await self.policy_worker.setup.call()
 
         self.request_id = 0
+        self.policy_version = 0
         self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
 
         self.vllm_config: VllmConfig = self.engine_config.create_vllm_config()
@@ -364,7 +365,7 @@ async def run(self):
             fut.set_result(request_output)
 
     @endpoint
-    async def update_weights(self):
+    async def update_weights(self, policy_version: int):
         # TODO: If generating long sequences, this might be long and will block policy weight updates
         curr_requests = [fut for _, fut in self.requests.values()]
         if curr_requests:
@@ -372,9 +373,9 @@ async def update_weights(self):
             await asyncio.gather(*curr_requests)
 
         self.logger.debug(f"Starting weight update on {self.__class__.__name__}")
-        await self.policy_worker.update.call(version=self.weights_version)
-        self.weights_version += 1
-        self.logger.info(f"Weight update completed (now v{self.weights_version})")
+        await self.policy_worker.update.call(version=policy_version)
+        self.policy_version = policy_version
+        self.logger.info(f"Weight update completed (now v{self.policy_version})")
 
     @endpoint
     async def _get_model_params(self) -> dict[str, torch.Tensor]:
@@ -388,7 +389,7 @@ async def _get_model_params(self) -> dict[str, torch.Tensor]:
     @endpoint
     async def get_version(self) -> int:
         """Get the current policy version."""
-        return self.weights_version
+        return self.policy_version
 
     @endpoint
     async def stop(self):
diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py
index ca7d487bc..fd60ce35c 100644
--- a/src/forge/actors/replay_buffer.py
+++ b/src/forge/actors/replay_buffer.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import random
 from dataclasses import dataclass
 from typing import Any, Callable
@@ -12,6 +13,9 @@
 
 from forge.controller import ForgeActor
 
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
 
 @dataclass
 class ReplayBuffer(ForgeActor):
@@ -23,6 +27,9 @@ class ReplayBuffer(ForgeActor):
     seed: int | None = None
     collate: Callable = lambda batch: batch
 
+    def __post_init__(self):
+        super().__init__()
+
     @endpoint
     async def setup(self) -> None:
         self.buffer: list = []
@@ -87,11 +94,18 @@ async def evict(self, curr_policy_version: int) -> None:
         self._evict(curr_policy_version)
 
     def _evict(self, curr_policy_version: int) -> None:
+        buffer_len_before_evict = len(self.buffer)
         self.buffer = [
             trajectory
             for trajectory in self.buffer
             if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age
         ]
+        buffer_len_after_evict = len(self.buffer)
+
+        logger.debug(
+            f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, "
+            f"{buffer_len_before_evict - buffer_len_after_evict} episodes expired, {buffer_len_after_evict} episodes left"
+        )
 
     @endpoint
     async def _getitem(self, idx: int):
@@ -106,6 +120,7 @@ async def _numel(self) -> int:
     async def clear(self) -> None:
         """Clear the replay buffer immediately - dropping all episodes."""
         self.buffer.clear()
+        logger.debug("replay buffer cleared")
 
     @endpoint
     async def state_dict(self) -> dict[str, Any]:
diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py
index 3dbbd560e..df79c302e 100644
--- a/src/forge/interfaces.py
+++ b/src/forge/interfaces.py
@@ -85,8 +85,12 @@ async def generate(self, request: Observation) -> Action:
 
     @endpoint
     @abstractmethod
-    async def update_weights(self):
-        """Update the policy weights."""
+    async def update_weights(self, policy_version: int):
+        """Update the policy weights.
+
+        Args:
+            policy_version: The version number to update to.
+        """
         pass
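
Note on the eviction arithmetic: with training_step now doubling as the policy version, ReplayBuffer._evict keeps a trajectory only while curr_policy_version - trajectory.policy_version <= max_policy_age. A minimal standalone sketch of that rule (ToyReplayBuffer and Trajectory below are illustrative stand-ins, not the Forge actors):

# Minimal sketch of the staleness rule used by ReplayBuffer._evict.
# ToyReplayBuffer/Trajectory are hypothetical stand-ins for illustration.
from dataclasses import dataclass, field


@dataclass
class Trajectory:
    policy_version: int


@dataclass
class ToyReplayBuffer:
    max_policy_age: int = 1
    buffer: list[Trajectory] = field(default_factory=list)

    def evict(self, curr_policy_version: int) -> None:
        # Keep only trajectories whose staleness is within max_policy_age,
        # the same predicate as in the diff above.
        self.buffer = [
            t
            for t in self.buffer
            if (curr_policy_version - t.policy_version) <= self.max_policy_age
        ]


buf = ToyReplayBuffer(max_policy_age=1)
buf.buffer = [Trajectory(policy_version=v) for v in (0, 1, 2, 3)]
training_step = 3  # training_step doubles as the current policy version
buf.evict(curr_policy_version=training_step)
# Versions 0 and 1 exceed max_policy_age=1 and are evicted; 2 and 3 survive.
print([t.policy_version for t in buf.buffer])  # -> [2, 3]

With max_policy_age=0 only same-version trajectories would survive sampling, which is roughly the on-policy behavior the toy_rl loop enforces more bluntly by calling replay_buffer.clear() after every update.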
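The interface change makes the caller the source of truth for versions: the trainer passes its training_step to both push_weights and update_weights, and the Policy records that number instead of incrementing a private counter. A toy asyncio sketch of the handshake (ToyPolicy is hypothetical; the real endpoint additionally drains in-flight generation futures before swapping weights):

# Toy sketch of the caller-supplied-version handshake; not the actual Policy actor.
import asyncio


class ToyPolicy:
    def __init__(self) -> None:
        self.policy_version: int | None = 0
        self.requests: list[asyncio.Task] = []  # stand-in for pending request futures

    async def update_weights(self, policy_version: int) -> None:
        # Drain in-flight generation before swapping weights, mirroring the
        # asyncio.gather over pending requests in the real endpoint.
        if self.requests:
            await asyncio.gather(*self.requests)
            self.requests.clear()
        self.policy_version = policy_version  # version comes from the trainer

    async def get_version(self) -> int:
        return self.policy_version


async def main() -> None:
    policy = ToyPolicy()
    for training_step in range(1, 4):
        # ... train_step / push_weights would run here ...
        await policy.update_weights(training_step)
    print(await policy.get_version())  # -> 3


asyncio.run(main())

One consequence of this design: the replay buffer, trainer, and policy can no longer drift apart, since a single counter (training_step) tags pushed weights, drives eviction, and is what get_version reports back to rollouts.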