Skip to content

Commit 4dd1025

Browse files
committed
Alias Generator as Policy
1 parent cadaef1 commit 4dd1025

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

apps/grpo/main.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,9 @@ def response_tensor(self) -> torch.Tensor:
 # Represents the group (G) of episodes in GRPO
 Group = list[Episode]
 
+# Represents the Policy Model to collect data from
+Policy = Generator
+
 
 def collate(
     batches: list[Group],
@@ -317,7 +320,7 @@ async def main(cfg: DictConfig):
         reward_actor,
     ) = await asyncio.gather(
         DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
-        Generator.options(**cfg.services.policy).as_service(**cfg.policy),
+        Policy.options(**cfg.services.policy).as_service(**cfg.policy),
         RLTrainer.options(**cfg.actors.trainer).as_actor(
             **cfg.trainer, loss=simple_grpo_loss
         ),

0 commit comments

Comments
 (0)