diff --git a/paddlenlp/peft/lora/lora_quantization_layers.py b/paddlenlp/peft/lora/lora_quantization_layers.py
index 35715b272bdc..ea9fd60e60c6 100644
--- a/paddlenlp/peft/lora/lora_quantization_layers.py
+++ b/paddlenlp/peft/lora/lora_quantization_layers.py
@@ -23,7 +23,7 @@
     mark_as_sequence_parallel_parameter,
 )
 
-from ...quantization.quantization_linear import quant_weight_linear
+from ...quantization.quantization_linear import quant_weight_linear, get_act_scale_group
 from ...utils.log import logger
 from .utils import rng_ctx
 
@@ -44,7 +44,17 @@ def __init__(self, layer, lora_config):
         else:
             self.quant_scale = layer.quant_scale
         self.bias = layer.bias
-
+        self.state = 0
+        if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]:
+            self.act_scale = self.create_parameter(
+                shape=[1],
+                dtype=self._dtype,
+                is_bias=False,
+                default_initializer=nn.initializer.Constant(value=0.0),
+            )
+            self.act_scale.is_distributed = False
+            self.act_scale.stop_gradient = True
+            self.group = get_act_scale_group(is_row=True)
         # LoRA related parameters
         self.lora_config = lora_config
         if not isinstance(self.lora_config.r, int) or self.lora_config.r <= 0:
@@ -77,7 +87,10 @@ def forward(self, x, add_bias=True):
             if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant)
             else None,
             bias=self.bias if add_bias else None,
+            act_state=(self.state, self.training, self.act_scale, self.group)
         )
+        if self.training:
+            self.state += 1
         return output
 
     def merge(self):
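
For context, here is a minimal sketch of how an `act_state` tuple of the form `(step, training, act_scale, group)` might be consumed downstream to maintain a running activation scale for the a8w8/a8w4/fp8 paths. This is an illustration only: the helper name `update_act_scale`, the running abs-max averaging rule, and the MAX all-reduce over `group` are assumptions, not the actual behavior of `quant_weight_linear` or `get_act_scale_group`.

```python
# Hypothetical sketch, not PaddleNLP's implementation: update a running
# activation scale from an act_state tuple like the one built above.
import paddle
import paddle.distributed as dist


def update_act_scale(x, act_state):
    """Fold the current batch's abs-max into a running activation scale.

    act_state = (step, training, act_scale, group), mirroring the tuple
    assembled in the forward pass of the quantized LoRA linear above.
    """
    step, training, act_scale, group = act_state
    if not training:
        # At inference time the recorded scale is used as-is.
        return act_scale
    # Abs-max of the current activations, cast to the scale's dtype.
    cur_scale = paddle.max(paddle.abs(x)).reshape([1]).astype(act_scale.dtype)
    if group is not None:
        # Agree on a single scale across the ranks that shard this activation.
        dist.all_reduce(cur_scale, op=dist.ReduceOp.MAX, group=group)
    with paddle.no_grad():
        # Running mean over steps; `step` starts at 0, so this averages
        # step + 1 samples.
        paddle.assign((act_scale * step + cur_scale) / (step + 1), output=act_scale)
    return act_scale
```

Under this reading, incrementing `self.state` once per training forward pass keeps the running average correctly weighted across steps.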