Commit 71a9e56

fix vit6b precision
1 parent 2c7c0ff commit 71a9e56

File tree

3 files changed: 30 additions, 21 deletions

lightllm/models/vit/layer_infer/transformer_layer_infer.py

Lines changed: 4 additions & 6 deletions
@@ -46,7 +46,8 @@ def tp_norm(self, input, weight):
         input_dtype = input.dtype
         input = input.to(torch.float32)
         tp_variance = input.pow(2).sum(-1, keepdim=True)
-        # dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False)
+        if self.world_size_ > 1:
+            dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False)
         variance = tp_variance / self.embed_dim_
         input = input * torch.rsqrt(variance + self.eps_)
         out = weight * input.to(input_dtype)
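
Note on the hunk above: `tp_variance` holds only this rank's partial sum of squares when the hidden dimension is sharded for tensor parallelism, so it has to be summed across ranks before dividing by the full `embed_dim_`; with the all_reduce commented out, each rank normalized with incomplete statistics. A minimal single-process sketch of that math, emulating the all_reduce by summing shard partials (`rms_norm_ref`, `full_dim` and `shards` are illustrative names, not from the repo):

import torch

def rms_norm_ref(x, weight, eps=1e-6):
    # Reference RMSNorm over the full hidden dimension.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return weight * (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)

full_dim, eps = 8, 1e-6
x = torch.randn(2, full_dim)
weight = torch.ones(full_dim)

# Shard the hidden dim across two pretend ranks.
shards = x.chunk(2, dim=-1)
w_shards = weight.chunk(2, dim=-1)

# Each rank computes a partial sum of squares; dist.all_reduce(SUM) would add them.
tp_variance = sum(s.float().pow(2).sum(-1, keepdim=True) for s in shards)
variance = tp_variance / full_dim

out_shards = [w * (s.float() * torch.rsqrt(variance + eps)).to(s.dtype)
              for s, w in zip(shards, w_shards)]
assert torch.allclose(torch.cat(out_shards, dim=-1), rms_norm_ref(x, weight), atol=1e-5)
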
@@ -79,8 +80,8 @@ def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         )
 
     def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
-        q_norm = self.norm(q, layer_weight.q_norm_weight_.weight)
-        k_norm = self.norm(k, layer_weight.k_norm_weight_.weight)
+        q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
+        k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
         return q_norm, k_norm
 
     def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
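
For the `_qk_norm` change: under tensor parallelism q and k are sharded along the head dimension, so `self.norm` computed RMS statistics over only the local shard, while `self.tp_norm` (above) reduces them across ranks and matches the single-GPU result. A tiny illustration of why the local statistic differs (plain torch, no distributed setup; values are illustrative):

import torch

x = torch.randn(4, 8)           # pretend the last dim is split across 2 ranks
local = x[:, :4]                # the slice one rank would see

rms_full = x.pow(2).mean(-1, keepdim=True).sqrt()
rms_local = local.pow(2).mean(-1, keepdim=True).sqrt()

# In general the per-shard RMS is not the full RMS, so normalizing q/k per shard
# changes the output versus a single-GPU run.
print((rms_full - rms_local).abs().max())
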
@@ -89,9 +90,6 @@ def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
         qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_)
         q, k, v = qkv.unbind(2)
-        q = q.contiguous()
-        k = k.contiguous()
-        v = v.contiguous()
         return q, k, v
 
     def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
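
On dropping the `.contiguous()` calls: after `view(...)` and `unbind(2)`, q, k and v are strided views into the fused qkv buffer, so forcing them contiguous cost three extra copies per layer. Skipping the copies is only safe because the Triton kernel below is updated to take each tensor's own strides. A small sketch of the resulting layout (shapes are made up for illustration):

import torch

batch, seq, heads, head_dim = 2, 16, 4, 32
qkv = torch.randn(batch, seq, 3 * heads * head_dim)
q, k, v = qkv.view(batch, seq, 3, heads, head_dim).unbind(2)

print(q.is_contiguous())                    # False: a strided view into qkv
print(q.stride(), q.contiguous().stride())  # different strides, same values
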

lightllm/models/vit/layer_weights/transformer_layer_weight.py

Lines changed: 0 additions & 1 deletion
@@ -163,6 +163,5 @@ def load_hf_weights(self, weights):
             ls2 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls2"]
             self.ls2 = self._cuda(ls2)
             self.use_ls = True
-            print(self.ls1)
 
         return super().load_hf_weights(weights)

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 26 additions & 14 deletions
@@ -22,6 +22,14 @@ def _fwd_kernel(
     q_stride_s,
     q_stride_h,
     q_stride_d,
+    k_stride_b,
+    k_stride_s,
+    k_stride_h,
+    k_stride_d,
+    v_stride_b,
+    v_stride_s,
+    v_stride_h,
+    v_stride_d,
     o_stride_b,
     o_stride_s,
     o_stride_h,
@@ -30,9 +38,9 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    cur_batch = tl.program_id(0)
+    cur_batch = tl.program_id(2)
     cur_head = tl.program_id(1)
-    start_m = tl.program_id(2)
+    start_m = tl.program_id(0)
 
     # initialize offsets
     offs_n = tl.arange(0, BLOCK_N)
@@ -49,9 +57,9 @@
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         off_k = (
-            cur_batch * q_stride_b
-            + (start_n + offs_n[None, :]) * q_stride_s
-            + cur_head * q_stride_h
+            cur_batch * k_stride_b
+            + (start_n + offs_n[None, :]) * k_stride_s
+            + cur_head * k_stride_h
             + offs_d[:, None]
         )
         k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < seq_len, other=0.0)
@@ -71,9 +79,9 @@
 
         # update acc
         off_v = (
-            cur_batch * q_stride_b
-            + (start_n + offs_n[:, None]) * q_stride_s
-            + cur_head * q_stride_h
+            cur_batch * v_stride_b
+            + (start_n + offs_n[:, None]) * v_stride_s
+            + cur_head * v_stride_h
             + offs_d[None, :]
         )
         v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < seq_len, other=0.0)
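
The two hunks above previously addressed K and V with q's strides, which is only correct when all three tensors share q's memory layout (as they did when they were forced contiguous). With q/k/v now being views of one fused buffer, each tensor needs its own strides. The index math the kernel performs, written out in plain PyTorch for one element (a sketch; tensor shapes are illustrative):

import torch

base = torch.randn(2, 16, 3, 4, 32)   # a fused buffer, shaped like qkv above
k = base.unbind(2)[1]                  # strided, non-contiguous view

# off_k above is b * k_stride_b + s * k_stride_s + h * k_stride_h + offs_d,
# i.e. the flat offset from k's start of storage (innermost stride assumed 1).
b, s, h, d = 1, 3, 2, 5
flat = (k.storage_offset()
        + b * k.stride(0) + s * k.stride(1) + h * k.stride(2) + d * k.stride(3))
assert base.reshape(-1)[flat] == k[b, s, h, d]
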
@@ -104,8 +112,8 @@ def flash_attention_fwd(
     batch_size, seq_len, head_num, head_dim = q.shape
 
     sm_scale = 1.0 / (head_dim ** 0.5)  # compute the softmax scale factor
-    grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK))  # batch, head,
-    # grid = (triton.cdiv(seq_len, BLOCK), batch_size, head_num)  # batch, head,
+    # grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK))  # batch, head,
+    grid = (triton.cdiv(seq_len, BLOCK), head_num, batch_size)  # batch, head,
     num_warps = 4
     _fwd_kernel[grid](
         q,
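
About the grid reorder: Triton maps grid positions 0/1/2 to `tl.program_id(0/1/2)`, so swapping the tuple requires the matching swap of `start_m`/`cur_batch` reads inside the kernel (done in the earlier hunk). A common motivation for this ordering is that CUDA allows a much larger extent on the first grid dimension than on the other two, and programs adjacent along axis 0 then process neighboring sequence blocks of the same batch and head. A sketch of the correspondence (illustrative sizes):

import triton

BLOCK, seq_len, head_num, batch_size = 64, 1024, 16, 8
grid = (triton.cdiv(seq_len, BLOCK), head_num, batch_size)

# Inside _fwd_kernel this pairs up as:
#   start_m   = tl.program_id(0)   # 0 .. triton.cdiv(seq_len, BLOCK) - 1
#   cur_head  = tl.program_id(1)   # 0 .. head_num - 1
#   cur_batch = tl.program_id(2)   # 0 .. batch_size - 1
print(grid)                        # (16, 16, 8)
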
@@ -118,6 +126,14 @@
         q.stride(1),
         q.stride(2),
         q.stride(3),
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        k.stride(3),
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
+        v.stride(3),
         o.stride(0),
         o.stride(1),
         o.stride(2),
@@ -157,7 +173,6 @@ def test():
     k = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
     v = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
     o = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-
     torch_out = torch_att(q, k, v)
     import time
 
@@ -174,6 +189,3 @@ def test():
     print("max ", torch.max(torch.abs(torch_out - o)))
     print("mean ", torch.mean(torch.abs(torch_out - o)))
     assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
-
-
-# test()
