|
1 | 1 | import os
|
2 | 2 | import contextlib
|
| 3 | +from typing import Optional |
3 | 4 |
|
4 | 5 | import torch
|
5 | 6 | from torch.profiler import record_function
|
6 | 7 | import triton
|
7 | 8 | import triton.language as tl
|
8 | 9 |
|
9 |
| -import triton_kernels_benchmark as benchmark_suit |
| 10 | +import triton_kernels_benchmark as benchmark_suite |
10 | 11 | from triton_kernels_benchmark import xetla_kernel
|
11 | 12 | import numpy as np
|
12 | 13 |
|
@@ -482,7 +483,7 @@ def backward(ctx, do):
|
482 | 483 | # https://github.com/pytorch/pytorch/issues/144778 has more details.
|
483 | 484 | with record_function(
|
484 | 485 | '__profile_kernel_of_func_bwd_fa'
|
485 |
| - ) if benchmark_suit.BENCHMARKING_METHOD == 'UPSTREAM_PYTORCH_PROFILER' else contextlib.nullcontext(): |
| 486 | + ) if benchmark_suite.BENCHMARKING_METHOD == 'UPSTREAM_PYTORCH_PROFILER' else contextlib.nullcontext(): |
486 | 487 | q, k, v, o, M = ctx.saved_tensors
|
487 | 488 | assert do.is_contiguous()
|
488 | 489 | assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
|
@@ -540,135 +541,182 @@ def check_close(f_val, f_ref, atol, rtol):
|
540 | 541 | print(f'Warning: {num_not_close}, out of {close.size} elements do not match ({num_perc:.2f}%) in XeTLA impl')
|
541 | 542 |
|
542 | 543 |
|
543 |
| -@benchmark_suit.perf_report( |
544 |
| - benchmark_suit.Benchmark( |
545 |
| - # argument names to use as an x-axis for the plot |
546 |
| - x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'], |
547 |
| - x_vals=[[z, h, 16384 // z, dhead, causal, mode] |
548 |
| - for z in [1, 2, 4, 8, 16, 32] |
549 |
| - for (h, dhead) in [(16, 128), (32, 64)] |
550 |
| - for causal in [False, True] |
551 |
| - for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] # |
552 |
| - + [[4, 48, 1024, 64, causal, mode] |
553 |
| - for causal in [False, True] |
554 |
| - for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]], |
555 |
| - line_arg='provider', |
556 |
| - # argument name whose value corresponds to a different line in the plot |
557 |
| - # possible values for `line_arg`` |
558 |
| - line_vals=['triton', 'xetla'], |
559 |
| - # label name for the lines |
560 |
| - line_names=['Triton', 'XeTLA'], |
561 |
| - # line styles |
562 |
| - styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], |
563 |
| - ylabel=['GB/s', 'TFlops'], # label name for the y-axis |
564 |
| - plot_name='attn-performance', |
565 |
| - # name for the plot. Used also as a file name for saving the plot. |
566 |
| - args={}, |
567 |
| - )) |
568 |
| -def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider): |
569 |
| - assert MODE in ['fwd', 'bwd'] |
570 |
| - dtype = torch.float16 |
571 |
| - q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True) |
572 |
| - k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True) |
573 |
| - v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True) |
574 |
| - sm_scale = 0.125 |
575 |
| - if MODE == 'bwd': |
576 |
| - sm_scale = 1.3 |
577 |
| - quantiles = [0.5, 0.0, 1.0] |
578 |
| - atol = 1e-1 if N_CTX == 16384 else 1e-2 |
579 |
| - # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed |
580 |
| - torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu( |
581 |
| - ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32) |
582 |
| - if MODE == 'bwd': |
583 |
| - torch_o = torch_fn() |
584 |
| - torch_do = torch.randn_like(torch_o) |
585 |
| - torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True) |
586 |
| - if provider == 'onednn': |
587 |
| - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) |
588 |
| - |
589 |
| - elif provider == 'triton': |
590 |
| - triton_fn = lambda: attention(q, k, v, CAUSAL, sm_scale) |
| 544 | +def get_benchmark( |
| 545 | + providers_filter: Optional[list[str]] = None, |
| 546 | + fa_kernel_mode='fwd', |
| 547 | + xetla_assert_result=False, |
| 548 | + xetla_warn_mismatch=True, |
| 549 | +): |
| 550 | + """ |
| 551 | + Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values. |
| 552 | + The benchmark can then be executed by calling the :code:`.run` method on the return value. |
| 553 | + """ |
| 554 | + |
| 555 | + supported_providers = { |
| 556 | + 'triton': 'Triton', |
| 557 | + 'xetla': 'XeTLA', |
| 558 | + } |
| 559 | + providers = benchmark_suite.filter_providers(supported_providers, providers_filter) |
| 560 | + |
| 561 | + @benchmark_suite.perf_report( |
| 562 | + benchmark_suite.Benchmark( |
| 563 | + # argument names to use as an x-axis for the plot |
| 564 | + x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'], |
| 565 | + x_vals=[[z, h, 16384 // z, dhead, causal, mode] |
| 566 | + for z in [1, 2, 4, 8, 16, 32] |
| 567 | + for (h, dhead) in [(16, 128), (32, 64)] |
| 568 | + for causal in [False, True] |
| 569 | + for mode in [fa_kernel_mode]] # |
| 570 | + + [[4, 48, 1024, 64, causal, mode] for causal in [False, True] for mode in [fa_kernel_mode]], |
| 571 | + line_arg='provider', |
| 572 | + # argument name whose value corresponds to a different line in the plot |
| 573 | + # possible values for `line_arg`` |
| 574 | + line_vals=list(providers.keys()), |
| 575 | + # label name for the lines |
| 576 | + line_names=list(providers.values()), |
| 577 | + # line styles |
| 578 | + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], |
| 579 | + ylabel=['GB/s', 'TFlops'], # label name for the y-axis |
| 580 | + plot_name='attn-performance', |
| 581 | + # name for the plot. Used also as a file name for saving the plot. |
| 582 | + args={}, |
| 583 | + )) |
| 584 | + # pylint: disable=too-many-branches |
| 585 | + def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider): |
| 586 | + modes = ['fwd', 'bwd'] |
| 587 | + if MODE not in modes: |
| 588 | + raise AssertionError(f'Unknown {MODE}, supported modes are {modes}') |
| 589 | + dtype = torch.float16 |
| 590 | + q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True) |
| 591 | + k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True) |
| 592 | + v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True) |
| 593 | + sm_scale = 0.125 |
591 | 594 | if MODE == 'bwd':
|
592 |
| - triton_o = triton_fn() |
593 |
| - triton_do = torch.randn_like(triton_o) |
594 |
| - triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True) |
595 |
| - if MODE == 'fwd': |
596 |
| - benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch') |
597 |
| - else: |
598 |
| - benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch') |
599 |
| - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) |
600 |
| - |
601 |
| - elif provider == 'xetla': |
602 |
| - xetla_fn = None |
603 |
| - if MODE == 'fwd': |
604 |
| - module_name = f'flash_attn_causal_{CAUSAL}'.lower() |
605 |
| - func = getattr(xetla_kernel, module_name) |
606 |
| - out = torch.empty_like(q, device='xpu', dtype=dtype) |
607 |
| - size_score = Z * H * N_CTX * N_CTX |
608 |
| - size_attn_mask = Z * N_CTX * N_CTX |
609 |
| - dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8) |
610 |
| - bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype) |
611 |
| - size_ml = Z * H * N_CTX |
612 |
| - m = torch.empty((size_ml, ), device='xpu', dtype=torch.float) |
613 |
| - l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) |
614 |
| - |
615 |
| - def xetla_fwd_fn(): |
616 |
| - func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale) |
617 |
| - return out |
618 |
| - |
619 |
| - xetla_fn = xetla_fwd_fn |
620 |
| - |
621 |
| - def check_xetla_fwd_result(): |
622 |
| - if os.getenv('XETLA_ASSERT_RESULT', '0') == '1': |
623 |
| - benchmark_suit.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch') |
624 |
| - elif os.getenv('XETLA_WARN_MISMATCH', '1') == '1': |
625 |
| - check_close(xetla_fn, torch_fn, atol, 1e-3) |
626 |
| - |
627 |
| - check_xetla_fwd_result() |
628 |
| - |
| 595 | + sm_scale = 1.3 |
| 596 | + quantiles = [0.5, 0.0, 1.0] |
| 597 | + atol = 1e-1 if N_CTX == 16384 else 1e-2 |
| 598 | + # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed |
| 599 | + torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu( |
| 600 | + ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32) |
629 | 601 | if MODE == 'bwd':
|
630 |
| - module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower() |
631 |
| - func = getattr(xetla_kernel, module_name) |
632 |
| - grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) |
633 |
| - bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) |
634 |
| - dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8) |
635 |
| - out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) |
636 |
| - log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True) |
637 |
| - workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True) |
638 |
| - grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True) |
639 |
| - alpha = sm_scale |
640 |
| - dropout_prob = 0 |
641 |
| - grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) |
642 |
| - grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True) |
643 |
| - grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True) |
644 |
| - grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True) |
645 |
| - bias_strideB = -1 |
646 |
| - bias_strideN = -1 |
647 |
| - bias_strideF = -1 |
648 |
| - attn_mask_padding = 0 |
649 |
| - |
650 |
| - def xetla_bwd_fn(): |
651 |
| - func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha, dropout_prob, |
652 |
| - grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX, N_CTX, bias_strideB, |
653 |
| - bias_strideN, bias_strideF, attn_mask_padding) |
654 |
| - return out |
655 |
| - |
656 |
| - xetla_fn = xetla_bwd_fn |
657 |
| - |
658 |
| - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) |
| 602 | + torch_o = torch_fn() |
| 603 | + torch_do = torch.randn_like(torch_o) |
| 604 | + torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True) |
| 605 | + |
| 606 | + if provider == 'onednn': |
| 607 | + _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench( |
| 608 | + torch_fn, |
| 609 | + n_warmup=10, |
| 610 | + n_repeat=10, |
| 611 | + quantiles=quantiles, |
| 612 | + ) |
659 | 613 |
|
660 |
| - else: |
661 |
| - raise NotImplementedError(f'Unsupported provider {provider}') |
| 614 | + elif provider == 'triton': |
| 615 | + triton_fn = lambda: attention(q, k, v, CAUSAL, sm_scale) |
| 616 | + if MODE == 'bwd': |
| 617 | + triton_o = triton_fn() |
| 618 | + triton_do = torch.randn_like(triton_o) |
| 619 | + triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True) |
| 620 | + if MODE == 'fwd': |
| 621 | + benchmark_suite.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch') |
| 622 | + else: |
| 623 | + benchmark_suite.assert_close( |
| 624 | + lambda: triton_o, |
| 625 | + lambda: torch_o, |
| 626 | + atol=1e-2, |
| 627 | + rtol=0, |
| 628 | + err_msg='triton to torch', |
| 629 | + ) |
| 630 | + _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench( |
| 631 | + triton_fn, |
| 632 | + n_warmup=10, |
| 633 | + n_repeat=10, |
| 634 | + quantiles=quantiles, |
| 635 | + ) |
| 636 | + |
| 637 | + elif provider == 'xetla': |
| 638 | + xetla_fn = None |
| 639 | + if MODE == 'fwd': |
| 640 | + module_name = f'flash_attn_causal_{CAUSAL}'.lower() |
| 641 | + func = getattr(xetla_kernel, module_name) |
| 642 | + out = torch.empty_like(q, device='xpu', dtype=dtype) |
| 643 | + size_score = Z * H * N_CTX * N_CTX |
| 644 | + size_attn_mask = Z * N_CTX * N_CTX |
| 645 | + dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8) |
| 646 | + bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype) |
| 647 | + size_ml = Z * H * N_CTX |
| 648 | + m = torch.empty((size_ml, ), device='xpu', dtype=torch.float) |
| 649 | + l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) |
| 650 | + |
| 651 | + def xetla_fwd_fn(): |
| 652 | + func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale) |
| 653 | + return out |
| 654 | + |
| 655 | + xetla_fn = xetla_fwd_fn |
| 656 | + |
| 657 | + def check_xetla_fwd_result(): |
| 658 | + if xetla_assert_result: |
| 659 | + benchmark_suite.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch') |
| 660 | + elif xetla_warn_mismatch: |
| 661 | + check_close(xetla_fn, torch_fn, atol, 1e-3) |
| 662 | + |
| 663 | + check_xetla_fwd_result() |
| 664 | + |
| 665 | + if MODE == 'bwd': |
| 666 | + module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower() |
| 667 | + func = getattr(xetla_kernel, module_name) |
| 668 | + grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) |
| 669 | + bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) |
| 670 | + dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8) |
| 671 | + out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) |
| 672 | + log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True) |
| 673 | + workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True) |
| 674 | + grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True) |
| 675 | + alpha = sm_scale |
| 676 | + dropout_prob = 0 |
| 677 | + grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) |
| 678 | + grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True) |
| 679 | + grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True) |
| 680 | + grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True) |
| 681 | + bias_strideB = -1 |
| 682 | + bias_strideN = -1 |
| 683 | + bias_strideF = -1 |
| 684 | + attn_mask_padding = 0 |
| 685 | + |
| 686 | + def xetla_bwd_fn(): |
| 687 | + func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha, dropout_prob, |
| 688 | + grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX, N_CTX, bias_strideB, |
| 689 | + bias_strideN, bias_strideF, attn_mask_padding) |
| 690 | + return out |
| 691 | + |
| 692 | + xetla_fn = xetla_bwd_fn |
| 693 | + |
| 694 | + _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench( |
| 695 | + xetla_fn, |
| 696 | + n_warmup=10, |
| 697 | + n_repeat=10, |
| 698 | + quantiles=quantiles, |
| 699 | + ) |
| 700 | + |
| 701 | + else: |
| 702 | + raise NotImplementedError(f'Unsupported provider {provider}') |
662 | 703 |
|
663 |
| - tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3) |
664 |
| - gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3) |
| 704 | + tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3) |
| 705 | + gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3) |
665 | 706 |
|
666 |
| - if MODE == 'bwd': |
667 |
| - tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3) |
668 |
| - gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3) |
| 707 | + if MODE == 'bwd': |
| 708 | + tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3) |
| 709 | + gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3) |
| 710 | + |
| 711 | + return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv |
669 | 712 |
|
670 |
| - return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv |
| 713 | + return benchmark |
671 | 714 |
|
672 | 715 |
|
673 | 716 | if __name__ == '__main__':
|
674 |
| - benchmark.run(show_plots=False, print_data=True) |
| 717 | + _benchmark = get_benchmark( |
| 718 | + fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'), |
| 719 | + xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'), |
| 720 | + xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '1') == '1'), |
| 721 | + ) |
| 722 | + _benchmark.run(show_plots=False, print_data=True) |
0 commit comments