import torch
+import torch.distributed as dist
+from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
+from lightllm.distributed.communication_op import all_reduce
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.server.embed_cache.utils import bytes2tensor, get_shm_name_embed, read_shm


class Gemma3PreLayerInfer(LlamaMultimodalPreLayerInfer):
    def __init__(self, network_config, mode):
        super().__init__(network_config, mode)
-        self.embed_scale = torch.tensor(network_config['hidden_size']**0.5, dtype=torch.float32)
+        self.embed_scale = torch.tensor(network_config["hidden_size"] ** 0.5, dtype=torch.float32)
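+        # special token ids marking the beginning (boi) and end (eoi) of an image
+        # segment; used in context_forward to decide which tokens get the embed scale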
+        self.boi_token_index: int = 255_999
+        self.eoi_token_index: int = 256_000
        return

    def context_forward(self, input_ids, infer_state, layer_weight):
-        input_embedding = super().context_forward(input_ids, infer_state, layer_weight)
-        input_dtype = input_embedding.dtype
-        return (input_embedding.float() * self.embed_scale.to(input_embedding.device).float()).to(input_dtype)
+        img_weight = []
+        img_start_token_ids = []
+        img_token_lens = []
+        img_start_loc = 0
+        img_start_locs = []
+        device = layer_weight.wte_weight_.device
+        dtype = layer_weight.wte_weight_.dtype
+        hidden_size = layer_weight.wte_weight_.shape[1]
+        weight_mask = torch.zeros((len(input_ids)), dtype=torch.float32, device=device)
+
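+        # build a per-token scale: Gemma3 multiplies text embeddings by sqrt(hidden_size),
+        # while the image tokens between boi and eoi keep a scale of 1.0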
+        scale = self.embed_scale
+        for idx, input_id in enumerate(input_ids):
+            if input_id == self.boi_token_index:
+                weight_mask[idx] = scale
+                scale = 1.0
+            elif input_id == self.eoi_token_index:
+                scale = self.embed_scale
+                weight_mask[idx] = scale
+            else:
+                weight_mask[idx] = scale
+
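+        # collect the cached vision embeddings for each distinct image in the batch,
+        # along with its token id, token count, and row offset in the concatenated tensor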
+        for batch_id, p in enumerate(infer_state.multimodal_params):
+            for img in p["images"]:
+                # skip duplicate images whose embeddings were already collected
+                if img["token_id"] in img_start_token_ids:
+                    continue
+                # read the cached image embedding bytes from shared memory by uuid
+                data = read_shm(get_shm_name_embed(img["uuid"]))
+                img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
+                img_start_token_ids.append(img["token_id"])
+                img_token_lens.append(img["token_num"])
+                img_start_locs.append(img_start_loc)
+                img_start_loc += img["token_num"]
+        out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)
+        if len(img_weight) > 0:
+            img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype)
+        else:
+            img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype)
+        assert img_weight.shape[1] == hidden_size, (
+            f"Dimension mismatch: text weight dimension is {hidden_size}, "
+            f"but image weight dimension is {img_weight.shape[1]}"
+        )
+        # every tensor-parallel rank fills in the image embeddings, so pre-divide by
+        # world_size to keep the later all-reduce sum correct
+        img_weight = img_weight / self.tp_world_size_
+        img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long)
+        img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long)
+        img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long)
+
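+        # scatter text and image embeddings into `out` in a single triton kernel:
+        # token ids in [vob_start_id_, vob_end_id_) are looked up in this rank's slice
+        # of the word embedding table, image token ids are filled from img_weight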
+        multimodal_emb(
+            out,
+            input_ids,
+            layer_weight.wte_weight_,
+            img_weight,
+            img_token_lens,
+            img_start_token_ids,
+            img_start_locs,
+            self.vob_start_id_,
+            self.vob_end_id_,
+        )
+        input_dtype = out.dtype
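+        # each tensor-parallel rank only embedded its own vocab slice (and 1/world_size
+        # of the image embeddings), so sum the partial results across ranks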
+        if self.tp_world_size_ > 1:
+            all_reduce(out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False)
+        return (out.float() * weight_mask.unsqueeze(1).float()).to(input_dtype)

    def token_forward(self, input_ids, infer_state, layer_weight):
        input_embedding = super().token_forward(input_ids, infer_state, layer_weight)
        input_dtype = input_embedding.dtype
        return (input_embedding.float() * self.embed_scale.to(input_embedding.device).float()).to(input_dtype)
-