
Commit 91fa553

linsj20 and pre-commit-ci[bot] authored and committed
[Feature] qlora support (#5586)
* [feature] qlora support
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* qlora follow commit
* migrate quantization folder to colossalai/
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* minor fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8954a0c commit 91fa553

File tree

14 files changed (+640, -143 lines)


LICENSE

Lines changed: 15 additions & 0 deletions
@@ -552,3 +552,18 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
+---------------- LICENSE FOR Hugging Face accelerate ----------------
+
+Copyright 2021 The HuggingFace Team
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.

applications/Colossal-LLaMA/colossal_llama/dataset/loader.py

Lines changed: 10 additions & 6 deletions
@@ -80,15 +80,19 @@ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch

         # `List[torch.Tensor]`
         batch_input_ids = [
-            torch.LongTensor(instance["input_ids"][: self.max_length])
-            if len(instance["input_ids"]) > self.max_length
-            else torch.LongTensor(instance["input_ids"])
+            (
+                torch.LongTensor(instance["input_ids"][: self.max_length])
+                if len(instance["input_ids"]) > self.max_length
+                else torch.LongTensor(instance["input_ids"])
+            )
             for instance in instances
         ]
         batch_labels = [
-            torch.LongTensor(instance["labels"][: self.max_length])
-            if len(instance["labels"]) > self.max_length
-            else torch.LongTensor(instance["labels"])
+            (
+                torch.LongTensor(instance["labels"][: self.max_length])
+                if len(instance["labels"]) > self.max_length
+                else torch.LongTensor(instance["labels"])
+            )
             for instance in instances
         ]

applications/Colossal-LLaMA/train.py

Lines changed: 5 additions & 3 deletions
@@ -253,9 +253,11 @@ def main() -> None:
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

     optimizer = HybridAdam(
-        model_params=filter(lambda p: p.requires_grad, model.parameters())
-        if args.freeze_non_embeds_params
-        else model.parameters(),
+        model_params=(
+            filter(lambda p: p.requires_grad, model.parameters())
+            if args.freeze_non_embeds_params
+            else model.parameters()
+        ),
         lr=args.lr,
         betas=(0.9, 0.95),
         weight_decay=args.weight_decay,

colossalai/booster/booster.py

Lines changed: 21 additions & 2 deletions
@@ -19,6 +19,7 @@
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig

 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -230,7 +231,12 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
         return self.plugin.no_sync(model, optimizer)

     def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: "peft.LoraConfig" = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
+        quantize=False,
     ) -> nn.Module:
         """
         Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
@@ -259,7 +265,20 @@ def enable_lora(
         assert (
             pretrained_dir is not None
         ), "Please provide pretrained directory path if not passing in lora configuration."
-        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
+        if quantize is True:
+            if bnb_quantization_config is not None:
+                warnings.warn(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                )
+            else:
+                bnb_quantization_config = BnbQuantizationConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                )
+
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)

     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 9 additions & 1 deletion
@@ -28,6 +28,7 @@
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.zero import LowLevelZeroOptimizer

 from .dp_plugin_base import DPPluginBase
@@ -338,14 +339,21 @@ def support_lora(self) -> bool:
         return True

     def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: Optional[Dict] = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
     ) -> nn.Module:
         from peft import PeftModel, get_peft_model

         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
         warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")

+        if bnb_quantization_config is not None:
+            model = quantize_model(model, bnb_quantization_config)
+
         if pretrained_dir is None:
             peft_model = get_peft_model(model, lora_config)
         else:
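For orientation, a minimal usage sketch of the extended API (not part of this commit): the checkpoint path and LoRA hyperparameters are placeholders, and the distributed launch boilerplate is omitted. With quantize=True and no explicit config, Booster.enable_lora falls back to the 4-bit NF4, double-quantization, bfloat16-compute defaults shown in the booster.py diff above, and the plugin then quantizes the model before wrapping it with peft.

# Hedged usage sketch (not from this commit); paths and LoRA settings are placeholders.
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam

# assumes colossalai.launch_from_torch(...) has already set up the distributed environment
booster = Booster(plugin=LowLevelZeroPlugin())

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")  # placeholder checkpoint
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

# quantize=True with no bnb_quantization_config uses the defaults added in booster.py:
# load_in_4bit=True, NF4 quant type, double quantization, bfloat16 compute dtype.
model = booster.enable_lora(model, lora_config=lora_config, quantize=True)

optimizer = HybridAdam(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)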

colossalai/booster/plugin/torch_ddp_plugin.py

Lines changed: 9 additions & 1 deletion
@@ -9,6 +9,7 @@
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig, quantize_model

 from .dp_plugin_base import DPPluginBase

@@ -237,10 +238,17 @@ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[Non
         return model.module.no_sync()

     def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: Optional[Dict] = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
     ) -> nn.Module:
         from peft import PeftModel, get_peft_model

+        if bnb_quantization_config is not None:
+            model = quantize_model(model, bnb_quantization_config)
+
         assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
         if pretrained_dir is None:
             return get_peft_model(model, lora_config)
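The other entry point is an explicit, user-defined quantization config. A hedged sketch (not from this commit), using only the config fields visible in this diff; the plugin choice and checkpoint path are illustrative:

# Hedged sketch (not from this commit): supplying a user-defined BnbQuantizationConfig.
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.quantization import BnbQuantizationConfig

bnb_config = BnbQuantizationConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

booster = Booster(plugin=TorchDDPPlugin())
model = AutoModelForCausalLM.from_pretrained("path/to/base-model")  # placeholder checkpoint
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32)

# A user-supplied config triggers the "not fully tested" warning in Booster.enable_lora,
# after which the plugin runs quantize_model(model, bnb_config) before get_peft_model.
model = booster.enable_lora(
    model,
    lora_config=lora_config,
    bnb_quantization_config=bnb_config,
    quantize=True,
)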

colossalai/inference/README.md

Lines changed: 19 additions & 19 deletions
@@ -165,7 +165,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc
 ##### Llama

 | batch_size | 8 | 16 | 32 |
-| :---------------------: | :----: | :----: | :----: |
+|:-----------------------:|:------:|:------:|:------:|
 | hugging-face torch fp16 | 199.12 | 246.56 | 278.4 |
 | colossal-inference | 326.4 | 582.72 | 816.64 |

@@ -174,7 +174,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc
 #### Bloom

 | batch_size | 8 | 16 | 32 |
-| :---------------------: | :----: | :----: | :----: |
+|:-----------------------:|:------:|:------:|:------:|
 | hugging-face torch fp16 | 189.68 | 226.66 | 249.61 |
 | colossal-inference | 323.28 | 538.52 | 611.64 |

@@ -187,40 +187,40 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t

 #### A10 7b, fp16

-| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
-| :-------------------------: | :---: | :---:| :---: | :---: | :---: | :---: |
-| Pipeline Inference | 40.35 | 77.10| 139.03| 232.70| 257.81| OOM |
-| Hugging Face | 41.43 | 65.30| 91.93 | 114.62| OOM | OOM |
+| batch_size(micro_batch size) | 2(1)  | 4(2)  | 8(4)   | 16(8)  | 32(8)  | 32(16) |
+|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:|
+| Pipeline Inference           | 40.35 | 77.10 | 139.03 | 232.70 | 257.81 | OOM    |
+| Hugging Face                 | 41.43 | 65.30 | 91.93  | 114.62 | OOM    | OOM    |


 ![ppllama7b](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama7b.png)

 #### A10 13b, fp16

-| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
-| :---: | :---: | :---: | :---: | :---: |
-| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
-| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
+| batch_size(micro_batch size) | 2(1)  | 4(2)  | 8(4)  | 16(4) |
+|:----------------------------:|:-----:|:-----:|:-----:|:-----:|
+| Pipeline Inference           | 25.39 | 47.09 | 83.7  | 89.46 |
+| Hugging Face                 | 23.48 | 37.59 | 53.44 | OOM   |

 ![ppllama13](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama13b.png)


 #### A800 7b, fp16

-| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
-| :---: | :---: | :---: | :---: | :---: | :---: |
-| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
-| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
+| batch_size(micro_batch size) | 2(1)  | 4(2)   | 8(4)   | 16(8)  | 32(16) |
+|:----------------------------:|:-----:|:------:|:------:|:------:|:------:|
+| Pipeline Inference           | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
+| Hugging Face                 | 42.44 | 76.5   | 151.97 | 212.88 | 256.13 |

 ![ppllama7b_a800](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a800-llama7b.png)

 ### Quantization LLama

-| batch_size | 8 | 16 | 32 |
-| :---------------------: | :----: | :----: | :----: |
-| auto-gptq | 199.20 | 232.56 | 253.26 |
-| smooth-quant | 142.28 | 222.96 | 300.59 |
-| colossal-gptq | 231.98 | 388.87 | 573.03 |
+| batch_size    | 8      | 16     | 32     |
+|:-------------:|:------:|:------:|:------:|
+| auto-gptq     | 199.20 | 232.56 | 253.26 |
+| smooth-quant  | 142.28 | 222.96 | 300.59 |
+| colossal-gptq | 231.98 | 388.87 | 573.03 |

 ![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-quant.png)

colossalai/quantization/__init__.py (new file)

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from .bnb import quantize_model
+from .bnb_config import BnbQuantizationConfig
+
+__all__ = [
+    "BnbQuantizationConfig",
+    "quantize_model",
+]
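For completeness, a hedged sketch of the primitive the plugins call under the hood (not from this commit): quantize the base model first, then wrap it with peft, mirroring the order in the plugin diffs above. The toy module and LoRA settings are illustrative, and running it requires a CUDA-enabled bitsandbytes install.

# Hedged sketch (not from this commit): quantize first, then add LoRA adapters,
# matching the order used by the plugins' enable_lora implementations above.
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

from colossalai.quantization import BnbQuantizationConfig, quantize_model


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)


bnb_config = BnbQuantizationConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = quantize_model(TinyModel(), bnb_config)  # base weights become 4-bit NF4
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=32, target_modules=["proj"]))
model.print_trainable_parameters()  # only the LoRA adapters stay trainable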
