
Commit dd89c52

[model] support dots.ocr (#5333)
1 parent 3f37045 commit dd89c52

File tree: 9 files changed, +117 −4 lines changed

docs/source/Instruction/支持的模型和数据集.md

Lines changed: 1 addition & 0 deletions

@@ -848,6 +848,7 @@
 |[moonshotai/Kimi-VL-A3B-Thinking](https://modelscope.cn/models/moonshotai/Kimi-VL-A3B-Thinking)|kimi_vl|kimi_vl|transformers<4.49|✘|-|[moonshotai/Kimi-VL-A3B-Thinking](https://huggingface.co/moonshotai/Kimi-VL-A3B-Thinking)|
 |[moonshotai/Kimi-VL-A3B-Thinking-2506](https://modelscope.cn/models/moonshotai/Kimi-VL-A3B-Thinking-2506)|kimi_vl|kimi_vl|transformers<4.49|✘|-|[moonshotai/Kimi-VL-A3B-Thinking-2506](https://huggingface.co/moonshotai/Kimi-VL-A3B-Thinking-2506)|
 |[Kwai-Keye/Keye-VL-8B-Preview](https://modelscope.cn/models/Kwai-Keye/Keye-VL-8B-Preview)|keye_vl|keye_vl|keye_vl_utils|✘|vision|[Kwai-Keye/Keye-VL-8B-Preview](https://huggingface.co/Kwai-Keye/Keye-VL-8B-Preview)|
+|[rednote-hilab/dots.ocr](https://modelscope.cn/models/rednote-hilab/dots.ocr)|dots_ocr|dots_ocr|transformers>=4.51.0|✘|-|[rednote-hilab/dots.ocr](https://huggingface.co/rednote-hilab/dots.ocr)|
 |[LLM-Research/Phi-3-vision-128k-instruct](https://modelscope.cn/models/LLM-Research/Phi-3-vision-128k-instruct)|phi3_vision|phi3_vision|transformers>=4.36|✘|vision|[microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)|
 |[LLM-Research/Phi-3.5-vision-instruct](https://modelscope.cn/models/LLM-Research/Phi-3.5-vision-instruct)|phi3_vision|phi3_vision|transformers>=4.36|✘|vision|[microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)|
 |[LLM-Research/Phi-4-multimodal-instruct](https://modelscope.cn/models/LLM-Research/Phi-4-multimodal-instruct)|phi4_multimodal|phi4_multimodal|transformers>=4.36,<4.49, backoff, soundfile|✘|vision, audio|[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct)|

docs/source_en/Instruction/Supported-models-and-datasets.md

Lines changed: 1 addition & 0 deletions

@@ -848,6 +848,7 @@ The table below introduces the models integrated with ms-swift:
 |[moonshotai/Kimi-VL-A3B-Thinking](https://modelscope.cn/models/moonshotai/Kimi-VL-A3B-Thinking)|kimi_vl|kimi_vl|transformers<4.49|✘|-|[moonshotai/Kimi-VL-A3B-Thinking](https://huggingface.co/moonshotai/Kimi-VL-A3B-Thinking)|
 |[moonshotai/Kimi-VL-A3B-Thinking-2506](https://modelscope.cn/models/moonshotai/Kimi-VL-A3B-Thinking-2506)|kimi_vl|kimi_vl|transformers<4.49|✘|-|[moonshotai/Kimi-VL-A3B-Thinking-2506](https://huggingface.co/moonshotai/Kimi-VL-A3B-Thinking-2506)|
 |[Kwai-Keye/Keye-VL-8B-Preview](https://modelscope.cn/models/Kwai-Keye/Keye-VL-8B-Preview)|keye_vl|keye_vl|keye_vl_utils|✘|vision|[Kwai-Keye/Keye-VL-8B-Preview](https://huggingface.co/Kwai-Keye/Keye-VL-8B-Preview)|
+|[rednote-hilab/dots.ocr](https://modelscope.cn/models/rednote-hilab/dots.ocr)|dots_ocr|dots_ocr|transformers>=4.51.0|✘|-|[rednote-hilab/dots.ocr](https://huggingface.co/rednote-hilab/dots.ocr)|
 |[LLM-Research/Phi-3-vision-128k-instruct](https://modelscope.cn/models/LLM-Research/Phi-3-vision-128k-instruct)|phi3_vision|phi3_vision|transformers>=4.36|✘|vision|[microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)|
 |[LLM-Research/Phi-3.5-vision-instruct](https://modelscope.cn/models/LLM-Research/Phi-3.5-vision-instruct)|phi3_vision|phi3_vision|transformers>=4.36|✘|vision|[microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)|
 |[LLM-Research/Phi-4-multimodal-instruct](https://modelscope.cn/models/LLM-Research/Phi-4-multimodal-instruct)|phi4_multimodal|phi4_multimodal|transformers>=4.36,<4.49, backoff, soundfile|✘|vision, audio|[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct)|

swift/llm/model/constant.py

Lines changed: 1 addition & 0 deletions

@@ -229,6 +229,7 @@ class MLLMModelType:
     step_audio = 'step_audio'
     kimi_vl = 'kimi_vl'
     keye_vl = 'keye_vl'
+    dots_ocr = 'dots_ocr'

     phi3_vision = 'phi3_vision'
     phi4_multimodal = 'phi4_multimodal'

swift/llm/model/model/mllm.py

Lines changed: 21 additions & 0 deletions

@@ -203,3 +203,24 @@ def get_model_tokenizer_keye_vl(model_dir: str, *args, **kwargs):
         tags=['vision'],
         requires=['keye_vl_utils'],
     ))
+
+
+def get_model_tokenizer_dots_ocr(model_dir, *args, **kwargs):
+    model_cls = get_class_from_dynamic_module('modeling_dots_vision.DotsVisionTransformer', model_dir)
+    model_cls._no_split_modules = ['DotsVisionBlock']
+    model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
+    return model, processor
+
+
+register_model(
+    ModelMeta(
+        MLLMModelType.dots_ocr,
+        [ModelGroup([
+            Model('rednote-hilab/dots.ocr', 'rednote-hilab/dots.ocr'),
+        ])],
+        TemplateType.dots_ocr,
+        get_model_tokenizer_dots_ocr,
+        model_arch=ModelArch.dots_ocr,
+        architectures=['DotsOCRForCausalLM'],
+        requires=['transformers>=4.51.0'],
+    ))
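
For reference, a minimal end-to-end inference sketch against this registration, assuming ms-swift with this commit installed; the prompt and image URL are borrowed from the test added in test_vision.py below. PtEngine / InferRequest / RequestConfig are the library's standard swift.llm entry points, but treat the exact parameter values as illustrative:

    from swift.llm import InferRequest, PtEngine, RequestConfig

    # Load the model through the dots_ocr registration added above.
    engine = PtEngine('rednote-hilab/dots.ocr')
    request = InferRequest(
        messages=[{'role': 'user', 'content': '<image>Extract the text content from this image.'}],
        images=['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/ocr.png'])
    # max_tokens/temperature are illustrative settings, not values from this commit.
    resp_list = engine.infer([request], RequestConfig(max_tokens=512, temperature=0))
    print(resp_list[0].choices[0].message.content)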

swift/llm/model/model_arch.py

Lines changed: 6 additions & 0 deletions

@@ -66,6 +66,7 @@ class MLLMModelArch:
     idefics3 = 'idefics3'

     got_ocr2 = 'got_ocr2'
+    dots_ocr = 'dots_ocr'

     ovis1_6 = 'ovis1_6'
     molmo = 'molmo'
@@ -640,6 +641,11 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
         vision_tower='visual',
     ))

+register_model_arch(MultiModelKeys(
+    MLLMModelArch.dots_ocr,
+    language_model='model',
+))
+

 def get_model_arch(arch_name: Optional[str]) -> Optional[MultiModelKeys]:
     return MODEL_ARCH_MAPPING.get(arch_name)
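
A small sketch of what this mapping exposes, assuming ms-swift is importable; the MultiModelKeys entry is what tuner-side logic consults to tell the language model apart from the vision tower (e.g. when freezing the ViT or scoping LoRA targets):

    from swift.llm.model.model_arch import get_model_arch

    arch = get_model_arch('dots_ocr')
    # language_model names the LLM submodule inside the wrapper model; depending
    # on MultiModelKeys' normalization it may come back as 'model' or ['model'].
    print(arch.language_model)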

swift/llm/template/constant.py

Lines changed: 1 addition & 0 deletions

@@ -185,6 +185,7 @@ class MLLMTemplateType:
     step_audio = 'step_audio'
     kimi_vl = 'kimi_vl'
     keye_vl = 'keye_vl'
+    dots_ocr = 'dots_ocr'

     idefics3 = 'idefics3'
     pixtral = 'pixtral'

swift/llm/template/template/__init__.py

Lines changed: 3 additions & 3 deletions

@@ -1,3 +1,3 @@
-from . import (baidu, bert, deepseek, emu3, gemma, glm, idefics3, internlm, internvl, kwai, llama, llava, llm, megrez,
-               microsoft, midashenglm, minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen,
-               stepfun, valley, yi)
+from . import (baidu, bert, deepseek, dots, emu3, gemma, glm, idefics3, internlm, internvl, kwai, llama, llava, llm,
+               megrez, microsoft, midashenglm, minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral,
+               qwen, stepfun, valley, yi)

swift/llm/template/template/dots.py

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict, List, Literal
+
+from ..base import Template
+from ..constant import MLLMTemplateType
+from ..register import register_template
+from ..template_inputs import StdTemplateInputs
+from ..utils import Context, findall
+from .utils import TemplateMeta
+
+
+class DotsOCRTemplate(Template):
+    image_token_id = 151665
+    placeholder_tokens = ['<|imgpad|>']
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        from qwen_vl_utils import fetch_image
+        assert media_type == 'image'
+        inputs.images[index] = fetch_image({'image': inputs.images[index]})
+        if self.mode == 'lmdeploy':
+            return ['<|img|>', [-100], '<|endofimg|>']
+        else:
+            return ['<|img|><|imgpad|><|endofimg|>']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        processor = self.processor
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
+
+        images = inputs.images
+        media_token = self.image_token_id
+        media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt', do_resize=False)
+        media_grid_thw = media_inputs['image_grid_thw']
+        idx_list = findall(input_ids, media_token)
+        merge_length = processor.image_processor.merge_size**2
+
+        def _get_new_tokens(i):
+            token_len = (media_grid_thw[i].prod() // merge_length)
+            return [media_token] * token_len
+
+        input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list, _get_new_tokens)
+        encoded.update(media_inputs)
+
+        encoded['input_ids'] = input_ids
+        encoded['labels'] = labels
+        encoded['loss_scale'] = loss_scale
+        return encoded
+
+    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
+        res = super()._data_collator_mm_data(batch)
+        grid_thw = self.concat_tensor(batch, 'image_grid_thw', 0)
+        if grid_thw is not None:
+            res['image_grid_thw'] = grid_thw
+        return res
+
+
+register_template(
+    TemplateMeta(
+        MLLMTemplateType.dots_ocr,
+        prefix=[''],
+        prompt=['<|user|>{{QUERY}}<|endofuser|><|assistant|>'],
+        chat_sep=['<|endofassistant|>'],
+        suffix=['<|endofassistant|>'],
+        system_prefix=['<|system|>{{SYSTEM}}<|endofsystem|>\n'],
+        template_cls=DotsOCRTemplate,
+    ))
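
To make the `_get_new_tokens` arithmetic above concrete, a standalone sketch with a hypothetical patch grid (values chosen only for illustration, not taken from the model):

    import torch

    # Hypothetical image_grid_thw row for one image: 1 temporal frame, 28 x 42 patches.
    image_grid_thw = torch.tensor([[1, 28, 42]])
    merge_size = 2                  # patches are merged merge_size x merge_size
    merge_length = merge_size ** 2  # 4 patches per final <|imgpad|> token

    # Mirrors _get_new_tokens(0): total patch count divided by the merge factor.
    token_len = int(image_grid_thw[0].prod()) // merge_length
    print(token_len)  # 294: the single <|imgpad|> placeholder expands to 294 image tokens

Given the TemplateMeta registered above, a single-turn request should render roughly as `<|system|>{system}<|endofsystem|>\n<|user|><|img|><|imgpad|><|endofimg|>{query}<|endofuser|><|assistant|>`, with `_encode` then repeating `<|imgpad|>` token_len times per image.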

tests/test_align/test_template/test_vision.py

Lines changed: 14 additions & 1 deletion

@@ -601,6 +601,18 @@ def test_keye_vl():
         'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png',
         'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png'
     ]
+    pt_engine.default_template.template_backend = 'swift'
+    response = _infer_model(pt_engine, messages=messages, images=images)
+    pt_engine.default_template.template_backend = 'jinja'
+    response2 = _infer_model(pt_engine, messages=messages, images=images)
+    assert response == response2
+
+
+def test_dots_ocr():
+    # https://github.com/modelscope/ms-swift/issues/2122
+    pt_engine = PtEngine('rednote-hilab/dots.ocr')
+    messages = [{'role': 'user', 'content': '<image>Extract the text content from this image.'}]
+    images = ['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/ocr.png']
     response = _infer_model(pt_engine, messages=messages, images=images)
     pt_engine.default_template.template_backend = 'jinja'
     response2 = _infer_model(pt_engine, messages=messages, images=images)
@@ -629,7 +641,7 @@ def test_keye_vl():
     # test_glm4v()
     # test_cogagent()
     # test_llava_onevision_hf()
-    test_minicpmv()
+    # test_minicpmv()
     # test_got_ocr()
     # test_got_ocr_hf()
     # test_paligemma()
@@ -664,3 +676,4 @@ def test_keye_vl():
     # test_glm4_1v()
     # test_gemma3n()
     # test_keye_vl()
+    test_dots_ocr()
