Skip to content

Commit 7e1119f

Browse files
committed
fix vit-6b
1 parent 28bf517 commit 7e1119f

File tree

2 files changed

+3
-10
lines changed

2 files changed

+3
-10
lines changed

lightllm/models/vit/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -89,16 +89,8 @@ def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Ten
8989
)
9090

9191
def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
92-
if self.tp_world_size_ > 1:
93-
q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
94-
k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
95-
else:
96-
q_norm = rms_norm(
97-
q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
98-
)
99-
k_norm = rms_norm(
100-
k, weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
101-
)
92+
q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
93+
k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
10294
return q_norm, k_norm
10395

10496
def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:

lightllm/models/vit/triton_kernel/rms_norm_vit.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ def rms_norm_kernel(
3535

3636
def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-5, use_custom_tensor_mananger: bool = False):
3737
"""Rms norm."""
38+
assert len(hidden_states.shape) == 2
3839
feat_size = weight.shape[0]
3940
seq_len = hidden_states.numel() // hidden_states.size(-1)
4041
input_stride = hidden_states.stride(-2)

0 commit comments

Comments (0)