Commit 877b5a7

Add tensor parallelism on QLoRA (#2424)
1 parent 3a95dfb commit 877b5a7

File tree

10 files changed: +161 -13 lines changed


.github/workflows/lint.yml

Lines changed: 0 additions & 4 deletions
@@ -36,7 +36,6 @@ jobs:
 -e no_proxy \
 -e python_version \
 -w /workspace ${docker_image}
-
 - name: Download Code
 env:
 work_dir: ${{ github.workspace }}
@@ -64,18 +63,15 @@ jobs:
 echo "Not in a pull_request event. Skipping PR-specific operations."
 fi
 git log --pretty=oneline -10
-
 if ! git show-ref --quiet refs/heads/develop; then \
 echo "local develop branch is missing, creating local develop branch that tracks remote develop branch"
 git fetch origin develop
 git branch develop --track origin/develop
 else
 echo "local develop branch exist, skipping"
 fi
-
 unset http_proxy && unset https_proxy
 '
-
 - name: Setup Environment
 run: |
 docker exec -t $container_name /bin/bash -c '

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -60,4 +60,4 @@ repos:
 entry: python scripts/codestyle/check_dead_links.py
 language: python
 files: \.(md|markdown|rst)$
-pass_filenames: true
+pass_filenames: true

(The removed and re-added pass_filenames line appears to differ only in whitespace, which is not visible in this rendering.)

paddleformers/peft/lora/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@
 from .lora_config import LoRAAutoConfig, LoRAConfig
 from .lora_layers import ColumnParallelLoRALinear, LoRALinear, RowParallelLoRALinear
 from .lora_model import LoRAModel
+from .lora_quantization_layers import QuantizationLoRABaseLinear
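
Re-exporting QuantizationLoRABaseLinear here lets callers (for example the trainer change further below) import it from the subpackage rather than from the module file. A minimal sketch, assuming this revision of paddleformers is installed:

# Both import paths resolve to the same class after this change.
from paddleformers.peft.lora import QuantizationLoRABaseLinear
from paddleformers.peft.lora.lora_quantization_layers import (
    QuantizationLoRABaseLinear as _SameClass,
)

assert QuantizationLoRABaseLinear is _SameClass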

paddleformers/peft/lora/lora_quantization_layers.py

Lines changed: 0 additions & 1 deletion
@@ -44,7 +44,6 @@ def __init__(self, layer, lora_config):
         else:
             self.weight_scale = layer.weight_scale
         self.bias = layer.bias
-
         # LoRA related parameters
         self.lora_config = lora_config
         if not isinstance(self.lora_config.r, int) or self.lora_config.r <= 0:

paddleformers/quantization/qlora.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def qlora_weight_quantize(
         return quant_weight, (qweight_scale, double_weight_scale, quant_sacle_offset)
     qweight_scale_name = f"{linear_name}.qweight_scale" if linear_name else "qweight_scale"
     double_weight_scale_name = f"{linear_name}.double_weight_scale" if linear_name else "double_weight_scale"
-    quant_sacle_offset_name = f"{linear_name}.quant_sacle_offset" if linear_name else "quant_sacle_offset"
+    quant_sacle_offset_name = f"{linear_name}.weight_scale_offset" if linear_name else "weight_scale_offset"
     qlora_state_dict = {
         qweight_scale_name: qweight_scale,
         double_weight_scale_name: double_weight_scale,
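
The functional change in this hunk is the serialized key name for the scale offset. As a quick illustration (the helper and the layer name below are hypothetical, not part of paddleformers), the per-layer scale keys produced under double quantization now look like this:

# Illustrative only: mirrors the naming scheme visible in the hunk above.
def qlora_state_dict_keys(linear_name=None):
    prefix = f"{linear_name}." if linear_name else ""
    return [
        f"{prefix}qweight_scale",
        f"{prefix}double_weight_scale",
        f"{prefix}weight_scale_offset",  # previously "quant_sacle_offset"
    ]

print(qlora_state_dict_keys("model.layers.0.mlp.gate_proj"))
# ['model.layers.0.mlp.gate_proj.qweight_scale',
#  'model.layers.0.mlp.gate_proj.double_weight_scale',
#  'model.layers.0.mlp.gate_proj.weight_scale_offset']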

paddleformers/quantization/quantization_linear.py

Lines changed: 137 additions & 0 deletions
@@ -357,6 +357,7 @@ def __init__(
                     dtype="float32",
                     is_bias=False,
                 )
+                self.weight_scale = None
             else:
                 self.weight_scale = self.create_parameter(
                     shape=[in_features * out_features // self.quantization_config.qlora_weight_blocksize],
@@ -496,6 +497,74 @@ def __init__(
             self.activation_scale.is_distributed = False
             self.activation_scale.stop_gradient = True
             self.group = get_activation_scale_group()
+        elif self.weight_quantize_algo in ["nf4", "fp4"]:
+            if qlora_weight_linear is None:
+                raise ImportError(
+                    "Please run the following commands to install: qlora related package first\n"
+                    "1) git clone https://github.com/PaddlePaddle/PaddleSlim \n"
+                    "2) cd PaddleSlim && pip install -e .\n"
+                    "3) cd csrc && python ./setup_cuda.py install"
+                )
+            # print(self.output_size_per_partition, in_features)
+            self.quant_weight = self.create_parameter(
+                shape=[self.output_size_per_partition * in_features // 2, 1],
+                attr=paddle.nn.initializer.Constant(value=0),
+                dtype="uint8",
+                is_bias=False,
+            )
+            self.quant_weight.is_distributed = True if self.is_mp else False
+            if self.quant_weight.is_distributed:
+                self.quant_weight.split_axis = 0
+            if self.quantization_config.qlora_weight_double_quant:
+                # quantized weight_scale
+                self.qweight_scale = self.create_parameter(
+                    shape=[
+                        in_features * self.output_size_per_partition // self.quantization_config.qlora_weight_blocksize
+                    ],
+                    dtype="uint8",
+                    is_bias=False,
+                )
+                # double weight_scale: weight_scale of quantized weight_scale
+                self.qweight_scale.stop_gradient = True
+                self.qweight_scale.is_distributed = True if self.is_mp else False
+                if self.qweight_scale.is_distributed:
+                    self.qweight_scale.split_axis = 0
+                self.double_weight_scale = self.create_parameter(
+                    shape=[
+                        in_features
+                        * self.output_size_per_partition
+                        // self.quantization_config.qlora_weight_blocksize
+                        // self.quantization_config.qlora_weight_double_quant_block_size
+                    ],
+                    dtype="float32",
+                    is_bias=False,
+                )
+                self.double_weight_scale.stop_gradient = True
+                self.double_weight_scale.is_distributed = True if self.is_mp else False
+                if self.double_weight_scale.is_distributed:
+                    self.double_weight_scale.split_axis = 0
+                self.weight_scale_offset = self.create_parameter(
+                    shape=[],
+                    dtype="float32",
+                    is_bias=False,
+                )
+                self.weight_scale_offset.stop_gradient = True
+                self.weight_scale_offset.is_distributed = True if self.is_mp else False
+                if self.weight_scale_offset.is_distributed:
+                    self.weight_scale_offset.split_axis = 0
+            else:
+                self.weight_scale = self.create_parameter(
+                    shape=[
+                        in_features * self.output_size_per_partition // self.quantization_config.qlora_weight_blocksize
+                    ],
+                    dtype="float32",
+                    is_bias=False,
+                )
+                self.weight_scale.stop_gradient = True
+                self.weight_scale.is_distributed = True if self.is_mp else False
+                if self.weight_scale.is_distributed:
+                    self.weight_scale.split_axis = 0
+
         else:
             raise NotImplementedError(f"Not yet support weight_quantize_algo: {self.weight_quantize_algo}")
         if bias_attr is False:
@@ -647,6 +716,74 @@ def __init__(
             self.activation_scale.is_distributed = False
             self.activation_scale.stop_gradient = True
             self.group = get_activation_scale_group(is_row=True)
+        elif self.weight_quantize_algo in ["nf4", "fp4"]:
+            if qlora_weight_linear is None:
+                raise ImportError(
+                    "Please run the following commands to install: qlora related package first\n"
+                    "1) git clone https://github.com/PaddlePaddle/PaddleSlim \n"
+                    "2) cd PaddleSlim && pip install -e .\n"
+                    "3) cd csrc && python ./setup_cuda.py install"
+                )
+            self.quant_weight = self.create_parameter(
+                shape=[out_features * self.input_size_per_partition // 2, 1],
+                attr=paddle.nn.initializer.Constant(value=0),
+                dtype="uint8",
+                is_bias=False,
+            )
+            self.quant_weight.is_distributed = True if self.is_mp else False
+            if self.quant_weight.is_distributed:
+                self.quant_weight.split_axis = 1
+            if self.quantization_config.qlora_weight_double_quant:
+                # quantized weight_scale
+                self.qweight_scale = self.create_parameter(
+                    shape=[
+                        self.input_size_per_partition * out_features // self.quantization_config.qlora_weight_blocksize
+                    ],
+                    dtype="uint8",
+                    is_bias=False,
+                )
+                self.qweight_scale.stop_gradient = True
+                self.qweight_scale.is_distributed = True if self.is_mp else False
+                if self.qweight_scale.is_distributed:
+                    self.qweight_scale.split_axis = 0
+                # double weight_scale: weight_scale of quantized weight_scale
+                self.double_weight_scale = self.create_parameter(
+                    shape=[
+                        self.input_size_per_partition
+                        * out_features
+                        // self.quantization_config.qlora_weight_blocksize
+                        // self.quantization_config.qlora_weight_double_quant_block_size
+                    ],
+                    dtype="float32",
+                    is_bias=False,
+                )
+                self.double_weight_scale.stop_gradient = True
+                self.double_weight_scale.is_distributed = True if self.is_mp else False
+                if self.double_weight_scale.is_distributed:
+                    self.double_weight_scale.split_axis = 1
+                self.weight_scale_offset = self.create_parameter(
+                    shape=[],
+                    dtype="float32",
+                    is_bias=False,
+                )
+                self.weight_scale_offset.stop_gradient = True
+                self.weight_scale_offset.is_distributed = True if self.is_mp else False
+                if self.weight_scale_offset.is_distributed:
+                    self.weight_scale_offset.split_axis = 0
+            else:
+                self.weight_scale = self.create_parameter(
+                    shape=[
+                        self.input_size_per_partition * out_features // self.quantization_config.qlora_weight_blocksize
+                    ],
+                    dtype="float32",
+                    is_bias=False,
+                )
+
+                self.weight_scale.stop_gradient = True
+                self.weight_scale.is_distributed = True if self.is_mp else False
+                if self.weight_scale.is_distributed:
+                    self.weight_scale.split_axis = 0
+
         else:
             raise NotImplementedError(f"Not yet support weight_quantize_algo: {self.weight_quantize_algo}")
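
These added branches are the core of the tensor-parallel QLoRA support: the packed nf4/fp4 weight stores two 4-bit codes per uint8 element, and every scale tensor is sized from the per-rank partition (output_size_per_partition in the column-parallel hunk, input_size_per_partition in the row-parallel one) rather than the full weight. A rough, illustrative restatement of the column-parallel shape arithmetic follows; the helper, its default block sizes, and the example dimensions are assumptions, not library code.

# Illustrative helper: reproduces the shape arithmetic from the diff above for
# the column-parallel case with double quantization enabled.
def column_parallel_nf4_shapes(in_features, out_features, mp_degree,
                               blocksize=64, double_quant_blocksize=256):
    out_per_rank = out_features // mp_degree        # output_size_per_partition
    n_codes = out_per_rank * in_features            # one 4-bit code per weight
    return {
        "quant_weight": [n_codes // 2, 1],          # two 4-bit codes packed per uint8
        "qweight_scale": [n_codes // blocksize],    # uint8, itself quantized
        "double_weight_scale": [n_codes // blocksize // double_quant_blocksize],
        "weight_scale_offset": [],                  # scalar float32
    }

# Example: a 4096x4096 projection split across 2 model-parallel ranks.
print(column_parallel_nf4_shapes(4096, 4096, mp_degree=2))

The row-parallel branch mirrors this arithmetic and sets split_axis = 1 on the packed weight instead of 0.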

paddleformers/quantization/quantization_utils.py

Lines changed: 2 additions & 2 deletions
@@ -197,7 +197,7 @@ def convert_to_qlora_state_dict(state_dict, name, quantization_config, dtype, we
     else:
         qweight_scale_name = name + ".qweight_scale"
         double_weight_scale_name = name + ".double_weight_scale"
-        quant_sacle_offset_name = name + ".quant_sacle_offset"
+        quant_sacle_offset_name = name + ".weight_scale_offset"
         quant_name_list += [qweight_scale_name, double_weight_scale_name, quant_sacle_offset_name]

     if all(quant_name in state_dict for quant_name in quant_name_list):
@@ -252,7 +252,7 @@ def update_loaded_state_dict_keys(state_dict, quantization_linear_list, quantiza
         activation_scale_name = name + ".activation_scale"
         qweight_scale_name = name + ".qweight_scale"
         double_weight_scale_name = name + ".double_weight_scale"
-        quant_sacle_offset_name = name + ".quant_sacle_offset"
+        quant_sacle_offset_name = name + ".weight_scale_offset"

         if quant_weight_name in state_dict and weight_scale_name in state_dict:
             continue

paddleformers/trainer/trainer.py

Lines changed: 13 additions & 2 deletions
@@ -84,6 +84,7 @@
     init_dataloader_comm_group,
 )
 from ..peft import LoKrModel, LoRAModel, PrefixModelForCausalLM, ReFTModel, VeRAModel
+from ..peft.lora import QuantizationLoRABaseLinear
 from ..quantization.quantization_linear import (
     ColumnParallelQuantizationLinear,
     QuantizationLinear,
@@ -524,7 +525,12 @@ def _wrap_amp_model(self, args, model):
             models=model,
             level=self.args.fp16_opt_level,
             dtype=self.amp_dtype,
-            excluded_layers=[QuantizationLinear, ColumnParallelQuantizationLinear, RowParallelQuantizationLinear]
+            excluded_layers=[
+                QuantizationLinear,
+                ColumnParallelQuantizationLinear,
+                RowParallelQuantizationLinear,
+                QuantizationLoRABaseLinear,
+            ]
             + self._decorate_exclude_layers(model),
         )
         # for pipeline mode and pure tensor parallel
@@ -2194,7 +2200,12 @@ def _wrap_model(self, model, training=True):
             optimizers=self.optimizer,
             level=self.args.fp16_opt_level,
             dtype=self.amp_dtype,
-            excluded_layers=[QuantizationLinear, ColumnParallelQuantizationLinear, RowParallelQuantizationLinear]
+            excluded_layers=[
+                QuantizationLinear,
+                ColumnParallelQuantizationLinear,
+                RowParallelQuantizationLinear,
+                QuantizationLoRABaseLinear,
+            ]
             + self._decorate_exclude_layers(model),
         )
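
Adding QuantizationLoRABaseLinear to excluded_layers keeps the quantized LoRA wrapper out of AMP casting, just like the existing quantization linears. A self-contained sketch of the mechanism (the toy model is an assumption; only the paddle.amp.decorate keywords mirror the calls above):

import paddle

class TinyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.proj = paddle.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

# Layer types listed in excluded_layers are left in their original dtype under
# O2 decoration; everything else is cast to the AMP dtype.
model = paddle.amp.decorate(
    models=TinyModel(),
    level="O2",
    dtype="bfloat16",
    excluded_layers=[paddle.nn.Linear],  # stand-in for QuantizationLoRABaseLinear etc.
)
print(model.proj.weight.dtype)  # expected to remain float32, since Linear is excluded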

paddleformers/transformers/conversion_utils.py

Lines changed: 3 additions & 1 deletion
@@ -67,10 +67,12 @@ def add_quant_mapping(name_action_mappings, quantization_config):
         post_quantize = quantization_config.weight_quantize_algo in [
             "weight_only_int4",
             "weight_only_int8",
+            "nf4",
         ]
     elif isinstance(quantization_config.weight_quantize_algo, dict):
         post_quantize = any(
-            key in ["weight_only_int4", "weight_only_int8"] for key in quantization_config.weight_quantize_algo.keys()
+            key in ["weight_only_int4", "weight_only_int8", "nf4"]
+            for key in quantization_config.weight_quantize_algo.keys()
         )
     else:
         post_quantize = False
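
With this change, "nf4" also counts as a post-quantize algorithm when add_quant_mapping builds the tensor-parallel name-action mapping; the model_utils.py hunk below extends the same check with "fp4" at load time. A standalone paraphrase of the predicate (hypothetical helper, not the library code):

# Hypothetical paraphrase of the branch above; weight_quantize_algo is either a
# single algorithm string or a dict whose keys are algorithm names.
def needs_post_quantize(weight_quantize_algo) -> bool:
    post_algos = {"weight_only_int4", "weight_only_int8", "nf4"}
    if isinstance(weight_quantize_algo, str):
        return weight_quantize_algo in post_algos
    if isinstance(weight_quantize_algo, dict):
        return any(key in post_algos for key in weight_quantize_algo)
    return False

assert needs_post_quantize("nf4") is True
assert needs_post_quantize({"nf4": ["qkv_proj"]}) is True       # example dict form
assert needs_post_quantize("some_other_algo") is False          # placeholder name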

paddleformers/transformers/model_utils.py

Lines changed: 3 additions & 1 deletion
@@ -2055,10 +2055,12 @@ def _load_pretrained_model(
             post_quantize = config.quantization_config.weight_quantize_algo in [
                 "weight_only_int4",
                 "weight_only_int8",
+                "nf4",
+                "fp4",
             ]
         elif isinstance(config.quantization_config.weight_quantize_algo, dict):
             post_quantize = any(
-                key in ["weight_only_int4", "weight_only_int8"]
+                key in ["weight_only_int4", "weight_only_int8", "nf4", "fp4"]
                 for key in config.quantization_config.weight_quantize_algo.keys()
             )
         else:
