
Commit 287d828

[Redo][Perf] make placement for logprob computation flexible (#287)
* initial commit
* comment
* rename
* remove options from the grpo app + doc string
* merge conflict
* Re-trigger PR
* ...
* format
1 parent 8d84099

2 files changed: +32 -9 lines

apps/grpo/main.py

Lines changed: 5 additions & 5 deletions
@@ -401,14 +401,14 @@ async def continuous_rollouts():
 
         t.step("reward_evaluation")
 
-        # Calculate reference logprobs
-        ref_logits = await ref_model.forward.route(input_ids)
-        t.step("reference_model_forward")
+        ref_logprobs = await ref_model.forward.route(
+            input_ids, max_req_tokens, return_logprobs=True
+        )
+        t.step("reference_model_calculate_logprobs")
 
-        ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
         for i, episode in enumerate(group.episodes):
             episode.ref_logprobs = ref_logprobs[i]
-        del ref_logits, ref_logprobs, input_ids
+        del ref_logprobs, input_ids
         t.step("compute_logprobs")
 
         # Calculate advantages and add to replay buffer
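With this change, the rollout loop receives a compact per-token logprob tensor from the reference model instead of the full logits. A back-of-envelope sketch of the difference in payload size, using illustrative numbers (group size, sequence lengths, and vocab size are assumptions for the example, not values from this PR):

# Illustrative payload sizes for ref_model.forward; every number below is an
# assumption for the sake of the example, not a value taken from this PR.
group_size, req_len, res_len, vocab = 8, 512, 1536, 128_256
fp32_bytes = 4

# return_logprobs=False: logits of shape [group_size, req + res length, vocab_size]
logits_bytes = group_size * (req_len + res_len) * vocab * fp32_bytes

# return_logprobs=True: one logprob per response token, [group_size, res length]
logprob_bytes = group_size * res_len * fp32_bytes

print(f"logits payload:   {logits_bytes / 1e9:.1f} GB")   # ~8.4 GB
print(f"logprobs payload: {logprob_bytes / 1e3:.1f} KB")  # ~49.2 KB

At these sizes the logits path moves roughly five orders of magnitude more data per forward call, which is why the app now asks the reference model to compute the logprobs in place.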

src/forge/actors/reference_model.py

Lines changed: 27 additions & 4 deletions
@@ -29,6 +29,7 @@
 from forge.controller import ForgeActor
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
+from forge.util.ops import compute_logprobs
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -90,8 +91,23 @@ async def setup(self):
         self.model.eval()
 
     @endpoint
-    async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
-
+    async def forward(
+        self, input_ids: torch.Tensor, max_req_tokens: int, return_logprobs: bool
+    ) -> torch.Tensor:
+        """
+        Args:
+            input_ids (torch.Tensor): input token ids with shape [group_size, req + res length].
+            max_req_tokens (int): maximum request length.
+            return_logprobs (bool): whether to return log probabilities instead of raw logits.
+
+        The return_logprobs flag significantly impacts the amount of data transferred to the caller:
+        - When False: returns logits with shape [group_size, req + res length, vocab_size].
+          This includes the full vocabulary distribution for each token position.
+
+        - When True: returns log probabilities with shape [group_size, res length].
+          This only includes the log probabilities of the response tokens, significantly
+          reducing memory usage and transfer overhead.
+        """
         # Record reference model metrics
         record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM)
         record_metric(
@@ -133,5 +149,12 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         if isinstance(logits, DTensor):
             logits = logits.full_tensor()
         t.step("forward")
-        t.stop()
-        return logits
+
+        if not return_logprobs:
+            t.stop()
+            return logits
+        else:
+            logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
+            t.step("compute_logprobs")
+            t.stop()
+            return logprobs
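For context, here is a minimal sketch of what compute_logprobs from forge.util.ops plausibly does, inferred from the call above (logits over the full sequence, targets input_ids[:, max_req_tokens:]). This is an assumption about its behavior, not the actual forge implementation:

import torch

def compute_logprobs_sketch(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    # Sketch only; the real forge.util.ops.compute_logprobs may differ.
    # logits: [group_size, seq_len, vocab_size] over request + response tokens.
    # target_ids: [group_size, res_len] sampled response tokens (int64).
    res_len = target_ids.shape[1]
    # Logits at position i predict token i + 1, so the response tokens are
    # scored by the res_len positions ending one before the final position.
    shifted = logits[:, -res_len - 1 : -1, :]
    logprobs = torch.log_softmax(shifted, dim=-1)
    # Pick the log probability assigned to each actually sampled token.
    return torch.gather(logprobs, dim=2, index=target_ids.unsqueeze(-1)).squeeze(-1)

Keeping the return_logprobs=False branch presumably lets callers that genuinely need the full vocabulary distribution still request raw logits and accept the transfer cost.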
