@@ -83,66 +83,45 @@ def compute_logprobs_parallel(
     local_logits = local_logits / temperature

     batch_size, seq_len, local_vocab_size = local_logits.shape
-    device = local_logits.device

-    # Move target_ids to the same device
-    target_ids = target_ids.to(device)
+    # Move target_ids to the same device as local_logits
+    target_ids = target_ids.to(local_logits.device)

     # Cast to float32 for numerical stability
     local_logits_fp32 = local_logits.float()

-    # ============================================================
-    # Step 1: Compute global max for numerical stability
-    # ============================================================
-    local_max = local_logits_fp32.max(dim=-1, keepdim=True).values  # [batch, seq, 1]
+    # Compute global max across all shards for numerical stability
+    local_max = local_logits_fp32.max(dim=-1, keepdim=True).values
     global_max = local_max.clone()
     dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)

-    # ============================================================
-    # Step 2: Compute global sum(exp(x - max))
-    # ============================================================
-    local_exp = torch.exp(local_logits_fp32 - global_max)  # [batch, seq, local_vocab]
-    local_sum_exp = local_exp.sum(dim=-1, keepdim=True)  # [batch, seq, 1]
+    # Compute global sum(exp(x - max)) for the log-sum-exp trick
+    local_exp = torch.exp(local_logits_fp32 - global_max)
+    local_sum_exp = local_exp.sum(dim=-1, keepdim=True)
     global_sum_exp = local_sum_exp.clone()
     dist.all_reduce(global_sum_exp, op=dist.ReduceOp.SUM, group=tp_group)

     # log_normalizer = global_max + log(global_sum_exp)
     log_normalizer = global_max + torch.log(global_sum_exp)  # [batch, seq, 1]
     log_normalizer = log_normalizer.squeeze(-1)  # [batch, seq]

-    # ============================================================
-    # Step 3: Extract logits at target positions (only on owning rank)
-    # ============================================================
-    # Create mask for tokens owned by this rank (vocab_start/vocab_end from helper)
+    # Extract logits at target positions - each rank only has part of the vocab
     is_local = (target_ids >= vocab_start) & (target_ids < vocab_end)

     # Convert global indices to local indices (only valid where is_local=True)
     local_indices = target_ids - vocab_start
     local_indices = local_indices.clamp(0, local_vocab_size - 1)  # Clamp for safety

-    # Gather logits at target positions
-    # local_logits_fp32: [batch, seq, local_vocab]
-    # local_indices: [batch, seq]
-    # We need logits_fp32[b, s, local_indices[b, s]]
     target_logits = torch.gather(
         local_logits_fp32,
         dim=-1,
         index=local_indices.unsqueeze(-1).long(),
-    ).squeeze(
-        -1
-    )  # [batch, seq]
+    ).squeeze(-1)

-    # Zero out logits where this rank doesn't own the token
+    # Zero out where this rank doesn't own the token, then reduce
     target_logits = target_logits * is_local.float()
-
-    # ============================================================
-    # Step 4: All-reduce to combine (only one rank has non-zero value)
-    # ============================================================
     dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=tp_group)

-    # ============================================================
-    # Step 5: Compute final log probability
-    # ============================================================
     logprobs = target_logits - log_normalizer

     return logprobs
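
As a quick sanity check of the sharded log-softmax math in this hunk, the sketch below emulates the three all-reduces in a single process by looping over contiguous vocabulary slices and compares the result against torch.log_softmax on the full vocabulary. It assumes the same contiguous vocab split implied by vocab_start/vocab_end; the names sharded_logprobs and num_shards are illustrative and not part of the patch.

    # Single-process sanity check of the sharded log-prob computation.
    # Each slice of the vocab plays the role of one tensor-parallel rank,
    # and plain max/sum over slices stand in for the all-reduce calls.
    import torch

    def sharded_logprobs(full_logits, target_ids, num_shards, temperature=1.0):
        full_logits = (full_logits / temperature).float()
        batch, seq, vocab = full_logits.shape
        shard_size = vocab // num_shards
        shards = full_logits.split(shard_size, dim=-1)  # one slice per "rank"

        # Stand-in for all_reduce(MAX): global max over each shard's local max.
        global_max = torch.stack(
            [s.max(dim=-1, keepdim=True).values for s in shards]
        ).max(dim=0).values

        # Stand-in for all_reduce(SUM): total sum of exp(x - global_max).
        global_sum_exp = sum(
            torch.exp(s - global_max).sum(dim=-1, keepdim=True) for s in shards
        )
        log_normalizer = (global_max + torch.log(global_sum_exp)).squeeze(-1)

        # Each "rank" contributes the target logit only where it owns the token id.
        target_logits = torch.zeros(batch, seq)
        for rank, shard in enumerate(shards):
            vocab_start = rank * shard_size
            is_local = (target_ids >= vocab_start) & (target_ids < vocab_start + shard_size)
            local_idx = (target_ids - vocab_start).clamp(0, shard_size - 1)
            gathered = shard.gather(-1, local_idx.unsqueeze(-1).long()).squeeze(-1)
            target_logits += gathered * is_local.float()  # stands in for the SUM all-reduce

        return target_logits - log_normalizer

    if __name__ == "__main__":
        torch.manual_seed(0)
        logits = torch.randn(2, 5, 64)
        targets = torch.randint(0, 64, (2, 5))
        expected = torch.log_softmax(logits.float(), dim=-1).gather(
            -1, targets.unsqueeze(-1)
        ).squeeze(-1)
        assert torch.allclose(sharded_logprobs(logits, targets, num_shards=4), expected, atol=1e-5)
        print("sharded log-probs match torch.log_softmax")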