 
 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
-from forge.util.ops import compute_logprobs
+
+from forge.util.ops import (
+    compute_logprobs,
+    compute_logprobs_parallel,
+    get_vocab_shard_info,
+)
+
+from tests.test_utils import gpu_test
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import DTensor, Shard
+from torch.testing._internal.common_fsdp import FSDPTest
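+# FSDPTest is PyTorch's internal multi-process test base; it is expected to spawn one
+# process per rank and initialize the default process group for each test, which is why
+# the dist.* calls below can run without explicit setup.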
 
 
 def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor):
@@ -162,3 +173,331 @@ def test_align_comparison(self):
 
         # Both should give the same result
         assert torch.allclose(result_aligned, result_manual, atol=1e-5)
+
+
+class TestParallelLogprobs(FSDPTest):
+    """Test parallel logprobs against reference implementation."""
+
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @gpu_test(gpu_count=2)
+    def test_parallel_logprobs_matches_sequential(self):
+        """Verify parallel logprobs produces same results as sequential version."""
+        torch.manual_seed(42)
+
+        batch_size = 4
+        seq_len = 16
+        vocab_size = 1000  # Must be divisible by world_size
+        target_len = 8
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        # Create test data on rank 0 and broadcast to ensure consistency
+        if rank == 0:
+            # Full logits tensor (what we'd have without sharding)
+            full_logits = torch.randn(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            # Target tokens for logprob computation
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, target_len), device=device
+            )
+        else:
+            full_logits = torch.empty(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.empty(
+                batch_size, target_len, dtype=torch.int64, device=device
+            )
+
+        # Broadcast to all ranks
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        # Compute reference result using sequential version
+        expected = compute_logprobs(full_logits, target_ids, align=True)
+
+        # Create device mesh for tensor parallel
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+
+        # Create DTensor sharded on vocab dimension (dim=2)
+        # Each rank gets vocab_size // world_size columns
+        dtensor_logits = DTensor.from_local(
+            full_logits[
+                :, :, rank * (vocab_size // 2) : (rank + 1) * (vocab_size // 2)
+            ],
+            mesh,
+            placements=[Shard(2)],  # Shard on vocab dimension
+        )
+
+        # Compute parallel result
+        result = compute_logprobs_parallel(dtensor_logits, target_ids, align=True)
+
+        # Verify results match
+        torch.testing.assert_close(
+            result,
+            expected,
+            atol=1e-5,
+            rtol=1e-5,
+            msg="Parallel logprobs should match sequential version",
+        )
+
+    @gpu_test(gpu_count=2)
+    def test_parallel_logprobs_with_temperature(self):
+        """Test parallel logprobs with different temperature values."""
+        torch.manual_seed(123)
+
+        batch_size = 2
+        seq_len = 10
+        vocab_size = 500
+        target_len = 5
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        if rank == 0:
+            full_logits = torch.randn(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, target_len), device=device
+            )
+        else:
+            full_logits = torch.empty(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.empty(
+                batch_size, target_len, dtype=torch.int64, device=device
+            )
+
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
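+        # DTensor.from_local performs no communication or re-layout: it assumes each
+        # rank's local tensor already matches the Shard(2) placement, so this manual
+        # slice has to agree with how DTensor splits the vocab dimension across the mesh.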
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        for temperature in [0.5, 1.0, 2.0]:
+            expected = compute_logprobs(
+                full_logits, target_ids, temperature=temperature, align=True
+            )
+            result = compute_logprobs_parallel(
+                dtensor_logits, target_ids, temperature=temperature, align=True
+            )
+            torch.testing.assert_close(
+                result,
+                expected,
+                atol=1e-5,
+                rtol=1e-5,
+                msg=f"Failed with temperature={temperature}",
+            )
+
+    @gpu_test(gpu_count=2)
+    def test_parallel_logprobs_align_false(self):
+        """Test parallel logprobs with align=False."""
+        torch.manual_seed(456)
+
+        batch_size = 3
+        seq_len = 8
+        vocab_size = 200
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        if rank == 0:
+            full_logits = torch.randn(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            # With align=False, target_ids same length as seq_len
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, seq_len), device=device
+            )
+        else:
+            full_logits = torch.empty(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.empty(
+                batch_size, seq_len, dtype=torch.int64, device=device
+            )
+
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        expected = compute_logprobs(full_logits, target_ids, align=False)
+
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        result = compute_logprobs_parallel(dtensor_logits, target_ids, align=False)
+
+        torch.testing.assert_close(
+            result,
+            expected,
+            atol=1e-5,
+            rtol=1e-5,
+            msg="Parallel logprobs with align=False should match",
+        )
+
+    @gpu_test(gpu_count=2)
+    def test_parallel_logprobs_numerical_stability(self):
+        """Test parallel logprobs handles extreme values correctly."""
+        torch.manual_seed(789)
+
+        batch_size = 2
+        seq_len = 4
+        vocab_size = 100
+        target_len = 2
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        # Test with large values
+        if rank == 0:
+            full_logits = torch.randn(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            # Add some extreme values
+            full_logits[:, :, 0] = 1000.0
+            full_logits[:, :, 50] = -1000.0
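+            # The spikes land on different vocab shards (token 0 on rank 0, token 50 on
+            # rank 1 at this world size), so a naive per-shard softmax would overflow; a
+            # numerically stable parallel implementation presumably reduces a global max
+            # across shards before exponentiating.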
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, target_len), device=device
+            )
+        else:
+            full_logits = torch.empty(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.empty(
+                batch_size, target_len, dtype=torch.int64, device=device
+            )
+
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        expected = compute_logprobs(full_logits, target_ids, align=True)
+
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        result = compute_logprobs_parallel(dtensor_logits, target_ids, align=True)
+
+        # Should not have NaN or Inf
+        assert torch.isfinite(result).all(), "Result contains NaN or Inf"
+        assert torch.isfinite(expected).all(), "Expected contains NaN or Inf"
+
+        torch.testing.assert_close(
+            result,
+            expected,
+            atol=1e-4,  # Slightly relaxed for extreme values
+            rtol=1e-4,
+            msg="Parallel logprobs should be numerically stable",
+        )
+
+    @gpu_test(gpu_count=2)
+    def test_get_vocab_shard_info(self):
+        """Test vocab shard info extraction."""
+        torch.manual_seed(111)
+
+        batch_size = 2
+        seq_len = 4
+        vocab_size = 100
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        full_logits = torch.randn(
+            batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+        )
+
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        tp_group, tp_rank, tp_size, vocab_start, vocab_end = get_vocab_shard_info(
+            dtensor_logits
+        )
+
+        assert tp_group is not None, "Should have TP group for sharded tensor"
+        assert tp_rank == rank, f"TP rank should be {rank}, got {tp_rank}"
+        assert tp_size == self.world_size, f"TP size should be {self.world_size}"
+        assert vocab_start == rank * local_vocab, "Vocab start incorrect"
+        assert vocab_end == (rank + 1) * local_vocab, "Vocab end incorrect"
+
+
+class TestParallelLogprobs4GPU(FSDPTest):
+    """Test parallel logprobs with 4 GPUs."""
+
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @gpu_test(gpu_count=4)
+    def test_parallel_logprobs_4_way_sharding(self):
+        """Test with 4-way vocab sharding."""
+        torch.manual_seed(999)
+
+        batch_size = 8
+        seq_len = 32
+        vocab_size = 1000  # Divisible by 4
+        target_len = 16
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        if rank == 0:
+            full_logits = torch.randn(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, target_len), device=device
+            )
+        else:
+            full_logits = torch.empty(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.empty(
+                batch_size, target_len, dtype=torch.int64, device=device
+            )
+
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        expected = compute_logprobs(full_logits, target_ids, align=True)
+
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        result = compute_logprobs_parallel(dtensor_logits, target_ids, align=True)
+
+        torch.testing.assert_close(
+            result,
+            expected,
+            atol=1e-5,
+            rtol=1e-5,
+            msg="4-way parallel logprobs should match sequential",
+        )