Skip to content

Commit 65ea69d

Browse files
authored
Fix internvl2 template (#1308)
1 parent a76c206 commit 65ea69d

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

swift/llm/utils/model.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3648,7 +3648,7 @@ def _new_forward(*args, **kwargs):
36483648
ModelType.internvl2_2b,
36493649
'OpenGVLab/InternVL2-2B',
36503650
LoRATM.internlm2,
3651-
TemplateType.internvl,
3651+
TemplateType.internvl2,
36523652
requires=['transformers>=4.35', 'timm'],
36533653
support_flash_attn=True,
36543654
placeholder_tokens=['<IMG_CONTEXT>'],
@@ -3658,7 +3658,7 @@ def _new_forward(*args, **kwargs):
36583658
ModelType.internvl2_4b,
36593659
'OpenGVLab/InternVL2-4B',
36603660
LoRATM.internlm2,
3661-
TemplateType.internvl,
3661+
TemplateType.internvl2,
36623662
requires=['transformers>=4.35', 'timm'],
36633663
support_flash_attn=True,
36643664
placeholder_tokens=['<IMG_CONTEXT>'],
@@ -3668,7 +3668,7 @@ def _new_forward(*args, **kwargs):
36683668
ModelType.internvl2_8b,
36693669
'OpenGVLab/InternVL2-8B',
36703670
LoRATM.internlm2,
3671-
TemplateType.internvl,
3671+
TemplateType.internvl2,
36723672
requires=['transformers>=4.35', 'timm'],
36733673
support_flash_attn=True,
36743674
placeholder_tokens=['<IMG_CONTEXT>'],
@@ -3678,7 +3678,7 @@ def _new_forward(*args, **kwargs):
36783678
ModelType.internvl2_26b,
36793679
'OpenGVLab/InternVL2-26B',
36803680
LoRATM.internlm2,
3681-
TemplateType.internvl,
3681+
TemplateType.internvl2,
36823682
requires=['transformers>=4.35', 'timm'],
36833683
support_flash_attn=True,
36843684
placeholder_tokens=['<IMG_CONTEXT>'],

swift/llm/utils/template.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@ class TemplateType:
5757
internlm2 = 'internlm2'
5858
internlm_xcomposer2 = 'internlm-xcomposer2'
5959
internvl = 'internvl'
60+
internvl2 = 'internvl2'
6061
internvl_phi3 = 'internvl-phi3'
6162
florence = 'florence'
6263
yi = 'yi'
@@ -1339,6 +1340,14 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
13391340
return generate_ids[0].tolist()
13401341

13411342

1343+
class Internvl2Template(InternvlTemplate):
    """Chat template for InternVL2 models.

    Builds on ``InternvlTemplate`` but configures the base ``Template`` with an
    empty first argument and an InternVL2-specific default system prompt. Turns
    are wrapped in ``<|im_start|>`` / ``<|im_end|>`` markers.
    """

    # Declared at class level (as ``InternvlPhi3Template.system`` is) so that
    # ``Internvl2Template.system`` yields this prompt without an instance;
    # ``self.system`` keeps resolving to the same value.
    system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型，英文名叫InternVL, 是一个有用无害的人工智能助手。'

    def __init__(self):
        # Deliberately calls ``Template.__init__`` rather than
        # ``super().__init__()``, so the arguments below are passed straight
        # to the base template.
        # NOTE(review): positional order is presumably prefix, prompt, chat
        # separator, suffix, default system, system prefix -- confirm against
        # ``Template.__init__``.
        Template.__init__(
            self,
            [],
            ['<|im_start|>user\n{{QUERY}}<|im_end|><|im_start|>assistant\n'],
            ['<|im_end|>'],
            ['<|im_end|>'],
            self.system,
            ['<|im_start|>system\n{{SYSTEM}}<|im_end|>'],
        )
1349+
1350+
13421351
class InternvlPhi3Template(InternvlTemplate):
13431352
system = 'You are an AI assistant whose name is Phi-3.'
13441353

@@ -1365,6 +1374,15 @@ def __init__(self):
13651374
dataloader_num_workers=0,
13661375
dataloader_pin_memory=False)
13671376

1377+
# Register the InternVL2 template under ``TemplateType.internvl2``; the keyword
# options are forwarded unchanged to ``register_template``.
register_template(
    TemplateType.internvl2,
    Internvl2Template(),
    use_model=True,
    lazy_tokenize=True,
    infer_media_type='dialogue',
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
)
1385+
13681386

13691387
class FlorenceTemplate(Template):
13701388

0 commit comments

Comments
 (0)