
Commit 82b89e6

chore: make routing explicit
1 parent 0c2dd54


2 files changed: +7 −4 lines changed


src/forge/actors/reference_model.py

Lines changed: 2 additions & 2 deletions
@@ -191,9 +191,9 @@ async def forward(
             return logits
         else:
             # Compute logprobs in parallel without gathering full vocab tensor
+            # Use parallel version when TP is enabled (vocab sharded across GPUs)
             response_tokens = input_ids[:, max_req_tokens:]
-            if isinstance(logits, DTensor):
-                # Use parallel logprobs - avoids materializing full vocab on each GPU
+            if parallel_dims.tp_enabled and isinstance(logits, DTensor):
                 logprobs = compute_logprobs_parallel(
                     logits, response_tokens, align=True
                 )
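
Why the extra condition: `isinstance(logits, DTensor)` alone routes to the parallel path even when tensor parallelism is off, since a DTensor can be replicated rather than vocab-sharded. Gating on `parallel_dims.tp_enabled` makes the routing explicit. A minimal sketch of the resulting dispatch, reusing the names visible in this diff (exact signatures in the repo may differ):

# Sketch only - mirrors the routing above; helper names come from this diff.
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older torch

def route_logprobs(logits, response_tokens, parallel_dims):
    if parallel_dims.tp_enabled and isinstance(logits, DTensor):
        # Vocab dim is sharded across TP ranks: compute shard-locally and
        # combine with collectives instead of gathering the full vocab.
        return compute_logprobs_parallel(logits, response_tokens, align=True)
    # TP disabled: any DTensor here is replicated, so the full vocab is
    # available locally and the regular computation is correct and cheaper.
    if isinstance(logits, DTensor):
        logits = logits.full_tensor()
    return compute_logprobs(logits, response_tokens)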

src/forge/util/parallel_logprobs.py

Lines changed: 5 additions & 2 deletions
@@ -46,11 +46,14 @@ def compute_logprobs_parallel(
     align: bool = True,
 ) -> torch.Tensor:
     """
-    Compute log probabilities for target tokens from vocab-sharded logits.
+    Compute log probabilities for target tokens from vocab-sharded DTensor logits.
 
     This function computes log_softmax(logits)[target_ids] distributedly,
     without ever gathering the full vocabulary dimension.
 
+    IMPORTANT: Only use this when logits is a DTensor sharded on vocab dimension.
+    For regular tensors or non-vocab-sharded DTensors, use compute_logprobs instead.
+
     Args:
         logits: DTensor of shape [batch_size, seq_len, vocab_size], sharded on dim=-1.
         target_ids: Tensor of shape [batch_size, target_len] with target token IDs.
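
The new docstring warning can also be checked programmatically. A hedged sketch of such a guard (illustrative only; `is_vocab_sharded` is not a helper from this repo, whose dispatch instead goes through `get_vocab_shard_info`, shown in the next hunk):

# Sketch: detect whether logits are actually sharded on the vocab (last) dim.
from torch.distributed.tensor import DTensor, Shard

def is_vocab_sharded(logits) -> bool:
    if not isinstance(logits, DTensor):
        return False
    last_dim = logits.ndim - 1
    # Sharded on the vocab dimension along some mesh dimension?
    return any(
        isinstance(p, Shard) and p.dim in (-1, last_dim)
        for p in logits.placements
    )
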
@@ -64,7 +67,7 @@ def compute_logprobs_parallel(
     tp_group, tp_rank, tp_size, vocab_start, vocab_end = get_vocab_shard_info(logits)
 
     if tp_group is None:
-        # Not sharded on vocab (TP=1 or Replicate), use regular computation
+        # DTensor but not sharded on vocab (Replicate or other dim sharding)
         return compute_logprobs(logits.full_tensor(), target_ids, temperature, align)
 
     # Get the local shard
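
For context on what the parallel path computes after taking the local shard: log_softmax over a vocab-sharded tensor needs only two small all-reduces (a max for numerical stability, then a sum of shifted exponentials), plus one more to fetch each target's logit from the rank that owns it. A simplified standalone sketch, not the repo's implementation, assuming each rank holds the contiguous vocab slice [vocab_start, vocab_end) and that logits and targets are already aligned:

import torch
import torch.distributed as dist

def sharded_logprobs(local_logits, target_ids, vocab_start, vocab_end, group):
    # local_logits: [B, T, V_local], this rank's slice of the vocab dim.
    # target_ids:   [B, T] global token ids, position-aligned with local_logits.

    # logsumexp over the full vocab without gathering it.
    global_max = local_logits.amax(dim=-1)                       # [B, T]
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=group)
    sum_exp = (local_logits - global_max.unsqueeze(-1)).exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)
    logsumexp = global_max + sum_exp.log()                       # [B, T]

    # Each target id lives on exactly one shard; that rank contributes its
    # logit, all others contribute zero, and a SUM all-reduce combines them.
    in_shard = (target_ids >= vocab_start) & (target_ids < vocab_end)
    local_idx = (target_ids - vocab_start).clamp(0, local_logits.size(-1) - 1)
    target_logit = local_logits.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=group)

    return target_logit - logsumexp                              # log p(target)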
