Commit a2a9567

style: clean up inline comments in parallel_logprobs
1 parent 82b89e6 commit a2a9567

1 file changed: +10 -31 lines changed

src/forge/util/parallel_logprobs.py

Lines changed: 10 additions & 31 deletions
@@ -83,66 +83,45 @@ def compute_logprobs_parallel(
     local_logits = local_logits / temperature
 
     batch_size, seq_len, local_vocab_size = local_logits.shape
-    device = local_logits.device
 
-    # Move target_ids to the same device
-    target_ids = target_ids.to(device)
+    # Move target_ids to the same device as local_logits
+    target_ids = target_ids.to(local_logits.device)
 
     # Cast to float32 for numerical stability
     local_logits_fp32 = local_logits.float()
 
-    # ============================================================
-    # Step 1: Compute global max for numerical stability
-    # ============================================================
-    local_max = local_logits_fp32.max(dim=-1, keepdim=True).values  # [batch, seq, 1]
+    # Compute global max across all shards for numerical stability
+    local_max = local_logits_fp32.max(dim=-1, keepdim=True).values
     global_max = local_max.clone()
     dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)
 
-    # ============================================================
-    # Step 2: Compute global sum(exp(x - max))
-    # ============================================================
-    local_exp = torch.exp(local_logits_fp32 - global_max)  # [batch, seq, local_vocab]
-    local_sum_exp = local_exp.sum(dim=-1, keepdim=True)  # [batch, seq, 1]
+    # Compute global sum(exp(x - max)) for the log-sum-exp trick
+    local_exp = torch.exp(local_logits_fp32 - global_max)
+    local_sum_exp = local_exp.sum(dim=-1, keepdim=True)
     global_sum_exp = local_sum_exp.clone()
     dist.all_reduce(global_sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
 
     # log_normalizer = global_max + log(global_sum_exp)
     log_normalizer = global_max + torch.log(global_sum_exp)  # [batch, seq, 1]
     log_normalizer = log_normalizer.squeeze(-1)  # [batch, seq]
 
-    # ============================================================
-    # Step 3: Extract logits at target positions (only on owning rank)
-    # ============================================================
-    # Create mask for tokens owned by this rank (vocab_start/vocab_end from helper)
+    # Extract logits at target positions - each rank only has part of the vocab
     is_local = (target_ids >= vocab_start) & (target_ids < vocab_end)
 
     # Convert global indices to local indices (only valid where is_local=True)
     local_indices = target_ids - vocab_start
     local_indices = local_indices.clamp(0, local_vocab_size - 1)  # Clamp for safety
 
-    # Gather logits at target positions
-    # local_logits_fp32: [batch, seq, local_vocab]
-    # local_indices: [batch, seq]
-    # We need logits_fp32[b, s, local_indices[b, s]]
     target_logits = torch.gather(
         local_logits_fp32,
         dim=-1,
         index=local_indices.unsqueeze(-1).long(),
-    ).squeeze(
-        -1
-    )  # [batch, seq]
+    ).squeeze(-1)
 
-    # Zero out logits where this rank doesn't own the token
+    # Zero out where this rank doesn't own the token, then reduce
     target_logits = target_logits * is_local.float()
-
-    # ============================================================
-    # Step 4: All-reduce to combine (only one rank has non-zero value)
-    # ============================================================
     dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=tp_group)
 
-    # ============================================================
-    # Step 5: Compute final log probability
-    # ============================================================
     logprobs = target_logits - log_normalizer
 
     return logprobs
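
For reference, here is a minimal single-process sketch of the same sharded log-softmax arithmetic that the diff above cleans up. It is not part of the commit: the shard loop stands in for the tensor-parallel all_reduce calls, and `sharded_logprobs_reference`, its argument shapes, and the even vocab split are illustrative assumptions.

import torch


def sharded_logprobs_reference(logits, target_ids, num_shards=4):
    """Emulate `num_shards` tensor-parallel ranks, each holding a contiguous
    slice of the vocab dimension (assumes vocab divides evenly by num_shards)."""
    logits = logits.float()
    batch, seq, vocab = logits.shape
    shard = vocab // num_shards
    shards = [logits[..., r * shard:(r + 1) * shard] for r in range(num_shards)]

    # Global max over shards (stands in for all_reduce with ReduceOp.MAX).
    global_max = torch.stack(
        [s.max(dim=-1).values for s in shards]
    ).max(dim=0).values.unsqueeze(-1)  # [batch, seq, 1]

    # Global sum(exp(x - max)) over shards (stands in for all_reduce with ReduceOp.SUM).
    global_sum_exp = sum(
        torch.exp(s - global_max).sum(dim=-1, keepdim=True) for s in shards
    )
    log_normalizer = (global_max + torch.log(global_sum_exp)).squeeze(-1)  # [batch, seq]

    # Each "rank" contributes the target logit only for tokens it owns.
    target_logits = torch.zeros(batch, seq)
    for rank, s in enumerate(shards):
        start = rank * shard
        is_local = (target_ids >= start) & (target_ids < start + shard)
        local_idx = (target_ids - start).clamp(0, shard - 1)
        gathered = torch.gather(s, -1, local_idx.unsqueeze(-1).long()).squeeze(-1)
        target_logits += gathered * is_local.float()

    return target_logits - log_normalizer


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 5, 32)
    targets = torch.randint(0, 32, (2, 5))
    expected = torch.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(sharded_logprobs_reference(logits, targets), expected, atol=1e-5)
    print("sharded reference matches torch.log_softmax")

The check against torch.log_softmax also shows why the max is reduced before the sum: subtracting the global max keeps exp from overflowing on any shard, which is the point of the log-sum-exp trick the comments now name explicitly.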
