 import torch
-import triton
-import triton.language as tl
 import math
 import time
-import torch.nn.functional as F
-from typing import Optional, Tuple
+import pytest
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 
 
@@ -34,7 +31,8 @@ def reference_attention_varlen(q, k, v, cu):
     return out
 
 
-def test_varlen(batch=4, heads=8, d=80, dtype=torch.bfloat16, atol=1e-2, device="cuda:0"):
+@pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-2), (torch.bfloat16, 2e-2)])
+def test_varlen(dtype, atol, batch=4, heads=8, d=80, device="cuda:0"):
     torch.manual_seed(0)
     lengths = torch.randint(1, 257, (batch,))
     max_len = int(lengths.max().item())
@@ -49,10 +47,10 @@ def test_varlen(batch=4, heads=8, d=80, dtype=torch.bfloat16, atol=1e-2, device=
     out_tri = torch.randn_like(q)
     flash_attention_fwd(q, k, v, out_tri, cu, max_len)
     a = time.time()
-    for _ in range(1000):
+    for _ in range(100):
         flash_attention_fwd(q, k, v, out_tri, cu, max_len)
     b = time.time()
-    print(f"flash_attention_fwd time: {(b - a) / 1000 * 1000:.2f} ms")
+    print(f"flash_attention_fwd time: {(b - a) / 100 * 1000:.2f} ms")
     out_ref = reference_attention_varlen(q, k, v, cu)
 
     max_err = (out_ref - out_tri).abs().max().item()
@@ -62,7 +60,4 @@ def test_varlen(batch=4, heads=8, d=80, dtype=torch.bfloat16, atol=1e-2, device=
 
 
 if __name__ == "__main__":
-    tests = [(torch.float16, 1e-2), (torch.bfloat16, 2e-2)]
-    for dt, tol in tests:
-        test_varlen(dtype=dt, atol=tol)
-    print("✓ variable-length Flash-Attention all dtypes pass")
+    pytest.main()
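
The body of reference_attention_varlen is not shown in this diff. For context, a minimal sketch of what such a per-sequence reference could look like, assuming q, k, v are packed as (total_tokens, heads, head_dim) and cu holds cumulative sequence-length offsets of shape (batch + 1,); the helper name below is hypothetical and not part of the change:

# Hypothetical sketch, not part of the diff: loops over each sequence in the
# packed layout and applies standard (non-causal) attention per sequence.
import torch
import torch.nn.functional as F

def reference_attention_varlen_sketch(q, k, v, cu):
    out = torch.empty_like(q)
    for i in range(cu.numel() - 1):
        s, e = int(cu[i]), int(cu[i + 1])
        # Slice one sequence and move heads to the batch dim: (heads, seq_len, head_dim).
        qi = q[s:e].transpose(0, 1)
        ki = k[s:e].transpose(0, 1)
        vi = v[s:e].transpose(0, 1)
        # Full (non-causal) attention over this sequence only.
        oi = F.scaled_dot_product_attention(qi, ki, vi)
        out[s:e] = oi.transpose(0, 1)
    return out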