 # LICENSE file in the root directory of this source tree.

 import torch
+import torch.distributed as dist
 import torch.nn.functional as F

+from torch.distributed.tensor import DTensor
+

 def compute_logprobs(
     logits: torch.Tensor,
@@ -95,3 +98,124 @@ def compute_logprobs(
     )

     return logprobs.reshape(batch_size, seq_len)
+
+
+def compute_logprobs_parallel(
+    logits: DTensor,
+    target_ids: torch.Tensor,
+    temperature: float = 1.0,
+    align: bool = True,
+) -> torch.Tensor:
+    """
+    Compute log probabilities for target tokens from vocab-sharded DTensor logits.
+
+    This function computes log_softmax(logits)[target_ids] in a distributed
+    fashion, without ever gathering the full vocabulary dimension.
+
+    IMPORTANT: Only use this when logits is a DTensor sharded on the vocab
+    dimension. For regular tensors or non-vocab-sharded DTensors, use
+    compute_logprobs instead.
+
+    Args:
+        logits: DTensor of shape [batch_size, seq_len, vocab_size], sharded on dim=-1.
+        target_ids: Tensor of shape [batch_size, target_len] with target token IDs.
+        temperature: Temperature for scaling logits (default 1.0).
+        align: If True, slice logits to align with target_ids (default True).
+
+    Returns:
+        Tensor of shape [batch_size, target_len] with log probabilities.
+    """
+    # Get sharding info using helper
+    tp_group, tp_rank, tp_size, vocab_start, vocab_end = get_vocab_shard_info(logits)
+
+    if tp_group is None:
+        # DTensor but not sharded on vocab (Replicate or other dim sharding)
+        return compute_logprobs(logits.full_tensor(), target_ids, temperature, align)
+
+    # Get the local shard
+    local_logits = logits._local_tensor  # [batch, seq_len, vocab_size / tp_size]
+
+    # Align logits with target if needed
+    if align:
+        # Slice to match target length: logits[:, -target_len-1:-1, :]
+        target_len = target_ids.size(1)
+        local_logits = local_logits[:, -target_len - 1 : -1, :]
+
+    # Scale by temperature
+    local_logits = local_logits / temperature
+
+    batch_size, seq_len, local_vocab_size = local_logits.shape
+
+    # Move target_ids to the same device as local_logits
+    target_ids = target_ids.to(local_logits.device)
+
+    # Cast to float32 for numerical stability
+    local_logits_fp32 = local_logits.float()
+
+    # Compute global max across all shards for numerical stability
+    local_max = local_logits_fp32.max(dim=-1, keepdim=True).values
+    global_max = local_max.clone()
+    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)
+
+    # Compute global sum(exp(x - max)) for the log-sum-exp trick
+    local_exp = torch.exp(local_logits_fp32 - global_max)
+    local_sum_exp = local_exp.sum(dim=-1, keepdim=True)
+    global_sum_exp = local_sum_exp.clone()
+    dist.all_reduce(global_sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
+
+    # log_normalizer = global_max + log(global_sum_exp)
+    log_normalizer = global_max + torch.log(global_sum_exp)  # [batch, seq, 1]
+    log_normalizer = log_normalizer.squeeze(-1)  # [batch, seq]
+
+    # Extract logits at target positions - each rank only has part of the vocab
+    is_local = (target_ids >= vocab_start) & (target_ids < vocab_end)
+
+    # Convert global indices to local indices (only valid where is_local=True)
+    local_indices = target_ids - vocab_start
+    local_indices = local_indices.clamp(0, local_vocab_size - 1)  # Clamp for safety
+
+    target_logits = torch.gather(
+        local_logits_fp32,
+        dim=-1,
+        index=local_indices.unsqueeze(-1).long(),
+    ).squeeze(-1)
+
+    # Zero out where this rank doesn't own the token, then reduce
+    target_logits = target_logits * is_local.float()
+    dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=tp_group)
+
+    logprobs = target_logits - log_normalizer
+
+    return logprobs
+
+
+def get_vocab_shard_info(
+    logits: DTensor,
+) -> tuple[dist.ProcessGroup | None, int, int, int, int]:
+    """
+    Get vocabulary sharding information from a DTensor.
+
+    Args:
+        logits: DTensor with shape [..., vocab_size], potentially sharded on vocab dim.
+
+    Returns:
+        Tuple of (tp_group, tp_rank, tp_size, vocab_start, vocab_end).
+        If not sharded, returns (None, 0, 1, 0, vocab_size).
+    """
+    from torch.distributed.tensor.placement_types import Shard
+
+    local_logits = logits._local_tensor
+    placements = logits.placements
+    device_mesh = logits.device_mesh
+
+    for i, p in enumerate(placements):
+        if isinstance(p, Shard) and p.dim == logits.ndim - 1:  # last dim is the vocab dimension
+            tp_group = device_mesh.get_group(mesh_dim=i)
+            tp_size = dist.get_world_size(tp_group)
+            tp_rank = dist.get_rank(tp_group)
+            local_vocab_size = local_logits.shape[-1]
+            vocab_start = tp_rank * local_vocab_size
+            vocab_end = vocab_start + local_vocab_size
+            return tp_group, tp_rank, tp_size, vocab_start, vocab_end
+
+    # Not sharded
+    return None, 0, 1, 0, local_logits.shape[-1]
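
A minimal sanity-check sketch (not part of the commit) of how the new helpers might be exercised, launched via torchrun so each rank holds one vocab shard. It assumes compute_logprobs, compute_logprobs_parallel, and get_vocab_shard_info are importable from the patched module, and that compute_logprobs accepts (logits, target_ids, temperature, align) with the same semantics as the fallback call above; the gloo backend, module path, and tensor sizes are illustrative only.

# check_logprobs_parallel.py -- run with e.g.: torchrun --nproc_per_node=2 check_logprobs_parallel.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Assumed import location (hypothetical module path):
# from <module> import compute_logprobs, compute_logprobs_parallel, get_vocab_shard_info


def main() -> None:
    dist.init_process_group("gloo")  # CPU backend for the sketch; use "nccl" on GPUs
    world_size = dist.get_world_size()
    mesh = init_device_mesh("cpu", (world_size,), mesh_dim_names=("tp",))

    torch.manual_seed(0)  # identical full logits on every rank
    batch, seq, vocab = 2, 5, 16 * world_size
    full_logits = torch.randn(batch, seq, vocab)
    target_ids = torch.randint(0, vocab, (batch, seq - 1))

    # Shard the vocab dimension (dim=2) across the "tp" mesh dimension
    sharded_logits = distribute_tensor(full_logits, mesh, placements=[Shard(2)])
    tp_group, tp_rank, tp_size, v_start, v_end = get_vocab_shard_info(sharded_logits)
    assert tp_group is not None and v_end - v_start == vocab // tp_size

    # The sharded path should match the dense reference computed on the full logits
    parallel = compute_logprobs_parallel(sharded_logits, target_ids)
    reference = compute_logprobs(full_logits, target_ids, 1.0, True)
    torch.testing.assert_close(parallel, reference, rtol=1e-5, atol=1e-5)

    if dist.get_rank() == 0:
        print("sharded and dense logprobs match, shape:", tuple(parallel.shape))
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

The payoff of the sharded path is that no rank ever materializes the full [batch, seq, vocab] logits: the normalizer costs two all_reduces of [batch, seq, 1] tensors (MAX and SUM) plus one more for the gathered target logits, instead of an all_gather over the vocabulary dimension.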