Merged
11 changes: 4 additions & 7 deletions apps/grpo/main.py
@@ -286,6 +286,7 @@ async def continuous_rollouts():
                 return
             prompt, target = sample["request"], sample["target"]
             responses = await policy.generate.choose(prompt)
+            # TODO: this shall be part of the responses metadata instead of a separate call
             version = await policy.get_version.choose()
             group = Group.new_group(
                 group_id=rollout_count,
@@ -343,21 +344,17 @@ async def continuous_rollouts():

     async def continuous_training():
         training_step = 0
-        policy_version = 0
         while True:
-            batch = await replay_buffer.sample.choose(
-                curr_policy_version=policy_version
-            )
+            batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
                 inputs, targets = batch
                 loss = await trainer.train_step.choose(inputs, targets)
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
-                await trainer.push_weights.call(policy_version)
-                policy_version += 1
-                await policy.update_weights.call()
+                await trainer.push_weights.call(training_step)
+                await policy.update_weights.call(training_step)

     print("Starting GRPO training loops...")
     # TODO: Start multiple rollouts once all serivces support it
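For reference, a minimal runnable sketch (not part of this diff) of the version contract the simplified loop relies on: the trainer tags pushed weights with the current training step, the policy adopts that same number as its version, and the replay buffer is sampled against it. Trainer, Policy, and ReplayBuffer below are illustrative stand-ins, not the forge services.

import asyncio

class Trainer:
    async def push_weights(self, version: int) -> None:
        print(f"pushed weights tagged v{version}")

class Policy:
    def __init__(self) -> None:
        self.policy_version = 0

    async def update_weights(self, policy_version: int) -> None:
        # adopt the caller-supplied version; no local counter to keep in sync
        self.policy_version = policy_version

class ReplayBuffer:
    async def sample(self, curr_policy_version: int):
        # pretend a fresh-enough batch is always available
        return ([0.1, 0.2], [1, 0])

async def continuous_training() -> None:
    trainer, policy, replay_buffer = Trainer(), Policy(), ReplayBuffer()
    training_step = 0
    while training_step < 3:
        batch = await replay_buffer.sample(curr_policy_version=training_step)
        if batch is None:
            await asyncio.sleep(0.1)
            continue
        training_step += 1
        await trainer.push_weights(training_step)   # weights tagged with the step
        await policy.update_weights(training_step)  # policy version == training step
    assert policy.policy_version == training_step == 3

asyncio.run(continuous_training())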
13 changes: 5 additions & 8 deletions apps/toy_rl/sumdigits.py
@@ -464,21 +464,18 @@ async def continuous_rollouts():

     async def continuous_training():
         training_step = 0
-        policy_version = 0
         while True:
-            batch = await replay_buffer.sample.choose(
-                curr_policy_version=policy_version
-            )
+            batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
                 loss = await trainer.train_step.choose(batch[0])
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
-                print(f"loss/training_step: {loss} at {training_step}")
-                await trainer.push_weights.call(policy_version)
-                policy_version += 1
-                await policy.update_weights.call()
+                print(f"loss/training_step: {loss} at training step {training_step}")
+                await trainer.push_weights.call(training_step)
+                await policy.update_weights.call(training_step)
+                # NOTE: hard-coded to be on-policy for faster convergence
                 await replay_buffer.clear.call()

     print("Starting training loop.")
13 changes: 7 additions & 6 deletions src/forge/actors/policy.py
@@ -123,12 +123,12 @@ class Policy(PolicyInterface):
     lora_request: LoRARequest | None = None
     tokenization_kwargs: dict = field(default_factory=dict)
     policy_worker: "PolicyWorker" = None
+    policy_version: int | None = None

     def __post_init__(self):
         self._run_task: asyncio.Task | None = None
         self._policy_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
-        self.weights_version: int = 0
         self.running = False
         if isinstance(self.engine_config, Mapping):
             self.engine_config = EngineConfig.from_dict(self.engine_config)
@@ -212,6 +212,7 @@ async def setup(self):
         await self.policy_worker.setup.call()

         self.request_id = 0
+        self.policy_version = 0
         self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
         self.vllm_config: VllmConfig = self.engine_config.create_vllm_config()

@@ -364,17 +365,17 @@ async def run(self):
                     fut.set_result(request_output)

     @endpoint
-    async def update_weights(self):
+    async def update_weights(self, policy_version: int):
         # TODO: If generating long sequences, this might be long and will block policy weight updates
         curr_requests = [fut for _, fut in self.requests.values()]
         if curr_requests:
             self.logger.debug(f"Waiting for {len(curr_requests)} pending requests")
             await asyncio.gather(*curr_requests)

         self.logger.debug(f"Starting weight update on {self.__class__.__name__}")
-        await self.policy_worker.update.call(version=self.weights_version)
-        self.weights_version += 1
-        self.logger.info(f"Weight update completed (now v{self.weights_version})")
+        await self.policy_worker.update.call(version=policy_version)
+        self.policy_version = policy_version
+        self.logger.info(f"Weight update completed (now v{self.policy_version})")

     @endpoint
     async def _get_model_params(self) -> dict[str, torch.Tensor]:
@@ -388,7 +389,7 @@ async def _get_model_params(self) -> dict[str, torch.Tensor]:
     @endpoint
     async def get_version(self) -> int:
         """Get the current policy version."""
-        return self.weights_version
+        return self.policy_version

     @endpoint
     async def stop(self):
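For illustration only, a self-contained sketch of the pattern update_weights follows after this change: drain in-flight generation futures, then record the caller-supplied version rather than incrementing a local counter. ToyPolicy and its attributes are hypothetical stand-ins, not the actual Policy actor or its endpoints.

import asyncio

class ToyPolicy:
    def __init__(self) -> None:
        self.policy_version = 0
        self.requests: dict[int, asyncio.Future] = {}

    async def update_weights(self, policy_version: int) -> None:
        pending = list(self.requests.values())
        if pending:
            # wait for in-flight generations before swapping weights
            await asyncio.gather(*pending)
        self.policy_version = policy_version  # version comes from the caller
        print(f"weight update completed (now v{self.policy_version})")

    async def get_version(self) -> int:
        return self.policy_version

async def main() -> None:
    policy = ToyPolicy()
    loop = asyncio.get_running_loop()
    fut = loop.create_future()
    policy.requests[0] = fut
    loop.call_later(0.05, fut.set_result, "generation finished")
    await policy.update_weights(7)
    assert await policy.get_version() == 7

asyncio.run(main())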
15 changes: 15 additions & 0 deletions src/forge/actors/replay_buffer.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
 import random
 from dataclasses import dataclass
 from typing import Any, Callable
@@ -12,6 +13,9 @@

 from forge.controller import ForgeActor

+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+

 @dataclass
 class ReplayBuffer(ForgeActor):
@@ -23,6 +27,9 @@ class ReplayBuffer(ForgeActor):
     seed: int | None = None
     collate: Callable = lambda batch: batch

+    def __post_init__(self):
+        super().__init__()
+
     @endpoint
     async def setup(self) -> None:
         self.buffer: list = []
@@ -87,11 +94,18 @@ async def evict(self, curr_policy_version: int) -> None:
         self._evict(curr_policy_version)

     def _evict(self, curr_policy_version: int) -> None:
+        buffer_len_before_evict = len(self.buffer)
         self.buffer = [
             trajectory
             for trajectory in self.buffer
             if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age
         ]
+        buffer_len_after_evict = len(self.buffer)
+
+        logger.debug(
+            f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, "
+            f"{buffer_len_before_evict - buffer_len_after_evict} episodes expired, {buffer_len_after_evict} episodes left"
+        )

     @endpoint
     async def _getitem(self, idx: int):
@@ -106,6 +120,7 @@ async def _numel(self) -> int:
     async def clear(self) -> None:
         """Clear the replay buffer immediately - dropping all episodes."""
         self.buffer.clear()
+        logger.debug("replay buffer cleared")

     @endpoint
     async def state_dict(self) -> dict[str, Any]:
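As a worked example of the age test _evict applies (illustrative code; Trajectory here is a stand-in type): a trajectory survives only if it is at most max_policy_age versions behind the current policy version.

from dataclasses import dataclass

@dataclass
class Trajectory:
    policy_version: int

def evict(buffer: list[Trajectory], curr_policy_version: int, max_policy_age: int) -> list[Trajectory]:
    # keep only trajectories that are at most max_policy_age versions old
    return [t for t in buffer if (curr_policy_version - t.policy_version) <= max_policy_age]

buffer = [Trajectory(0), Trajectory(2), Trajectory(4)]
# with the current policy at version 4 and max_policy_age=2, the v0 trajectory expires
assert [t.policy_version for t in evict(buffer, 4, 2)] == [2, 4]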
8 changes: 6 additions & 2 deletions src/forge/interfaces.py
@@ -85,8 +85,12 @@ async def generate(self, request: Observation) -> Action:

     @endpoint
     @abstractmethod
-    async def update_weights(self):
-        """Update the policy weights."""
+    async def update_weights(self, policy_version: int):
+        """Update the policy weights.
+
+        Args:
+            policy_version: The version number to update to.
+        """
         pass


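A toy sketch of implementing the updated interface signature, with the @endpoint decorator and actor plumbing omitted; InMemoryPolicy is hypothetical and only shows that implementers now receive the version from the caller instead of tracking it themselves.

from abc import ABC, abstractmethod

class PolicyInterface(ABC):
    @abstractmethod
    async def update_weights(self, policy_version: int) -> None:
        """Update the policy weights to the given version."""

class InMemoryPolicy(PolicyInterface):
    def __init__(self) -> None:
        self.policy_version = 0

    async def update_weights(self, policy_version: int) -> None:
        # record the caller-supplied version rather than incrementing a local counter
        self.policy_version = policy_version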