Skip to content

Commit 4044087

Browse files
committed
Last updates
1 parent 7eedc91 commit 4044087

File tree

2 files changed

+2
-5
lines changed

2 files changed

+2
-5
lines changed

apps/grpo/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ async def train_step(self, batch: list[list[Episode]]):
196 196
197 197   mask = response != pad_id
198 198   loss = self.loss(logprobs, ref_logprobs, advantages, mask)
199     - self.optimizer.zero_grad()
200 199   loss.backward()
201 200   self.optimizer.step()
    201 + self.optimizer.zero_grad(set_to_none=True)
202 202
203 203   return loss.item()
204 204
@@ -447,7 +447,7 @@ async def continuous_training():
447 447   if batch is None:
448 448   await asyncio.sleep(0.1)
449 449   else:
450     - loss = sum(await trainer.train_step.call(batch))
    450 + loss = await trainer.train_step.choose(batch)
451 451   training_step += 1
452 452   mlogger.log("loss/training_step", loss, training_step)
453 453   await trainer.push_weights.call(policy_version)

src/forge/actors/policy.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,6 @@ async def generate(self, prompt: str, priority: int = 0) -> RequestOutput:
243 243   Returns:
244 244   RequestOutput: vLLM class with the generated response.
245 245   """
246     - return await self._generate(prompt, priority)
247     -
248     - async def _generate(self, prompt: str, priority: int = 0) -> RequestOutput:
249 246   self.request_id += 1 % sys.maxsize
250 247   request_id = str(self.request_id) # implement from a counter
251 248
0 commit comments

Comments
 (0)