Commit 5abab36

Support IterableDataset (#1596)
1 parent aca5a7c commit 5abab36

File tree: 16 files changed (+723, -360 lines)


docs/source/LLM/命令行参数.md (1 addition, 0 deletions)

@@ -23,6 +23,7 @@
 - LLAVA model: `https://github.com/haotian-liu/LLaVA.git`
 - `--sft_type`: The fine-tuning method, default `'lora'`. Options: 'lora', 'full', 'longlora', 'adalora', 'ia3', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft'. To use qlora, set `--sft_type lora --quantization_bit 4`.
 - `--packing`: pack the dataset to `max-length`, default `False`.
+- `--streaming`: whether to use streaming data processing, default `False`.
 - `--freeze_parameters`: When sft_type is 'full', freeze the bottommost parameters of the model. The range is 0. ~ 1., default `0.`. This provides a compromise between lora and full-parameter fine-tuning.
 - `--additional_trainable_parameters`: A supplement to freeze_parameters, only allowed when sft_type is 'full', default `[]`. For example, if you want to additionally train the embedding layer while training 50% of the parameters, you can set `--freeze_parameters 0.5 --additional_trainable_parameters transformer.wte`; all parameters starting with `transformer.wte` will be activated. You can also set `--freeze_parameters 1 --additional_trainable_parameters xxx` to customize the trainable layers.
 - `--tuner_backend`: The backend for lora and qlora, default `'peft'`. Options: 'swift', 'peft', 'unsloth'.

docs/source_en/LLM/Command-line-parameters.md (1 addition, 0 deletions)

@@ -22,6 +22,7 @@
 - LLAVA model: `https://github.com/haotian-liu/LLaVA.git`
 - `--sft_type`: Fine-tuning method, default is `'lora'`. Options include: 'lora', 'full', 'longlora', 'adalora', 'ia3', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft'. If using qlora, you need to set `--sft_type lora --quantization_bit 4`.
 - `--packing`: pack the dataset to `max-length`, default `False`.
+- `--streaming`: whether to use an iterable (streaming) dataset, default `False`.
 - `--freeze_parameters`: When sft_type is set to 'full', freeze the bottommost parameters of the model. Range is 0. ~ 1., default is `0.`. This provides a compromise between lora and full fine-tuning.
 - `--additional_trainable_parameters`: In addition to freeze_parameters, only allowed when sft_type is 'full', default is `[]`. For example, if you want to train the embedding layer in addition to 50% of the parameters, you can set `--freeze_parameters 0.5 --additional_trainable_parameters transformer.wte`; all parameters starting with `transformer.wte` will be activated. You can also set `--freeze_parameters 1 --additional_trainable_parameters xxx` to customize the trainable layers.
 - `--tuner_backend`: Backend support for lora, qlora, default is `'peft'`. Options include: 'swift', 'peft', 'unsloth'.

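For orientation, here is a minimal Python sketch of how the new `--streaming` flag and its companion parameters could be driven. It is an illustration only: the `sft_main`/`SftArguments` entry point from `swift.llm` and the model/dataset names are assumptions, not part of this commit.

# Hypothetical usage sketch; model_type/dataset values are placeholders.
from swift.llm import SftArguments, sft_main

args = SftArguments(
    model_type='qwen2-7b-instruct',   # placeholder model
    dataset=['alpaca-zh'],            # placeholder dataset
    streaming=True,                   # flag added by this commit
    streaming_val_size=1000,          # carve a validation split off the stream
    streaming_buffer_size=16384,      # shuffle buffer for the iterable dataset
    max_steps=2000,                   # required: a streaming dataset has no known length
)
sft_main(args)

On the command line this would correspond to something like `swift sft --streaming true --max_steps 2000 ...`.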
swift/llm/rlhf.py (12 additions, 6 deletions)

@@ -24,6 +24,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
     logger.info(f'args: {args}')
     seed_everything(args.seed)
     training_args = args.training_args
+    streaming = args.streaming
     if is_torch_npu_available():
         print(f'device_count: {torch.npu.device_count()}')
     else:
@@ -169,6 +170,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
     if val_dataset is None:
         training_args.evaluation_strategy = IntervalStrategy.NO
         training_args.do_eval = False
+        training_args.eval_strategy = IntervalStrategy.NO

     template_kwargs = {}
     template_info = TEMPLATE_MAPPING[args.template_type]
@@ -183,10 +185,10 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:

     template: Template = get_template(
         args.template_type, tokenizer, args.system, args.max_length, args.truncation_strategy, model=model)
-    if not template.support_multi_round and 'history' in train_dataset[0]:
+    if not template.support_multi_round and 'history' in next(iter(train_dataset)):
         logger.info(
             'The current template does not support multi-turn dialogue. The chatml template is used by default. \
-You can also use the --model_type parameter to specify the template.')
+            You can also use the --model_type parameter to specify the template.')
         template: Template = get_template(
             'chatml', tokenizer, args.system, args.max_length, args.truncation_strategy, model=model)
     args.system = template.default_system
@@ -206,6 +208,8 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
     trainer_kwargs['is_vision'] = args.is_vision
     model.config.model_type += '_'  # add suffix to avoid checks in hfDPOTrainer

+    trainer_kwargs['streaming'] = streaming
+
     trainer = trainer_cls(
         model=model,
         train_dataset=train_dataset,
@@ -227,7 +231,8 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
     last_model_checkpoint = getattr(trainer.state, 'last_model_checkpoint', None)
     logger.info(f'last_model_checkpoint: {last_model_checkpoint}')
     logger.info(f'best_model_checkpoint: {trainer.state.best_model_checkpoint}')
-    train_time = get_time_info(trainer.state.log_history, len(train_dataset))
+    if not streaming:
+        train_time = get_time_info(trainer.state.log_history, len(train_dataset))
     # Visualization
     if is_master():
         if 'tensorboard' in args.training_args.report_to:
@@ -239,15 +244,16 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
         trainer.push_to_hub()
     run_info = {
         'memory': trainer.perf['memory'],
-        'train_time': train_time,
         'last_model_checkpoint': last_model_checkpoint,
         'best_model_checkpoint': trainer.state.best_model_checkpoint,
         'best_metric': trainer.state.best_metric,
         'global_step': trainer.state.global_step,
         'log_history': trainer.state.log_history,
-        'model_info': model_info,
-        'dataset_info': trainer.dataset_info,
+        'model_info': model_info
     }
+    if not streaming:
+        run_info.update({'train_time': train_time})
+        run_info.update({'dataset_info': trainer.dataset_info})
     if is_master():
         jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
         append_to_jsonl(jsonl_path, run_info)

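The `next(iter(train_dataset))` change above is the core pattern for streaming support: an IterableDataset has no random access, so the first example must be pulled from an iterator instead of being indexed. A self-contained sketch using the `datasets` library (the toy columns are made up for illustration):

from datasets import Dataset

# Toy map-style dataset with a 'history' column, plus an iterable view of it.
ds = Dataset.from_dict({'query': ['hello'], 'history': [[['hi', 'there']]]})
iterable_ds = ds.to_iterable_dataset()

# Map-style: indexing works.
assert 'history' in ds[0]

# Iterable/streaming: no __getitem__, so peek at the first example instead,
# which is exactly what the rlhf.py change does.
first = next(iter(iterable_ds))
assert 'history' in first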
swift/llm/sft.py (24 additions, 11 deletions)

@@ -36,7 +36,10 @@ def _get_train_val_dataset(args: SftArguments) -> Tuple[HfDataset, Optional[HfDa
         args.dataset_seed,
         check_dataset_strategy=args.check_dataset_strategy,
         model_name=args.model_name,
-        model_author=args.model_author)
+        model_author=args.model_author,
+        streaming=args.streaming,
+        streaming_val_size=args.streaming_val_size,
+        streaming_buffer_size=args.streaming_buffer_size)
     if len(args.val_dataset) > 0:
         # Loading val dataset
         _, val_dataset = get_dataset(
@@ -45,7 +48,10 @@ def _get_train_val_dataset(args: SftArguments) -> Tuple[HfDataset, Optional[HfDa
             args.dataset_seed,
             check_dataset_strategy=args.check_dataset_strategy,
             model_name=args.model_name,
-            model_author=args.model_author)
+            model_author=args.model_author,
+            streaming=args.streaming,
+            streaming_val_size=args.streaming_val_size,
+            streaming_buffer_size=args.streaming_buffer_size)

     train_dataset, val_dataset = args._handle_dataset_compat(train_dataset, val_dataset)
     logger.info(f'train_dataset: {train_dataset}')
@@ -111,6 +117,7 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]:
 def llm_sft(args: SftArguments) -> Dict[str, Any]:
     logger.info(f'args: {args}')
     is_generation = TEMPLATE_MAPPING[args.template_type].get('is_generation', False)
+    streaming = args.streaming
     if is_generation and type(args) is SftArguments:
         logger.warning(f"Please check if args.template_type: '{args.template_type}' is correct. "
                        'Currently, SFT is in progress, but the template is used for PT.')
@@ -267,7 +274,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         fsdp_flatten_parameters=False)

     train_dataset, val_dataset = _get_train_val_dataset(args)
-    training_args.train_dataset_sample = train_dataset.shape[0] if train_dataset is not None else 0  # torchacc
+    if use_torchacc():
+        training_args.train_dataset_sample = train_dataset.shape[0] if train_dataset is not None else 0
     template_kwargs = {}
     template_kwargs['use_loss_scale'] = args.use_loss_scale
     if args.loss_scale_config_path is not None:
@@ -288,6 +296,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         args.truncation_strategy,
         model=model,
         **template_kwargs)
+    if streaming:
+        template.encode = partial(template.encode, streaming=streaming)
     args.system = template.default_system
     logger.info(f'system: {args.system}')
     logger.info(f'args.lazy_tokenize: {args.lazy_tokenize}')
@@ -307,10 +317,11 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         dataset_info['val_dataset'] = stat_dataset(val_dataset)
     elif not args.lazy_tokenize:
         dataset_info = {}
-        logger.info(f'Using num_proc: {args.preprocess_num_proc}')
-        train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc)
+        if not streaming:
+            logger.info(f'Using num_proc: {args.preprocess_num_proc}')
+        train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc, streaming=streaming)
         if val_dataset is not None:
-            val_dataset = dataset_map(val_dataset, template.encode, args.preprocess_num_proc)
+            val_dataset = dataset_map(val_dataset, template.encode, args.preprocess_num_proc, streaming=streaming)
         if args.test_oom_error:
             train_dataset = sort_by_max_length(train_dataset, 20000)
         # Data analysis
@@ -321,11 +332,11 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
             raise AttributeError('Failed to access dataset attributes,train_dataset is None. This might be because:\n'
                                  '(1) The dataset contains None for input or labels;\n'
                                  "(2) The 'max_length' setting is too short causing data truncation.")
-        td0, tkwargs0 = train_dataset.data[0]
+        td0, tkwargs0 = train_dataset.data[0] if not streaming else (next(iter(train_dataset)), {})
         print_example(td0, tokenizer, tkwargs0)
-        dataset_info['train_dataset'] = stat_dataset(train_dataset)
+        dataset_info['train_dataset'] = stat_dataset(train_dataset) if not streaming else None
         if val_dataset is not None:
-            dataset_info['val_dataset'] = stat_dataset(val_dataset)
+            dataset_info['val_dataset'] = stat_dataset(val_dataset) if not streaming else None
     else:
         dataset_info = None
         td0, tkwargs0 = template.encode(train_dataset[0])
@@ -395,7 +406,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
     last_model_checkpoint = getattr(trainer.state, 'last_model_checkpoint', None)
     logger.info(f'last_model_checkpoint: {last_model_checkpoint}')
     logger.info(f'best_model_checkpoint: {trainer.state.best_model_checkpoint}')
-    train_time = get_time_info(trainer.state.log_history, len(train_dataset))
+    if not streaming:
+        train_time = get_time_info(trainer.state.log_history, len(train_dataset))
     # Visualization
     if is_master() and not use_torchacc():
         if 'tensorboard' in args.training_args.report_to:
@@ -407,7 +419,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         trainer.push_to_hub()
     run_info = {
         'memory': trainer.perf['memory'],
-        'train_time': train_time,
         'last_model_checkpoint': last_model_checkpoint,
         'best_model_checkpoint': trainer.state.best_model_checkpoint,
         'best_metric': trainer.state.best_metric,
@@ -416,6 +427,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         'model_info': model_info,
         'dataset_info': dataset_info,
     }
+    if not streaming:
+        run_info.update({'train_time': train_time})
     for key in ['gen_time', 'gen_len']:
         if trainer.perf[key] != 0:
             run_info[key] = trainer.perf[key]

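In the sft.py changes above, `dataset_map(..., streaming=streaming)` skips the `num_proc` pre-processing pass; with an IterableDataset the encoding is presumably applied lazily, in the style of `IterableDataset.map`. A rough stand-alone sketch of that behavior (the file name and the toy encode function are assumptions, not swift's actual `template.encode`):

from datasets import load_dataset

MAX_LENGTH = 2048

def encode(example):
    # Stand-in for template.encode: map text to token ids (character codes here).
    ids = [ord(ch) for ch in example['text']][:MAX_LENGTH]
    return {'input_ids': ids, 'labels': list(ids)}

# streaming=True yields an IterableDataset; 'train.jsonl' is a hypothetical file
# with a 'text' column.
stream = load_dataset('json', data_files='train.jsonl', split='train', streaming=True)

# No num_proc here: the map is lazy and runs on the fly as the trainer iterates.
encoded = stream.map(encode, remove_columns=['text'])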
swift/llm/utils/__init__.py (1 addition, 1 deletion)

@@ -5,7 +5,7 @@
                          get_model_list_client_async, inference_client, inference_client_async)
 from .dataset import (DATASET_MAPPING, DatasetName, HfDataset, get_dataset, get_dataset_from_repo,
                       load_dataset_from_local, load_ms_dataset, register_dataset, register_dataset_info,
-                      register_local_dataset, sample_dataset)
+                      register_local_dataset, sample_dataset, standard_keys)
 from .media import MediaCache, MediaTag
 from .model import (MODEL_MAPPING, GetModelTokenizerFunction, LoRATM, ModelType, get_additional_saved_files,
                     get_default_lora_target_modules, get_default_template_type, get_model_tokenizer,

swift/llm/utils/argument.py (61 additions, 6 deletions)

@@ -13,6 +13,7 @@
 import torch.distributed as dist
 import transformers
 from datasets import Dataset as HfDataset
+from datasets import IterableDataset as HfIterableDataset
 from datasets import concatenate_datasets
 from packaging import version
 from torch import dtype as Dtype
@@ -34,6 +35,7 @@
 from .utils import is_lmdeploy_available, is_quant_model, is_vllm_available

 logger = get_logger()
+DATASET_TYPE = Union[HfDataset, HfIterableDataset]


 def is_adapter(sft_type: str) -> bool:
@@ -374,11 +376,14 @@ def _register_self_cognition(self: Union['SftArguments', 'InferArguments']) -> N
                              'Representing the model name and model author in Chinese and English.')
             setattr(self, k, v)

-    def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_dataset: Optional[HfDataset],
-                               val_dataset: Optional[HfDataset]) -> Tuple[Optional[HfDataset], Optional[HfDataset]]:
+    def _handle_dataset_compat(
+            self: Union['SftArguments', 'InferArguments'], train_dataset: Optional[DATASET_TYPE],
+            val_dataset: Optional[DATASET_TYPE]) -> Tuple[Optional[DATASET_TYPE], Optional[DATASET_TYPE]]:
         # compatibility. (Deprecated)
+        streaming = getattr(self, 'streaming', False)
         random_state = np.random.RandomState(self.dataset_seed)
         val_dataset_sample = self.val_dataset_sample
+
         if train_dataset is not None and self.train_dataset_sample >= 0:
             train_dataset_sample = min(self.train_dataset_sample, train_dataset.shape[0])
             if train_dataset.shape[0] > train_dataset_sample:
@@ -388,10 +393,13 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
         if val_dataset_sample is None:
             val_dataset_sample = max(int(train_dataset_sample * self.dataset_test_ratio), 1)
         if val_dataset is not None and val_dataset_sample is not None and val_dataset_sample >= 0:
-            if val_dataset.shape[0] > val_dataset_sample:
+            if not streaming and val_dataset.shape[0] > val_dataset_sample:
                 logger.info(f'val_dataset_sample: {val_dataset_sample}')
                 val_idxs = random_state.permutation(val_dataset_sample)
                 val_dataset = val_dataset.select(val_idxs)
+            elif streaming:
+                val_dataset = val_dataset.shuffle(
+                    seed=self.dataset_seed, buffer_size=self.streaming_buffer_size).take(val_dataset_sample)

         if (train_dataset is None or not hasattr(self, 'train_dataset_mix_ratio') or self.train_dataset_mix_ratio <= 0
                 or len(self.train_dataset_mix_ds) == 0):
@@ -401,7 +409,11 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
             logger.info(f'train_dataset_mix_ds: {self.train_dataset_mix_ds}')
             logger.info(f'len(train_dataset): {len(train_dataset)}, mix_dataset_sample: {mix_dataset_sample}')
             mixed_dataset = get_dataset(
-                self.train_dataset_mix_ds, 0.0, random_state, check_dataset_strategy=self.check_dataset_strategy)[0]
+                self.train_dataset_mix_ds,
+                0.0,
+                random_state,
+                check_dataset_strategy=self.check_dataset_strategy,
+                streaming=streaming)[0]
             if len(mixed_dataset) < mix_dataset_sample:
                 logger.warn(f'The length of dataset used for mixin: {self.train_dataset_mix_ds} are '
                             'lesser than the ratio required by the `train_dataset_mix_ratio` '
@@ -590,7 +602,10 @@ class SftArguments(ArgumentsBase):
     max_length: int = 2048  # -1: no limit
     truncation_strategy: Literal['delete', 'truncation_left'] = 'delete'
     check_dataset_strategy: Literal['none', 'discard', 'error', 'warning'] = 'none'
-
+    # streaming dataset
+    streaming: bool = False
+    streaming_val_size: int = 0
+    streaming_buffer_size: int = 16384
     # Chinese name and English name
     model_name: List[str] = field(default_factory=lambda: [None, None], metadata={'help': "e.g. ['小黄', 'Xiao Huang']"})
     model_author: List[str] = field(
@@ -1025,7 +1040,8 @@ def __post_init__(self) -> None:
         if self.gradient_accumulation_steps is None:
             self.gradient_accumulation_steps = math.ceil(16 / self.batch_size / self.world_size)
         template_info = TEMPLATE_MAPPING[self.template_type]
-        if self.lazy_tokenize is None:
+        self._handle_streaming_args()
+        if self.lazy_tokenize is None and not self.streaming:
             self.lazy_tokenize = template_info.get('lazy_tokenize', False)
             logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}')
         if self.dataloader_num_workers is None:
@@ -1095,6 +1111,9 @@ def _init_training_args(self) -> None:
         else:
             kwargs['evaluation_strategy'] = self.evaluation_strategy

+        if 'accelerator_config' in parameters:
+            kwargs['accelerator_config'] = {'dispatch_batches': False}
+
         training_args = Seq2SeqTrainingArguments(
             output_dir=self.output_dir,
             logging_dir=self.logging_dir,
@@ -1181,6 +1200,42 @@ def _handle_pai_compat(self) -> None:
         self.add_output_dir_suffix = False
         logger.info(f'Setting args.add_output_dir_suffix: {self.add_output_dir_suffix}')

+    def _handle_streaming_args(self) -> None:
+        if not self.streaming:
+            return
+        if self.max_steps == -1:
+            raise ValueError('Please specify `max_steps` in streaming mode.')
+
+        if self.packing:
+            self.packing = False
+            logger.warning('Packing is not supported for streaming dataset, set to False')
+
+        if self.test_oom_error:
+            self.test_oom_error = False
+            logger.warning('test_oom_error is not supported for streaming dataset, set to False')
+
+        if self.lazy_tokenize:
+            self.lazy_tokenize = False
+            logger.info('lazy_tokenize set to False in streaming dataset')
+
+        if self.train_dataset_mix_ratio > 0:
+            logger.warning('train_dataset_mix_ratio is not supported for streaming dataset, set to 0')
+            self.train_dataset_mix_ratio = 0
+
+        if self.dataset_test_ratio > 0:
+            logger.info('Set dataset_test_ratio to 0 in streaming mode. '
+                        'You can manually set val_dataset and val_dataset_sample, '
+                        'or set streaming_val_size instead to split from the train dataset.')
+            self.dataset_test_ratio = 0
+
+        if self.train_dataset_sample > 0:
+            logger.warning('train_dataset_sample is not supported for streaming dataset, set to -1')
+            self.train_dataset_sample = -1
+
+        if self.dataloader_num_workers is None or self.dataloader_num_workers > 0:
+            logger.info('Set dataloader_num_workers to 0 in streaming mode')
+            self.dataloader_num_workers = 0
+

 @dataclass
 class InferArguments(ArgumentsBase):

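The `_handle_dataset_compat` change above shows the streaming counterpart of random sub-sampling: an approximate shuffle within a bounded buffer followed by `take`. A minimal sketch of that idiom, and of the take/skip split that `streaming_val_size` presumably relies on (the file name and sizes are illustrative):

from datasets import load_dataset

# 'train.jsonl' is a hypothetical file; 16384 mirrors the streaming_buffer_size default.
stream = load_dataset('json', data_files='train.jsonl', split='train', streaming=True)

shuffled = stream.shuffle(seed=42, buffer_size=16384)

val_dataset = shuffled.take(1000)    # matches the .shuffle(...).take(...) call in argument.py
train_dataset = shuffled.skip(1000)  # presumably how streaming_val_size splits off a validation set

The `accelerator_config = {'dispatch_batches': False}` addition, guarded by a check that the installed `transformers` exposes that argument, makes each rank fetch batches from its own dataloader instead of having the main process dispatch them, a common workaround when training on iterable datasets.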