Skip to content

Commit e29cf5a

Browse files
authored
Support hd_num (#1801)
1 parent 089234c commit e29cf5a

File tree

10 files changed

+89
-28
lines changed

10 files changed

+89
-28
lines changed

docs/source/LLM/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
## LLM文档
22

3+
[English Documentation](https://swift.readthedocs.io/en/latest/)
4+
35
### 📚教程
46

57
1. [LLM推理文档](LLM推理文档.md)

docs/source/LLM/命令行参数.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
- `--resume_from_checkpoint`: 用于断点续训, 默认为`None`. 你可以将其设置为checkpoint的路径, 例如: `--resume_from_checkpoint output/qwen-7b-chat/vx-xxx/checkpoint-xxx`, 来进行断点续训. 支持调节`--resume_only_model`在断点续训时只读取模型文件.
3838
- `--resume_only_model`: 默认为`False`, 即为严格的断点续训, 这会读取模型、优化器和lr_scheduler的权重和各个设备存储的随机种子, 并将从上次训练暂停的stpes后继续计数进行训练. 如果设置为`True`, 则只读取模型的权重.
3939
- `--dtype`: 基模型载入时的torch_dtype, 默认为`'AUTO'`, 即智能选择dtype: 如果机器不支持bf16, 则使用fp16, 如果`MODEL_MAPPING`中对应模型有指定torch_dtype, 则使用其对应dtype, 否则使用bf16. 你可以选择的值包括: 'bf16', 'fp16', 'fp32'.
40+
- `--model_kwargs`: 用于传入多模态模型中针对于模型的额外参数, 例如: `'{"hd_num": 16}'`. 你可以传入json字符串或者直接传入字典. 默认为`None`. 除了使用该参数,你也可以通过环境变量传入, 例如: `HD_NUM=16`.
4041
- `--dataset`: 用于选择训练的数据集, 默认为`[]`. 可以选择的数据集可以查看[支持的数据集](支持的模型和数据集.md#数据集). 如果需要使用多个数据集进行训练, 你可以使用','或者' '进行分割, 例如: `--dataset alpaca-en,alpaca-zh` or `--dataset alpaca-en alpaca-zh`. 支持Modelscope Hub/HuggingFace Hub/本地路径、subsets选择与数据集采样, 每个数据集指定格式如下: `[HF or MS::]{dataset_name} or {dataset_id} or {dataset_path}[:subset1/subset2/...][#dataset_sample]`, 最简只需要指定dataset_name、dataset_id或者dataset_path即可. 自定义数据集可以查看[数据集的自定义与拓展文档](自定义与拓展.md#自定义数据集).
4142
- 支持MS和HF hub, 以及dataset_sample的支持. e.g. 'MS::alpaca-zh#2000', 'HF::jd-sentiment-zh#2000' (默认使用的hub, 由`USE_UF`环境变量控制, 默认MS).
4243
- 对subsets更细粒度的控制: 默认使用注册时指定的subsets(注册时未指定则使用'default'). e.g. 'sharegpt-gpt4'. 如果指定subsets则使用对应子集的数据集. e.g. 'sharegpt-gpt4:default/V3_format#2000'. 这里使用`default``V3_format`子数据集, 使用'/'进行分隔, 并取2000条.
@@ -299,6 +300,7 @@ RLHF参数继承了sft参数, 除此之外增加了以下参数:
299300
- `--device_max_memory`: 每个设备device_map的最大可用显存, `List`, 默认为`[]`, 传递的值数量必须和可见显卡数量相等. 比如`10GB 10GB`.
300301
- `--seed`: 默认值为`42`, 具体的参数介绍可以在`sft命令行参数`中查看.
301302
- `--dtype`: 默认值为`'AUTO`, 具体的参数介绍可以在`sft命令行参数`中查看.
303+
- `--model_kwargs`: 默认值为`'None`, 具体的参数介绍可以在`sft命令行参数`中查看.
302304
- `--dataset`: 默认值为`[]`, 具体的参数介绍可以在`sft命令行参数`中查看.
303305
- `--val_dataset`: 默认为`[]`, 具体的参数介绍可以在`sft命令行参数`中查看.
304306
- `--dataset_seed`: 默认值为`None`, 具体的参数介绍可以在`sft命令行参数`中查看.

docs/source_en/LLM/Command-line-parameters.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
- `--resume_from_checkpoint`: Used for resuming training from a checkpoint, default is `None`. You can set it to the path of the checkpoint, for example: `--resume_from_checkpoint output/qwen-7b-chat/vx-xxx/checkpoint-xxx`, to resume training from that point. Supports adjusting `--resume_only_model` to only read the model file during checkpoint continuation.
3737
- `--resume_only_model`: Default is `False`, which means strict checkpoint continuation, this will read the weights of the model, optimizer, lr_scheduler, and the random seeds stored on each device, and continue training from the last paused steps. If set to `True`, it will only read the weights of the model.
3838
- `--dtype`: torch_dtype when loading base model, default is `'AUTO'`, i.e. intelligently select dtype: if machine does not support bf16, use fp16; if `MODEL_MAPPING` specifies torch_dtype for corresponding model, use its dtype; otherwise use bf16. Options include: 'bf16', 'fp16', 'fp32'.
39+
- `--model_kwargs`: Used for passing additional parameters to the multimodal model, for example: `'{"hd_num": 16}'`. You can either pass a JSON string or directly pass a dictionary. The default is `None`. In addition to using this parameter, you can also pass it through environment variables, for example: `HD_NUM=16`.
3940
- `--dataset`: Used to select the training dataset, default is `[]`. You can see the list of available datasets [here](Supported-models-datasets.md#Datasets). If you need to train with multiple datasets, you can use ',' or ' ' to separate them, for example: `--dataset alpaca-en,alpaca-zh` or `--dataset alpaca-en alpaca-zh`. It supports Modelscope Hub/HuggingFace Hub/local paths, subset selection, and dataset sampling. The specified format for each dataset is as follows: `[HF or MS::]{dataset_name} or {dataset_id} or {dataset_path}[:subset1/subset2/...][#dataset_sample]`. The simplest case requires specifying only dataset_name, dataset_id, or dataset_path. Customizing datasets can be found in the [Customizing and Extending Datasets document](Customization.md#custom-dataset)
4041
- Supports MS and HF hub, as well as dataset_sample. For example, 'MS::alpaca-zh#2000', 'HF::jd-sentiment-zh#2000' (the default hub used is controlled by the `USE_UF` environment variable, default is MS).
4142
- More fine-grained control over subsets: It uses the subsets specified during registration by default (if not specified during registration, it uses 'default'). For example, 'sharegpt-gpt4'. If subsets are specified, it uses the corresponding subset of the dataset. For example, 'sharegpt-gpt4:default/V3_format#2000'. Here, the `default` and `V3_format` sub-datasets are used, separated by '/', and 2000 entries are selected.
@@ -301,6 +302,7 @@ RLHF parameters are an extension of the sft parameters, with the addition of the
301302
- `--device_max_memory`: The max memory of each device can use for `device_map`, `List`, default is `[]`, The number of values must equal to the device count. Like `10GB 10GB`.
302303
- `--seed`: Default is `42`, see `sft command line arguments` for parameter details.
303304
- `--dtype`: Default is `'AUTO`, see `sft command line arguments` for parameter details.
305+
- `--model_kwargs`: Default is `None`, see `sft command line arguments` for parameter details.
304306
- `--dataset`: Default is `[]`, see `sft command line arguments` for parameter details.
305307
- `--val_dataset`: Default is `[]`, see `sft command line arguments` for parameter details.
306308
- `--dataset_seed`: Default is `None`, see `sft command line arguments` for parameter details.

docs/source_en/LLM/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
## LLM Documentation
22

3+
[中文文档](https://swift.readthedocs.io/zh-cn/latest/LLM/index.html)
4+
35
### 📚Tutorials!
46

57
1. [LLM Inference](LLM-inference.md)

swift/llm/megatron/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def init_megatron_env() -> None:
1818
if 'MEGATRON_LM_PATH' not in os.environ:
1919
megatron_path = git_clone_github(
2020
'https://github.com/NVIDIA/Megatron-LM', commit_hash='6dbe4cf699880038b1e5cd90b23ee71053c7f2ee')
21-
os.environ['MEGATRON_LM_PATH'] = megatron_path
2221
else:
2322
megatron_path = os.environ['MEGATRON_LM_PATH']
2423
if not is_megatron_available():
@@ -28,10 +27,9 @@ def init_megatron_env() -> None:
2827
if 'PAI_MEGATRON_PATCH_PATH' not in os.environ:
2928
megatron_patch_path = git_clone_github(
3029
'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='6fd5d050b240fd959f0ba69f1e9cd9a053e5a81d')
31-
os.environ['PAI_MEGATRON_PATCH_PATH'] = megatron_patch_path
3230
else:
3331
megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH']
34-
sys.path.append(os.environ['PAI_MEGATRON_PATCH_PATH'])
32+
sys.path.append(megatron_patch_path)
3533

3634
# rename qwen1.5->qwen1_5 files
3735
qwen1_5_folders = ['toolkits/model_checkpoints_convertor/qwen']

swift/llm/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage, Function, Model,
2222
ModelList, UsageInfo, XRequestConfig, random_uuid)
2323
from .template import (DEFAULT_SYSTEM, TEMPLATE_MAPPING, History, Prompt, StopWords, Template, TemplateType,
24-
get_template, register_template)
24+
get_env_args, get_template, register_template)
2525
from .utils import (LazyLLMDataset, LLMDataset, dataset_map, download_dataset, find_all_linears, find_embedding,
2626
find_ln, get_max_model_len, get_time_info, history_to_messages, inference, inference_stream,
2727
is_lmdeploy_available, is_megatron_available, is_quant_model, is_vllm_available,

swift/llm/utils/argument.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@ def is_adapter(sft_type: str) -> bool:
4646

4747
class ArgumentsBase:
4848

49+
def __post_init__(self) -> None:
50+
if self.max_length == -1:
51+
self.max_length = None
52+
model_kwargs = self.model_kwargs
53+
if model_kwargs is None:
54+
model_kwargs = {}
55+
if isinstance(model_kwargs, str):
56+
model_kwargs = json.loads(model_kwargs)
57+
for k, v in model_kwargs.items():
58+
k = k.upper()
59+
os.environ[k] = str(v)
60+
4961
@classmethod
5062
def _check_path(cls,
5163
value: Union[str, List[str]],
@@ -592,6 +604,9 @@ class SftArguments(ArgumentsBase):
592604
min_lr: Optional[float] = None
593605
sequence_parallel: bool = False
594606

607+
# multimodal
608+
model_kwargs: Optional[str] = None
609+
595610
# dataset_id or dataset_name or dataset_path or ...
596611
dataset: List[str] = field(
597612
default_factory=list, metadata={'help': f'dataset choices: {list(DATASET_MAPPING.keys())}'})
@@ -889,6 +904,7 @@ def _prepare_modules_to_save(self, modules_to_save) -> List[str]:
889904
return modules_to_save
890905

891906
def __post_init__(self) -> None:
907+
super().__post_init__()
892908
self.handle_compatibility()
893909
if len(self.val_dataset) > 0:
894910
self.dataset_test_ratio = 0.0
@@ -1040,8 +1056,6 @@ def __post_init__(self) -> None:
10401056
self.eval_batch_size = self.batch_size
10411057
if self.save_total_limit == -1:
10421058
self.save_total_limit = None
1043-
if self.max_length == -1:
1044-
self.max_length = None
10451059

10461060
if self.deepspeed is not None:
10471061
if is_mp():
@@ -1276,6 +1290,9 @@ class InferArguments(ArgumentsBase):
12761290
seed: int = 42
12771291
dtype: Literal['bf16', 'fp16', 'fp32', 'AUTO'] = 'AUTO'
12781292

1293+
# multimodal
1294+
model_kwargs: Optional[str] = None
1295+
12791296
# dataset_id or dataset_name or dataset_path or ...
12801297
dataset: List[str] = field(
12811298
default_factory=list, metadata={'help': f'dataset choices: {list(DATASET_MAPPING.keys())}'})
@@ -1363,6 +1380,7 @@ class InferArguments(ArgumentsBase):
13631380
vllm_lora_modules: List[str] = None
13641381

13651382
def __post_init__(self) -> None:
1383+
super().__post_init__()
13661384
if self.ckpt_dir is not None and not self.check_ckpt_dir_correct(self.ckpt_dir):
13671385
logger.warning(f'The checkpoint dir {self.ckpt_dir} passed in is invalid, please make sure'
13681386
'the dir contains a `configuration.json` file.')
@@ -1419,8 +1437,6 @@ def __post_init__(self) -> None:
14191437

14201438
self.bnb_4bit_compute_dtype, self.load_in_4bit, self.load_in_8bit = self.select_bnb()
14211439

1422-
if self.max_length == -1:
1423-
self.max_length = None
14241440
if self.overwrite_generation_config is None:
14251441
if self.ckpt_dir is None:
14261442
self.overwrite_generation_config = False
@@ -1518,9 +1534,6 @@ class DeployArguments(InferArguments):
15181534
verbose: bool = True # Whether to log request_info
15191535
log_interval: int = 10 # Interval for printing global statistics
15201536

1521-
def __post_init__(self):
1522-
super().__post_init__()
1523-
15241537

15251538
@dataclass
15261539
class EvalArguments(InferArguments):

swift/llm/utils/model.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch.nn.functional as F
1414
import torch.utils.checkpoint
1515
import transformers
16+
from accelerate.utils import find_device
1617
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
1718
GenerationConfig, GPTQConfig, snapshot_download)
1819
from modelscope.hub.utils.utils import get_cache_dir
@@ -28,8 +29,8 @@
2829
from swift import get_logger
2930
from swift.utils import get_dist_setting, safe_ddp_context, subprocess_run, use_torchacc
3031
from swift.utils.module_mapping import get_regex_for_mm_default_lora
31-
from .template import TemplateType
32-
from .utils import get_max_model_len, get_rope_scaling, is_unsloth_available, set_rope_scaling
32+
from .template import TemplateType, get_env_args
33+
from .utils import get_max_model_len, get_rope_scaling, is_unsloth_available, set_rope_scaling, to_device
3334

3435
logger = get_logger()
3536

@@ -1293,7 +1294,7 @@ def get_model_tokenizer_phi3_vision(model_dir: str,
12931294
**kwargs):
12941295
processor_kwargs = {}
12951296
if 'num_crops' in kwargs:
1296-
processor_kwargs['num_crops'] = kwargs['num_crops']
1297+
processor_kwargs['num_crops'] = get_env_args('num_crops', int, kwargs['num_crops'])
12971298
from transformers import AutoProcessor
12981299
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True, **processor_kwargs)
12991300
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
@@ -4282,19 +4283,27 @@ def _use_submodel_func(model, submodel_name: str, func_list: List[str]) -> None:
42824283
submodel = getattr(model, submodel_name)
42834284

42844285
def _get_new_func(func_name: str):
4285-
_old_func = getattr(submodel, func_name)
4286+
_old_func = getattr(submodel.__class__, func_name)
42864287

42874288
@wraps(_old_func)
4288-
def _new_func(*args, **kwargs):
4289-
return _old_func(*args, **kwargs)
4289+
def _new_func(self, *args, **kwargs):
4290+
res = _old_func(self, *args, **kwargs)
4291+
if func_name == 'forward':
4292+
device = find_device(args)
4293+
if device is None:
4294+
device = find_device(kwargs)
4295+
res = res.__class__(**to_device(res, device))
4296+
return res
42904297

42914298
return _new_func
42924299

42934300
for key in func_list:
4294-
model_key = key
4295-
if key == 'forward' and hasattr(model, '_old_forward'): # device_map
4296-
model_key = '_old_forward'
4297-
setattr(model, model_key, _get_new_func(key))
4301+
value = MethodType(_get_new_func(key), submodel)
4302+
setattr(model, key, value)
4303+
if key == 'generate' and model.device != submodel.device:
4304+
submodel.__class__.device = model.device
4305+
if key == 'forward' and 'generate' in func_list:
4306+
setattr(submodel, key, value)
42984307

42994308

43004309
@register_model(

swift/llm/utils/template.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
import inspect
3+
import os
34
import re
45
from contextlib import contextmanager
56
from copy import deepcopy
67
from functools import partial, wraps
78
from types import MethodType
8-
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
9+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union
910

1011
import json
1112
import torch
@@ -1539,6 +1540,28 @@ class Llama3Template(Llama3TemplateMixin, Template):
15391540
Template(['<s>'], ['<|User|>:{{QUERY}}\n<|Bot|>:'], ['<eoa>\n'], ['<eoa>'], INTERNLM_SYSTEM,
15401541
['<s><|System|>:{{SYSTEM}}\n']))
15411542

1543+
_T = TypeVar('_T')
1544+
1545+
_log_set = set() # log once
1546+
1547+
1548+
def get_env_args(args_name: str,
1549+
type_func: Callable[[str], _T] = int,
1550+
default_value: Optional[_T] = None) -> Optional[_T]:
1551+
args_name_upper = args_name.upper()
1552+
value = os.getenv(args_name_upper)
1553+
if value is None:
1554+
value = default_value
1555+
log_info = (f'Setting {args_name}: {default_value}. '
1556+
f'You can adjust this hyperparameter through the environment variable: `{args_name_upper}`.')
1557+
else:
1558+
value = type_func(value)
1559+
log_info = f'Using environment variable `{args_name_upper}`, Setting {args_name}: {value}.'
1560+
if log_info not in _log_set:
1561+
_log_set.add(log_info)
1562+
logger.info(log_info)
1563+
return value
1564+
15421565

15431566
class Internlm2Template(ChatmlTemplate):
15441567
system = INTERNLM_SYSTEM
@@ -1595,12 +1618,14 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
15951618

15961619
if self.version == 'v2.5':
15971620
hd_num = 24
1598-
Image_transform = get_class_from_dynamic_module('ixc_utils.Image_transform', self.tokenizer.model_dir)
15991621
if len(images) > 1:
16001622
hd_num = 6
1623+
hd_num = get_env_args('hd_num', int, hd_num)
1624+
Image_transform = get_class_from_dynamic_module('ixc_utils.Image_transform', self.tokenizer.model_dir)
16011625
images = [Image_transform(image, hd_num=hd_num) for image in images]
16021626
elif self.version == 'v2-4khd':
16031627
hd_num = 55
1628+
hd_num = get_env_args('hd_num', int, hd_num)
16041629
HD_transform = get_class_from_dynamic_module('ixc_utils.HD_transform', self.tokenizer.model_dir)
16051630
images = [HD_transform(image, hd_num=hd_num) for image in images]
16061631
images = [self.model.vis_processor(image).to(dtype) for image in images]
@@ -1723,7 +1748,9 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
17231748
images = example.get('images')
17241749
if images:
17251750
labels = inputs.get('labels')
1726-
pixel_values_images = [transform_image(image) for image in images]
1751+
input_size = get_env_args('input_size', int, 448)
1752+
max_num = get_env_args('max_num', int, 12)
1753+
pixel_values_images = [transform_image(image, input_size, max_num) for image in images]
17271754
pixel_values = torch.cat(pixel_values_images, dim=0).to(self.model.dtype)
17281755
image_bs = pixel_values.shape[0]
17291756

@@ -1784,7 +1811,8 @@ def replace_tag(self, media_type, index, example) -> List[Context]:
17841811
if media_type == 'image':
17851812
return image_context
17861813
elif media_type == 'video':
1787-
load_video = partial(load_video_internvl, num_segments=self.video_segments)
1814+
video_segments = get_env_args('video_segments', int, self.video_segments)
1815+
load_video = partial(load_video_internvl, num_segments=video_segments)
17881816
return _replace_video2image(load_video, example, lambda i: [f'Frame{i + 1}: '] + image_context)
17891817

17901818
def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]:
@@ -1816,7 +1844,9 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
18161844
images = example.get('images')
18171845
if images:
18181846
has_video = bool(example.get('videos'))
1819-
pixel_values = [transform_image(image, max_num=1 if has_video else 12) for image in images]
1847+
input_size = get_env_args('input_size', int, 448)
1848+
max_num = get_env_args('max_num', int, 1 if has_video else 12)
1849+
pixel_values = [transform_image(image, input_size, max_num) for image in images]
18201850
num_patches = [pv.shape[0] for pv in pixel_values]
18211851
pixel_values = torch.cat(pixel_values).to(self.model.dtype)
18221852
else:
@@ -1924,7 +1954,9 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
19241954
processor = self.tokenizer.processor
19251955
images = example.get('images') or []
19261956
assert len(images) == 1, 'Florence series models only supports input with a single image.'
1927-
image_tensors = transform_image(images[0])
1957+
input_size = get_env_args('input_size', int, 448)
1958+
max_num = get_env_args('max_num', int, 12)
1959+
image_tensors = transform_image(images[0], input_size, max_num)
19281960
example['_image'] = image_tensors
19291961

19301962
# process bbox
@@ -2789,6 +2821,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
27892821
use_image_id = False
27902822
max_slice_nums = 1 # or 2
27912823

2824+
max_slice_nums = get_env_args('max_slice_nums', int, max_slice_nums)
27922825
input_ids = inputs['input_ids']
27932826
labels = inputs['labels']
27942827
idx_list = _findall(input_ids, -100)

0 commit comments

Comments
 (0)