
Commit 603a655

Support Latex-OCR dataset (#1810)
1 parent ebc0a90 commit 603a655

10 files changed: +72 −18 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ You can contact us and communicate with us by adding our group:

 ## 🎉 News
 - 🔥2024.08.22: Support the `reft` tuner from [ReFT](https://github.com/stanfordnlp/pyreft), which is 15×–65× more parameter-efficient than LoRA; use `--sft_type reft` to begin!
-- 2024.08.21: Support for phi3_5-mini-instruct, phi3_5-moe-instruct, and phi3_5-vision-instruct.
+- 🔥2024.08.21: Support for phi3_5-mini-instruct, phi3_5-moe-instruct, and phi3_5-vision-instruct. The best practice for fine-tuning Latex OCR using phi3_5-vision-instruct can be found [here](https://github.com/modelscope/ms-swift/issues/1809).
 - 2024.08.21: Support for idefics3-8b-llama3, llava-onevision-qwen2-0_5b-ov, llava-onevision-qwen2-7b-ov, and llava-onevision-qwen2-72b-ov.
 - 🔥2024.08.20: Support fine-tuning of multimodal large models using DeepSpeed-Zero3.
 - 2024.08.20: Supported models: longwriter-glm4-9b, longwriter-llama3_1-8b. Supported dataset: longwriter-6k.

README_CN.md

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ SWIFT has rich and comprehensive documentation; please check our documentation website:

 ## 🎉 News
 - 🔥2024.08.22: Support [ReFT](https://github.com/stanfordnlp/pyreft); this tuner can match or outperform LoRA with only 1/15–1/65 of its parameters. Use `--sft_type reft` to start training!
-- 2024.08.21: Support phi3_5-mini-instruct, phi3_5-moe-instruct, and phi3_5-vision-instruct.
+- 🔥2024.08.21: Support phi3_5-mini-instruct, phi3_5-moe-instruct, and phi3_5-vision-instruct. The best practice for fine-tuning Latex OCR with phi3_5-vision-instruct can be found [here](https://github.com/modelscope/ms-swift/issues/1809).
 - 2024.08.21: Support idefics3-8b-llama3, llava-onevision-qwen2-0_5b-ov, llava-onevision-qwen2-7b-ov, and llava-onevision-qwen2-72b-ov.
 - 🔥2024.08.20: Support fine-tuning multimodal large models with deepspeed-zero3.
 - 2024.08.20: Supported models: longwriter-glm4-9b, longwriter-llama3_1-8b. Supported dataset: longwriter-6k.

docs/source/LLM/支持的模型和数据集.md

Lines changed: 3 additions & 1 deletion
@@ -510,9 +510,11 @@
 |coco-en-2|[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption/summary)|coco_2014_caption|454617|36.8±2.8, min=32, max=89|chat, multi-modal, vision|-|
 |🔥coco-en-2-mini|[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption/summary)|coco_2014_caption|40504|36.8±2.6, min=32, max=75|chat, multi-modal, vision|-|
 |capcha-images|[AI-ModelScope/captcha-images](https://modelscope.cn/datasets/AI-ModelScope/captcha-images/summary)||8000|31.0±0.0, min=31, max=31|chat, multi-modal, vision|-|
+|latex-ocr-print|[AI-ModelScope/LaTeX_OCR](https://modelscope.cn/datasets/AI-ModelScope/LaTeX_OCR/summary)|full|17918|362.7±34.8, min=294, max=528|chat, ocr, multi-modal, vision|[linxy/LaTeX_OCR](https://huggingface.co/datasets/linxy/LaTeX_OCR)|
+|latex-ocr-handwrite|[AI-ModelScope/LaTeX_OCR](https://modelscope.cn/datasets/AI-ModelScope/LaTeX_OCR/summary)|synthetic_handwrite|95424|375.1±59.4, min=292, max=2115|chat, ocr, multi-modal, vision|[linxy/LaTeX_OCR](https://huggingface.co/datasets/linxy/LaTeX_OCR)|
 |aishell1-zh|[speech_asr/speech_asr_aishell1_trainsets](https://modelscope.cn/datasets/speech_asr/speech_asr_aishell1_trainsets/summary)||141600|152.2±36.8, min=63, max=419|chat, multi-modal, audio|-|
 |🔥aishell1-zh-mini|[speech_asr/speech_asr_aishell1_trainsets](https://modelscope.cn/datasets/speech_asr/speech_asr_aishell1_trainsets/summary)||14526|152.2±35.6, min=74, max=359|chat, multi-modal, audio|-|
-|🔥video-chatgpt|[swift/VideoChatGPT](https://modelscope.cn/datasets/swift/VideoChatGPT/summary)|Generic<br>Temporal<br>Consistency|3206|88.4±48.3, min=32, max=399|chat, multi-modal, video|-|
+|🔥video-chatgpt|[swift/VideoChatGPT](https://modelscope.cn/datasets/swift/VideoChatGPT/summary)|Generic<br>Temporal<br>Consistency|3206|88.4±48.3, min=32, max=399|chat, multi-modal, video|[lmms-lab/VideoChatGPT](https://huggingface.co/datasets/lmms-lab/VideoChatGPT)|
 |hh-rlhf|[AI-ModelScope/hh-rlhf](https://modelscope.cn/datasets/AI-ModelScope/hh-rlhf/summary)|harmless-base<br>helpful-base<br>helpful-online<br>helpful-rejection-sampled|127459|245.4±190.7, min=22, max=1999|rlhf, dpo, pairwise|-|
 |🔥hh-rlhf-cn|[AI-ModelScope/hh_rlhf_cn](https://modelscope.cn/datasets/AI-ModelScope/hh_rlhf_cn/summary)|hh_rlhf<br>harmless_base_cn<br>harmless_base_en<br>helpful_base_cn<br>helpful_base_en|355920|171.2±122.7, min=22, max=3078|rlhf, dpo, pairwise|-|
 |orpo-dpo-mix-40k|[AI-ModelScope/orpo-dpo-mix-40k](https://modelscope.cn/datasets/AI-ModelScope/orpo-dpo-mix-40k/summary)|default|43666|548.3±397.4, min=28, max=8483|dpo, orpo, en, quality|[mlabonne/orpo-dpo-mix-40k](https://huggingface.co/datasets/mlabonne/orpo-dpo-mix-40k)|

docs/source/Multi-Modal/index.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 4. [InternVL Series Best Practice](internvl最佳实践.md)
 5. [Deepseek-VL Best Practice](deepseek-vl最佳实践.md)
 6. [Internlm2-Xcomposers Best Practice](internlm-xcomposer2最佳实践.md)
-7. [Phi3-Vision Best Practice](phi3-vision最佳实践.md)
+7. [Phi3-Vision Best Practice](phi3-vision最佳实践.md), [Phi3.5-Vision Best Practice](https://github.com/modelscope/ms-swift/issues/1809).


 A single round of dialogue can contain only one image (it may also contain no image):

docs/source_en/LLM/Supported-models-datasets.md

Lines changed: 3 additions & 1 deletion
@@ -510,9 +510,11 @@ The table below introduces the datasets supported by SWIFT:
 |coco-en-2|[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption/summary)|coco_2014_caption|454617|36.8±2.8, min=32, max=89|chat, multi-modal, vision|-|
 |🔥coco-en-2-mini|[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption/summary)|coco_2014_caption|40504|36.8±2.6, min=32, max=75|chat, multi-modal, vision|-|
 |capcha-images|[AI-ModelScope/captcha-images](https://modelscope.cn/datasets/AI-ModelScope/captcha-images/summary)||8000|31.0±0.0, min=31, max=31|chat, multi-modal, vision|-|
+|latex-ocr-print|[AI-ModelScope/LaTeX_OCR](https://modelscope.cn/datasets/AI-ModelScope/LaTeX_OCR/summary)|full|17918|362.7±34.8, min=294, max=528|chat, ocr, multi-modal, vision|[linxy/LaTeX_OCR](https://huggingface.co/datasets/linxy/LaTeX_OCR)|
+|latex-ocr-handwrite|[AI-ModelScope/LaTeX_OCR](https://modelscope.cn/datasets/AI-ModelScope/LaTeX_OCR/summary)|synthetic_handwrite|95424|375.1±59.4, min=292, max=2115|chat, ocr, multi-modal, vision|[linxy/LaTeX_OCR](https://huggingface.co/datasets/linxy/LaTeX_OCR)|
 |aishell1-zh|[speech_asr/speech_asr_aishell1_trainsets](https://modelscope.cn/datasets/speech_asr/speech_asr_aishell1_trainsets/summary)||141600|152.2±36.8, min=63, max=419|chat, multi-modal, audio|-|
 |🔥aishell1-zh-mini|[speech_asr/speech_asr_aishell1_trainsets](https://modelscope.cn/datasets/speech_asr/speech_asr_aishell1_trainsets/summary)||14526|152.2±35.6, min=74, max=359|chat, multi-modal, audio|-|
-|🔥video-chatgpt|[swift/VideoChatGPT](https://modelscope.cn/datasets/swift/VideoChatGPT/summary)|Generic<br>Temporal<br>Consistency|3206|88.4±48.3, min=32, max=399|chat, multi-modal, video|-|
+|🔥video-chatgpt|[swift/VideoChatGPT](https://modelscope.cn/datasets/swift/VideoChatGPT/summary)|Generic<br>Temporal<br>Consistency|3206|88.4±48.3, min=32, max=399|chat, multi-modal, video|[lmms-lab/VideoChatGPT](https://huggingface.co/datasets/lmms-lab/VideoChatGPT)|
 |hh-rlhf|[AI-ModelScope/hh-rlhf](https://modelscope.cn/datasets/AI-ModelScope/hh-rlhf/summary)|harmless-base<br>helpful-base<br>helpful-online<br>helpful-rejection-sampled|127459|245.4±190.7, min=22, max=1999|rlhf, dpo, pairwise|-|
 |🔥hh-rlhf-cn|[AI-ModelScope/hh_rlhf_cn](https://modelscope.cn/datasets/AI-ModelScope/hh_rlhf_cn/summary)|hh_rlhf<br>harmless_base_cn<br>harmless_base_en<br>helpful_base_cn<br>helpful_base_en|355920|171.2±122.7, min=22, max=3078|rlhf, dpo, pairwise|-|
 |orpo-dpo-mix-40k|[AI-ModelScope/orpo-dpo-mix-40k](https://modelscope.cn/datasets/AI-ModelScope/orpo-dpo-mix-40k/summary)|default|43666|548.3±397.4, min=28, max=8483|dpo, orpo, en, quality|[mlabonne/orpo-dpo-mix-40k](https://huggingface.co/datasets/mlabonne/orpo-dpo-mix-40k)|
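The two new rows reference the same data on ModelScope (AI-ModelScope/LaTeX_OCR, subsets `full` and `synthetic_handwrite`) and on Hugging Face (linxy/LaTeX_OCR). For a quick look at what the rows describe, the Hugging Face mirror can be inspected with the `datasets` library; this is only a sketch, and the `image`/`text` column names are taken from the preprocessor added in swift/llm/utils/dataset.py below.

```python
from datasets import load_dataset

# Load the handwritten subset from the Hugging Face mirror listed in the table.
ds = load_dataset('linxy/LaTeX_OCR', 'synthetic_handwrite', split='train')
print(ds)  # expected columns, per the preprocessor in this commit: 'image', 'text'

sample = ds[0]
print(sample['text'])               # ground-truth LaTeX source of the formula
sample['image'].save('sample.png')  # the corresponding formula image (PIL)
```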

docs/source_en/Multi-Modal/index.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ A single round of dialogue can contain multiple images (or no images):
 4. [InternVL Series Best Practice](internvl-best-practice.md)
 5. [Deepseek-VL Best Practice](deepseek-vl-best-practice.md)
 6. [Internlm2-Xcomposers Best Practice](internlm-xcomposer2-best-practice.md)
-7. [Phi3-Vision Best Practice](phi3-vision-best-practice.md)
+7. [Phi3-Vision Best Practice](phi3-vision-best-practice.md), [Phi3.5-Vision Best Practice](https://github.com/modelscope/ms-swift/issues/1809).


 A single round of dialogue can only contain one image:

swift/llm/utils/client_utils.py

Lines changed: 2 additions & 1 deletion
@@ -100,7 +100,8 @@ def _from_base64(img_base64: Union[str, 'PIL.Image.Image'], tmp_dir: str = 'tmp'
     sha256_hash = hashlib.sha256(img_base64.encode('utf-8')).hexdigest()
     img_path = os.path.join(tmp_dir, f'{sha256_hash}.png')
     image = Image.open(BytesIO(base64.b64decode(img_base64)))
-    image.save(img_path)
+    if not os.path.exists(img_path):
+        image.save(img_path)
     return img_path

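The change guards the disk write so that a payload that was already decoded and saved is not written again; because the file name is the SHA-256 digest of the base64 string, an existing file with that name already holds the same image. Below is a minimal standalone sketch of the same idea, not the ms-swift helper itself; in this variant the decode is also skipped when the cached file exists.

```python
import base64
import hashlib
import os
from io import BytesIO

from PIL import Image


def base64_to_cached_file(img_base64: str, tmp_dir: str = 'tmp') -> str:
    """Decode a base64-encoded image and cache it on disk, keyed by content hash."""
    os.makedirs(tmp_dir, exist_ok=True)
    # Identical payloads hash to identical file names, so the path doubles as a cache key.
    sha256_hash = hashlib.sha256(img_base64.encode('utf-8')).hexdigest()
    img_path = os.path.join(tmp_dir, f'{sha256_hash}.png')
    if not os.path.exists(img_path):
        # Decode and save only when the cached file is missing.
        image = Image.open(BytesIO(base64.b64decode(img_base64)))
        image.save(img_path)
    return img_path
```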
swift/llm/utils/dataset.py

Lines changed: 43 additions & 5 deletions
@@ -156,6 +156,8 @@ class DatasetName:
     coco_en_2 = 'coco-en-2'
     coco_en_2_mini = 'coco-en-2-mini'
     capcha_images = 'capcha-images'
+    latex_ocr_print = 'latex-ocr-print'
+    latex_ocr_handwrite = 'latex-ocr-handwrite'
     # for qwen-audio
     aishell1_zh = 'aishell1-zh'
     aishell1_zh_mini = 'aishell1-zh-mini'
@@ -747,7 +749,10 @@ def _process(d):
         response = d[response_key]
         return {'query': query * len(response), 'response': response, 'images': images}

-    return dataset.map(_process)
+    kwargs = {}
+    if not isinstance(dataset, HfIterableDataset):
+        kwargs['load_from_cache_file'] = dataset_enable_cache
+    return dataset.map(_process, **kwargs)


 register_dataset(
@@ -861,6 +866,7 @@ def _process(d):
     _preprocess_video_chatgpt,
     get_dataset_from_repo,
     split=['test'],
+    hf_dataset_id='lmms-lab/VideoChatGPT',
     tags=['chat', 'multi-modal', 'video', '🔥'])


@@ -1784,7 +1790,7 @@ def preprocess_row(row):
         query = row['question']
         response = row['choices'][row['answer']]
         solution = row['solution']
-        return {'query': query, 'response': f'{solution}\nSo the final answer is:{response}'}
+        return {'query': query, 'response': f'{solution}\nSo the final answer is: {response}'}

     kwargs = {}
     if not isinstance(dataset, HfIterableDataset):
@@ -2028,16 +2034,48 @@ def preprocess(row):
     tags=['chat', 'general', 'multi-round'])


+def _preprocess_latex_ocr_dataset(dataset: DATASET_TYPE) -> DATASET_TYPE:
+    from datasets import Image
+    prompt = 'Using LaTeX to perform OCR on the image.'
+
+    def _process(d):
+        return {'query': prompt, 'response': d['text']}
+
+    kwargs = {}
+    if not isinstance(dataset, HfIterableDataset):
+        kwargs['load_from_cache_file'] = dataset_enable_cache
+    return dataset.map(_process, **kwargs).rename_column('image', 'images')
+
+
+register_dataset(
+    DatasetName.latex_ocr_print,
+    'AI-ModelScope/LaTeX_OCR',
+    ['full'],
+    _preprocess_latex_ocr_dataset,
+    get_dataset_from_repo,
+    split=['validation', 'test'],  # There are some problems in the training dataset.
+    hf_dataset_id='linxy/LaTeX_OCR',
+    tags=['chat', 'ocr', 'multi-modal', 'vision'])
+
+register_dataset(
+    DatasetName.latex_ocr_handwrite,
+    'AI-ModelScope/LaTeX_OCR', ['synthetic_handwrite'],
+    _preprocess_latex_ocr_dataset,
+    get_dataset_from_repo,
+    split=['train', 'validation', 'test'],
+    hf_dataset_id='linxy/LaTeX_OCR',
+    tags=['chat', 'ocr', 'multi-modal', 'vision'])
+
+
 def _preprocess_capcha_images(dataset: DATASET_TYPE) -> DATASET_TYPE:
     from datasets import Image
     query = 'recognize the content.'
-    image_key = 'image'
     response_key = 'solution'

     def _process(d):
-        return {'query': query * len(d[response_key]), 'response': d[response_key], 'images': [d[image_key]]}
+        return {'query': query * len(d[response_key]), 'response': d[response_key]}

-    return dataset.map(_process).cast_column('image', Image(decode=True))
+    return dataset.map(_process).rename_column('image', 'images')


 register_dataset(
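Once these registrations are in place, the new dataset names can be passed to SWIFT like any other dataset. Below is a minimal sketch of a LoRA fine-tuning run on the handwritten subset, modeled on `test_vlm_sft` in tests/custom/test_main.py further down; the `#2000` subsampling suffix follows the `rlaif-v#100` convention used elsewhere in this commit, and the argument values are illustrative rather than recommended settings.

```python
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from swift.llm import sft_main, SftArguments, infer_main, InferArguments

# Fine-tune a vision-language model on the newly registered handwritten LaTeX OCR
# dataset, subsampled to 2000 rows for a quick run.
output = sft_main(
    SftArguments(
        model_type='phi3_5-vision-instruct',
        dataset='latex-ocr-handwrite#2000',
        sft_type='lora'))
last_model_checkpoint = output['last_model_checkpoint']

# Evaluate the tuned checkpoint on the validation set recorded with it.
infer_main(InferArguments(ckpt_dir=last_model_checkpoint, load_dataset_config=True))
```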

swift/llm/utils/model.py

Lines changed: 2 additions & 5 deletions
@@ -4287,7 +4287,7 @@ def _get_new_func(func_name: str):

         @wraps(_old_func)
         def _new_func(self, *args, **kwargs):
-            res = _old_func(self, *args, **kwargs)
+            res = _old_func(getattr(self, submodel_name), *args, **kwargs)
             if func_name == 'forward':
                 device = find_device(args)
                 if device is None:
@@ -4298,12 +4298,9 @@ def _new_func(self, *args, **kwargs):
         return _new_func

     for key in func_list:
-        value = MethodType(_get_new_func(key), submodel)
-        setattr(model, key, value)
+        setattr(model, key, MethodType(_get_new_func(key), model))
         if key == 'generate' and model.device != submodel.device:
             submodel.__class__.device = model.device
-        if key == 'forward' and 'generate' in func_list:
-            setattr(submodel, key, value)


 @register_model(
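The rewritten loop binds each wrapped function to the outer model and resolves the submodel with `getattr(self, submodel_name)` at call time, which is what makes the extra `forward` patch on the submodel unnecessary. A generic, self-contained sketch of this forwarding pattern is shown below; class and attribute names are illustrative, not the ms-swift helper itself.

```python
from functools import wraps
from types import MethodType


def use_submodel_func(model, submodel_name, func_list):
    """Expose selected methods of model.<submodel_name> as methods of `model` itself."""

    def _get_new_func(func_name):
        # Unbound function taken from the submodel's class.
        _old_func = getattr(getattr(model, submodel_name).__class__, func_name)

        @wraps(_old_func)
        def _new_func(self, *args, **kwargs):
            # `self` is the outer model; the submodel is looked up at call time
            # and the original function is invoked on it.
            return _old_func(getattr(self, submodel_name), *args, **kwargs)

        return _new_func

    for key in func_list:
        # Bind the wrapper to the outer model rather than to the submodel.
        setattr(model, key, MethodType(_get_new_func(key), model))


class Submodel:
    def forward(self, x):
        return x * 2


class Wrapper:
    def __init__(self):
        self.model = Submodel()


wrapper = Wrapper()
use_submodel_func(wrapper, 'model', ['forward'])
print(wrapper.forward(3))  # 6, executed by the submodel's forward
```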

tests/custom/test_main.py

Lines changed: 15 additions & 1 deletion
@@ -20,6 +20,7 @@ def test_pt():


 def test_vlm_sft():
+    # lora full
     os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
     from swift.llm import sft_main, SftArguments, infer_main, InferArguments
     model_type = 'phi3_5-vision-instruct'
@@ -45,9 +46,22 @@ def test_llm_sft():
         InferArguments(ckpt_dir=last_model_checkpoint, load_dataset_config=True, merge_lora=True, infer_backend='pt'))


+def test_vlm_dpo():
+    # lora, full, stream
+    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
+    from swift.llm import rlhf_main, RLHFArguments, infer_main, InferArguments
+    model_type = 'internvl2-2b'
+    dataset = 'rlaif-v#100'
+
+    output = rlhf_main(RLHFArguments(model_type=model_type, dataset=dataset, max_length=8192, sft_type='full'))
+    last_model_checkpoint = output['last_model_checkpoint']
+    infer_main(InferArguments(ckpt_dir=last_model_checkpoint, load_dataset_config=True))
+
+
 if __name__ == '__main__':
     # test_eval_llm()
     # test_eval_vlm()
     # test_pt()
-    test_vlm_sft()
+    # test_vlm_sft()
     # test_llm_sft()
+    test_vlm_dpo()
