 from forge.controller import ForgeActor
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
+from forge.util.ops import compute_logprobs

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -90,8 +91,23 @@ async def setup(self): |
         self.model.eval()

     @endpoint
-    async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
-
+    async def forward(
+        self, input_ids: torch.Tensor, max_req_tokens: int, return_logprobs: bool
+    ) -> torch.Tensor:
+        """
+        Args:
+            input_ids (torch.Tensor): input token ids with shape [group_size, req + res length].
+            max_req_tokens (int): maximum request length.
+            return_logprobs (bool): whether to return log probabilities instead of raw logits.
+
+        The return_logprobs flag significantly impacts the amount of data transferred to the caller:
+        - When False: returns logits with shape [group_size, req + res_length, vocab_size].
+          This includes the full vocabulary distribution for each token position.
+
+        - When True: returns log probabilities with shape [group_size, res_length].
+          This only includes probabilities for the response tokens, significantly reducing
+          memory usage and transfer overhead.
+        """
         # Record reference model metrics
         record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM)
         record_metric(
@@ -133,5 +149,12 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor: |
         if isinstance(logits, DTensor):
             logits = logits.full_tensor()
         t.step("forward")
-        t.stop()
-        return logits
+
+        if not return_logprobs:
+            t.stop()
+            return logits
+        else:
+            logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
+            t.step("compute_logprobs")
+            t.stop()
+            return logprobs
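
To make the docstring's transfer claim concrete with illustrative numbers (not from the PR): for group_size 8, a 2,048-token sequence, and a 128K-entry vocabulary, fp32 logits occupy 8 × 2048 × 131072 × 4 bytes ≈ 8.6 GB per forward call, while per-token logprobs for a 1,024-token response occupy 8 × 1024 × 4 bytes ≈ 32 KB, roughly a 260,000× reduction in data moved between actors.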
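The new forward body delegates the reduction to forge.util.ops.compute_logprobs, whose implementation isn't part of this diff. As a rough mental model, such a helper typically log-softmaxes the logits aligned with the response tokens and gathers each sampled token's log probability. A minimal sketch, assuming the standard one-position offset between logits and the tokens they predict (names and signature are illustrative, not the library's):

```python
import torch
import torch.nn.functional as F

def compute_logprobs_sketch(
    logits: torch.Tensor,        # [group_size, req + res_length, vocab_size]
    response_ids: torch.Tensor,  # [group_size, res_length], i.e. input_ids[:, max_req_tokens:]
) -> torch.Tensor:
    """Hypothetical stand-in for forge.util.ops.compute_logprobs (not the PR's code).

    The logit at position i scores the token at position i + 1, so the logits
    that predict the response tokens end one position before the sequence end.
    """
    res_length = response_ids.shape[1]
    # Slice the logits aligned with the response tokens.
    response_logits = logits[:, -res_length - 1 : -1, :]
    # Normalize over the vocabulary, then pick out each sampled token's logprob.
    logprobs = F.log_softmax(response_logits.float(), dim=-1)
    return torch.gather(logprobs, dim=-1, index=response_ids.unsqueeze(-1)).squeeze(-1)
```

Under these assumptions the output is [group_size, res_length], matching the shape documented for the return_logprobs=True path.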