Skip to content

Commit 3e32264

Browse files
committed
fix typo
1 parent 3ba0df6 commit 3e32264

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

apps/grpo/main.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,7 @@ class Trainer(ForgeActor):
136136
epsilon: float = 0.1
137137
device: torch.device | None = None
138138

139+
@endpoint
139140
def setup(self):
140141
# Set device
141142
if self.device is None:
@@ -255,7 +256,7 @@ class ComputeAdvantages(ForgeActor):
255256
async def compute(self, group: Group) -> list[float]:
256257
# TODO: add batch processing
257258
rewards = torch.Tensor([[e.reward for e in group.episodes]])
258-
advantages = (rewards - rewards.me / an(1, keepdim=True)) / (
259+
advantages = (rewards - rewards.mean(1, keepdim=True)) / (
259260
rewards.std(1, keepdim=True) + 1e-4
260261
)
261262
return advantages.squeeze(0)

0 commit comments

Comments
 (0)