We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cadaef1 · commit 4dd1025 (Copy full SHA for 4dd1025)
apps/grpo/main.py
@@ -79,6 +79,9 @@ def response_tensor(self) -> torch.Tensor:
79
# Represents the group (G) of episodes in GRPO
80
Group = list[Episode]
81
82
+# Represents the Policy Model to collect data from
83
+Policy = Generator
84
+
85
86
def collate(
87
batches: list[Group],
@@ -317,7 +320,7 @@ async def main(cfg: DictConfig):
317
320
reward_actor,
318
321
) = await asyncio.gather(
319
322
DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
- Generator.options(**cfg.services.policy).as_service(**cfg.policy),
323
+ Policy.options(**cfg.services.policy).as_service(**cfg.policy),
324
RLTrainer.options(**cfg.actors.trainer).as_actor(
325
**cfg.trainer, loss=simple_grpo_loss
326
),
0 commit comments