
Commit a8dd61f

[bugfix] fix multimodal cached_dataset (#5671)
1 parent 8f22319 commit a8dd61f
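
Exporting a cached dataset for multimodal models previously failed because ExportCachedDataset set self.processor = None before building the template and then initialized the template with that None processor, which breaks any multimodal template that needs a processor (or a model) to tokenize data. This commit loads the tokenizer/processor first, instantiating a dummy model on PyTorch's meta device when the template requires a model, threads arbitrary kwargs through _prepare_model_tokenizer, forces lazy_tokenize off when --to_cached_dataset is set, and adds an end-to-end Qwen2.5-Omni example (export, train, infer).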

7 files changed: +88 -12 lines changed

examples/export/cached_dataset/mcore.sh

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+# Note: cached_dataset does not currently support CP.
 swift export \
     --model Qwen/Qwen3-30B-A3B-Base \
     --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT' \
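
The added comment records a current limitation of the Megatron-Core export path: cached-dataset export cannot yet be combined with CP (context parallelism).
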
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+OMP_NUM_THREADS=14 \
+MAX_PIXELS=1003520 \
+VIDEO_MAX_PIXELS=50176 \
+FPS_MAX_FRAMES=12 \
+swift export \
+    --model Qwen/Qwen2.5-Omni-7B \
+    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#10000' \
+              'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \
+              'speech_asr/speech_asr_aishell1_trainsets:validation#5000' \
+    --max_length 4096 \
+    --split_dataset_ratio 0.01 \
+    --dataset_num_proc 16 \
+    --to_cached_dataset true \
+    --lazy_tokenize false \
+    --output_dir ./qwen2_5_omni_cached_dataset
+
+# 4 * 70GiB
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+MAX_PIXELS=1003520 \
+VIDEO_MAX_PIXELS=50176 \
+FPS_MAX_FRAMES=12 \
+NPROC_PER_NODE=4 \
+ENABLE_AUDIO_OUTPUT=0 \
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+swift sft \
+    --model Qwen/Qwen2.5-Omni-7B \
+    --train_type full \
+    --cached_dataset './qwen2_5_omni_cached_dataset' \
+    --num_train_epochs 1 \
+    --split_dataset_ratio 0.01 \
+    --torch_dtype bfloat16 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --learning_rate 1e-5 \
+    --gradient_accumulation_steps 1 \
+    --packing true \
+    --freeze_llm false \
+    --freeze_vit true \
+    --freeze_aligner true \
+    --eval_steps 200 \
+    --save_steps 200 \
+    --logging_steps 5 \
+    --max_length 4096 \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 8 \
+    --dataset_num_proc 8 \
+    --save_total_limit 2 \
+    --save_only_model true \
+    --output_dir output/Qwen2.5-Omni-7B \
+    --deepspeed zero2 \
+    --use_liger_kernel true \
+    --attn_impl flash_attn
+
+# Use the validation set
+CUDA_VISIBLE_DEVICES=0 \
+MAX_PIXELS=1003520 \
+VIDEO_MAX_PIXELS=50176 \
+FPS_MAX_FRAMES=12 \
+ENABLE_AUDIO_OUTPUT=0 \
+swift infer \
+    --model output/Qwen2.5-Omni-7B/vx-xxx/checkpoint-xxx \
+    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#10000' \
+              'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \
+              'speech_asr/speech_asr_aishell1_trainsets:validation#5000' \
+    --max_length 4096 \
+    --split_dataset_ratio 0.01 \
+    --attn_impl flash_attn \
+    --stream true \
+    --temperature 0 \
+    --max_new_tokens 512
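
This new example script walks the full multimodal cached-dataset workflow for Qwen2.5-Omni-7B: swift export --to_cached_dataset true tokenizes the mixed text, OCR, and ASR datasets ahead of time into ./qwen2_5_omni_cached_dataset; swift sft --cached_dataset then trains from the cache, with packing applied at training time (which is why the export step rejects --packing, per export_args.py below); and swift infer evaluates the result, with vx-xxx/checkpoint-xxx standing in for the actual run directory.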

examples/train/multimodal/omni/sft.sh

Lines changed: 1 addition & 1 deletion
@@ -2,13 +2,13 @@
 # A demo for four modalities that can be run directly
 nproc_per_node=4
 
+# If using zero3, please set `ENABLE_AUDIO_OUTPUT=0`.
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 ENABLE_AUDIO_OUTPUT=1 \
 NPROC_PER_NODE=$nproc_per_node \
 VIDEO_MAX_PIXELS=50176 \
 FPS_MAX_FRAMES=12 \
 MAX_PIXELS=1003520 \
-ENABLE_AUDIO_OUTPUT=0 \
 swift sft \
     --model Qwen/Qwen2.5-Omni-7B \
     --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#2000' \
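
Shell environment assignments are applied left to right, so the stray trailing ENABLE_AUDIO_OUTPUT=0 was overriding the intended ENABLE_AUDIO_OUTPUT=1, and the four-modality demo never actually enabled audio output. The fix drops the duplicate and documents the one case (DeepSpeed zero3) where audio output still has to be disabled.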

swift/llm/argument/base_args/data_args.py

Lines changed: 2 additions & 4 deletions
@@ -27,10 +27,8 @@ class DataArguments:
         custom_dataset_info (Optional[str]): Path to custom dataset_info.json file. Default is None.
     """
     # dataset_id or dataset_dir or dataset_path
-    dataset: List[str] = field(
-        default_factory=list, metadata={'help': f'dataset choices: {list(DATASET_MAPPING.keys())}'})
-    val_dataset: List[str] = field(
-        default_factory=list, metadata={'help': f'dataset choices: {list(DATASET_MAPPING.keys())}'})
+    dataset: List[str] = field(default_factory=list)
+    val_dataset: List[str] = field(default_factory=list)
     split_dataset_ratio: float = 0.
 
     data_seed: int = 42
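
A plausible motivation for dropping the help metadata: an f-string in a dataclass field is evaluated once, while the class body executes, so the full DATASET_MAPPING key list was rendered into each help string at import time and dumped wholesale into --help output. A self-contained sketch of that evaluation order, using a stand-in mapping rather than swift's real DATASET_MAPPING:

from dataclasses import dataclass, field
from typing import List

# Stand-in for DATASET_MAPPING; the real mapping holds hundreds of entries.
BIG_MAPPING = {f'dataset_{i}': None for i in range(500)}

@dataclass
class Demo:
    # The f-string below runs once, as the class body executes, so the
    # entire key list is baked into the field's metadata up front.
    dataset: List[str] = field(
        default_factory=list,
        metadata={'help': f'dataset choices: {list(BIG_MAPPING.keys())}'})

print(len(Demo.__dataclass_fields__['dataset'].metadata['help']))  # thousands of characters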

swift/llm/argument/export_args.py

Lines changed: 2 additions & 1 deletion
@@ -119,7 +119,8 @@ def __post_init__(self):
         if self.quant_method in {'gptq', 'awq'} and len(self.dataset) == 0:
             raise ValueError(f'self.dataset: {self.dataset}, Please input the quant dataset.')
         if self.to_cached_dataset:
+            self.lazy_tokenize = False
             if self.packing:
                 raise ValueError('Packing will be handled during training; here we only perform tokenization '
                                  'in advance, so you do not need to set up packing separately.')
-            assert not self.streaming and not self.lazy_tokenize, 'not supported'
+            assert not self.streaming, 'not supported'
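
Behavioral note: with --to_cached_dataset set, lazy_tokenize is now forced to False up front rather than rejected by the old assertion, so swift export no longer aborts when a config happens to enable lazy tokenization; streaming remains unsupported. The new Omni example above passes --lazy_tokenize false explicitly, matching the enforced default.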

swift/llm/export/cached_dataset.py

Lines changed: 10 additions & 4 deletions
@@ -2,7 +2,9 @@
 import os
 from typing import List, Optional, Union
 
-from swift.llm import ExportArguments
+import torch
+
+from swift.llm import TEMPLATE_MAPPING, ExportArguments
 from swift.llm.train import SwiftSft
 from swift.utils import get_logger
 
@@ -16,10 +18,14 @@ class ExportCachedDataset(SwiftSft):
     def __init__(self, args: Optional[Union[List[str], ExportArguments]] = None) -> None:
         super(SwiftSft, self).__init__(args)
         self.train_msg = {}  # dummy
-        self.processor = None
+        template_cls = TEMPLATE_MAPPING[args.template].template_cls
+        if template_cls and template_cls.use_model:
+            kwargs = {'return_dummy_model': True}
+        else:
+            kwargs = {'load_model': False}
+        with torch.device('meta'):
+            self._prepare_model_tokenizer(**kwargs)
         self._prepare_template()
-        self._prepare_model_tokenizer(load_model=self.template.use_model)
-        self.template.init_processor(self.processor)
 
     def main(self):
         train_dataset, val_dataset = self._get_dataset()
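
This is the heart of the bugfix. The old code set self.processor = None, built the template, and then initialized it with that None processor, which broke multimodal templates whose tokenization needs a real processor (and, when template_cls.use_model is set, a model). The new code checks TEMPLATE_MAPPING first, requests either a dummy model or no model at all, and runs the load under torch.device('meta') so nothing is materialized. A minimal, swift-free sketch of the meta-device behavior being relied on (PyTorch 2.x device context manager; the Linear layer is only an illustration):

import torch
import torch.nn as nn

# Inside this context manager, new tensors and modules land on the "meta"
# device: shapes and dtypes are tracked, but no storage is allocated.
with torch.device('meta'):
    dummy = nn.Linear(4096, 4096)

print(dummy.weight.device)  # meta
print(dummy.weight.shape)   # torch.Size([4096, 4096]); metadata survives
# The weights hold no data, which is fine here: the exporter only
# tokenizes the dataset, so the model never has to execute a forward pass.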

swift/llm/train/sft.py

Lines changed: 2 additions & 2 deletions
@@ -42,12 +42,12 @@ def _prepare_generation_config(self):
             args.get_request_config(), self.tokenizer)
         logger.info(f'model.generation_config: {self.model.generation_config}')
 
-    def _prepare_model_tokenizer(self, load_model=True):
+    def _prepare_model_tokenizer(self, **kwargs):
         args = self.args
         if args.sequence_parallel_size > 1:
             from swift.trainers.sequence_parallel import sequence_parallel
             sequence_parallel.init_sequence_parallel(args.sequence_parallel_size)
-        self.model, self.processor = args.get_model_processor(load_model=load_model)
+        self.model, self.processor = args.get_model_processor(**kwargs)
         if self.model is None:
             return
         if hasattr(self.model, 'hf_device_map'):
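
_prepare_model_tokenizer now forwards its kwargs straight through to args.get_model_processor, which is what lets ExportCachedDataset choose between load_model=False and the dummy-model path (return_dummy_model=True) without SwiftSft having to know about either option; ordinary training call sites pass no kwargs and keep the previous default behavior.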
