Skip to content

Commit 833a6b6

Browse files
committed
Remove extraneous 'calculations'
1 parent a13a1ac commit 833a6b6

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -174,10 +174,7 @@ def setup(self):
174174

175175
@endpoint
176176
async def train_step(self, batch: list[Episode]):
177-
total_loss = 0.0
178-
num_episodes_processed = 0
179177
pad_id = batch[0].pad_id
180-
bsz = len(batch)
181178

182179
# prepare batch
183180
request = [e.request_tensor for e in batch]
@@ -212,10 +209,7 @@ async def train_step(self, batch: list[Episode]):
212209

213210
self.optimizer.step()
214211

215-
total_loss += loss.item()
216-
avg_loss = total_loss / bsz
217-
218-
return {"loss": avg_loss, "episodes_processed": num_episodes_processed}
212+
return {"loss": loss.item()}
219213

220214
@endpoint
221215
async def update_weights(self, policy_actor):

0 commit comments

Comments
 (0)