
Commit 69a5161

[model] support hunyuan_ocr (#7038)

* support hunyuanOCR
* fix bug
* fix typos
* fix position_ids for multi images

1 parent e4e5807

File tree: 10 files changed, +160 -4 lines changed

docs/source/Instruction/Supported-models-and-datasets.md

Lines changed: 1 addition & 0 deletions
@@ -1053,6 +1053,7 @@
 |[mistralai/Ministral-3-14B-Reasoning-2512](https://modelscope.cn/models/mistralai/Ministral-3-14B-Reasoning-2512)|mistral_2512_thinking|mistral_2512_thinking|transformers>=5.0.0.dev0, mistral-common>=1.8.6|✘|vision|[mistralai/Ministral-3-14B-Reasoning-2512](https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512)|
 |[PaddlePaddle/PaddleOCR-VL](https://modelscope.cn/models/PaddlePaddle/PaddleOCR-VL)|paddle_ocr|paddle_ocr|-|✘|-|[PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)|
 |[JinaAI/jina-reranker-m0](https://modelscope.cn/models/JinaAI/jina-reranker-m0)|jina_reranker_m0|jina_reranker_m0|-|✘|reranker, vision|[JinaAI/jina-reranker-m0](https://huggingface.co/JinaAI/jina-reranker-m0)|
+|[Tencent-Hunyuan/HunyuanOCR](https://modelscope.cn/models/Tencent-Hunyuan/HunyuanOCR)|hunyuan_ocr|hunyuan_ocr|transformers>=4.49.0|✘|vision|[tencent/HunyuanOCR](https://huggingface.co/tencent/HunyuanOCR)|
 
 
 ## 数据集

docs/source_en/Instruction/Supported-models-and-datasets.md

Lines changed: 1 addition & 0 deletions
@@ -1053,6 +1053,7 @@ The table below introduces the models integrated with ms-swift:
 |[mistralai/Ministral-3-14B-Reasoning-2512](https://modelscope.cn/models/mistralai/Ministral-3-14B-Reasoning-2512)|mistral_2512_thinking|mistral_2512_thinking|transformers>=5.0.0.dev0, mistral-common>=1.8.6|✘|vision|[mistralai/Ministral-3-14B-Reasoning-2512](https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512)|
 |[PaddlePaddle/PaddleOCR-VL](https://modelscope.cn/models/PaddlePaddle/PaddleOCR-VL)|paddle_ocr|paddle_ocr|-|✘|-|[PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)|
 |[JinaAI/jina-reranker-m0](https://modelscope.cn/models/JinaAI/jina-reranker-m0)|jina_reranker_m0|jina_reranker_m0|-|✘|reranker, vision|[JinaAI/jina-reranker-m0](https://huggingface.co/JinaAI/jina-reranker-m0)|
+|[Tencent-Hunyuan/HunyuanOCR](https://modelscope.cn/models/Tencent-Hunyuan/HunyuanOCR)|hunyuan_ocr|hunyuan_ocr|transformers>=4.49.0|✘|vision|[tencent/HunyuanOCR](https://huggingface.co/tencent/HunyuanOCR)|
 
 
 ## Datasets

swift/llm/model/constant.py

Lines changed: 1 addition & 0 deletions
@@ -279,6 +279,7 @@ class MLLMModelType:
     mistral_2512 = 'mistral_2512'
     mistral_2512_thinking = 'mistral_2512_thinking'
     paddle_ocr = 'paddle_ocr'
+    hunyuan_ocr = 'hunyuan_ocr'
 
 
 class RerankerModelType:

swift/llm/model/model/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from . import (baai, baichuan, baidu, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba,
                microsoft, minicpm, minimax, mistral, mllm, moonshot, mplug, openbuddy, qwen, seed, skywork, stepfun,
-               telechat, valley, yi)
+               telechat, tencent, valley, yi)

swift/llm/model/model/tencent.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
from typing import Any, Dict

from swift.llm import TemplateType
from ..constant import MLLMModelType
from ..model_arch import ModelArch
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, register_model
from ..utils import ModelInfo


def get_model_tokenizer_hunyuan_vl(model_dir: str,
                                   model_info: ModelInfo,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
    from transformers import HunYuanVLForConditionalGeneration
    kwargs['automodel_class'] = kwargs['automodel_class'] or HunYuanVLForConditionalGeneration
    kwargs['attn_impl'] = kwargs['attn_impl'] or 'eager'
    model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.hunyuan_ocr,
        [
            ModelGroup([
                Model('Tencent-Hunyuan/HunyuanOCR', 'tencent/HunyuanOCR'),
            ]),
        ],
        TemplateType.hunyuan_ocr,
        get_model_tokenizer_hunyuan_vl,
        architectures=['HunYuanVLForConditionalGeneration'],
        model_arch=ModelArch.hunyuan_vl,
        requires=['transformers>=4.49.0'],
    ))
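
Once registered, the hunyuan_ocr model type resolves to this loader and to the hunyuan_ocr template automatically. Below is a minimal inference sketch against the registered model; the image URL comes from the test added later in this commit, while the shortened prompt and generation settings are illustrative assumptions rather than part of the change:

from swift.llm import InferRequest, PtEngine, RequestConfig

# Downloads Tencent-Hunyuan/HunyuanOCR from ModelScope and builds the template registered above.
engine = PtEngine('Tencent-Hunyuan/HunyuanOCR')
request = InferRequest(
    messages=[{'role': 'user', 'content': '<image>Extract all text from the image.'}],
    images=['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/ocr.png'])
resp = engine.infer([request], RequestConfig(max_tokens=512, temperature=0))
print(resp[0].choices[0].message.content)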

swift/llm/model/model_arch.py

Lines changed: 9 additions & 0 deletions
@@ -85,6 +85,7 @@ class MLLMModelArch:
 
     midashenglm = 'midashenglm'
     step_audio2_mini = 'step_audio2_mini'
+    hunyuan_vl = 'hunyuan_vl'
 
 
 class ModelArch(LLMModelArch, MLLMModelArch):
@@ -722,6 +723,14 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
         vision_tower='model.visual',
     ))
 
+register_model_arch(
+    MultiModelKeys(
+        MLLMModelArch.hunyuan_vl,
+        language_model='model',
+        aligner='vit.perceive',
+        vision_tower='vit',
+    ))
+
 
 def get_model_arch(arch_name: Optional[str]) -> Optional[MultiModelKeys]:
     return MODEL_ARCH_MAPPING.get(arch_name)
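
The MultiModelKeys entry records which submodules of HunYuanVLForConditionalGeneration act as the language model, vision tower, and aligner; ms-swift's freezing and LoRA-targeting logic keys off these names. A small lookup sketch, assuming the internal module path stays as laid out in this commit (the printed values simply echo the registration above):

from swift.llm.model.model_arch import get_model_arch

keys = get_model_arch('hunyuan_vl')
# Language model under `model`, ViT under `vit`, projector/aligner under `vit.perceive`.
print(keys.language_model, keys.vision_tower, keys.aligner)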

swift/llm/template/constant.py

Lines changed: 1 addition & 0 deletions
@@ -233,6 +233,7 @@ class MLLMTemplateType:
     mistral_2512 = 'mistral_2512'
     mistral_2512_thinking = 'mistral_2512_thinking'
     paddle_ocr = 'paddle_ocr'
+    hunyuan_ocr = 'hunyuan_ocr'
 
 
 class TemplateType(LLMTemplateType, MLLMTemplateType, RMTemplateType):

swift/llm/template/template/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from . import (baai, baidu, bert, deepseek, dots, gemma, glm, idefics3, internlm, internvl, kwai, llama, llava, llm,
                megrez, microsoft, midashenglm, minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral,
-               qwen, seed, stepfun, valley, yi)
+               qwen, seed, stepfun, tencent, valley, yi)

swift/llm/template/template/tencent.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

import torch

from ..base import Template
from ..constant import MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, findall


@dataclass
class HunYuanVLTemplateMeta(TemplateMeta):
    prefix: Prompt = field(default_factory=lambda: ['<|hy_begin▁of▁sentence|>'])
    prompt: Prompt = field(default_factory=lambda: ['{{QUERY}}<|hy_User|>'])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|hy_Assistant|><|hy_begin▁of▁sentence|>'])
    suffix: Prompt = field(default_factory=lambda: ['<|hy_Assistant|>'])
    system_prefix: Optional[Prompt] = field(
        default_factory=lambda: ['<|hy_begin▁of▁sentence|>{{SYSTEM}}<|hy_place▁holder▁no▁3|>'])


class HunYuanVLTemplate(Template):
    image_token_id = 120120
    image_token = '<|hy_place▁holder▁no▁102|>'
    image_placeholder = ['<|hy_place▁holder▁no▁102|>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'image'
        if self.mode == 'vllm':
            return ['<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>']
        return [[-100]]

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        loss_scale = encoded.get('loss_scale', None)
        idx_list = findall(input_ids, -100)
        processor = self.processor
        images = inputs.images
        if images:
            image_inputs = processor.image_processor(images=images, return_tensors='pt')
            image_grid_thw = image_inputs['image_grid_thw']
            merge_size = processor.image_processor.merge_size

            def _get_new_tokens(i):
                grid_h, grid_w = image_grid_thw[i][-2:]
                patch_h = grid_h // merge_size
                patch_w = grid_w // merge_size
                img_tokens: List[int] = [self.image_token_id] * (patch_h * (patch_w + 1) + 2)
                return img_tokens

            encoded['input_ids'], encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
                input_ids, labels, loss_scale, idx_list, _get_new_tokens)
            encoded['pixel_values'] = image_inputs['pixel_values']
            encoded['image_grid_thw'] = image_grid_thw

            input_ids = encoded['input_ids']
            position_ids = torch.arange(len(input_ids))
            position_ids_w = torch.arange(len(input_ids))
            position_ids_h = torch.arange(len(input_ids))
            position_ids_t = torch.arange(len(input_ids))
            image_tokens_cumsum = [0]
            for i in range(len(image_grid_thw)):
                grid_h, grid_w = image_grid_thw[i][-2:]
                patch_h = grid_h // merge_size
                patch_w = grid_w // merge_size
                num_image_tokens = patch_h * (patch_w + 1) + 2
                image_tokens_cumsum.append(image_tokens_cumsum[-1] + int(num_image_tokens))
                image_token_pos_indices = torch.where(torch.tensor(input_ids) == self.image_token_id)
                start_pos = image_token_pos_indices[0][image_tokens_cumsum[i]] + 1
                replace_num = (patch_w + 1) * patch_h
                position_ids_w[start_pos:start_pos + replace_num] = torch.tensor(
                    list(range(patch_w + 1)) * patch_h, dtype=torch.int64)
                patch_h_list = []
                for h in range(patch_h):
                    patch_h_list += [h] * (patch_w + 1)
                position_ids_h[start_pos:start_pos + replace_num] = torch.tensor(patch_h_list, dtype=torch.int64)
                position_ids_t[start_pos:start_pos + replace_num] = 0
            position_ids = torch.stack([position_ids, position_ids_w, position_ids_h, position_ids_t]).unsqueeze(0)
            encoded['position_ids'] = position_ids
            attention_mask = torch.tensor(input_ids).ne(processor.pad_id)
            encoded['attention_mask'] = attention_mask
        return encoded


register_template(HunYuanVLTemplateMeta(MLLMTemplateType.hunyuan_ocr, template_cls=HunYuanVLTemplate))
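
For each image, the template expands the placeholder into patch_h * (patch_w + 1) + 2 copies of the image token (the grid divided by merge_size, with what appears to be one extra token per patch row plus two boundary tokens), and it builds a 4-row position_ids tensor: the usual sequence index plus width, height, and time indices over the image span. A minimal sketch of that bookkeeping for a single image, using hypothetical grid values rather than real processor output:

import torch

# Hypothetical image-processor output for one image: a 64x48 patch grid with merge_size=2.
grid_h, grid_w, merge_size = 64, 48, 2
patch_h, patch_w = grid_h // merge_size, grid_w // merge_size  # 32, 24

# Number of image tokens the template inserts into input_ids for this image.
num_image_tokens = patch_h * (patch_w + 1) + 2  # 32 * 25 + 2 = 802

# Width index cycles 0..patch_w within each row; height index repeats the row number.
pos_w = torch.tensor(list(range(patch_w + 1)) * patch_h)
pos_h = torch.tensor([h for h in range(patch_h) for _ in range(patch_w + 1)])
assert pos_w.numel() == pos_h.numel() == (patch_w + 1) * patch_h  # length of the replaced span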

tests/test_align/test_template/test_vision.py

Lines changed: 21 additions & 2 deletions
@@ -1127,6 +1127,24 @@ def test_mistral_2512_thinking():
     assert response1[:256] == response2[:256]
 
 
+def test_hunyuan_ocr():
+    pt_engine = PtEngine('Tencent-Hunyuan/HunyuanOCR')
+    images = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/ocr.png']
+    messages = [{
+        'role':
+        'user',
+        'content': ('Extract all information from the main body of the document image '
+                    'and represent it in markdown format, ignoring headers and footers. '
+                    'Tables should be expressed in HTML format, formulas in the document '
+                    'should be represented using LaTeX format, and the parsing should be '
+                    'organized according to the reading order.')
+    }]
+    response1 = _infer_model(pt_engine, messages=messages, images=images)
+    pt_engine.default_template.template_backend = 'jinja'
+    response2 = _infer_model(pt_engine, messages=messages, images=images)
+    assert response1 == response2
+
+
 if __name__ == '__main__':
     from swift.llm import PtEngine, RequestConfig
     from swift.utils import get_logger, seed_everything
@@ -1206,5 +1224,6 @@ def test_mistral_2512_thinking():
     # test_ernie_vl_thinking()
     # test_mistral_2506()
     # test_sensenova_si()
-    test_mistral_2512()
-    test_mistral_2512_thinking()
+    # test_mistral_2512()
+    # test_mistral_2512_thinking()
+    test_hunyuan_ocr()