
Commit a76ec74

fix
1 parent 5595dd4 commit a76ec74

8 files changed: +160 additions, -100 deletions

lightllm/models/gemma3/gemma3_visual.py

Lines changed: 12 additions & 8 deletions
@@ -29,8 +29,8 @@ def load_model(self, weight_dir):
         else:
             assert False, "only hf format model is supported for Gemma3"

-        self.patches_per_image = int(config['vision_config']['image_size'] // config['vision_config']['patch_size'])
-        self.tokens_per_side = int(config['mm_tokens_per_image']**0.5)
+        self.patches_per_image = int(config["vision_config"]["image_size"] // config["vision_config"]["patch_size"])
+        self.tokens_per_side = int(config["mm_tokens_per_image"] ** 0.5)
         self.kernel_size = self.patches_per_image // self.tokens_per_side
         self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

@@ -43,7 +43,7 @@ def load_model(self, weight_dir):
     def load_hf_model(self, config, weight_dir):
         from transformers import AutoConfig, AutoProcessor, Gemma3ForConditionalGeneration

-        config = AutoConfig.from_pretrained(weight_dir, trust_remote_code=True)
+        # config = AutoConfig.from_pretrained(weight_dir, trust_remote_code=True)
         processor = AutoProcessor.from_pretrained(weight_dir)
         self.image_processor = processor.image_processor

@@ -79,6 +79,7 @@ def cuda(self):
     def gemma3_rms_norm(self, input, weight, eps: float = 1e-6):
         def _norm(x):
             return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
         output = _norm(input.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402

@@ -89,7 +90,7 @@ def _norm(x):
     def forward(self, x):
         x = x.to(torch.bfloat16).cuda()
         x = self.vision_tower(x, output_hidden_states=True).last_hidden_state
-
+
         batch_size, _, seq_length = x.shape

         reshaped_vision_outputs = x.transpose(1, 2)

@@ -102,10 +103,14 @@ def forward(self, x):
         pooled_vision_outputs = pooled_vision_outputs.flatten(2)
         pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)

-        normed_vision_outputs = self.gemma3_rms_norm(pooled_vision_outputs.float(), self.projector_weights['model.mm_projector.norm']).to(torch.float32)
+        normed_vision_outputs = self.gemma3_rms_norm(
+            pooled_vision_outputs.float(), self.projector_weights["model.mm_projector.norm"]
+        ).to(torch.bfloat16)
+
+        projected_vision_outputs = torch.matmul(
+            normed_vision_outputs, self.projector_weights["model.mm_projector.linear"]
+        )

-        projected_vision_outputs = torch.matmul(normed_vision_outputs, self.projector_weights['model.mm_projector.linear'])
-        #print(projected_vision_outputs.type_as(x))
         return projected_vision_outputs.type_as(x)

     def encode(self, images: List[ImageItem]):

@@ -120,7 +125,6 @@ def encode(self, images: List[ImageItem]):
                 image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data))
                 t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
-                #print(t)
                 img_tensors.append(t)
             else:
                 raise Exception("Unsupport input types: {} for {}".format(type(img), img))
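Note: the pooling geometry configured in load_model and applied in forward can be sanity-checked in isolation. A minimal sketch, assuming the commonly published Gemma3 vision_config values (image_size=896, patch_size=14, mm_tokens_per_image=256) and a placeholder vision hidden size of 1152; real values should be read from the checkpoint's config.json:

import torch
from torch import nn

# Assumed config values (not read from this repo): image_size=896, patch_size=14, mm_tokens_per_image=256.
image_size, patch_size, mm_tokens_per_image = 896, 14, 256

patches_per_image = image_size // patch_size        # 64 patches per side
tokens_per_side = int(mm_tokens_per_image ** 0.5)   # 16
kernel_size = patches_per_image // tokens_per_side  # 4
avg_pool = nn.AvgPool2d(kernel_size=kernel_size, stride=kernel_size)

# Fake vision-tower output: (batch, num_patches, hidden); 1152 is a placeholder hidden size.
x = torch.randn(1, patches_per_image ** 2, 1152)
x = x.transpose(1, 2).reshape(1, 1152, patches_per_image, patches_per_image)
pooled = avg_pool(x).flatten(2).transpose(1, 2)
print(pooled.shape)  # torch.Size([1, 256, 1152]) -> mm_tokens_per_image tokens per image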

lightllm/models/gemma3/infer_struct.py

Lines changed: 25 additions & 9 deletions
@@ -27,20 +27,36 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
             self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)

-            self.position_cos_local = torch.index_select(model._cos_cached_local, 0, position_ids).view(position_ids.shape[0], -1)
-            self.position_sin_local = torch.index_select(model._sin_cached_local, 0, position_ids).view(position_ids.shape[0], -1)
+            self.position_cos_local = torch.index_select(model._cos_cached_local, 0, position_ids).view(
+                position_ids.shape[0], -1
+            )
+            self.position_sin_local = torch.index_select(model._sin_cached_local, 0, position_ids).view(
+                position_ids.shape[0], -1
+            )

-            self.position_cos_global = torch.index_select(model._cos_cached_global, 0, position_ids).view(position_ids.shape[0], -1)
-            self.position_sin_global = torch.index_select(model._sin_cached_global, 0, position_ids).view(position_ids.shape[0], -1)
+            self.position_cos_global = torch.index_select(model._cos_cached_global, 0, position_ids).view(
+                position_ids.shape[0], -1
+            )
+            self.position_sin_global = torch.index_select(model._sin_cached_global, 0, position_ids).view(
+                position_ids.shape[0], -1
+            )
             position_ids = None
         else:
             position_ids = self.b_seq_len - 1
             self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
             self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)

-            self.position_cos_local = torch.index_select(model._cos_cached_local, 0, position_ids).view(self.b_seq_len.shape[0], -1)
-            self.position_sin_local = torch.index_select(model._sin_cached_local, 0, position_ids).view(self.b_seq_len.shape[0], -1)
+            self.position_cos_local = torch.index_select(model._cos_cached_local, 0, position_ids).view(
+                self.b_seq_len.shape[0], -1
+            )
+            self.position_sin_local = torch.index_select(model._sin_cached_local, 0, position_ids).view(
+                self.b_seq_len.shape[0], -1
+            )

-            self.position_cos_global = torch.index_select(model._cos_cached_global, 0, position_ids).view(self.b_seq_len.shape[0], -1)
-            self.position_sin_global = torch.index_select(model._sin_cached_global, 0, position_ids).view(self.b_seq_len.shape[0], -1)
-        return
+            self.position_cos_global = torch.index_select(model._cos_cached_global, 0, position_ids).view(
+                self.b_seq_len.shape[0], -1
+            )
+            self.position_sin_global = torch.index_select(model._sin_cached_global, 0, position_ids).view(
+                self.b_seq_len.shape[0], -1
+            )
+        return
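Note: this hunk is only a line-length reflow; the gather itself is unchanged. For readers unfamiliar with the pattern, a toy sketch of how a rotary cache row is selected per position (illustrative shapes, not lightllm's actual cache construction):

import torch

# Toy cache: one row of rotary cos values per absolute position.
max_positions, half_dim = 1024, 64
_cos_cached = torch.randn(max_positions, half_dim)

# Prefill uses one position id per prompt token; decode uses b_seq_len - 1 (the current position per request).
position_ids = torch.tensor([0, 1, 2, 3])
position_cos = torch.index_select(_cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
print(position_cos.shape)  # torch.Size([4, 64])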

lightllm/models/gemma3/layer_infer/post_layer_infer.py

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,6 @@
 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight


-
 class Gemma3PostLayerInfer(LlamaPostLayerInfer):
     """ """


@@ -15,9 +14,10 @@ def __init__(self, network_config, mode):
         self.eps_ = 1e-6
         return

-    def gemma3_rmsnorm(self, input, weight, eps: float = 1e-6, out = None):
+    def gemma3_rmsnorm(self, input, weight, eps: float = 1e-6, out=None):
         def _inner_norm(x):
             return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
         output = _inner_norm(input.float())
         output = output * (1.0 + weight.float())
         if out is not None:

@@ -26,7 +26,7 @@ def _inner_norm(x):

     def _norm(self, input, infer_state, layer_weight) -> torch.Tensor:
         return self.gemma3_rmsnorm(input, layer_weight.final_norm_weight_, eps=self.eps_)
-
+
     def token_forward(self, input_embdings, infer_state, layer_weight):
         # print('last_hidden_before_norm', input_embdings)
         last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)

@@ -58,4 +58,4 @@ def token_forward(self, input_embdings, infer_state, layer_weight):
         )
         ans_logics[:, :] = gather_data.permute(1, 0)
         gather_data = None
-        return ans_logics
+        return ans_logics
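Note: the comment in gemma3_visual.py ("Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)", see https://github.com/huggingface/transformers/pull/29402) describes the same ordering used by gemma3_rmsnorm here. A standalone sketch of the two orderings; this is not lightllm code, just the two formulas side by side:

import torch

def gemma_style_rmsnorm(x, w, eps=1e-6):
    # Normalize in float32, apply (1 + w) in float32, then cast back (the gemma3_rmsnorm ordering).
    out = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return (out * (1.0 + w.float())).to(x.dtype)

def llama_style_rmsnorm(x, w, eps=1e-6):
    # Llama casts the normalized activations back to the working dtype first, then multiplies by w.
    out = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return out.to(x.dtype) * w

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.randn(8, dtype=torch.bfloat16)
# The two results differ only in where the bf16 rounding happens.
print((gemma_style_rmsnorm(x, w) - llama_style_rmsnorm(x, 1.0 + w)).abs().max())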

lightllm/models/gemma3/layer_infer/pre_layer_infer.py

Lines changed: 69 additions & 5 deletions
@@ -1,20 +1,84 @@
 import torch
+from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
+from lightllm.distributed.communication_op import all_reduce
 from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.server.embed_cache.utils import bytes2tensor, get_shm_name_embed, read_shm


 class Gemma3PreLayerInfer(LlamaMultimodalPreLayerInfer):
     def __init__(self, network_config, mode):
         super().__init__(network_config, mode)
-        self.embed_scale = torch.tensor(network_config['hidden_size']**0.5, dtype=torch.float32)
+        self.embed_scale = torch.tensor(network_config["hidden_size"] ** 0.5, dtype=torch.float32)
+        self.boi_token_index: int = 255_999
+        self.eoi_token_index: int = 256_000
         return

     def context_forward(self, input_ids, infer_state, layer_weight):
-        input_embedding = super().context_forward(input_ids, infer_state, layer_weight)
-        input_dtype = input_embedding.dtype
-        return (input_embedding.float() * self.embed_scale.to(input_embedding.device).float()).to(input_dtype)
+        img_weight = []
+        img_start_token_ids = []
+        img_token_lens = []
+        img_start_loc = 0
+        img_start_locs = []
+        device = layer_weight.wte_weight_.device
+        dtype = layer_weight.wte_weight_.dtype
+        hidden_size = layer_weight.wte_weight_.shape[1]
+        weight_mask = torch.zeros((len(input_ids)), dtype=torch.float32, device=device)
+
+        scale = self.embed_scale
+        for idx, input_id in enumerate(input_ids):
+            if input_id == self.boi_token_index:
+                weight_mask[idx] = scale
+                scale = 1.0
+            elif input_id == self.eoi_token_index:
+                scale = self.embed_scale
+                weight_mask[idx] = scale
+            else:
+                weight_mask[idx] = scale
+
+        for batch_id, p in enumerate(infer_state.multimodal_params):
+            for img in p["images"]:
+                # skip the same image
+                if img["token_id"] in img_start_token_ids:
+                    continue
+                # pull the img_embeds by uid from shm
+                data = read_shm(get_shm_name_embed(img["uuid"]))
+                img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
+                img_start_token_ids.append(img["token_id"])
+                img_token_lens.append(img["token_num"])
+                img_start_locs.append(img_start_loc)
+                img_start_loc += img["token_num"]
+        out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)
+        if len(img_weight) > 0:
+            img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype)
+        else:
+            img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype)
+        assert img_weight.shape[1] == hidden_size, (
+            f"Dimension mismatch: text weight dimension is {hidden_size}, "
+            f"but image weight dimension is {img_weight.shape[1]}"
+        )
+        # each tp will fill the img embeds, should divide by world_size
+        img_weight = img_weight / self.tp_world_size_
+        img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long)
+        img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long)
+        img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long)
+
+        multimodal_emb(
+            out,
+            input_ids,
+            layer_weight.wte_weight_,
+            img_weight,
+            img_token_lens,
+            img_start_token_ids,
+            img_start_locs,
+            self.vob_start_id_,
+            self.vob_end_id_,
+        )
+        input_dtype = out.dtype
+        if self.tp_world_size_ > 1:
+            all_reduce(out, group=infer_state.dist_group, op=torch.dist.ReduceOp.SUM, async_op=False)
+        return (out.float() * weight_mask.unsqueeze(1).float()).to(input_dtype)

     def token_forward(self, input_ids, infer_state, layer_weight):
         input_embedding = super().token_forward(input_ids, infer_state, layer_weight)
         input_dtype = input_embedding.dtype
         return (input_embedding.float() * self.embed_scale.to(input_embedding.device).float()).to(input_dtype)
-
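Note: the new context_forward scales text-token embeddings by sqrt(hidden_size) (the embed_scale above) while leaving the image embeddings that sit between the begin/end-of-image tokens unscaled. A minimal standalone sketch of just that masking loop, with toy token ids and a placeholder hidden size; only the boi/eoi ids (255_999 and 256_000) come from the diff above:

import torch

BOI, EOI = 255_999, 256_000      # begin/end-of-image token ids from the diff above
embed_scale = 1152 ** 0.5        # sqrt(hidden_size); 1152 is a placeholder value

input_ids = [10, 20, BOI, 7, 7, 7, EOI, 30]  # text tokens, an image span, then text again
weight_mask = torch.zeros(len(input_ids))

scale = embed_scale
for idx, input_id in enumerate(input_ids):
    if input_id == BOI:
        weight_mask[idx] = scale  # the boi token itself is still scaled
        scale = 1.0               # tokens inside the image span stay unscaled
    elif input_id == EOI:
        scale = embed_scale       # text scaling resumes at the eoi token
        weight_mask[idx] = scale
    else:
        weight_mask[idx] = scale

print(weight_mask)  # [s, s, s, 1, 1, 1, s, s] with s = sqrt(hidden_size)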
