|
1 | 1 | from collections.abc import Callable |
2 | 2 | import os |
| 3 | +import functools |
3 | 4 |
|
4 | 5 | import numpy as np |
5 | 6 | import pytest |
6 | 7 | import torch |
7 | 8 | from torch.testing import assert_close |
8 | 9 |
|
| 10 | +if hasattr(torch.nn.functional, "scaled_mm"): |
| 11 | + from torch.nn.functional import ScalingType, SwizzleType |
| 12 | + |
9 | 13 | import thunder |
10 | 14 | import thunder.core.devices as devices |
11 | 15 | import thunder.core.dtypes as dtypes |
@@ -419,6 +423,302 @@ def fn(a): |
419 | 423 | assert_close(b, b_ref) |
420 | 424 |
|
421 | 425 |
|
| 426 | +def _cuda_version_tuple() -> tuple[int, int] | None: |
| 427 | + if torch.version.cuda is None: |
| 428 | + return None |
| 429 | + parts = torch.version.cuda.split(".") |
| 430 | + try: |
| 431 | + major = int(parts[0]) |
| 432 | + minor = int(parts[1]) if len(parts) > 1 else 0 |
| 433 | + return major, minor |
| 434 | + except ValueError: |
| 435 | + return None |
| 436 | + |
| 437 | + |
| 438 | +def _require_scaled_mm(fn): |
| 439 | + @functools.wraps(fn) |
| 440 | + def wrapper(*args, **kwargs): |
| 441 | + if not hasattr(torch.nn.functional, "scaled_mm"): |
| 442 | + pytest.skip("torch.nn.functional.scaled_mm is not found in this PyTorch") |
| 443 | + return fn(*args, **kwargs) |
| 444 | + |
| 445 | + return wrapper |
| 446 | + |
| 447 | + |
def _ensure_fp8_tensorwise(device: torch.device) -> None:
    """Skip the running test when *device* predates SM89, per the message
    the minimum compute capability for tensor-wise fp8 ``scaled_mm``."""
    capability = torch.cuda.get_device_capability(device)
    if capability < (8, 9):
        pytest.skip("scaled_mm tensor-wise support requires SM89 or newer")
| 451 | + |
| 452 | + |
| 453 | +def _require_fp8_tensorwise(fn): |
| 454 | + @functools.wraps(fn) |
| 455 | + def wrapper(*args, **kwargs): |
| 456 | + device = torch.device("cuda") |
| 457 | + _ensure_fp8_tensorwise(device) |
| 458 | + return fn(*args, **kwargs) |
| 459 | + |
| 460 | + return wrapper |
| 461 | + |
| 462 | + |
def _require_fp8_rowwise(device: torch.device) -> None:
    """Skip the running test unless row-wise fp8 ``scaled_mm`` is usable.

    Row-wise scaling needs everything tensor-wise scaling does, plus an
    SM90+ GPU and — when the CUDA version can be determined — CUDA 12.9+.
    """
    _ensure_fp8_tensorwise(device)
    if torch.cuda.get_device_capability(device) < (9, 0):
        pytest.skip("row-wise scaled_mm requires SM90 or newer")
    version = _cuda_version_tuple()
    # An unparsable/unknown CUDA version is treated as "new enough".
    if version is not None and version < (12, 9):
        pytest.skip("row-wise scaled_mm requires CUDA 12.9 or newer")
| 470 | + |
| 471 | + |
def _require_fp8_blockwise(device: torch.device) -> None:
    """Skip unless block-wise fp8 ``scaled_mm`` is usable.

    Block-wise scaling currently shares the row-wise hardware/CUDA
    requirements, so this just delegates.
    """
    _require_fp8_rowwise(device)
| 474 | + |
| 475 | + |
# Adapted from https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py#L645-L659
@requiresCUDA
@_require_fp8_tensorwise
@_require_scaled_mm
def test_scaled_mm_tensorwise_matches_torch():
    """thunder.jit of a tensor-wise fp8 scaled_mm call matches eager PyTorch."""
    device = torch.device("cuda")

    def reference_fn(mat_a, mat_b, scale_a, scale_b):
        # One 0-dim decode scale per operand: tensor-wise scaling.
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.TensorWise,
            scale_b,
            ScalingType.TensorWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=torch.bfloat16,
        )

    M, K, N = 16, 32, 16
    lhs_fp8 = torch.randn(M, K, device=device, dtype=torch.float32).to(torch.float8_e4m3fn)
    rhs_fp8 = torch.randn(K, N, device=device, dtype=torch.float32).to(torch.float8_e4m3fn)
    unit_scale_a = torch.tensor(1.0, device=device, dtype=torch.float32)
    unit_scale_b = torch.tensor(1.0, device=device, dtype=torch.float32)

    # Eager execution may still be unsupported for this layout/arch; treat
    # that as a skip rather than a failure.
    try:
        expected = reference_fn(lhs_fp8, rhs_fp8, unit_scale_a, unit_scale_b)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jitted = thunder.jit(reference_fn)
    actual = jitted(lhs_fp8, rhs_fp8, unit_scale_a, unit_scale_b)
    assert_close(actual, expected)
| 512 | + |
| 513 | + |
# Adapted from https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py#L862-L910
@requiresCUDA
@_require_fp8_tensorwise
@_require_scaled_mm
def test_scaled_mm_matches_scaled_data():
    """thunder.jit of tensor-wise scaled_mm over genuinely quantized data
    matches the eager result."""
    device = torch.device("cuda")

    def quantize_to_fp8(tensor):
        # Symmetric tensor-wise quantization: a single encode/decode pair.
        fp8 = torch.float8_e4m3fn
        limit = torch.finfo(fp8).max
        amax = tensor.abs().max()
        encode = (limit / torch.clamp(amax, min=1e-12)).to(torch.float32)
        quant = torch.clamp(tensor * encode, min=-limit, max=limit).to(fp8)
        return quant, encode.reciprocal(), encode

    def scaled_mm_fp8(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.TensorWise,
            scale_b,
            ScalingType.TensorWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=out_dtype,
        )

    M, K, N = 32, 64, 32
    lhs = torch.randn(M, K, device=device, dtype=torch.float32)
    rhs_rows = torch.randn(N, K, device=device, dtype=torch.float32)

    lhs_fp8, decode_a, _ = quantize_to_fp8(lhs)
    rhs_fp8_rows, decode_b, _ = quantize_to_fp8(rhs_rows)
    # To use cublaslt, the second matrix needs to be column-major.
    rhs_fp8 = rhs_fp8_rows.t()

    try:
        expected = scaled_mm_fp8(lhs_fp8, rhs_fp8, decode_a, decode_b, out_dtype=torch.float32)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jitted = thunder.jit(lambda a, b, sa, sb: scaled_mm_fp8(a, b, sa, sb, out_dtype=torch.float32))
    actual = jitted(lhs_fp8, rhs_fp8, decode_a, decode_b)

    assert_close(actual, expected)
| 561 | + |
| 562 | + |
@requiresCUDA
@_require_scaled_mm
def test_scaled_mm_rowwise_matches_torch():
    """thunder.jit of a row-wise fp8 scaled_mm call matches eager PyTorch."""
    device = torch.device("cuda")
    _require_fp8_rowwise(device)

    def reference_fn(mat_a, mat_b, scale_a, scale_b):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.RowWise,
            scale_b,
            ScalingType.RowWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=torch.bfloat16,
        )

    M, K, N = 16, 32, 16
    lhs_fp8 = torch.randn(M, K, device=device, dtype=torch.float32).to(torch.float8_e4m3fn)
    # To use cublaslt, the second matrix needs to be column-major.
    rhs_fp8 = torch.randn(N, K, device=device, dtype=torch.float32).to(torch.float8_e4m3fn).t()
    # One decode scale per row of A and per column of B.
    row_scale_a = torch.ones((M, 1), device=device, dtype=torch.float32)
    col_scale_b = torch.ones((1, N), device=device, dtype=torch.float32)

    try:
        expected = reference_fn(lhs_fp8, rhs_fp8, row_scale_a, col_scale_b)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    actual = thunder.jit(reference_fn)(lhs_fp8, rhs_fp8, row_scale_a, col_scale_b)
    assert_close(actual, expected)
| 599 | + |
| 600 | + |
@requiresCUDA
@_require_scaled_mm
def test_scaled_mm_rowwise_matches_scaled_data():
    """thunder.jit of row-wise scaled_mm over quantized data stays close to
    the eager result (within quantization error)."""
    device = torch.device("cuda")
    _require_fp8_rowwise(device)

    fp8 = torch.float8_e4m3fn
    limit = torch.finfo(fp8).max

    def rowwise_quantize(tensor, *, dim):
        # One encode/decode scale per slice along `dim`.
        amax = tensor.abs().amax(dim=dim, keepdim=True)
        encode = (limit / torch.clamp(amax, min=1e-12)).to(torch.float32)
        quant = torch.clamp(tensor * encode, min=-limit, max=limit).to(fp8)
        return quant, encode.reciprocal(), encode

    def scaled_mm_rowwise(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.RowWise,
            scale_b,
            ScalingType.RowWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=out_dtype,
        )

    M, K, N = 32, 64, 32
    lhs = torch.randn(M, K, device=device, dtype=torch.bfloat16)
    rhs = torch.randn(K, N, device=device, dtype=torch.bfloat16)

    lhs_fp8, decode_a, _ = rowwise_quantize(lhs.to(torch.float32), dim=1)
    rhs_fp8, decode_b, _ = rowwise_quantize(rhs.to(torch.float32), dim=0)

    try:
        expected = scaled_mm_rowwise(lhs_fp8, rhs_fp8, decode_a, decode_b, out_dtype=torch.bfloat16)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jitted = thunder.jit(lambda a, b, sa, sb: scaled_mm_rowwise(a, b, sa, sb, out_dtype=torch.bfloat16))
    actual = jitted(lhs_fp8, rhs_fp8, decode_a, decode_b)

    # Quantization noise makes exact equality too strict; compare in fp32
    # with a loose tolerance.
    assert_close(actual.to(torch.float32), expected.to(torch.float32), atol=3e-2, rtol=3e-2)
| 649 | + |
| 650 | + |
| 651 | +def _blockwise_quantize(tensor: torch.Tensor, block_rows: int, block_cols: int) -> tuple[torch.Tensor, torch.Tensor]: |
| 652 | + dtype_fp8 = torch.float8_e4m3fn |
| 653 | + max_val = torch.finfo(dtype_fp8).max |
| 654 | + |
| 655 | + M, K = tensor.shape |
| 656 | + assert M % block_rows == 0 and K % block_cols == 0 |
| 657 | + |
| 658 | + reshaped = tensor.reshape(M // block_rows, block_rows, K // block_cols, block_cols) |
| 659 | + amax = reshaped.abs().amax(dim=(1, 3), keepdim=True) |
| 660 | + encode = (max_val / torch.clamp(amax, min=1e-12)).to(torch.float32) |
| 661 | + quant = torch.clamp(reshaped * encode, min=-max_val, max=max_val).to(dtype_fp8) |
| 662 | + |
| 663 | + return quant.reshape(M, K), encode.reshape(M // block_rows, K // block_cols).to(tensor.device) |
| 664 | + |
| 665 | + |
@requiresCUDA
@_require_scaled_mm
@pytest.mark.parametrize("output_dtype", [torch.bfloat16])
@pytest.mark.parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
def test_scaled_mm_blockwise_matches_torch(output_dtype, lhs_block, rhs_block):
    """thunder.jit of block-wise (1x128 / 128x128) scaled_mm matches eager."""
    device = torch.device("cuda")
    _require_fp8_blockwise(device)

    M, K, N = 256, 256, 256
    # Cubing spreads magnitudes so per-block scales genuinely differ.
    lhs = torch.randn(M, K, device=device, dtype=output_dtype).pow(3)
    rhs_rows = torch.randn(N, K, device=device, dtype=output_dtype).pow(3)

    lhs_fp8, encode_a = _blockwise_quantize(lhs.to(torch.float32), lhs_block, 128)
    rhs_fp8_rows, encode_b = _blockwise_quantize(rhs_rows.to(torch.float32), rhs_block, 128)
    rhs_fp8 = rhs_fp8_rows.t().contiguous()

    # scaled_mm consumes decode scales; B's scales are transposed to match
    # the transposed operand.
    decode_a = encode_a.reciprocal().contiguous()
    decode_b = encode_b.reciprocal().t().contiguous()

    recipe_map = {
        1: ScalingType.BlockWise1x128,
        128: ScalingType.BlockWise128x128,
    }

    def blockwise_mm(a, b, sa, sb):
        return torch.nn.functional.scaled_mm(
            a,
            b,
            sa,
            recipe_map[lhs_block],
            sb,
            recipe_map[rhs_block],
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=output_dtype,
        )

    try:
        expected = blockwise_mm(lhs_fp8, rhs_fp8, decode_a, decode_b)
    except (RuntimeError, NotImplementedError, ValueError) as exc:
        pytest.skip(str(exc))

    actual = thunder.jit(blockwise_mm)(lhs_fp8, rhs_fp8, decode_a, decode_b)
    assert_close(actual, expected)
| 720 | + |
| 721 | + |
422 | 722 | # https://github.com/Lightning-AI/lightning-thunder/issues/1857 |
423 | 723 | def test_max_with_int(): |
424 | 724 | def f(x, ids): |
|
0 commit comments