Commit 2c7c0ff

fix precision(ongoing)
1 parent 3fbf429 commit 2c7c0ff

File tree

4 files changed: +14 −4 lines changed

lightllm/models/vit/layer_infer/transformer_layer_infer.py

Lines changed: 11 additions & 4 deletions
@@ -31,18 +31,22 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
 
     def norm(self, input, weight):
         input_dtype = input.dtype
+        input_shape = input.shape
+        input = input.view(-1, self.tp_padding_head_num * self.head_dim_)
         input = input.to(torch.float32)
         variance = input.pow(2).mean(-1, keepdim=True)
         input = input * torch.rsqrt(variance + self.eps_)
-        return weight * input.to(input_dtype)
+        out = weight * input.to(input_dtype)
+        out = out.reshape(input_shape)
+        return out
 
     def tp_norm(self, input, weight):
         input_shape = input.shape
         input = input.view(-1, self.tp_padding_head_num * self.head_dim_)
         input_dtype = input.dtype
         input = input.to(torch.float32)
         tp_variance = input.pow(2).sum(-1, keepdim=True)
-        dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False)
+        # dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False)
         variance = tp_variance / self.embed_dim_
         input = input * torch.rsqrt(variance + self.eps_)
         out = weight * input.to(input_dtype)
@@ -75,8 +79,8 @@ def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         )
 
     def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
-        q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
-        k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
+        q_norm = self.norm(q, layer_weight.q_norm_weight_.weight)
+        k_norm = self.norm(k, layer_weight.k_norm_weight_.weight)
         return q_norm, k_norm
 
     def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
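Taken together with the hunk above, this routes the q/k normalization off the tensor-parallel path: the cross-rank all_reduce in tp_norm is commented out, and _qk_norm now calls the local norm(), which flattens the input to the padded per-rank head width, computes the RMS statistic in float32, and reshapes the result back to the caller's layout. Below is a minimal standalone sketch of that float32 RMSNorm, with illustrative names and shapes rather than the module's real API:

import torch

def rms_norm_fp32(x: torch.Tensor, weight: torch.Tensor, norm_width: int, eps: float = 1e-6) -> torch.Tensor:
    # Flatten to (tokens, norm_width); in the diff, norm_width is tp_padding_head_num * head_dim_.
    orig_dtype = x.dtype
    orig_shape = x.shape
    x = x.reshape(-1, norm_width)
    # Accumulate the statistic in fp32 so fp16/bf16 rounding does not skew it.
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Cast back, apply the learned scale, and restore the caller's shape.
    out = weight * x.to(orig_dtype)
    return out.reshape(orig_shape)

# Example (hypothetical sizes): q of shape (batch, seq, heads, head_dim) in bf16,
# normalized across all heads of each token.
q = torch.randn(2, 196, 16, 64, dtype=torch.bfloat16)
w = torch.ones(16 * 64, dtype=torch.bfloat16)
q_norm = rms_norm_fp32(q, w, norm_width=16 * 64)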
@@ -85,6 +89,9 @@ def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
         qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_)
         q, k, v = qkv.unbind(2)
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
         return q, k, v
 
     def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
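The added .contiguous() calls account for the fact that unbind() returns strided views into the packed qkv tensor, while downstream attention kernels commonly expect densely packed inputs. A small sketch with made-up shapes showing the effect:

import torch

# (batch, seq_len, 3, heads, head_dim) -- packed qkv layout as above, sizes are illustrative.
qkv = torch.randn(2, 196, 3, 16, 64)
q, k, v = qkv.unbind(2)        # views that still stride over the original dim 2
print(q.is_contiguous())       # False: q's elements are interleaved with k and v in memory
q = q.contiguous()             # materializes a densely packed copy
print(q.is_contiguous())       # True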

lightllm/models/vit/layer_weights/transformer_layer_weight.py

Lines changed: 1 addition & 0 deletions
@@ -163,5 +163,6 @@ def load_hf_weights(self, weights):
             ls2 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls2"]
             self.ls2 = self._cuda(ls2)
             self.use_ls = True
+            print(self.ls1)
 
         return super().load_hf_weights(weights)

lightllm/models/vit/model.py

Lines changed: 1 addition & 0 deletions
@@ -163,6 +163,7 @@ def encode(self, image_uuids: List):
 
         imgs = torch.cat(img_tensors, dim=0)
         pixel_values = imgs.cuda().to(dtype=self.data_type)
+        print(pixel_values.shape, pixel_values.dtype)
        all_img_embeds = self.forward(pixel_values)
         return all_img_embeds, uuids, valid_ids
 

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ def exposed_init_model(self, kvargs):
                 "quant_cfg": kvargs["quant_cfg"],
             }
             self.model = VisionTransformer(kvargs)
+            # self.model = InternVLVisionModel()
         else:
             raise Exception(f"can not support {self.model_type} now")
 
