Commit 28bf517

flyinglandlord, gushiqiao, and hiworldwzj authored
[Model] Add support for gemma3 model (#825)
Co-authored-by: gushiqiao <[email protected]> Co-authored-by: hiworldwzj <[email protected]>
1 parent 1e01498 commit 28bf517

File tree

17 files changed: +810 -4 lines changed


docs/CN/source/models/supported_models.rst

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,8 @@ lightllm supports most of the mainstream open-source large language models and multimodal models
      - :code:`--enable_multimodal`
    * - `Qwen2-VL <https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct>`_
      - :code:`--enable_multimodal`
+   * - `Google Gemma3 <https://huggingface.co/google/gemma-3-12b-it>`_
+     - :code:`--enable_multimodal`


 Reward Model

docs/EN/source/models/supported_models.rst

Lines changed: 2 additions & 0 deletions
@@ -76,6 +76,8 @@ VLM
      - :code:`--enable_multimodal`
    * - `Llava-13b <https://huggingface.co/liuhaotian/llava-v1.5-13b>`_
      - :code:`--enable_multimodal`
+   * - `Google Gemma3 <https://huggingface.co/google/gemma-3-12b-it>`_
+     - :code:`--enable_multimodal`


 Reward Model

lightllm/models/gemma3/__init__.py

Whitespace-only changes.
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import json
import os
from PIL import Image
from typing import List, Union
from safetensors import safe_open
from io import BytesIO
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from lightllm.utils.log_utils import init_logger


logger = init_logger(__name__)


class Gemma3VisionModel:
    def __init__(self):
        pass

    def load_model(self, weight_dir):
        config_file = os.path.join(weight_dir, "config.json")
        config = json.load(open(config_file))

        # HF-format Gemma3 checkpoints keep the LLM settings under "text_config"; load them via transformers
        if "text_config" in config:
            self.load_hf_model(config, weight_dir)
        else:
            assert False, "only hf format model is supported for Gemma3"

        self.patches_per_image = int(config["vision_config"]["image_size"] // config["vision_config"]["patch_size"])
        self.tokens_per_side = int(config["mm_tokens_per_image"] ** 0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

        self.vision_tower.requires_grad_(False)
        self.device = torch.device("cpu")

        assert "model.mm_projector.linear" in self.projector_weights
        assert "model.mm_projector.norm" in self.projector_weights

    def load_hf_model(self, config, weight_dir):
        from transformers import AutoConfig, AutoProcessor, Gemma3ForConditionalGeneration

        # config = AutoConfig.from_pretrained(weight_dir, trust_remote_code=True)
        processor = AutoProcessor.from_pretrained(weight_dir)
        self.image_processor = processor.image_processor

        model = Gemma3ForConditionalGeneration.from_pretrained(
            weight_dir,
            torch_dtype=torch.float16,
        )
        self.vision_tower = model.vision_tower
        model.multi_modal_projector = None
        model.language_model = None

        # load projector weights
        self.projector_weights = {}
        for f in os.listdir(weight_dir):
            if f.endswith(".safetensors"):
                d = safe_open(os.path.join(weight_dir, f), "pt", "cpu")
                for k in d.keys():
                    if "multi_modal_projector.mm_input_projection_weight" in k:
                        self.projector_weights[
                            k.replace("multi_modal_projector.mm_input_projection_weight", "model.mm_projector.linear")
                        ] = d.get_tensor(k).to(torch.bfloat16)
                    if "multi_modal_projector.mm_soft_emb_norm.weight" in k:
                        self.projector_weights[
                            k.replace("multi_modal_projector.mm_soft_emb_norm.weight", "model.mm_projector.norm")
                        ] = d.get_tensor(k).to(torch.bfloat16)

    def cuda(self):
        self.vision_tower = self.vision_tower.cuda()
        for k, v in self.projector_weights.items():
            self.projector_weights[k] = v.cuda()
        return self

    def gemma3_rms_norm(self, input, weight, eps: float = 1e-6):
        def _norm(x):
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

        output = _norm(input.float())
        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + weight.float())
        return output.type_as(input)

    # batch images infer
    def forward(self, x):
        x = x.to(torch.bfloat16).cuda()
        x = self.vision_tower(x, output_hidden_states=True).last_hidden_state

        batch_size, _, seq_length = x.shape

        reshaped_vision_outputs = x.transpose(1, 2)
        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
            batch_size, seq_length, self.patches_per_image, self.patches_per_image
        )
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()

        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)

        normed_vision_outputs = self.gemma3_rms_norm(
            pooled_vision_outputs.float(), self.projector_weights["model.mm_projector.norm"]
        ).to(torch.bfloat16)

        projected_vision_outputs = torch.matmul(
            normed_vision_outputs, self.projector_weights["model.mm_projector.linear"]
        )

        return projected_vision_outputs.type_as(x)

    def encode(self, images: List[ImageItem]):
        img_tensors = []
        uuids = []
        valid_id = 0
        valid_ids = []

        for i, img in enumerate(images):
            if isinstance(img, ImageItem):
                uuids.append(img.uuid)
                image_data = read_shm(get_shm_name_data(img.uuid))
                image_data = Image.open(BytesIO(image_data))
                t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
                img_tensors.append(t)
            else:
                raise Exception("Unsupported input types: {} for {}".format(type(img), img))

            cur_num = img_tensors[-1].shape[0]
            valid_ids.append([valid_id, valid_id + cur_num])
            valid_id += cur_num

        if len(img_tensors) <= 0:
            return None

        img = torch.cat(img_tensors, dim=0)
        all_img_embeds = self.forward(img)

        return all_img_embeds, uuids, valid_ids
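Editor's note: the projector set up in load_model reduces the SigLIP patch grid to a fixed number of image tokens via average pooling. Below is a minimal sketch of that arithmetic, assuming the published gemma-3-12b-it vision config values (image_size=896, patch_size=14, mm_tokens_per_image=256) and a hypothetical hidden size; treat the concrete numbers as an illustration, not a guarantee of what a given checkpoint contains.

import torch
import torch.nn as nn

# Assumed config values for illustration; read the real ones from the checkpoint's config.json.
image_size, patch_size, mm_tokens_per_image = 896, 14, 256

patches_per_image = image_size // patch_size        # 64 patches per side
tokens_per_side = int(mm_tokens_per_image ** 0.5)   # 16 image tokens per side
kernel_size = patches_per_image // tokens_per_side  # 4x4 patch blocks pooled into one token
avg_pool = nn.AvgPool2d(kernel_size=kernel_size, stride=kernel_size)

hidden = 1152                                        # hypothetical vision hidden size
x = torch.randn(1, patches_per_image ** 2, hidden)   # vision tower output: (B, 4096, hidden)
x = x.transpose(1, 2).reshape(1, hidden, patches_per_image, patches_per_image)
pooled = avg_pool(x).flatten(2).transpose(1, 2)      # (B, 256, hidden), matching mm_tokens_per_image
print(pooled.shape)                                  # torch.Size([1, 256, 1152])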
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
import torch
import numpy as np
from lightllm.common.basemodel import InferStateInfo
from lightllm.common.req_manager import ReqManager
from lightllm.models.llama.infer_struct import LlamaInferStateInfo


class Gemma3InferStateInfo(LlamaInferStateInfo):
    def __init__(self):
        super().__init__()
        self.position_cos_global = None
        self.position_sin_global = None
        self.position_sin_local = None
        self.position_cos_local = None

    def init_some_extra_state(self, model, input_ids: torch.Tensor):
        if self.is_prefill:
            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
            self.max_seq_len = b_seq_len_numpy.max()
            b_ready_cache_len_numpy = self.b_ready_cache_len.cpu().numpy()
            position_ids = torch.from_numpy(
                np.concatenate(
                    [np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))],
                    axis=0,
                )
            ).cuda()
            self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
            self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)

            self.position_cos_local = torch.index_select(model._cos_cached_local, 0, position_ids).view(
                position_ids.shape[0], -1
            )
            self.position_sin_local = torch.index_select(model._sin_cached_local, 0, position_ids).view(
                position_ids.shape[0], -1
            )

            self.position_cos_global = torch.index_select(model._cos_cached_global, 0, position_ids).view(
                position_ids.shape[0], -1
            )
            self.position_sin_global = torch.index_select(model._sin_cached_global, 0, position_ids).view(
                position_ids.shape[0], -1
            )
            position_ids = None
        else:
            position_ids = self.b_seq_len - 1
            self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
            self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)

            self.position_cos_local = torch.index_select(model._cos_cached_local, 0, position_ids).view(
                self.b_seq_len.shape[0], -1
            )
            self.position_sin_local = torch.index_select(model._sin_cached_local, 0, position_ids).view(
                self.b_seq_len.shape[0], -1
            )

            self.position_cos_global = torch.index_select(model._cos_cached_global, 0, position_ids).view(
                self.b_seq_len.shape[0], -1
            )
            self.position_sin_global = torch.index_select(model._sin_cached_global, 0, position_ids).view(
                self.b_seq_len.shape[0], -1
            )
        return
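Editor's note: Gemma3InferStateInfo carries two extra cos/sin pairs because Gemma3 interleaves sliding-window (local) attention layers with full (global) attention layers, and the two layer types use different RoPE base frequencies. The sketch below shows how such paired caches could be precomputed and indexed; the base values (10_000 local, 1_000_000 global), head_dim, and max_len are illustrative assumptions, not the exact values lightllm derives from a checkpoint config.

import torch

def build_rope_cache(head_dim: int, max_len: int, base: float):
    # Standard RoPE cache: one (max_len, head_dim // 2) table each for cos and sin.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_len).float()
    freqs = torch.outer(t, inv_freq)
    return torch.cos(freqs), torch.sin(freqs)

head_dim, max_len = 256, 4096                                                    # hypothetical sizes
cos_local, sin_local = build_rope_cache(head_dim, max_len, base=10_000.0)        # sliding-window layers
cos_global, sin_global = build_rope_cache(head_dim, max_len, base=1_000_000.0)   # full-attention layers

position_ids = torch.tensor([0, 1, 2, 5])
# Mirrors the index_select(...).view(n, -1) lookup pattern used in init_some_extra_state.
position_cos_local = torch.index_select(cos_local, 0, position_ids).view(position_ids.shape[0], -1)
position_cos_global = torch.index_select(cos_global, 0, position_ids).view(position_ids.shape[0], -1)
print(position_cos_local.shape, position_cos_global.shape)                       # torch.Size([4, 128]) each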

lightllm/models/gemma3/layer_infer/__init__.py

Whitespace-only changes.
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
import numpy as np
import torch

from lightllm.distributed.communication_op import all_gather
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight


class Gemma3PostLayerInfer(LlamaPostLayerInfer):
    """ """

    def __init__(self, network_config, mode):
        super().__init__(network_config, mode)
        self.eps_ = 1e-6
        return

    def gemma3_rmsnorm(self, input, weight, eps: float = 1e-6, out=None):
        def _inner_norm(x):
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

        output = _inner_norm(input.float())
        output = output * (1.0 + weight.float())
        if out is not None:
            out = output.to(out.dtype)
        return output

    def _norm(self, input, infer_state, layer_weight) -> torch.Tensor:
        return self.gemma3_rmsnorm(input, layer_weight.final_norm_weight_, eps=self.eps_)

    def token_forward(self, input_embdings, infer_state, layer_weight):
        last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
        input_embdings_dtype = input_embdings.dtype
        last_input = self._norm(last_input.float(), infer_state, layer_weight).to(torch.bfloat16)
        last_input = last_input.permute(1, 0).view(-1, token_num)
        logic_batch = self.alloc_tensor(
            (layer_weight.lm_head_weight_.shape[0], last_input.shape[1]), dtype=last_input.dtype
        )
        torch.mm(layer_weight.lm_head_weight_.to(last_input.dtype), last_input, out=logic_batch)
        last_input = None
        if self.tp_world_size_ == 1:
            gather_data = logic_batch
        else:
            gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype)
            split_indexes = np.linspace(0, self.vocab_size_, self.tp_world_size_ + 1, dtype=np.int64)
            all_gather(
                [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)],
                logic_batch,
                group=infer_state.dist_group,
                async_op=False,
            )
        logic_batch = None
        ans_logics = self.alloc_tensor(
            (token_num, self.vocab_size_),
            dtype=torch.float32,
            is_graph_out=True,
            microbatch_index=infer_state.microbatch_index,
        )
        ans_logics[:, :] = gather_data.permute(1, 0)
        gather_data = None
        return ans_logics
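Editor's note: the gemma3_rmsnorm helper differs from the Llama-style RMSNorm in two ways: the weight acts as (1.0 + w) rather than w, and the scaling is applied in float32 before casting back (see the transformers PR referenced in the vision file). A small self-contained sketch contrasting the two conventions on toy tensors, for illustration only:

import torch

def llama_rms_norm(x, w, eps=1e-6):
    # Normalize in float32, cast down, then scale: x.to(dtype) * w
    h = x.float()
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return h.type_as(x) * w

def gemma3_rms_norm(x, w, eps=1e-6):
    # Normalize and scale in float32 with a (1 + w) weight, then cast down: (x * (1 + w)).to(dtype)
    h = x.float()
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return (h * (1.0 + w.float())).type_as(x)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.zeros(8, dtype=torch.bfloat16)
# With a zero-initialized weight, the Gemma3 convention reproduces plain RMSNorm,
# while the Llama convention would zero the output entirely.
print(llama_rms_norm(x, w)[0, :3], gemma3_rms_norm(x, w)[0, :3])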
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import torch
from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
from lightllm.distributed.communication_op import all_reduce
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
from lightllm.server.embed_cache.utils import bytes2tensor, get_shm_name_embed, read_shm


class Gemma3PreLayerInfer(LlamaMultimodalPreLayerInfer):
    def __init__(self, network_config, mode):
        super().__init__(network_config, mode)
        self.embed_scale = torch.tensor(network_config["hidden_size"] ** 0.5, dtype=torch.float32)
        self.boi_token_index: int = 255_999
        self.eoi_token_index: int = 256_000
        return

    def context_forward(self, input_ids, infer_state, layer_weight):
        img_weight = []
        img_start_token_ids = []
        img_token_lens = []
        img_start_loc = 0
        img_start_locs = []
        device = layer_weight.wte_weight_.device
        dtype = layer_weight.wte_weight_.dtype
        hidden_size = layer_weight.wte_weight_.shape[1]
        weight_mask = torch.zeros((len(input_ids)), dtype=torch.float32, device=device)

        scale = self.embed_scale
        for idx, input_id in enumerate(input_ids):
            if input_id == self.boi_token_index:
                weight_mask[idx] = scale
                scale = 1.0
            elif input_id == self.eoi_token_index:
                scale = self.embed_scale
                weight_mask[idx] = scale
            else:
                weight_mask[idx] = scale

        for batch_id, p in enumerate(infer_state.multimodal_params):
            for img in p["images"]:
                # skip the same image
                if img["token_id"] in img_start_token_ids:
                    continue
                # pull the img_embeds by uid from shm
                data = read_shm(get_shm_name_embed(img["uuid"]))
                img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
                img_start_token_ids.append(img["token_id"])
                img_token_lens.append(img["token_num"])
                img_start_locs.append(img_start_loc)
                img_start_loc += img["token_num"]
        out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)
        if len(img_weight) > 0:
            img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype)
        else:
            img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype)
        assert img_weight.shape[1] == hidden_size, (
            f"Dimension mismatch: text weight dimension is {hidden_size}, "
            f"but image weight dimension is {img_weight.shape[1]}"
        )
        # each tp rank fills the img embeds, so divide by the world size
        img_weight = img_weight / self.tp_world_size_
        img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long)
        img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long)
        img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long)

        multimodal_emb(
            out,
            input_ids,
            layer_weight.wte_weight_,
            img_weight,
            img_token_lens,
            img_start_token_ids,
            img_start_locs,
            self.vob_start_id_,
            self.vob_end_id_,
        )
        input_dtype = out.dtype
        if self.tp_world_size_ > 1:
            all_reduce(out, group=infer_state.dist_group, op=torch.distributed.ReduceOp.SUM, async_op=False)
        return (out.float() * weight_mask.unsqueeze(1).float()).to(input_dtype)

    def token_forward(self, input_ids, infer_state, layer_weight):
        input_embedding = super().token_forward(input_ids, infer_state, layer_weight)
        input_dtype = input_embedding.dtype
        return (input_embedding.float() * self.embed_scale.to(input_embedding.device).float()).to(input_dtype)
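Editor's note: context_forward scales text embeddings by sqrt(hidden_size), as Gemma models do, while leaving the image-token span between the begin-of-image and end-of-image markers unscaled; that is what the weight_mask loop encodes. A toy sketch of the mask construction follows, using the class's boi/eoi ids and otherwise hypothetical token ids and hidden size:

import torch

hidden_size = 3840                          # hypothetical; sqrt(hidden_size) is the Gemma embedding scale
embed_scale = hidden_size ** 0.5
boi_token_index, eoi_token_index = 255_999, 256_000

# text, text, <start_of_image>, three image placeholder tokens, <end_of_image>, text
input_ids = [10, 20, boi_token_index, 7, 7, 7, eoi_token_index, 30]

weight_mask = torch.zeros(len(input_ids), dtype=torch.float32)
scale = embed_scale
for idx, input_id in enumerate(input_ids):
    if input_id == boi_token_index:
        weight_mask[idx] = scale            # the begin-of-image marker itself is still scaled
        scale = 1.0                         # image embeddings that follow are left unscaled
    elif input_id == eoi_token_index:
        scale = embed_scale                 # resume text scaling from the end marker onward
        weight_mask[idx] = scale
    else:
        weight_mask[idx] = scale

print(weight_mask)  # roughly [61.97, 61.97, 61.97, 1.0, 1.0, 1.0, 61.97, 61.97]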
