Commit 40ea35f

[FlashAttention] Remove XeTLA for fwd mode (#4524)
Since the results from XeTLA cannot be verified and we now have CUTLASS as a reference, which offers better performance, remove the XeTLA provider for flash attention forward mode.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 4db79f5 commit 40ea35f
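
After this change the benchmark entry points read only the FA_KERNEL_MODE environment variable; the former XETLA_ASSERT_RESULT and XETLA_WARN_MISMATCH switches are removed. A minimal sketch of driving the forward benchmark after this commit, assuming the triton_kernels_benchmark package is importable and an XPU device is available:

# Minimal sketch, not part of this commit: how the FA benchmark is now invoked.
# Assumes triton_kernels_benchmark is installed and exposes flash_attention_benchmark.
import os

from triton_kernels_benchmark import flash_attention_benchmark

# Only FA_KERNEL_MODE ('fwd' by default, or 'bwd') selects the mode now.
_benchmark = flash_attention_benchmark.get_benchmark(
    fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
)
_benchmark.run(show_plots=False, print_data=True)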

File tree

3 files changed: +14, -71 lines

.github/workflows/triton-benchmarks.yml

Lines changed: 0 additions & 2 deletions

@@ -274,7 +274,6 @@ jobs:
 
           source ../../scripts/capture-hw-details.sh
           python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark flash-attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-          python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark flash-attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
           python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-cutlass-report.csv --benchmark flash-attn --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG
 
       - name: Run Triton FA bwd kernel benchmark
@@ -300,7 +299,6 @@ jobs:
 
           source ../../scripts/capture-hw-details.sh
           python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-triton-report.csv --benchmark flash-attn-tensor-desc --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-          python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-xetla-report.csv --benchmark flash-attn-tensor-desc --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
           python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-cutlass-report.csv --benchmark flash-attn-tensor-desc --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG
 
       - name: Run Triton FlexAttention Causal Mask fwd kernel benchmark

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 13 additions & 60 deletions

@@ -10,7 +10,6 @@
 import triton_kernels_benchmark as benchmark_suite
 from triton_kernels_benchmark import xetla_kernel
 from triton_kernels_benchmark import cutlass_kernel
-import numpy as np
 
 
 # pylint: disable=unused-argument
@@ -529,25 +528,10 @@ def backward(ctx, do):
 attention = _attention.apply
 
 
-def check_close(f_val, f_ref, atol, rtol):
-    x = f_val()
-    y = f_ref()
-    x = x.cpu().detach().numpy()
-    y = y.cpu().detach().numpy()
-    close = np.isclose(x, y, atol=atol, rtol=rtol)
-    num_close = np.count_nonzero(close)
-    num_not_close = close.size - num_close
-    num_perc = num_not_close / close.size * 100
-    if num_not_close != 0:
-        print(f'Warning: {num_not_close}, out of {close.size} elements do not match ({num_perc:.2f}%) in XeTLA impl')
-
-
 def get_benchmark(
         providers_filter: Optional[list[str]] = None,
         fa_kernel_mode='fwd',
         attn_fwd=_attn_fwd_with_block_pointers,
-        xetla_assert_result=False,
-        xetla_warn_mismatch=False,
 ):
     """
     Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values.
@@ -647,33 +631,6 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
             )
 
         elif provider == 'xetla':
-            xetla_fn = None
-            if MODE == 'fwd':
-                module_name = f'flash_attn_causal_{CAUSAL}'.lower()
-                func = getattr(xetla_kernel, module_name)
-                out = torch.empty_like(q, device='xpu', dtype=dtype)
-                size_score = Z * H * N_CTX * N_CTX
-                size_attn_mask = Z * N_CTX * N_CTX
-                dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-                bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-                size_ml = Z * H * N_CTX
-                m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-                l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-
-                def xetla_fwd_fn():
-                    func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-                    return out
-
-                xetla_fn = xetla_fwd_fn
-
-                def check_xetla_fwd_result():
-                    if xetla_assert_result:
-                        benchmark_suite.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch')
-                    elif xetla_warn_mismatch:
-                        check_close(xetla_fn, torch_fn, atol, 1e-3)
-
-                check_xetla_fwd_result()
-
             if MODE == 'bwd':
                 module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower()
                 func = getattr(xetla_kernel, module_name)
@@ -701,18 +658,20 @@ def xetla_bwd_fn():
                      bias_strideN, bias_strideF, attn_mask_padding)
                 return out
 
-                xetla_fn = xetla_bwd_fn
+                _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
+                    xetla_bwd_fn,
+                    n_warmup=10,
+                    n_repeat=10,
+                    quantiles=quantiles,
+                )
 
-            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
-                xetla_fn,
-                n_warmup=10,
-                n_repeat=10,
-                quantiles=quantiles,
-            )
+            else:
+                min_ms = float('nan')
+                max_ms = float('nan')
+                mean = float('nan')
+                cv = float('nan')
 
         elif provider == 'cutlass':
-            cutlass_fn = None
-
             if MODE == 'fwd':
                 name = 'attention'
                 func = getattr(cutlass_kernel, name)
@@ -723,17 +682,15 @@ def cutlass_fwd_fn():
                     return out
 
                 benchmark_suite.assert_close(cutlass_fwd_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='cutlass to torch')
-                cutlass_fn = cutlass_fwd_fn
 
                 _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
-                    cutlass_fn,
+                    cutlass_fwd_fn,
                     n_warmup=10,
                     n_repeat=10,
                     quantiles=quantiles,
                 )
 
             else:
-                cutlass_fn = None
                 min_ms = float('nan')
                 max_ms = float('nan')
                 mean = float('nan')
@@ -755,9 +712,5 @@ def cutlass_fwd_fn():
 
 
 if __name__ == '__main__':
-    _benchmark = get_benchmark(
-        fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
-        xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
-        xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
-    )
+    _benchmark = get_benchmark(fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'), )
     _benchmark.run(show_plots=False, print_data=True)

benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py

Lines changed: 1 addition & 9 deletions

@@ -141,22 +141,14 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
 def get_benchmark(
         providers_filter: Optional[list[str]] = None,
         fa_kernel_mode='fwd',
-        xetla_assert_result=False,
-        xetla_warn_mismatch=False,
 ):
     return flash_attention_benchmark.get_benchmark(
         providers_filter=providers_filter,
         fa_kernel_mode=fa_kernel_mode,
         attn_fwd=_attn_fwd_with_tensor_desc,
-        xetla_assert_result=xetla_assert_result,
-        xetla_warn_mismatch=xetla_warn_mismatch,
     )
 
 
 if __name__ == '__main__':
-    _benchmark = get_benchmark(
-        fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
-        xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
-        xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
-    )
+    _benchmark = get_benchmark(fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'), )
     _benchmark.run(show_plots=False, print_data=True)
