Commit 20d8644

Merge branch 'main' into language-reward-feature

2 parents: 866dff7 + c30ddf5

3 files changed: 19 additions, 21 deletions

apps/grpo/main.py

Lines changed: 15 additions & 17 deletions
@@ -439,23 +439,21 @@ async def continuous_rollouts():
             input_ids[i, :max_req_tokens] = episode.request_tensor
             input_ids[i, max_req_tokens:] = episode.response_tensor

-            # drop episodes if
-            # 1> reward std-dev is very small (including all 0s and all 1s)
-            # 2> response is potentially truncated (response_len >= max_res_tokens)
-            rewards = [e.reward for e in episodes]
-            rewards_std = torch.std(torch.tensor(rewards))
-            max_response_len = max(
-                e.completion.token_ids.shape[0] for e in episodes
-            )
-            drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens
-            record_metric(
-                "main/continuous_rollouts/dropped_episodes",
-                1 if drop else 0,
-                Reduce.SUM,
-            )
-            if drop:
-                del input_ids, episodes
-                continue
+            # drop episodes if
+            # 1> reward std-dev is very small (including all 0s and all 1s)
+            # 2> response is potentially truncated (response_len >= max_res_tokens)
+            rewards = [e.reward for e in episodes]
+            rewards_std = torch.std(torch.tensor(rewards))
+            max_response_len = max(e.completion.token_ids.shape[0] for e in episodes)
+            drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens
+            record_metric(
+                "main/continuous_rollouts/dropped_episodes",
+                1 if drop else 0,
+                Reduce.SUM,
+            )
+            if drop:
+                del input_ids, episodes
+                continue

             t.step("reward_evaluation")
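
The dropped-episode check above is self-contained enough to illustrate on its own. Below is a minimal, hypothetical sketch of the same rule under assumed inputs: plain lists of rewards and response lengths rather than forge's episode objects, and a standalone should_drop helper that does not exist in the repo. The 1e-3 threshold and the max_res_tokens comparison are taken directly from the diff.

    import torch

    def should_drop(rewards: list[float], response_lens: list[int], max_res_tokens: int) -> bool:
        # Near-zero reward std-dev (e.g. all 0s or all 1s) means the group
        # carries no preference signal for GRPO-style advantage estimation,
        # so the whole batch is skipped.
        rewards_std = torch.std(torch.tensor(rewards))
        # A response that fills the token budget may have been cut off
        # mid-generation, making its reward unreliable.
        return bool(rewards_std < 1e-3) or max(response_lens) >= max_res_tokens

    # All-identical rewards -> drop; a budget-filling response -> drop.
    assert should_drop([1.0, 1.0, 1.0], [10, 12], max_res_tokens=128)
    assert should_drop([0.0, 1.0], [128, 40], max_res_tokens=128)
    assert not should_drop([0.0, 1.0], [10, 12], max_res_tokens=128)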

src/forge/controller/launcher.py

Lines changed: 2 additions & 2 deletions
@@ -248,8 +248,8 @@ async def launch_mast_job(self):
             scheduler_args={
                 "hpcIdentity": "hyper_monarch",
                 "hpcJobOncall": "monarch",
-                "hpcClusterUuid": "MastProdCluster",
-                "rmAttribution": "pytorch4all_clients_approved",
+                "hpcClusterUuid": "MastGenAICluster",
+                "rmAttribution": "msl_infra_hw_enab_agentrl",
             },
             appdef=self.build_appdef(),
             workspace=Workspace(
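
The net effect of this change is just two values in the MAST scheduler configuration. A sketch of the updated dict, lifted straight from the diff (the surrounding launch call is elided here):

    scheduler_args = {
        "hpcIdentity": "hyper_monarch",
        "hpcJobOncall": "monarch",
        "hpcClusterUuid": "MastGenAICluster",          # was: "MastProdCluster"
        "rmAttribution": "msl_infra_hw_enab_agentrl",  # was: "pytorch4all_clients_approved"
    }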

src/forge/controller/service/service.py

Lines changed: 2 additions & 2 deletions
@@ -285,8 +285,8 @@ async def _migrate_remaining_requests(self, failed_replica: Replica):
             return

         # Distribute requests among healthy replicas
-        for i, request in enumerate(migrated_requests):
-            target_replica = healthy_replicas[i % len(healthy_replicas)]
+        for request in migrated_requests:
+            target_replica = self._default_router.get_replica(healthy_replicas)
             await target_replica.enqueue_request(request)

         # Update session mapping if needed
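
The migration loop now delegates replica choice to the service's default router instead of hand-rolled round-robin indexing, so failover reuses whatever load-balancing policy normal request routing runs with. A minimal sketch of what such a router could look like; RoundRobinRouter, its get_replica(replicas) signature, and the string replica IDs are hypothetical stand-ins for the forge types:

    from itertools import count

    class RoundRobinRouter:
        """Hypothetical router with the get_replica(replicas) shape used above."""

        def __init__(self) -> None:
            self._counter = count()

        def get_replica(self, replicas: list[str]) -> str:
            # Observably the same as healthy_replicas[i % len(healthy_replicas)],
            # but the policy now lives in one reusable object that the service
            # could swap for least-loaded, session-affine, etc.
            return replicas[next(self._counter) % len(replicas)]

    router = RoundRobinRouter()
    healthy = ["replica-0", "replica-1"]
    assert [router.get_replica(healthy) for _ in range(3)] == [
        "replica-0",
        "replica-1",
        "replica-0",
    ]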
