Skip to content

Commit bb5cbe6

Browse files
authored
support iic/DocOwl2 (#2728)
1 parent c4e7bef commit bb5cbe6

File tree

10 files changed

+113
-7
lines changed

10 files changed

+113
-7
lines changed

docs/source/Instruction/支持的模型和数据集.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@
587587
|[iic/mPLUG-Owl3-2B-241014](https://modelscope.cn/models/iic/mPLUG-Owl3-2B-241014)|mplug_owl3|mplug_owl3|transformers>=4.36, icecream, decord|vision, video|[mPLUG/mPLUG-Owl3-2B-241014](https://huggingface.co/mPLUG/mPLUG-Owl3-2B-241014)|
588588
|[iic/mPLUG-Owl3-7B-240728](https://modelscope.cn/models/iic/mPLUG-Owl3-7B-240728)|mplug_owl3|mplug_owl3|transformers>=4.36, icecream, decord|vision, video|[mPLUG/mPLUG-Owl3-7B-240728](https://huggingface.co/mPLUG/mPLUG-Owl3-7B-240728)|
589589
|[iic/mPLUG-Owl3-7B-241101](https://modelscope.cn/models/iic/mPLUG-Owl3-7B-241101)|mplug_owl3_241101|mplug_owl3_241101|transformers>=4.36, icecream|vision, video|[mPLUG/mPLUG-Owl3-7B-241101](https://huggingface.co/mPLUG/mPLUG-Owl3-7B-241101)|
590+
|[iic/DocOwl2](https://modelscope.cn/models/iic/DocOwl2)|doc_owl2|doc_owl2|transformers>=4.36, icecream|vision|[mPLUG/DocOwl2](https://huggingface.co/mPLUG/DocOwl2)|
590591
|[BAAI/Emu3-Gen](https://modelscope.cn/models/BAAI/Emu3-Gen)|emu3_gen|emu3_gen|-|t2i|[BAAI/Emu3-Gen](https://huggingface.co/BAAI/Emu3-Gen)|
591592
|[BAAI/Emu3-Chat](https://modelscope.cn/models/BAAI/Emu3-Chat)|emu3_chat|emu3_chat|transformers>=4.44.0|vision|[BAAI/Emu3-Chat](https://huggingface.co/BAAI/Emu3-Chat)|
592593
|[stepfun-ai/GOT-OCR2_0](https://modelscope.cn/models/stepfun-ai/GOT-OCR2_0)|got_ocr2|got_ocr2|-|vision|[stepfun-ai/GOT-OCR2_0](https://huggingface.co/stepfun-ai/GOT-OCR2_0)|

docs/source_en/Instruction/Supported-models-and-datasets.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ The table below introduces the models integrated with ms-swift:
587587
|[iic/mPLUG-Owl3-2B-241014](https://modelscope.cn/models/iic/mPLUG-Owl3-2B-241014)|mplug_owl3|mplug_owl3|transformers>=4.36, icecream, decord|vision, video|[mPLUG/mPLUG-Owl3-2B-241014](https://huggingface.co/mPLUG/mPLUG-Owl3-2B-241014)|
588588
|[iic/mPLUG-Owl3-7B-240728](https://modelscope.cn/models/iic/mPLUG-Owl3-7B-240728)|mplug_owl3|mplug_owl3|transformers>=4.36, icecream, decord|vision, video|[mPLUG/mPLUG-Owl3-7B-240728](https://huggingface.co/mPLUG/mPLUG-Owl3-7B-240728)|
589589
|[iic/mPLUG-Owl3-7B-241101](https://modelscope.cn/models/iic/mPLUG-Owl3-7B-241101)|mplug_owl3_241101|mplug_owl3_241101|transformers>=4.36, icecream|vision, video|[mPLUG/mPLUG-Owl3-7B-241101](https://huggingface.co/mPLUG/mPLUG-Owl3-7B-241101)|
590+
|[iic/DocOwl2](https://modelscope.cn/models/iic/DocOwl2)|doc_owl2|doc_owl2|transformers>=4.36, icecream|vision|[mPLUG/DocOwl2](https://huggingface.co/mPLUG/DocOwl2)|
590591
|[BAAI/Emu3-Gen](https://modelscope.cn/models/BAAI/Emu3-Gen)|emu3_gen|emu3_gen|-|t2i|[BAAI/Emu3-Gen](https://huggingface.co/BAAI/Emu3-Gen)|
591592
|[BAAI/Emu3-Chat](https://modelscope.cn/models/BAAI/Emu3-Chat)|emu3_chat|emu3_chat|transformers>=4.44.0|vision|[BAAI/Emu3-Chat](https://huggingface.co/BAAI/Emu3-Chat)|
592593
|[stepfun-ai/GOT-OCR2_0](https://modelscope.cn/models/stepfun-ai/GOT-OCR2_0)|got_ocr2|got_ocr2|-|vision|[stepfun-ai/GOT-OCR2_0](https://huggingface.co/stepfun-ai/GOT-OCR2_0)|

swift/llm/model/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ class MLLMModelType:
153153
mplug_owl2_1 = 'mplug_owl2_1'
154154
mplug_owl3 = 'mplug_owl3'
155155
mplug_owl3_241101 = 'mplug_owl3_241101'
156+
doc_owl2 = 'doc_owl2'
156157

157158
emu3_gen = 'emu3_gen'
158159
emu3_chat = 'emu3_chat'

swift/llm/model/model/mplug.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def get_model_tokenizer_mplug_owl3(model_dir: str,
8282
processor = model.init_processor(tokenizer)
8383
if model is not None:
8484
func_list = ['generate', 'forward']
85-
_use_submodel_func(model, 'language_model', func_list)
85+
use_submodel_func(model, 'language_model', func_list)
8686
return model, processor
8787

8888

@@ -115,3 +115,28 @@ def get_model_tokenizer_mplug_owl3(model_dir: str,
115115
model_arch=ModelArch.mplug_owl3,
116116
requires=['transformers>=4.36', 'icecream'],
117117
tags=['vision', 'video']))
118+
119+
120+
def get_model_tokenizer_doc_owl2(model_dir: str,
121+
model_info: ModelInfo,
122+
model_kwargs: Dict[str, Any],
123+
load_model: bool = True,
124+
**kwargs):
125+
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
126+
processor = model.init_processor(tokenizer, basic_image_size=504, crop_anchors='grid_12')
127+
return model, processor
128+
129+
130+
register_model(
131+
ModelMeta(
132+
MLLMModelType.doc_owl2, [
133+
ModelGroup([
134+
Model('iic/DocOwl2', 'mPLUG/DocOwl2'),
135+
]),
136+
],
137+
TemplateType.doc_owl2,
138+
get_model_tokenizer_doc_owl2,
139+
architectures=['mPLUGDocOwl2'],
140+
model_arch=ModelArch.doc_owl2,
141+
requires=['transformers>=4.36', 'icecream'],
142+
tags=['vision']))

swift/llm/model/model_arch.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class MLLMModelArch:
4949
mplug_owl2 = 'mplug_owl2'
5050
mplug_owl2_1 = 'mplug_owl2_1'
5151
mplug_owl3 = 'mplug_owl3'
52+
doc_owl2 = 'doc_owl2'
5253

5354
phi3v = 'phi3v'
5455
florence = 'florence'
@@ -354,6 +355,14 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
354355
vision_tower='vision_model',
355356
))
356357

358+
register_model_arch(
359+
MultiModelKeys(
360+
MLLMModelArch.doc_owl2,
361+
language_model='model.layers',
362+
aligner=['model.vision2text', 'model.hr_compressor'],
363+
vision_tower='model.vision_model',
364+
))
365+
357366
register_model_arch(
358367
MultiModelKeys(
359368
MLLMModelArch.deepseek_vl,

swift/llm/template/base.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,12 @@ def _pad_sequence(self, sequences: List[torch.Tensor], padding_value: float = 0.
974974
return torch.stack(padded_sequences)
975975

976976
def safe_decode(self, input_ids: List[int], **tokenizer_kwargs) -> str:
977-
placeholder_tokens = self.template_meta.placeholder_tokens
977+
if isinstance(self, Template):
978+
tokenizer = self.tokenizer
979+
placeholder_tokens = self.template_meta.placeholder_tokens
980+
else:
981+
tokenizer = self
982+
placeholder_tokens = []
978983

979984
def _is_special(token: int) -> bool:
980985
if isinstance(token, float) or token < 0:
@@ -995,12 +1000,12 @@ def _is_special(token: int) -> bool:
9951000
continue
9961001
if _is_special(input_ids[i]) and not _is_special(input_ids[i - 1]):
9971002
s = i
998-
result_str += self.tokenizer.decode(input_ids[e:s], **tokenizer_kwargs)
1003+
result_str += tokenizer.decode(input_ids[e:s], **tokenizer_kwargs)
9991004
if not _is_special(input_ids[i]) and _is_special(input_ids[i - 1]):
10001005
e = i
10011006
result_str += f'[{input_ids[i - 1]} * {e - s}]'
10021007
if _is_special(input_ids[i]):
10031008
result_str += f'[{input_ids[i]} * {len(input_ids) - s}]'
10041009
else:
1005-
result_str += self.tokenizer.decode(input_ids[e:], **tokenizer_kwargs)
1010+
result_str += tokenizer.decode(input_ids[e:], **tokenizer_kwargs)
10061011
return result_str

swift/llm/template/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class MLLMTemplateType:
127127
mplug_owl2 = 'mplug_owl2'
128128
mplug_owl3 = 'mplug_owl3'
129129
mplug_owl3_241101 = 'mplug_owl3_241101'
130+
doc_owl2 = 'doc_owl2'
130131

131132
emu3_chat = 'emu3_chat'
132133
emu3_gen = 'emu3_gen'

swift/llm/template/template/mplug.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,40 @@ class mPlugOwl3TemplateMeta(QwenTemplateMeta):
174174
register_template(mPlugOwl3TemplateMeta(MLLMTemplateType.mplug_owl3, template_cls=mPlugOwl3Template))
175175

176176
register_template(mPlugOwl3TemplateMeta(MLLMTemplateType.mplug_owl3_241101, template_cls=mPlugOwl3_241101Template))
177+
178+
179+
class DocOwl2Template(Template):
180+
181+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
182+
inputs: StdTemplateInputs) -> List[Context]:
183+
if media_type == 'image':
184+
return [f'<img {index + 1}>', [-200]]
185+
186+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
187+
encoded = super()._encode(inputs)
188+
if inputs.images:
189+
image_tensor, patch_positions, _ = self.processor._process_image(inputs.images)
190+
image_tensor = image_tensor.to(self.config.torch_dtype)
191+
encoded.update({'images': image_tensor, 'patch_positions': patch_positions})
192+
return encoded
193+
194+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
195+
keys = ['images', 'patch_positions']
196+
res = self.fetch_inputs(batch, keys)
197+
for key in keys:
198+
val = res.get(key)
199+
if val:
200+
res[key] = torch.concat([v for v in val if v is not None])
201+
res.update(super()._data_collator(batch, padding_to=padding_to))
202+
return res
203+
204+
205+
register_template(
206+
TemplateMeta(
207+
MLLMTemplateType.doc_owl2,
208+
prefix=['<s>'],
209+
prompt=[' USER: {{QUERY}} ASSISTANT:'],
210+
chat_sep=['</s>'],
211+
suffix=['</s>'],
212+
template_cls=DocOwl2Template,
213+
))

tests/llm/test_custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
4242
],
4343
template='custom',
4444
get_function=get_model_tokenizer_with_flash_attn,
45-
ignore_file_pattern=['nemo']))
45+
ignore_patterns=['nemo']))
4646

4747

4848
class TestCustom(unittest.TestCase):

tests/test_align/test_template/test_vision.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,32 @@ def test_molmoe():
255255
"effect that emphasizes the young feline's charm.")
256256

257257

258+
def test_doc_owl2():
259+
pt_engine = PtEngine('iic/DocOwl2', torch_dtype=torch.float16)
260+
response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': '你是谁'}], images=[])
261+
images = [
262+
'https://modelscope.cn/models/iic/DocOwl2/resolve/master/examples/docowl2_page0.png',
263+
'https://modelscope.cn/models/iic/DocOwl2/resolve/master/examples/docowl2_page1.png',
264+
'https://modelscope.cn/models/iic/DocOwl2/resolve/master/examples/docowl2_page2.png',
265+
'https://modelscope.cn/models/iic/DocOwl2/resolve/master/examples/docowl2_page3.png',
266+
'https://modelscope.cn/models/iic/DocOwl2/resolve/master/examples/docowl2_page4.png',
267+
'https://modelscope.cn/models/iic/DocOwl2/resolve/master/examples/docowl2_page5.png',
268+
]
269+
response = _infer_model(
270+
pt_engine,
271+
messages=[{
272+
'role': 'user',
273+
'content': '<image>' * len(images) + 'what is this paper about? provide detailed information.'
274+
}],
275+
images=images)
276+
assert response == (
277+
'This paper is about multimodal Language Models(MLMs) achieving promising OCR-free '
278+
'Document Understanding by performing understanding by the cost of generating thorough sands of visual '
279+
'tokens for a single document image, leading to excessive GPU computation time. The paper also discusses '
280+
'the challenges and limitations of existing multimodal OCR approaches and proposes a new framework for '
281+
'more efficient and accurate OCR-free document understanding.')
282+
283+
258284
if __name__ == '__main__':
259285
from swift.llm import PtEngine, RequestConfig, get_template
260286
from swift.utils import get_logger, seed_everything
@@ -278,14 +304,14 @@ def test_molmoe():
278304
# test_llava_hf()
279305
# test_florence()
280306
# test_glm_edge_v()
281-
#
282307
# test_phi3_vision()
283308
# test_internvl2_5()
284-
test_internvl2_5_mpo()
309+
# test_internvl2_5_mpo()
285310
# test_mplug_owl3()
286311
# test_xcomposer2_5()
287312
# test_megrez_omni()
288313
# test_qvq()
289314
# test_mplug_owl2()
290315
# test_molmo()
291316
# test_molmoe()
317+
test_doc_owl2()

0 commit comments

Comments
 (0)