 from torch import Tensor
 from torch._inductor import config, utils
 from torch._inductor.test_case import run_tests, TestCase
+from torch._inductor.utils import run_and_get_code
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MX_GEMM,

     HAS_CPU,
     HAS_CUDA_AND_TRITON,
 )
+from torch.testing._internal.jit_utils import FileCheck
 from torch.utils._triton import has_triton_tma_device

@@ -465,6 +467,86 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
         # setting a small absolute tolerance in these tests
         torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)

+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
+    @unittest.skipIf(
+        not has_triton_tma_device(), "Need device-side TMA support in Triton"
+    )
+    @parametrize("dtype", (torch.bfloat16, torch.float32))
+    @parametrize("shape", ("16,32,32", "1024,1024,512"))
+    @parametrize("use_fast_accum", (False, True))
+    def test_tensorwise_scaling_tma_template(
+        self,
+        dtype: torch.dtype,
+        shape: str,
+        use_fast_accum: bool,
+    ):
+        device = "cuda"
+        dtype_float8 = torch.float8_e4m3fn
+        dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device)
+
+        shape = [int(dim) for dim in shape.split(",")]
+        M, K, N = shape  # Matmul Y = X [M, K] x W [N, K]
+        # input and output dtypes of _scaled_mm do not need to be the same, but
+        # typically in a model they are
+        x = torch.randn(M, K, dtype=dtype, device=device)
+        w = torch.randn(N, K, dtype=dtype, device=device)
+        bias = None
+
+        # quantize weight (prior to inference)
+        w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8)
+        w_t_fp8 = w_fp8.t()
+
+        # quantize input x
+        x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8)
+
+        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
+            y = torch._scaled_mm(
+                x_fp8,
+                w_t_fp8,
+                x_inverse_scale,
+                w_inverse_scale,
+                bias,
+                out_dtype=dtype,
+                use_fast_accum=use_fast_accum,
+            )
+            return y
+
+        y_eager = linear(
+            x_fp8,
+            x_inverse_scale,
+            w_t_fp8,
+            w_inverse_scale,
+            bias,
+        )
+        with config.patch(
+            {
+                "triton.enable_persistent_tma_matmul": True,
+                "test_configs.autotune_choice_name_regex": "triton_scaled_mm_device_tma",
+                "max_autotune_gemm_backends": "TRITON",
+                "max_autotune": True,
+            }
+        ):
+            linear_compiled = torch.compile(
+                linear, backend="inductor", mode="max-autotune"
+            )
+            y_compiled, code = run_and_get_code(
+                linear_compiled,
+                x_fp8,
+                x_inverse_scale,
+                w_t_fp8,
+                w_inverse_scale,
+                bias,
+            )
+
+        FileCheck().check("SCALING_ROWWISE : tl.constexpr = False").run(code[0])
+        self.assertEqual(y_eager.dtype, dtype)
+        self.assertEqual(y_compiled.dtype, dtype)
+        # depending on the kernel config (BLOCK_M size, etc) selected during Inductor
+        # autotuning for the compiled case, the results can be different because of
+        # the way blocks of results are accumulated (float addition not associative), so
+        # setting a small absolute tolerance in these tests
+        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512"))
     @parametrize("has_bias", (False, True))

@@ -531,6 +613,81 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
         self.assertEqual(y_compiled.dtype, dtype)
         torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)

+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
+    @unittest.skipIf(
+        not has_triton_tma_device(), "Need device-side TMA support in Triton"
+    )
+    @parametrize("shape", ("16,32,32", "1024,1024,512"))
+    @parametrize("use_fast_accum", (False, True))
+    def test_rowwise_scaling_tma_template(
+        self,
+        shape: str,
+        use_fast_accum: bool,
+    ):
+        # Only bf16 output type is supported for row-wise scaling, not fp32
+        dtype: torch.dtype = torch.bfloat16
+        device = "cuda"
+        dtype_float8 = torch.float8_e4m3fn
+        dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device)
+
+        shape = [int(dim) for dim in shape.split(",")]
+        M, K, N = shape  # Matmul Y = X [M, K] x W [N, K]
+        x = torch.randn(M, K, dtype=dtype, device=device)
+        w = torch.randn(N, K, dtype=dtype, device=device)
+        bias = None
+
+        # quantize weight (prior to inference)
+        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8)
+        w_t_fp8 = w_fp8.t()
+        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
+
+        # quantize input x
+        x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8)
+
+        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
+            y = torch._scaled_mm(
+                x_fp8,
+                w_t_fp8,
+                x_inverse_scale,
+                w_inverse_scale,
+                bias,
+                out_dtype=dtype,
+                use_fast_accum=use_fast_accum,
+            )
+            return y
+
+        y_eager = linear(
+            x_fp8,
+            x_inverse_scale,
+            w_t_fp8,
+            w_inverse_scale,
+            bias,
+        )
+        with config.patch(
+            {
+                "triton.enable_persistent_tma_matmul": True,
+                "test_configs.autotune_choice_name_regex": "triton_scaled_mm_device_tma",
+                "max_autotune_gemm_backends": "TRITON",
+                "max_autotune": True,
+            }
+        ):
+            linear_compiled = torch.compile(
+                linear, backend="inductor", mode="max-autotune"
+            )
+            y_compiled, code = run_and_get_code(
+                linear_compiled,
+                x_fp8,
+                x_inverse_scale,
+                w_t_fp8,
+                w_inverse_scale,
+                bias,
+            )
+
+        FileCheck().check("SCALING_ROWWISE : tl.constexpr = True").run(code[0])
+        self.assertEqual(y_eager.dtype, dtype)
+        self.assertEqual(y_compiled.dtype, dtype)
+        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
     @parametrize("K", (16, 32, 1024))
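
Note: the _quantize_tensorwise and _quantize_rowwise helpers called by the new tests are defined earlier in this test file and are not part of the diff. As a minimal, hypothetical sketch of what such helpers typically do, assuming an amax-based scale and float32 inverse scales (the _sketch names and exact formulas below are illustrative, not the file's actual implementation):

import torch

def _quantize_tensorwise_sketch(x: torch.Tensor, float8_dtype: torch.dtype):
    # One scale for the whole tensor: map the largest magnitude onto the
    # fp8 format's maximum representable value, then cast to fp8.
    finfo = torch.finfo(float8_dtype)
    scale = (finfo.max / x.abs().max().clamp(min=1e-12)).to(torch.float32)
    x_fp8 = (x.to(torch.float32) * scale).clamp(-finfo.max, finfo.max).to(float8_dtype)
    # torch._scaled_mm expects the inverse scale as a float32 tensor
    return x_fp8, scale.reciprocal()

def _quantize_rowwise_sketch(x: torch.Tensor, float8_dtype: torch.dtype):
    # One scale per row, kept as an (M, 1) column so it lines up with
    # _scaled_mm's row-wise scale_a; the test transposes the weight's
    # scale to (1, N) before passing it as scale_b.
    finfo = torch.finfo(float8_dtype)
    amax = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = (finfo.max / amax).to(torch.float32)
    x_fp8 = (x.to(torch.float32) * scale).clamp(-finfo.max, finfo.max).to(float8_dtype)
    return x_fp8, scale.reciprocal()

The inverse scales returned here correspond to the x_inverse_scale / w_inverse_scale arguments that the tests forward to torch._scaled_mm.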