Skip to content

Commit 591fe97

Browse files
authored
Align the creation of the q, k and v to the 06 tutorial to solve the flaky in accuracy for backward kernel. (#5176)
Align the creation of the q, k and v to the 06 tutorial to solve the flaky in accuracy for backward kernel. Signed-off-by: Lu,Chengjun <[email protected]>
1 parent e67ac5d commit 591fe97

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -605,9 +605,10 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
605605
raise AssertionError(f'Unknown {MODE}, supported modes are {modes}')
606606
dtype = torch.float16
607607
torch.xpu.empty_cache()
608-
q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
609-
k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
610-
v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
608+
torch.manual_seed(20)
609+
q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='xpu').normal_(mean=0.0, std=0.5).requires_grad_())
610+
k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='xpu').normal_(mean=0.0, std=0.5).requires_grad_())
611+
v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='xpu').normal_(mean=0.0, std=0.5).requires_grad_())
611612
sm_scale = 0.125
612613
quantiles = [0.5, 0.0, 1.0]
613614
atol = 1e-1 if N_CTX == 16384 else 1e-2

0 commit comments

Comments
 (0)