
Commit 696be85

fix
1 parent 4356d2e commit 696be85


3 files changed: +9 -19 lines changed


lightllm/common/basemodel/layer_infer/cache_tensor_manager.py

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ def alloc_tensor(
         # shape type conversion
         if isinstance(shape, list):
             shape = torch.Size(shape)
-
+
         # when the cache manager is not in normal use
         if not self.cache_env_ok:
             return torch.empty(shape, dtype=data_type, device=device, requires_grad=False)
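Note: apart from the change on line 138 (a blank line rewritten, most likely whitespace only), the surrounding logic normalizes list shapes to torch.Size and falls back to a plain torch.empty when the cache environment is inactive. A minimal standalone sketch of that fallback path (the function name, defaults, and cache_env_ok flag here are illustrative stand-ins, not the manager's real interface):

import torch

# Illustrative stand-in for the fallback path shown in the hunk above: normalize a
# list shape to torch.Size, then allocate a plain tensor when the cache manager is
# not in normal use.
def alloc_tensor_fallback(shape, data_type=torch.float16, device="cpu", cache_env_ok=False):
    if isinstance(shape, list):
        shape = torch.Size(shape)  # shape type conversion, as in the original code
    if not cache_env_ok:
        # cache manager not in normal use: allocate directly, without gradients
        return torch.empty(shape, dtype=data_type, device=device, requires_grad=False)
    raise NotImplementedError("cache-backed allocation is outside this sketch")

buf = alloc_tensor_fallback([4, 1024], data_type=torch.bfloat16)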

lightllm/models/vit/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 7 deletions
@@ -1,13 +1,8 @@
 import torch
-import torch.functional as F
 import torch.distributed as dist
-import numpy as np
-from typing import Tuple
-from functools import partial
-import triton
+
 
 from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
@@ -108,7 +103,7 @@ def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
         reshape = lambda t: t.view(total_len, head_num, head_dim)
         q, k, v, out = map(reshape, (q, k, v, out))
         cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seq_len
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        max_seqlen = seq_len
         flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen)
         return out.reshape(batch_size, seq_len, -1)
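Note on the second hunk: the ViT batch is built from equal-length sequences, so cu_seqlens is an arithmetic progression with step seq_len and the maximum of its consecutive differences is seq_len by construction. Using the constant skips the .max().item() reduction, which would force a device-to-host synchronization when cu_seqlens lives on the GPU. A small sketch of the equivalence (CPU tensors, illustrative sizes):

import torch

batch_size, seq_len = 4, 1025  # illustrative values

# cu_seqlens as built in _context_attention_kernel: 0, seq_len, 2*seq_len, ...
cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32) * seq_len

# Old: reduce over per-sequence lengths, then sync to the host with .item()
max_seqlen_old = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

# New: every sequence has length seq_len, so the maximum is seq_len itself
max_seqlen_new = seq_len

assert max_seqlen_old == max_seqlen_new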

unit_tests/models/vit/test_flash_attention_forward.py

Lines changed: 6 additions & 11 deletions
@@ -1,10 +1,7 @@
 import torch
-import triton
-import triton.language as tl
 import math
 import time
-import torch.nn.functional as F
-from typing import Optional, Tuple
+import pytest
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 
 
@@ -34,7 +31,8 @@ def reference_attention_varlen(q, k, v, cu):
     return out
 
 
-def test_varlen(batch=4, heads=8, d=80, dtype=torch.bfloat16, atol=1e-2, device="cuda:0"):
+@pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-2), (torch.bfloat16, 2e-2)])
+def test_varlen(dtype, atol, batch=4, heads=8, d=80, device="cuda:0"):
     torch.manual_seed(0)
     lengths = torch.randint(1, 257, (batch,))
     max_len = int(lengths.max().item())
@@ -49,10 +47,10 @@ def test_varlen(batch=4, heads=8, d=80, dtype=torch.bfloat16, atol=1e-2, device="cuda:0"):
     out_tri = torch.randn_like(q)
     flash_attention_fwd(q, k, v, out_tri, cu, max_len)
     a = time.time()
-    for _ in range(1000):
+    for _ in range(100):
         flash_attention_fwd(q, k, v, out_tri, cu, max_len)
     b = time.time()
-    print(f"flash_attention_fwd time: {(b - a) / 1000 * 1000:.2f} ms")
+    print(f"flash_attention_fwd time: {(b - a) / 100 * 1000:.2f} ms")
     out_ref = reference_attention_varlen(q, k, v, cu)
 
     max_err = (out_ref - out_tri).abs().max().item()
@@ -62,7 +60,4 @@ def test_varlen(batch=4, heads=8, d=80, dtype=torch.bfloat16, atol=1e-2, device="cuda:0"):
 
 
 if __name__ == "__main__":
-    tests = [(torch.float16, 1e-2), (torch.bfloat16, 2e-2)]
-    for dt, tol in tests:
-        test_varlen(dtype=dt, atol=tol)
-    print("✓ variable-length Flash-Attention all dtypes pass")
+    pytest.main()
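Note: with the parametrize decorator, the float16 and bfloat16 cases run as separate pytest test cases, and the __main__ block now defers to pytest. One way to run just this file (the arguments here are examples; a CUDA device is required since the test allocates on "cuda:0"):

import pytest

# -s keeps the timing print visible; the path is the test file shown in this diff.
pytest.main(["-s", "unit_tests/models/vit/test_flash_attention_forward.py"])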
