
Commit 46c70ff

Improve internvl for turbomind engine (#3769)
* support internvl using moe model as LLM part
* improvement
* update
* update
* add interns1 template
* update
* support pt interns1 hf
* support interns1 in turbomind
* fix embedding data type mapping
* add escaping for regex
* add escaping for regex
* fix linting

---------

Co-authored-by: RunningLeon <[email protected]>
1 parent a731a7f commit 46c70ff
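
With this change, InternS1-style checkpoints (InternVL with a Qwen3-MoE language model) can be loaded by the turbomind backend. A rough usage sketch, assuming the standard lmdeploy pipeline API; the model id and prompt below are placeholders, not part of this commit:

```python
# Hedged sketch: the repo id and prompt are placeholders, not taken from this commit.
from lmdeploy import TurbomindEngineConfig, pipeline

if __name__ == '__main__':
    pipe = pipeline(
        'internlm/Intern-S1',                       # placeholder HF repo id
        backend_config=TurbomindEngineConfig(tp=1),  # select the turbomind engine
    )
    print(pipe('Describe what an expert router does in a MoE layer.').text)
```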

13 files changed (+129, -151 lines)


lmdeploy/turbomind/deploy/source_model/deepseek_vl.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ class DeepSeekVLReader(LlamaReader):
     """DeepSeekVL model reader."""
 
     attn_layer_prefix = 'language_model.model.layers'
-    attn_layer_patten = r'language_model.model.layers.([0-9]+).'
+    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
     tok_embeddings_key = 'language_model.model.embed_tokens.weight'
     norm_weight_key = 'language_model.model.norm.weight'
     output_weight_key = 'language_model.lm_head.weight'
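
Most of the one-line changes in this commit escape the literal dots in `attn_layer_patten`. In a regex an unescaped `.` matches any character, so the old patterns could also match keys that merely resemble the layer prefix. A small illustration (the key below is contrived, not from a real checkpoint):

```python
import re

old = re.compile(r'language_model.model.layers.([0-9]+).')
new = re.compile(r'language_model\.model\.layers\.([0-9]+).')

key = 'language_modelXmodelXlayersX0X'   # contrived key; dots should not match 'X'
print(bool(old.match(key)))  # True  -- '.' matches any character
print(bool(new.match(key)))  # False -- '\.' requires a literal dot
```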

lmdeploy/turbomind/deploy/source_model/glm4.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 class Glm4Reader(LlamaReader):
     """Glm4Reader."""
 
-    attn_layer_patten = r'transformer.encoder.layers.([0-9]+).'
+    attn_layer_patten = r'transformer\.encoder\.layers\.([0-9]+).'
     tok_embeddings_key = 'transformer.embedding.word_embeddings.weight'
     norm_weight_key = 'transformer.encoder.final_layernorm.weight'
     output_weight_key = 'transformer.output_layer.weight'

lmdeploy/turbomind/deploy/source_model/internlm2.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ class InternLM2Reader(LlamaReader):
     """InternLM2 model reader."""
 
     attn_layer_prefix = 'model.layers'
-    attn_layer_patten = r'model.layers.([0-9]+).'
+    attn_layer_patten = r'model\.layers\.([0-9]+).'
     tok_embeddings_key = 'model.tok_embeddings.weight'
     norm_weight_key = 'model.norm.weight'
     output_weight_key = 'output.weight'
lmdeploy/turbomind/deploy/source_model/internvl.py

Lines changed: 31 additions & 53 deletions
@@ -1,18 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import json
-import os.path as osp
-
-from ..config import RopeParam
 from .base import INPUT_MODELS
 from .internlm2 import InternLM2Reader
 from .llama import LlamaModel, LlamaReader
+from .qwen import Qwen3MoeReader
 
 
 class InternVLReader(LlamaReader):
     """InternVLReader for llama model."""
 
     attn_layer_prefix = 'language_model.model.layers'
-    attn_layer_patten = r'language_model.model.layers.([0-9]+).'
+    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
     tok_embeddings_key = 'language_model.model.embed_tokens.weight'
     norm_weight_key = 'language_model.model.norm.weight'
     output_weight_key = 'language_model.lm_head.weight'
@@ -27,7 +24,7 @@ class InternVL2Reader(InternLM2Reader):
     """InternVLReader for InternLM2 model."""
 
     attn_layer_prefix = 'language_model.model.layers'
-    attn_layer_patten = r'language_model.model.layers.([0-9]+).'
+    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
     tok_embeddings_key = 'language_model.model.tok_embeddings.weight'
     norm_weight_key = 'language_model.model.norm.weight'
     output_weight_key = 'language_model.output.weight'
@@ -37,6 +34,22 @@ def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_
         super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)
 
 
+class InternS1Reader(Qwen3MoeReader):
+    """InternVL3Reader for InternVL+Qwen3MoE model."""
+
+    attn_layer_prefix = 'model.language_model.layers'
+    attn_layer_patten = r'model\.language_model\.layers\.([0-9]+).'
+    tok_embeddings_key = 'model.language_model.embed_tokens.weight'
+    norm_weight_key = 'model.language_model.norm.weight'
+    output_weight_key = 'lm_head.weight'
+
+    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):
+        model_cfg = model_cfg.get('text_config')
+        if model_cfg is None:
+            raise ValueError(f'Miss "text_config" in model config: {model_cfg}')
+        super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)
+
+
 @INPUT_MODELS.register_module(name='internvl')
 class InternVLModel(LlamaModel):
     """InternVL model in hf format."""
@@ -45,53 +58,18 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
         super().__init__(model_path, tokenizer_path, **kwargs)
         from transformers import AutoConfig
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-        llm_config = getattr(config, 'llm_config', None) or getattr(config, 'text_config', None)
-        arch = llm_config.architectures[0]
-        _readers = dict(InternLM2ForCausalLM=InternVL2Reader,
-                        LlamaForCausalLM=InternVLReader,
-                        Qwen2ForCausalLM=InternVLReader)
-        self.Reader = _readers[arch]
+        self.llm_config = getattr(config, 'llm_config', None) or getattr(config, 'text_config', None)
+        arch = self.llm_config.architectures[0]
+        relations = dict(
+            InternLM2ForCausalLM=('internlm2', InternVL2Reader),
+            LlamaForCausalLM=('llama', InternVLReader),
+            Qwen2ForCausalLM=('qwen2', InternVLReader),
+            Qwen3MoeForCausalLM=('qwen3-moe', InternS1Reader),
+        )
+        llm_model, self.Reader = relations[arch]
+        self.llm_model = INPUT_MODELS.get(llm_model)(model_path=model_path, tokenizer_path=tokenizer_path, **kwargs)
 
     def model_info(self):
         """Read model info."""
-        params_path = osp.join(self.model_path, 'config.json')
-        with open(params_path) as f:
-            file_content = json.load(f)
-        model_arg = file_content.get('llm_config') or file_content.get('text_config')
-        num_layer = model_arg['num_hidden_layers']
-        norm_eps = model_arg['rms_norm_eps']
-        hidden_units = model_arg['hidden_size']
-        attn_head_num = model_arg['num_attention_heads']
-        vocab_size = model_arg['vocab_size']
-        inter_size = model_arg['intermediate_size']
-        if 'num_key_value_heads' in model_arg:
-            kv_head_num = model_arg['num_key_value_heads']
-        else:
-            kv_head_num = model_arg['num_attention_heads']
-        rope_theta = float(model_arg.get('rope_theta', 10000.0))
-        max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))
-        rope_scaling = model_arg.get('rope_scaling', None)
-        scaling_factor = 0.0
-        scaling_type = 'default'
-        if isinstance(rope_scaling, dict):
-            scaling_type = model_arg['rope_scaling'].get('type', 'default')
-            scaling_factor = model_arg['rope_scaling'].get('factor', '')
-        attn_bias = 1 if model_arg['architectures'][0] == 'Qwen2ForCausalLM' else 0
-        rotary_embedding = hidden_units // attn_head_num
-        rope_param = RopeParam(type=scaling_type,
-                               base=rope_theta,
-                               dim=rotary_embedding,
-                               max_position_embeddings=max_position_embeddings,
-                               factor=scaling_factor)
-
-        return dict(num_layer=num_layer,
-                    size_per_head=hidden_units // attn_head_num,
-                    attn_bias=attn_bias,
-                    norm_eps=norm_eps,
-                    hidden_units=hidden_units,
-                    inter_size=inter_size,
-                    vocab_size=vocab_size,
-                    head_num=attn_head_num,
-                    kv_head_num=kv_head_num,
-                    max_position_embeddings=max_position_embeddings,
-                    rope_param=rope_param)
+        self.llm_model.model_config = self.llm_config.to_dict()
+        return self.llm_model.model_info()
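
After this rework, `InternVLModel` no longer re-parses `config.json`: it instantiates the registered input model for its LLM part (picked from `relations` by architecture name) and forwards `model_info()` to it after handing over the `llm_config`/`text_config`. A stripped-down sketch of that delegation, using hypothetical stand-in classes:

```python
# Sketch only: the Dummy* classes are hypothetical stand-ins, not lmdeploy code.
class DummyLLMModel:
    """Plays the role of the registered LLM input model (e.g. 'qwen3-moe')."""

    def __init__(self):
        self.model_config = {}

    def model_info(self):
        # The real readers derive num_layer, head counts, rope params, etc.
        return {'num_layer': self.model_config.get('num_hidden_layers')}


class DummyVLModel:
    """Plays the role of InternVLModel: keep only the LLM sub-config and delegate."""

    def __init__(self, full_config: dict):
        # newer HF layouts use 'text_config'; older InternVL configs use 'llm_config'
        self.llm_config = full_config.get('llm_config') or full_config.get('text_config')
        self.llm_model = DummyLLMModel()

    def model_info(self):
        self.llm_model.model_config = self.llm_config
        return self.llm_model.model_info()


print(DummyVLModel({'text_config': {'num_hidden_layers': 48}}).model_info())
# {'num_layer': 48}
```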

lmdeploy/turbomind/deploy/source_model/llama.py

Lines changed: 70 additions & 75 deletions
@@ -1,7 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import json
 import math
-import os.path as osp
 import re
 
 import torch
@@ -17,7 +15,7 @@ class LlamaReader(BaseReader):
     """LlamaReader."""
 
     attn_layer_prefix = 'model.layers'
-    attn_layer_patten = r'model.layers.([0-9]+).'
+    attn_layer_patten = r'model\.layers\.([0-9]+).'
     tok_embeddings_key = 'model.embed_tokens.weight'
     norm_weight_key = 'model.norm.weight'
     output_weight_key = 'lm_head.weight'
@@ -118,79 +116,76 @@ def readers(self):
 
     def model_info(self):
         """Read model info."""
-        params_path = osp.join(self.model_path, 'config.json')
-        with open(params_path) as f:
-            model_arg = json.load(f)
-        num_layer = model_arg['num_hidden_layers']
-        norm_eps = model_arg['rms_norm_eps']
-        attn_head_num = model_arg['num_attention_heads']
-        vocab_size = model_arg['vocab_size']
-        inter_size = model_arg['intermediate_size']
-        if 'num_key_value_heads' in model_arg:
-            kv_head_num = model_arg['num_key_value_heads']
-        else:
-            kv_head_num = model_arg['num_attention_heads']
-        hidden_units = model_arg['hidden_size']
-        head_dim = model_arg.get('head_dim', hidden_units // attn_head_num)
-        # compute rope param
-        rope_theta = float(model_arg.get('rope_theta', 10000.0))
-        max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))
-        rope_param = RopeParam(type='default', base=rope_theta, dim=head_dim)
-        rope_scaling = model_arg.get('rope_scaling', None)
-        if isinstance(rope_scaling, dict):
-            llama2_scaling_type = rope_scaling.get('type', '')
-            llama3_scaling_type = rope_scaling.get('rope_type', '')
-            if llama2_scaling_type and llama3_scaling_type \
-                    and llama2_scaling_type != llama3_scaling_type:
-                raise ValueError(f'Ambiguous rope_scaling in config: {model_arg}')
-            scaling_type = llama2_scaling_type if llama2_scaling_type \
-                else llama3_scaling_type
-            if rope_scaling.get('mrope_section') is not None:
-                # TODO: treat mrope as an option to the common rope functions
-                scaling_type = 'mrope'
-            scaling_factor = rope_scaling.get('factor', 0.0)
-            if scaling_type == 'default':
-                pass
-            elif scaling_type == 'dynamic':
-                rope_param.type = 'dynamic'
-                rope_param.factor = scaling_factor
-                rope_param.max_position_embeddings = max_position_embeddings
-            elif scaling_type == 'linear':
-                rope_param.type = 'linear'
-                rope_param.factor = scaling_factor
-            elif scaling_type == 'llama3':
-                low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
-                high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
-                original_max_position_embeddings = model_arg['rope_scaling'].get(
-                    'original_max_position_embeddings', 0)
-                rope_param.type = 'llama3'
-                rope_param.factor = scaling_factor
-                rope_param.low_freq_factor = low_freq_factor
-                rope_param.high_freq_factor = high_freq_factor
-                rope_param.original_max_position_embeddings = original_max_position_embeddings
-            elif scaling_type == 'yarn':
-                attention_factor = rope_scaling.get('attention_factor', None)
-                if attention_factor is None:
-                    attention_factor = 0.1 * math.log(scaling_factor) + 1.0
-                beta_fast = rope_scaling.get('beta_fast', 32.0)
-                beta_slow = rope_scaling.get('beta_slow', 1.0)
-                rope_param.type = 'yarn'
-                if 'original_max_position_embeddings' in rope_scaling:
-                    original_max_position_embeddings = rope_scaling['original_max_position_embeddings']
-                    scaling_factor = max_position_embeddings / original_max_position_embeddings
-                else:
-                    original_max_position_embeddings = max_position_embeddings
-                rope_param.factor = scaling_factor
-                rope_param.max_position_embeddings = original_max_position_embeddings
-                rope_param.attention_factor = attention_factor
-                rope_param.beta_fast = beta_fast
-                rope_param.beta_slow = beta_slow
-            elif scaling_type == 'mrope':
-                mrope_section = rope_scaling.get('mrope_section')
-                rope_param.type = 'mrope'
-                rope_param.mrope_section = mrope_section
+        model_arg = self.model_config
+        num_layer = model_arg['num_hidden_layers']
+        norm_eps = model_arg['rms_norm_eps']
+        attn_head_num = model_arg['num_attention_heads']
+        vocab_size = model_arg['vocab_size']
+        inter_size = model_arg['intermediate_size']
+        if 'num_key_value_heads' in model_arg:
+            kv_head_num = model_arg['num_key_value_heads']
+        else:
+            kv_head_num = model_arg['num_attention_heads']
+        hidden_units = model_arg['hidden_size']
+        head_dim = model_arg.get('head_dim', hidden_units // attn_head_num)
+        # compute rope param
+        rope_theta = float(model_arg.get('rope_theta', 10000.0))
+        max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))
+        rope_param = RopeParam(type='default', base=rope_theta, dim=head_dim)
+        rope_scaling = model_arg.get('rope_scaling', None)
+        if isinstance(rope_scaling, dict):
+            llama2_scaling_type = rope_scaling.get('type', '')
+            llama3_scaling_type = rope_scaling.get('rope_type', '')
+            if llama2_scaling_type and llama3_scaling_type \
+                    and llama2_scaling_type != llama3_scaling_type:
+                raise ValueError(f'Ambiguous rope_scaling in config: {model_arg}')
+            scaling_type = llama2_scaling_type if llama2_scaling_type \
+                else llama3_scaling_type
+            if rope_scaling.get('mrope_section') is not None:
+                # TODO: treat mrope as an option to the common rope functions
+                scaling_type = 'mrope'
+            scaling_factor = rope_scaling.get('factor', 0.0)
+            if scaling_type == 'default':
+                pass
+            elif scaling_type == 'dynamic':
+                rope_param.type = 'dynamic'
+                rope_param.factor = scaling_factor
+                rope_param.max_position_embeddings = max_position_embeddings
+            elif scaling_type == 'linear':
+                rope_param.type = 'linear'
+                rope_param.factor = scaling_factor
+            elif scaling_type == 'llama3':
+                low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
+                high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
+                original_max_position_embeddings = model_arg['rope_scaling'].get('original_max_position_embeddings', 0)
+                rope_param.type = 'llama3'
+                rope_param.factor = scaling_factor
+                rope_param.low_freq_factor = low_freq_factor
+                rope_param.high_freq_factor = high_freq_factor
+                rope_param.original_max_position_embeddings = original_max_position_embeddings
+            elif scaling_type == 'yarn':
+                attention_factor = rope_scaling.get('attention_factor', None)
+                if attention_factor is None:
+                    attention_factor = 0.1 * math.log(scaling_factor) + 1.0
+                beta_fast = rope_scaling.get('beta_fast', 32.0)
+                beta_slow = rope_scaling.get('beta_slow', 1.0)
+                rope_param.type = 'yarn'
+                if 'original_max_position_embeddings' in rope_scaling:
+                    original_max_position_embeddings = rope_scaling['original_max_position_embeddings']
+                    scaling_factor = max_position_embeddings / original_max_position_embeddings
                 else:
-                raise RuntimeError(f'Unsupported rope type: {scaling_type}')
+                    original_max_position_embeddings = max_position_embeddings
+                rope_param.factor = scaling_factor
+                rope_param.max_position_embeddings = original_max_position_embeddings
+                rope_param.attention_factor = attention_factor
+                rope_param.beta_fast = beta_fast
+                rope_param.beta_slow = beta_slow
+            elif scaling_type == 'mrope':
+                mrope_section = rope_scaling.get('mrope_section')
+                rope_param.type = 'mrope'
+                rope_param.mrope_section = mrope_section
+            else:
+                raise RuntimeError(f'Unsupported rope type: {scaling_type}')
 
         return dict(size_per_head=head_dim,
                     num_layer=num_layer,
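
The `model_info()` body is unchanged apart from reading `self.model_config` (now populated by the caller, e.g. the InternVL wrapper above) instead of opening `config.json` itself. As a quick worked example of the yarn branch, using a hypothetical config:

```python
import math

# Hypothetical HF-style text config, trimmed to the rope-related fields read above.
model_arg = {
    'max_position_embeddings': 262144,
    'rope_theta': 1000000.0,
    'rope_scaling': {
        'rope_type': 'yarn',
        'factor': 4.0,
        'original_max_position_embeddings': 65536,
    },
}

rope_scaling = model_arg['rope_scaling']
max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))
scaling_factor = rope_scaling.get('factor', 0.0)          # 4.0, taken from the config
attention_factor = rope_scaling.get('attention_factor', None)
if attention_factor is None:
    # default used by the yarn branch when the config omits attention_factor
    attention_factor = 0.1 * math.log(scaling_factor) + 1.0
if 'original_max_position_embeddings' in rope_scaling:
    # the factor is recomputed from the position-embedding ratio: 262144 / 65536 = 4.0
    scaling_factor = max_position_embeddings / rope_scaling['original_max_position_embeddings']

print(scaling_factor, round(attention_factor, 4))  # 4.0 1.1386
```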

lmdeploy/turbomind/deploy/source_model/llava.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ class LlavaReader(LlamaReader):
     """LlavaReader for llama model."""
 
     attn_layer_prefix = 'language_model.model.layers'
-    attn_layer_patten = r'language_model.model.layers.([0-9]+).'
+    attn_layer_patten = r'language_model\.model\.layers\.([0-9]+).'
     tok_embeddings_key = 'language_model.model.embed_tokens.weight'
     norm_weight_key = 'language_model.model.norm.weight'
     output_weight_key = 'language_model.lm_head.weight'

lmdeploy/turbomind/deploy/source_model/minicpmv.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ class MiniCPMVReader(LlamaReader):
     """MiniCPMVReader for llama model."""
 
     attn_layer_prefix = 'llm.model.layers'
-    attn_layer_patten = r'llm.model.layers.([0-9]+).'
+    attn_layer_patten = r'llm\.model\.layers\.([0-9]+).'
     tok_embeddings_key = 'llm.model.embed_tokens.weight'
     norm_weight_key = 'llm.model.norm.weight'
     output_weight_key = 'llm.lm_head.weight'

lmdeploy/turbomind/deploy/source_model/molmo.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 
 class MolmoReader(LlamaReader):
     attn_layer_prefix = 'model.transformer.blocks'
-    attn_layer_patten = r'model.transformer.blocks.([0-9]+).'
+    attn_layer_patten = r'model\.transformer\.blocks\.([0-9]+).'
     norm_weight_key = 'model.transformer.ln_f.weight'
     output_weight_key = 'model.transformer.ff_out.weight'
 
lmdeploy/turbomind/deploy/source_model/qwen.py

Lines changed: 5 additions & 5 deletions
@@ -12,7 +12,7 @@
 class QwenReader(LlamaReader):
     """QwenReader."""
 
-    attn_layer_patten = r'transformer.h.([0-9]+).'
+    attn_layer_patten = r'transformer\.h\.([0-9]+).'
     tok_embeddings_key = 'transformer.wte.weight'
     norm_weight_key = 'transformer.ln_f.weight'
     output_weight_key = 'lm_head.weight'
@@ -124,28 +124,28 @@ def moe_ffn_expert(self, e=None, i=None, kind=None):
             return self.filter(r'experts')
         result = []
         for key in ['gate', 'down', 'up']:
-            name = f'model.layers.{i}.mlp.experts.{e}.{key}_proj.{kind}'
+            name = f'{self.attn_layer_prefix}.{i}.mlp.experts.{e}.{key}_proj.{kind}'
             tensor = self.params.get(name)
             tensor = self.transform(tensor, kind)
             result.append(tensor)
         return (*result, )
 
     def moe_ffn_gate(self, i):
-        return self.transform(self.params.get(f'model.layers.{i}.mlp.gate.weight'), 'weight')
+        return self.transform(self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.gate.weight'), 'weight')
 
     def _ffn(self, i: int, kind: str):
         """Get ffn kind for layer i."""
         if not kind:
             return self.filter(self.ffn_pattern)
         result = []
         for key in ['gate', 'down', 'up']:
-            tensor = self.params[f'model.layers.{i}.mlp.shared_expert.{key}_proj.{kind}']
+            tensor = self.params[f'{self.attn_layer_prefix}.{i}.mlp.shared_expert.{key}_proj.{kind}']
             tensor = self.transform(tensor, kind)
             result.append(tensor)
         return (*result, )
 
     def moe_ffn_shared_gate(self, i):
-        return self.params.get(f'model.layers.{i}.mlp.shared_expert_gate.weight')
+        return self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.shared_expert_gate.weight')
 
 
 @INPUT_MODELS.register_module(name='qwen2-moe')
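
Switching the hard-coded `model.layers` prefix to `self.attn_layer_prefix` lets subclasses such as the new `InternS1Reader` reuse these MoE weight lookups against a different checkpoint layout. A quick sketch of the resulting key names (the layer and expert indices are arbitrary examples):

```python
# Sketch of how the prefixed key names are built; indices are arbitrary examples.
def expert_key(attn_layer_prefix: str, i: int, e: int, key: str, kind: str) -> str:
    return f'{attn_layer_prefix}.{i}.mlp.experts.{e}.{key}_proj.{kind}'

print(expert_key('model.layers', 0, 3, 'gate', 'weight'))
# model.layers.0.mlp.experts.3.gate_proj.weight                  (Qwen3-MoE layout)
print(expert_key('model.language_model.layers', 0, 3, 'gate', 'weight'))
# model.language_model.layers.0.mlp.experts.3.gate_proj.weight   (InternS1 layout)
```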

lmdeploy/turbomind/supported_models.py

Lines changed: 1 addition & 2 deletions
@@ -42,13 +42,12 @@
     InternVLChatModel='internvl',
     # internvl3
     InternVLForConditionalGeneration='internvl',
+    InternS1ForConditionalGeneration='internvl',
     # deepseek-vl
     MultiModalityCausalLM='deepseekvl',
     DeepseekV2ForCausalLM='deepseek2',
     # MiniCPMV
     MiniCPMV='minicpmv',
-    # mini gemini
-    MGMLlamaForCausalLM='llama',
     # chatglm2/3, glm4
     ChatGLMModel='glm4',
     ChatGLMForConditionalGeneration='glm4',
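
`InternS1ForConditionalGeneration` is now routed to the `internvl` input model, and the stale mini-gemini entry is removed. A minimal sketch of how such an architecture map is typically consulted (the helper below is illustrative, not the actual `supported_models.py` API):

```python
# Illustrative only: a trimmed copy of the mapping and a hypothetical lookup helper.
SUPPORTED_ARCHS = dict(
    InternVLChatModel='internvl',
    InternVLForConditionalGeneration='internvl',
    InternS1ForConditionalGeneration='internvl',
)

def turbomind_model_name(architecture: str) -> str:
    name = SUPPORTED_ARCHS.get(architecture)
    if name is None:
        raise ValueError(f'{architecture} is not supported by the turbomind engine')
    return name

print(turbomind_model_name('InternS1ForConditionalGeneration'))  # internvl
```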
