Commit d5f0a4d
Support max memory args (#1382)
1 parent 6c963d8 commit d5f0a4d

File tree

6 files changed: +29 additions, -4 deletions


docs/source/LLM/命令行参数.md

Lines changed: 4 additions & 2 deletions

@@ -134,7 +134,8 @@
 - Matching rules are applied with the following priority, from highest to lowest: query field > response-specific fields > regular-expression matching rules.
 - `--custom_register_path`: Default is `None`. Pass in a `.py` file used to register templates, models, and datasets.
 - `--custom_dataset_info`: Default is `None`. Pass in the path to an external dataset_info.json, a JSON string, or a dict, used to extend datasets. Format reference: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json
-- `--device_map_config_path`: Manually configure the model's device_map from a local file, default is None
+- `--device_map_config_path`: Manually configure the model's device_map from a local file, default is None.
+- `--device_max_memory`: The maximum memory each device may use for device_map, a `List`, default `[]`. The number of values passed must equal the number of visible GPUs, e.g. `10GB 10GB`.

 ### Long Context

@@ -252,7 +253,8 @@ RLHF parameters inherit the sft parameters, with the following additions:
 - `--load_args_from_ckpt_dir`: Whether to read model configuration from the `sft_args.json` file in `ckpt_dir`. Default is `True`.
 - `--load_dataset_config`: Takes effect only when `--load_args_from_ckpt_dir true`, i.e. whether to read dataset-related configuration from the `sft_args.json` file in `ckpt_dir`. Default is `False`.
 - `--eval_human`: Whether to evaluate with the validation split of the dataset or by manual input. Default is `None` for automatic selection: if no dataset (including custom datasets) is passed, manual evaluation is used; if a dataset is passed, dataset evaluation is used.
-- `--device_map_config_path`: Manually configure the model's device_map from a local file, default is None
+- `--device_map_config_path`: Manually configure the model's device_map from a local file, default is None.
+- `--device_max_memory`: The maximum memory each device may use for device_map, a `List`, default `[]`. The number of values passed must equal the number of visible GPUs, e.g. `10GB 10GB`.
 - `--seed`: Default is `42`; see `sft.sh command line arguments` for details.
 - `--dtype`: Default is `'AUTO'`; see `sft.sh command line arguments` for details.
 - `--dataset`: Default is `[]`; see `sft.sh command line arguments` for details.

docs/source_en/LLM/Command-line-parameters.md

Lines changed: 4 additions & 2 deletions

@@ -135,7 +135,8 @@
 - The application priority of matching rules is as follows, from highest to lowest: query fields > specific response fields > regular expression matching rules.
 - `--custom_register_path`: Default is `None`. Pass in a `.py` file used to register templates, models, and datasets.
 - `--custom_dataset_info`: Default is `None`. Pass in the path to an external `dataset_info.json`, a JSON string, or a dictionary. Used to register custom datasets. Format example: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json
-- `device_map_config_path`: Manually configure the model's device map from a local file, defaults to None.
+- `--device_map_config_path`: Manually configure the model's device map from a local file, defaults to None.
+- `--device_max_memory`: The maximum memory each device may use for `device_map`, a `List`, default `[]`. The number of values must equal the visible device count, e.g. `10GB 10GB`.

 ### Long Context

@@ -253,7 +254,8 @@ RLHF parameters are an extension of the sft parameters, with the addition of the
 - `--load_args_from_ckpt_dir`: Whether to read model configuration info from the `sft_args.json` file in `ckpt_dir`. Default is `True`.
 - `--load_dataset_config`: This parameter only takes effect when `--load_args_from_ckpt_dir true`, i.e. whether to read dataset-related configuration from the `sft_args.json` file in `ckpt_dir`. Default is `False`.
 - `--eval_human`: Whether to evaluate using the validation split of the dataset or manual evaluation. Default is `None` for automatic selection: if no datasets (including custom datasets) are passed, manual evaluation is used; if datasets are passed, dataset evaluation is used.
-- `device_map_config_path`: Manually configure the model's device map from a local file, defaults to None.
+- `--device_map_config_path`: Manually configure the model's device map from a local file, defaults to None.
+- `--device_max_memory`: The maximum memory each device may use for `device_map`, a `List`, default `[]`. The number of values must equal the visible device count, e.g. `10GB 10GB`.
 - `--seed`: Default is `42`; see `sft.sh command line arguments` for details.
 - `--dtype`: Default is `'AUTO'`; see `sft.sh command line arguments` for details.
 - `--dataset`: Default is `[]`; see `sft.sh command line arguments` for details.
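The documented values are human-readable sizes like `10GB`. Downstream, transformers/accelerate consume such strings as-is, so no conversion is needed in the commit; purely for illustration, a hypothetical converter (not part of this change) makes explicit what a value denotes in bytes, including the decimal-GB vs binary-GiB distinction:

```python
import re

# Decimal (KB/MB/GB/TB) and binary (KiB/MiB/GiB/TiB) multipliers.
_UNITS = {'KB': 10**3, 'MB': 10**6, 'GB': 10**9, 'TB': 10**12,
          'KIB': 2**10, 'MIB': 2**20, 'GIB': 2**30, 'TIB': 2**40}

def to_bytes(size: str) -> int:
    """Illustrative parser for sizes like '10GB' or '1GiB' (hypothetical helper)."""
    m = re.fullmatch(r'(\d+)\s*([KMGT]I?B)', size.strip(), re.IGNORECASE)
    if m is None:
        raise ValueError(f'unrecognized size: {size!r}')
    return int(m.group(1)) * _UNITS[m.group(2).upper()]

print(to_bytes('10GB'))   # 10000000000
print(to_bytes('1GiB'))   # 1073741824
```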

swift/llm/infer.py

Lines changed: 3 additions & 0 deletions

@@ -146,6 +146,9 @@ def prepare_model_template(args: InferArguments,
     if device_map == 'auto':
         model_kwargs['low_cpu_mem_usage'] = True
     model_kwargs['device_map'] = device_map
+    if args.device_max_memory:
+        assert len(args.device_max_memory) == torch.cuda.device_count()
+        model_kwargs['max_memory'] = {i: mem for i, mem in enumerate(args.device_max_memory)}

     # Loading Model and Tokenizer
     if hasattr(args, 'quant_config'):
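In the single-process inference path the mapping is purely positional: value `i` caps device `i`. A minimal sketch of the dict the infer.py hunk builds (the helper name and validation message are ours, not the repo's):

```python
def build_max_memory(device_max_memory, device_count):
    """Pair each memory cap with its device index, as infer.py does via enumerate."""
    if not device_max_memory:
        return None  # empty list is falsy, so the feature is opt-in
    if len(device_max_memory) != device_count:
        raise ValueError('number of values must equal the visible device count')
    return {i: mem for i, mem in enumerate(device_max_memory)}

print(build_max_memory(['10GB', '10GB'], 2))  # {0: '10GB', 1: '10GB'}
```

The resulting dict is the shape `from_pretrained` accepts as `max_memory` alongside `device_map='auto'`.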

swift/llm/rlhf.py

Lines changed: 8 additions & 0 deletions

@@ -61,6 +61,14 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
     else:
         model_kwargs['device_map'] = 'auto'

+    if args.device_max_memory:
+        n_gpu = torch.cuda.device_count()
+        assert len(args.device_max_memory) == n_gpu / local_world_size
+        model_kwargs['max_memory'] = {
+            i: mem
+            for i, mem in zip(list(range(local_rank, n_gpu, local_world_size)), args.device_max_memory)
+        }
+
     # quantization
     if args.quant_method == 'hqq':
         from transformers import HqqConfig
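In the multi-process training paths, each local rank owns every `local_world_size`-th GPU starting at its own `local_rank`, and the hunk zips those indices with the user-supplied caps. A sketch of that per-rank slicing (function name is ours; the GPU/rank counts below are illustrative):

```python
def per_rank_max_memory(device_max_memory, n_gpu, local_rank, local_world_size):
    """Assign each cap to the devices this rank manages, mirroring the rlhf.py hunk."""
    return {
        i: mem
        for i, mem in zip(range(local_rank, n_gpu, local_world_size), device_max_memory)
    }

# 4 GPUs, 2 local processes, 2 caps per rank:
print(per_rank_max_memory(['10GB', '10GB'], 4, 0, 2))  # {0: '10GB', 2: '10GB'}
print(per_rank_max_memory(['10GB', '10GB'], 4, 1, 2))  # {1: '10GB', 3: '10GB'}
```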

swift/llm/sft.py

Lines changed: 8 additions & 0 deletions

@@ -60,6 +60,14 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     elif not use_torchacc():
         model_kwargs['device_map'] = 'auto'

+    if args.device_max_memory:
+        n_gpu = torch.cuda.device_count()
+        assert len(args.device_max_memory) == n_gpu / local_world_size
+        model_kwargs['max_memory'] = {
+            i: mem
+            for i, mem in zip(list(range(local_rank, n_gpu, local_world_size)), args.device_max_memory)
+        }
+
     if args.quant_method == 'hqq':
         from transformers import HqqConfig
         if args.hqq_dynamic_config_path is not None:
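One detail worth noting in the sft.py and rlhf.py asserts: `n_gpu / local_world_size` is true division, so an int length is compared against a float. The check still behaves as intended because Python compares int and float by numeric value; a tiny sketch with made-up counts:

```python
def caps_match(num_caps, n_gpu, local_world_size):
    """Mirror the assert's int-vs-float comparison from the diff."""
    return num_caps == n_gpu / local_world_size

print(caps_match(2, 4, 2))  # True: 2 == 2.0
print(caps_match(2, 8, 2))  # False: 8 GPUs over 2 local ranks needs 4 caps
```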

swift/llm/utils/argument.py

Lines changed: 2 additions & 0 deletions

@@ -635,6 +635,7 @@ class SftArguments(ArgumentsBase):
     custom_dataset_info: Optional[str] = None  # .json

     device_map_config_path: Optional[str] = None
+    device_max_memory: List[str] = field(default_factory=list)

     # generation config
     max_new_tokens: int = 2048

@@ -1134,6 +1135,7 @@ class InferArguments(ArgumentsBase):
     custom_register_path: Optional[str] = None  # .py
     custom_dataset_info: Optional[str] = None  # .json
     device_map_config_path: Optional[str] = None
+    device_max_memory: List[str] = field(default_factory=list)

     # vllm
     gpu_memory_utilization: float = 0.9
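The new fields use `field(default_factory=list)` rather than `= []` because dataclasses reject mutable defaults, and a factory gives each instance its own list. A minimal sketch with a stand-in class (not the real `SftArguments`/`InferArguments`):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Args:  # stand-in for the argument dataclasses in the diff
    device_max_memory: List[str] = field(default_factory=list)

a, b = Args(), Args()
a.device_max_memory.append('10GB')
print(b.device_max_memory)  # [] -- instances do not share the default list
```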
