Skip to content

Commit 2a1546b

Browse files
committed
Get current version
1 parent 5935260 commit 2a1546b

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

apps/grpo/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ async def main():
4949

5050
async def continuous_rollouts():
5151
while True:
52-
current_version = await policy.get_current_version()
53-
episode = await generate_rollout()
52+
version = await policy.get_current_version.choose()
53+
episode = await generate_rollout(version)
5454
await replay_buffer.add.call(episode)
5555

5656
rollout_task = asyncio.create_task(continuous_rollouts())

src/forge/actors/policy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class PolicyRouter(Actor):
4343
sampling_params: SamplingParams = None
4444
lora_request: LoRARequest = None
4545
tokenization_kwargs: dict = None
46+
version: int = 0
4647

4748
@endpoint
4849
async def setup(self):
@@ -78,6 +79,10 @@ async def setup(self):
7879
log_stats=None,
7980
)
8081

82+
@endpoint
83+
async def get_current_version(self) -> int:
84+
return self.version
85+
8186
@endpoint
8287
async def generate(self, prompt: str, priority: int = 0):
8388
self.request_id += 1 % sys.maxsize

0 commit comments

Comments
 (0)