Commit b584f4e

refactor: move compute_logprobs_parallel into ops alongside compute_logprobs
1 parent: a2a9567
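
For downstream callers the practical effect of this refactor is an import-path change. A minimal before/after sketch (the old path is taken from the removed test import further down this page):

# Before this commit, the parallel helpers lived in a separate module:
# from forge.util.parallel_logprobs import compute_logprobs_parallel, get_vocab_shard_info

# After this commit, everything is imported from ops:
from forge.util.ops import compute_logprobs, compute_logprobs_parallel, get_vocab_shard_info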

File tree: 3 files changed, +129 -162 lines

src/forge/util/ops.py

Lines changed: 124 additions & 0 deletions
@@ -5,8 +5,11 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+import torch.distributed as dist
 import torch.nn.functional as F

+from torch.distributed.tensor import DTensor
+

 def compute_logprobs(
     logits: torch.Tensor,
@@ -95,3 +98,124 @@ def compute_logprobs(
     )

     return logprobs.reshape(batch_size, seq_len)
+
+
+def compute_logprobs_parallel(
+    logits: DTensor,
+    target_ids: torch.Tensor,
+    temperature: float = 1.0,
+    align: bool = True,
+) -> torch.Tensor:
+    """
+    Compute log probabilities for target tokens from vocab-sharded DTensor logits.
+
+    This function computes log_softmax(logits)[target_ids] in a distributed fashion,
+    without ever gathering the full vocabulary dimension.
+
+    IMPORTANT: Only use this when logits is a DTensor sharded on the vocab dimension.
+    For regular tensors or non-vocab-sharded DTensors, use compute_logprobs instead.
+
+    Args:
+        logits: DTensor of shape [batch_size, seq_len, vocab_size], sharded on dim=-1.
+        target_ids: Tensor of shape [batch_size, target_len] with target token IDs.
+        temperature: Temperature for scaling logits (default 1.0).
+        align: If True, slice logits to align with target_ids (default True).
+
+    Returns:
+        Tensor of shape [batch_size, target_len] with log probabilities.
+    """
+    # Get sharding info using helper
+    tp_group, tp_rank, tp_size, vocab_start, vocab_end = get_vocab_shard_info(logits)
+
+    if tp_group is None:
+        # DTensor but not sharded on vocab (Replicate or other dim sharding)
+        return compute_logprobs(logits.full_tensor(), target_ids, temperature, align)
+
+    # Get the local shard
+    local_logits = logits._local_tensor  # [batch, seq_len, vocab_size / tp_size]
+
+    # Align logits with target if needed
+    if align:
+        # Slice to match target length: logits[:, -target_len-1:-1, :]
+        target_len = target_ids.size(1)
+        local_logits = local_logits[:, -target_len - 1 : -1, :]
+
+    # Scale by temperature
+    local_logits = local_logits / temperature
+
+    batch_size, seq_len, local_vocab_size = local_logits.shape
+
+    # Move target_ids to the same device as local_logits
+    target_ids = target_ids.to(local_logits.device)
+
+    # Cast to float32 for numerical stability
+    local_logits_fp32 = local_logits.float()
+
+    # Compute global max across all shards for numerical stability
+    local_max = local_logits_fp32.max(dim=-1, keepdim=True).values
+    global_max = local_max.clone()
+    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)
+
+    # Compute global sum(exp(x - max)) for the log-sum-exp trick
+    local_exp = torch.exp(local_logits_fp32 - global_max)
+    local_sum_exp = local_exp.sum(dim=-1, keepdim=True)
+    global_sum_exp = local_sum_exp.clone()
+    dist.all_reduce(global_sum_exp, op=dist.ReduceOp.SUM, group=tp_group)

+    # log_normalizer = global_max + log(global_sum_exp)
+    log_normalizer = global_max + torch.log(global_sum_exp)  # [batch, seq, 1]
+    log_normalizer = log_normalizer.squeeze(-1)  # [batch, seq]
+
+    # Extract logits at target positions - each rank only has part of the vocab
+    is_local = (target_ids >= vocab_start) & (target_ids < vocab_end)
+
+    # Convert global indices to local indices (only valid where is_local=True)
+    local_indices = target_ids - vocab_start
+    local_indices = local_indices.clamp(0, local_vocab_size - 1)  # Clamp for safety
+
+    target_logits = torch.gather(
+        local_logits_fp32,
+        dim=-1,
+        index=local_indices.unsqueeze(-1).long(),
+    ).squeeze(-1)
+
+    # Zero out where this rank doesn't own the token, then reduce
+    target_logits = target_logits * is_local.float()
+    dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=tp_group)
+
+    logprobs = target_logits - log_normalizer
+
+    return logprobs
+
+
+def get_vocab_shard_info(
+    logits: DTensor,
+) -> tuple[dist.ProcessGroup | None, int, int, int, int]:
+    """
+    Get vocabulary sharding information from a DTensor.
+
+    Args:
+        logits: DTensor with shape [..., vocab_size], potentially sharded on the vocab dim.
+
+    Returns:
+        Tuple of (tp_group, tp_rank, tp_size, vocab_start, vocab_end).
+        If not sharded, returns (None, 0, 1, 0, vocab_size).
+    """
+    from torch.distributed.tensor.placement_types import Shard
+
+    local_logits = logits._local_tensor
+    placements = logits.placements
+    device_mesh = logits.device_mesh
+
+    for i, p in enumerate(placements):
+        if isinstance(p, Shard) and p.dim == 2:  # vocab dimension
+            tp_group = device_mesh.get_group(mesh_dim=i)
+            tp_size = dist.get_world_size(tp_group)
+            tp_rank = dist.get_rank(tp_group)
+            local_vocab_size = local_logits.shape[-1]
+            vocab_start = tp_rank * local_vocab_size
+            vocab_end = vocab_start + local_vocab_size
+            return tp_group, tp_rank, tp_size, vocab_start, vocab_end
+
+    # Not sharded
+    return None, 0, 1, 0, local_logits.shape[-1]
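
To see why the two all_reduce calls are enough, here is a minimal single-process sketch (not part of the commit) that redoes the same sharded log-sum-exp arithmetic with plain tensor chunks standing in for per-rank DTensor shards; the names shards and chunks are illustrative only:

# Single-process sketch of the sharded log-sum-exp arithmetic used above.
# The vocab is split into plain chunks standing in for per-rank shards,
# and the two all_reduce calls become a max and a sum over the chunks.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, vocab, shards = 2, 4, 32, 4
logits = torch.randn(batch, seq, vocab)
targets = torch.randint(0, vocab, (batch, seq))

chunks = logits.chunk(shards, dim=-1)                     # "local shards"
local_max = torch.stack([c.max(dim=-1, keepdim=True).values for c in chunks])
global_max = local_max.max(dim=0).values                  # all_reduce(MAX) equivalent
sum_exp = sum(torch.exp(c - global_max).sum(dim=-1, keepdim=True) for c in chunks)
log_normalizer = (global_max + torch.log(sum_exp)).squeeze(-1)

# Each "rank" contributes the target logit only where it owns that vocab slice.
shard_size = vocab // shards
target_logits = torch.zeros(batch, seq)
for r, c in enumerate(chunks):
    lo, hi = r * shard_size, (r + 1) * shard_size
    owned = (targets >= lo) & (targets < hi)
    local_idx = (targets - lo).clamp(0, shard_size - 1)
    gathered = torch.gather(c, -1, local_idx.unsqueeze(-1)).squeeze(-1)
    target_logits += gathered * owned.float()             # all_reduce(SUM) equivalent

logprobs = target_logits - log_normalizer
reference = torch.gather(
    F.log_softmax(logits, dim=-1), -1, targets.unsqueeze(-1)
).squeeze(-1)
assert torch.allclose(logprobs, reference, atol=1e-5)

The identity being exercised is log p(y) = z_y - (m + log sum_v exp(z_v - m)) with m the global max; the DTensor version computes exactly this, reducing the max and the sum once each across the vocab shards.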

src/forge/util/parallel_logprobs.py

Lines changed: 0 additions & 160 deletions
This file was deleted.

tests/unit_tests/util/test_parallel_logprobs.py

Lines changed: 5 additions & 2 deletions
@@ -14,8 +14,11 @@
 import torch
 import torch.distributed as dist

-from forge.util.ops import compute_logprobs
-from forge.util.parallel_logprobs import compute_logprobs_parallel, get_vocab_shard_info
+from forge.util.ops import (
+    compute_logprobs,
+    compute_logprobs_parallel,
+    get_vocab_shard_info,
+)
 from tests.test_utils import gpu_test
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor import DTensor, Shard
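
Beyond the unit test above, a standalone parity check can exercise the moved function end to end. The sketch below is hypothetical (it is not part of this commit) and assumes two CUDA devices launched with torchrun; it builds a vocab-sharded DTensor and compares both code paths:

# Hypothetical parity check, assuming 2 GPUs:
#   torchrun --nproc_per_node=2 parity_check.py
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

from forge.util.ops import compute_logprobs, compute_logprobs_parallel


def main() -> None:
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (2,))          # also initializes the process group

    torch.manual_seed(0)                           # identical full logits on every rank
    logits = torch.randn(2, 8, 128, device="cuda")
    targets = torch.randint(0, 128, (2, 7), device="cuda")

    sharded = distribute_tensor(logits, mesh, [Shard(2)])  # shard the vocab dimension
    parallel = compute_logprobs_parallel(sharded, targets)
    reference = compute_logprobs(logits, targets)
    torch.testing.assert_close(parallel, reference, atol=1e-4, rtol=1e-4)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()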

0 commit comments