We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3ba0df6 · commit 3e32264 — Copy full SHA for 3e32264
apps/grpo/main.py
@@ -136,6 +136,7 @@ class Trainer(ForgeActor):
136
epsilon: float = 0.1
137
device: torch.device | None = None
138
139
+ @endpoint
140
def setup(self):
141
# Set device
142
if self.device is None:
@@ -255,7 +256,7 @@ class ComputeAdvantages(ForgeActor):
255
256
async def compute(self, group: Group) -> list[float]:
257
# TODO: add batch processing
258
rewards = torch.Tensor([[e.reward for e in group.episodes]])
- advantages = (rewards - rewards.mean(1, keepdim=True)) / (
259
+ advantages = (rewards - rewards.mean(1, keepdim=True)) / (
260
rewards.std(1, keepdim=True) + 1e-4
261
)
262
return advantages.squeeze(0)
0 commit comments