Commit ca955ba

Support restuning/side/lora-embedding/qlora, etc (#48)
## New algorithms:
* Add bnb 4bit & 8bit and autogptq lora support
* Add lora support for torch.nn.Embedding
* Add sidetuner
* Add restuner-bypass
* Fix some bugs

## New features:
* llm_sft support cross-validation with model.generate
* llm_sft support perf recording
* All tuners support activate and deactivate
* Add more unit tests
* Fix some bugs
1 parent 70d956a commit ca955ba


43 files changed: +3077 -238 lines

README.md

Lines changed: 6 additions & 1 deletion
@@ -21,12 +21,17 @@ Currently supported approaches (and counting):

1. LoRA: [LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/abs/2106.09685)
2. Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751)
3. Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119)
- 4. All tuners offered on [Peft](https://github.com/huggingface/peft).
+ 4. Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503)
+ 5. ResTuning-Bypass
+ 6. All tuners offered on [Peft](https://github.com/huggingface/peft)

Key features:

1. By integrating the ModelScope library, models can be readily obtained via a model-id.
2. Tuners provided by SWIFT can be combined to allow exploration of multiple tuners on a model for the best result.
+ 3. Support calling `activate_adapter` or `deactivate_adapter` to activate/deactivate a single tuner, so users can run one model with multiple tuners in different threads.

Users can check the [documentation of Swift](./docs/Get Started/1.Introduction.md) for detailed tutorials.

## LLM SFT Example
[code link](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm)

README_CN.md

Lines changed: 6 additions & 1 deletion
@@ -20,11 +20,16 @@ SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning) is a scalable

1. LoRA: [LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/abs/2106.09685)
2. Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751)
3. Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119)
- 4. All tuners offered on [Peft](https://github.com/huggingface/peft).
+ 4. Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503)
+ 5. ResTuning-Bypass
+ 6. All tuners offered on [Peft](https://github.com/huggingface/peft)

Key features:

1. By integrating the ModelScope library, models can be easily obtained via a model id.
2. Tuners provided by SWIFT can be combined to explore multiple tuners on one model for the best result.
+ 3. Support calling `activate_adapter` or `deactivate_adapter` to activate or deactivate a tuner; at inference time, one model can serve multiple tuners in different threads without interference.

Users can check the [official Swift documentation](./docs/Get Started/1.Introduction.md) for details.

## LLM fine-tuning example
[code link](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm)

docs/Get Started/1.Introduction.md

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
# Introduction

Swift is an open-source framework for lightweight training and inference of LLM models. The main capability Swift provides is `efficient tuners`: tuners are extra structures attached to a model dynamically at runtime; during training the parameters of the original model are frozen and only the tuner part is trained, which achieves fast training and lower GPU memory usage. The most commonly used tuner, for example, is LoRA.

In short, this framework provides the following features:

- **Efficient tuners with SOTA capabilities**: combined with large models, they enable lightweight training and inference (on commercial-grade GPUs) with good results
- **Trainer using the ModelScope Hub**: built on the `transformers` trainer, it supports training LLM models and uploading trained models to the [ModelScope Hub](https://www.modelscope.cn/models)
- **Runnable model examples**: training and inference scripts for popular large models, with preprocessing logic for popular open-source datasets, ready to run

# Quick Start

This section describes how to quickly install swift, set up the runtime environment, and run a first example.

Installing swift is very simple; users only need to run the following in a python>=3.8 environment:

```shell
pip install ms-swift
```
The code below trains the `bert-base-uncased` model on a classification task using LoRA:

**Before running the code below, install modelscope as well:**

```shell
pip install 'modelscope>=1.9.0'
```
```python
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from modelscope import AutoModelForSequenceClassification, AutoTokenizer, MsDataset
from transformers import default_data_collator

from swift import Trainer, LoRAConfig, Swift, TrainingArguments

# Load the base model and tokenizer from the ModelScope hub.
model = AutoModelForSequenceClassification.from_pretrained(
    'AI-ModelScope/bert-base-uncased', revision='v1.0.0')
tokenizer = AutoTokenizer.from_pretrained(
    'AI-ModelScope/bert-base-uncased', revision='v1.0.0')
# Patch LoRA onto the attention projections; the base weights stay frozen.
lora_config = LoRAConfig(target_modules=['query', 'key', 'value'])
model = Swift.prepare_model(model, config=lora_config)

train_dataset = MsDataset.load('clue', subset_name='afqmc', split='train').to_hf_dataset().select(range(100))
val_dataset = MsDataset.load('clue', subset_name='afqmc', split='validation').to_hf_dataset().select(range(100))


def tokenize_function(examples):
    return tokenizer(examples["sentence1"], examples["sentence2"],
                     padding="max_length", truncation=True, max_length=128)


train_dataset = train_dataset.map(tokenize_function)
val_dataset = val_dataset.map(tokenize_function)

arguments = TrainingArguments(
    output_dir='./outputs',
    per_device_train_batch_size=16,
)

trainer = Trainer(model, arguments, train_dataset=train_dataset,
                  eval_dataset=val_dataset,
                  data_collator=default_data_collator)

trainer.train()
```
In the example above, we used `bert-base-uncased` as the base model, patched LoRA modules onto the three Linear layers ['query', 'key', 'value'], and ran a training pass.

After training finishes you will see an outputs folder with the following structure:

```text
outputs
└── checkpoint-xx
    ├── configuration.json
    ├── default
    │   ├── adapter_config.json
    │   └── adapter_model.bin
    └── ...
```

This folder can be used for inference:
```python
from modelscope import AutoModelForSequenceClassification, AutoTokenizer
from swift import Swift

# Reload the base model, then attach the trained LoRA weights from the checkpoint;
# the LoRA config is restored from the saved adapter_config.json.
model = AutoModelForSequenceClassification.from_pretrained(
    'AI-ModelScope/bert-base-uncased', revision='v1.0.0')
tokenizer = AutoTokenizer.from_pretrained(
    'AI-ModelScope/bert-base-uncased', revision='v1.0.0')
model = Swift.from_pretrained(model, model_id='./outputs/checkpoint-21')

print(model(**tokenizer('this is a test', return_tensors='pt')))
```

docs/Get Started/2.Installation.md

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
# Installation and Usage

## Install from wheel

Install with pip:

```shell
pip install ms-swift
```

## Install from source

```shell
git clone https://github.com/modelscope/swift.git
cd swift
pip install -e .
```
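
A quick way to check that either install worked is to import the core entry points used throughout these docs. A minimal sketch; it only verifies that the package imports and a config can be constructed:

```python
# Smoke test: these imports should succeed after either install method.
from swift import Swift, LoRAConfig

print(LoRAConfig(target_modules=['query']))
```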
## Notebook environment

Most of the models Swift supports for training fit on an `A10` GPU, and users can use the free GPU resources provided by ModelScope:

1. Go to the official [ModelScope](https://www.modelscope.cn) website and log in
2. Click `My Notebook` on the left and start a free GPU instance
3. Enjoy the free A10 GPU
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
# Swift API

## Using Swift in training

Call `Swift.prepare_model()` to attach tuners to a model:

```python
from modelscope import Model
from swift import Swift, LoRAConfig
import torch

model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
lora_config = LoRAConfig(
    r=16,
    target_modules=['query_key_value'],
    lora_alpha=32,
    lora_dropout=0.)
model = Swift.prepare_model(model, lora_config)
# use model to do other things
```
Multiple tuners can also be used at the same time:

```python
from modelscope import Model
from swift import Swift, LoRAConfig, AdapterConfig
import torch

model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
lora_config = LoRAConfig(
    r=16,
    target_modules=['query_key_value'],
    lora_alpha=32,
    lora_dropout=0.)
adapter_config = AdapterConfig(
    dim=model.config.hidden_size,
    target_modules=['mlp'],
    method_name='forward',
    hidden_pos=0,
    adapter_length=32,
)
model = Swift.prepare_model(model, {'first_tuner': lora_config, 'second_tuner': adapter_config})
# use model to do other things
```
When using multiple tuners, the second argument must be a Dict whose keys are the tuner names and whose values are the tuner configs.

After training, call:

```python
model.save_pretrained(save_directory='./output')
```

to save the model checkpoint. The checkpoint only contains the tuners' weights, not the weights of the base model itself. The saved structure is:

```text
outputs
├── configuration.json
├── first_tuner
│   ├── adapter_config.json
│   └── adapter_model.bin
├── second_tuner
│   ├── adapter_config.json
│   └── adapter_model.bin
└── ...
```

If a single config is passed in, the default name `default` is used:

```text
outputs
├── configuration.json
├── default
│   ├── adapter_config.json
│   └── adapter_model.bin
└── ...
```
## Using Swift at inference time

Use `Swift.from_pretrained()` to load a checkpoint saved after training:

```python
from modelscope import Model
from swift import Swift
import torch

model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
model = Swift.from_pretrained(model, './output')
```
## Loading multiple tuners and using them in parallel in different threads

When a model is serving requests, one model often serves several HTTP threads at the same time, each representing a class of user requests. Swift supports activating different tuners in different threads:

```python
from modelscope import Model
from swift import Swift
import torch

model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
# Assume four trained tuners a, b, c, d exist in ./output
model = Swift.from_pretrained(model, './output')

# Assume two request types: one uses tuners a, b, c; the other uses a, c, d
type_1 = ['a', 'b', 'c']
type_2 = ['a', 'c', 'd']

def request(_input, _type):
    if _type == 'type_1':
        model.set_active_adapters(type_1)
    elif _type == 'type_2':
        model.set_active_adapters(type_2)
    return model(**_input)
```

Using the same tuner from different threads is safe.
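
To make the threading pattern concrete, here is a minimal sketch that dispatches the two request types from separate threads, reusing the `request` helper above (the tokenizer call is an assumption; any code that builds model inputs works):

```python
import threading

from modelscope import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('ZhipuAI/chatglm2-6b', trust_remote_code=True)

def handle(text, request_type):
    # Adapter activation is thread-local by default (USE_UNIQUE_THREAD != '0'),
    # so each thread can serve its own tuner combination.
    inputs = tokenizer(text, return_tensors='pt')
    return request(inputs, request_type)

threads = [
    threading.Thread(target=handle, args=('query one', 'type_1')),
    threading.Thread(target=handle, args=('query two', 'type_2')),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```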

docs/Get Started/4.examples.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
# LLM training solutions

Swift provides a complete LLM training solution; see the [Examples README](../../examples/pytorch/llm/README_CN.md).

docs/Modules/1.swift.md

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
# API Reference

## Swift

##### Swift.prepare_model(model: Union[nn.Module, 'SwiftModel'], config: Union[SwiftConfig, PeftConfig, Dict[str, SwiftConfig]], **kwargs)

> This static method randomly initializes tuners of the specified types.
>
> model: the model to attach tuners to; it may be a SwiftModel, in which case tuners added later take effect together with the ones already in the SwiftModel
>
> config: the config of the tuner to load; a SwiftConfig or PeftConfig, or a dict of configs keyed by name. If no name is passed, the name defaults to `default`
>
> kwargs:
>
> - extra_state_keys: List[str] keys of original-model weights that should additionally be saved to file
>
> - inference_mode: bool whether to start in inference mode

See each tuner's documentation for the concrete parameters of its SwiftConfig.
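The kwargs above can be passed straight through `Swift.prepare_model`. A minimal sketch, where the key names are illustrative for a bert classifier; `extra_state_keys` asks a later `save_pretrained` to also persist those base-model weights:

```python
from modelscope import AutoModelForSequenceClassification
from swift import Swift, LoRAConfig

model = AutoModelForSequenceClassification.from_pretrained(
    'AI-ModelScope/bert-base-uncased', revision='v1.0.0')
model = Swift.prepare_model(
    model,
    LoRAConfig(target_modules=['query', 'key', 'value']),
    # The classification head is trained from scratch and belongs to no tuner,
    # so ask Swift to save it alongside the adapter weights.
    extra_state_keys=['classifier.weight', 'classifier.bias'])
```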
##### Swift.from_pretrained(model: Union[nn.Module, 'SwiftModel'], model_id: str = None, adapter_name: Union[str, List[str]] = None, revision: str = None, **kwargs)

> This static method loads previously saved tuner checkpoints.
>
> model: the model to attach tuners to; it may be a SwiftModel, in which case tuners added later take effect together with the ones already in the SwiftModel
>
> model_id: a local directory of saved tuners, or a ModelScope hub id.
>
> adapter_name: the name(s) of the adapters to load; the default None loads all of them
>
> kwargs:
>
> - inference_mode: bool whether to start in inference mode
>
> - revision: the revision of model_id
>
> - extra_state_keys: extra weights to be saved on the next save_pretrained
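
A short sketch of loading only part of the saved adapters; the directory and adapter names follow the save_pretrained example in the Get Started docs:

```python
from modelscope import Model
from swift import Swift
import torch

model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
# Load only 'first_tuner' from the checkpoint; 'second_tuner' stays on disk.
model = Swift.from_pretrained(model, './output', adapter_name=['first_tuner'])
```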
## SwiftModel

Both `Swift.prepare_model` and `Swift.from_pretrained` return an instance of `SwiftModel`. This instance wraps the model that was actually passed in.

##### save_pretrained(self, save_directory: str, safe_serialization: bool = False, adapter_name: Union[str, List[str]] = None, **kwargs)

> Instance method; saves the model to local disk, from where it can be loaded directly with Swift.from_pretrained.
>
> save_directory: the directory to save to
>
> safe_serialization: whether to save safetensors
>
> adapter_name: the name(s) of the adapters to save; the default None saves all of them

##### set_active_adapters(self, adapter_names: List[str])

> Instance method; sets all adapters that are active for the model in the current thread. If the environment variable `USE_UNIQUE_THREAD` is set to '0', the setting takes effect in all threads at once.
>
> adapter_names: list of adapter names

##### activate_adapter(self, adapter_name)

> Instance method; activates a single adapter in the current thread. If the environment variable `USE_UNIQUE_THREAD` is set to '0', the setting takes effect in all threads at once.
>
> adapter_name: adapter name

##### deactivate_adapter(self, adapter_name)

> Instance method; deactivates a single adapter in the current thread. If the environment variable `USE_UNIQUE_THREAD` is set to '0', the setting takes effect in all threads at once.
>
> adapter_name: adapter name
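
A minimal sketch of toggling adapters with the three methods above; the adapter names are made up for illustration:

```python
from modelscope import AutoModelForSequenceClassification
from swift import Swift, LoRAConfig

model = AutoModelForSequenceClassification.from_pretrained(
    'AI-ModelScope/bert-base-uncased', revision='v1.0.0')
model = Swift.prepare_model(model, {
    'adapter_a': LoRAConfig(target_modules=['query', 'key', 'value']),
    'adapter_b': LoRAConfig(target_modules=['query', 'key', 'value']),
})

# Only adapter_a participates in forward passes in this thread.
model.set_active_adapters(['adapter_a'])

# Toggle adapters one at a time; with USE_UNIQUE_THREAD='0' these calls
# would affect every thread instead of only the current one.
model.deactivate_adapter('adapter_a')
model.activate_adapter('adapter_b')
```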

docs/Modules/2.lora.md

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
# LoRA

LoRA is the lightweight training component from the paper [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). LoRA can be attached to Linear, Embedding, Conv2d, and similar operators.

```python
LoRAConfig (
    r: int  # rank of the LoRA structure
    target_modules: Union[List[str], str]  # module keys to patch; a str is used for a full-match wildcard lookup, a List is matched by suffix
    lora_alpha: int  # weight scale of the LoRA structure; lora_alpha/r is the effective LoRA weight
    lora_dropout: float  # dropout ratio of the LoRA structure
    merge_weights: bool  # whether to merge the LoRA weights into the original weights at inference time
    use_merged_linear: bool  # whether the target is a merged linear structure
    enable_lora: List[bool]  # when use_merged_linear is set, which sub-modules get a LoRA structure
    bias: str  # whether biases are trained and saved: 'none' trains no biases, 'all' trains the biases of all modules, 'lora_only' trains only the biases of the LoRA structures
)
```
An example of using LoRA:

```python
from modelscope import Model
from swift import Swift, LoRAConfig
import torch

model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
lora_config = LoRAConfig(
    r=16,
    target_modules=['query_key_value'],
    lora_alpha=32,
    lora_dropout=0.)
model = Swift.prepare_model(model, lora_config)
# use model to do other things
```
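
Since LoRA can also patch `torch.nn.Embedding` (support added in this commit), here is a hedged sketch of targeting an embedding layer; the module key 'word_embeddings' matches bert's naming and is an assumption for other models:

```python
from modelscope import AutoModelForSequenceClassification
from swift import Swift, LoRAConfig

model = AutoModelForSequenceClassification.from_pretrained(
    'AI-ModelScope/bert-base-uncased', revision='v1.0.0')
# Target the token embedding (a torch.nn.Embedding) instead of attention Linears.
lora_config = LoRAConfig(r=8, lora_alpha=16, target_modules=['word_embeddings'])
model = Swift.prepare_model(model, lora_config)
```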
