
Commit a4ebea1

support qa_lora (#104)
1 parent 7846f3a commit a4ebea1

File tree

1 file changed: +46 −7 lines changed

swift/tuners/lora.py

Lines changed: 46 additions & 7 deletions
@@ -84,12 +84,32 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        use_qa_lora: bool = False,
         **kwargs,
     ):
-        super(ActivationMixin,
-              self).__init__(adapter_name, quant_linear_module, r,
-                             lora_alpha, lora_dropout, **kwargs)
+        from peft.tuners.lora import LoraLayer
+        torch.nn.Module.__init__(self)
+        self.group_size = kwargs.get('group_size', None)
+        self.use_qa_lora = use_qa_lora
+        if self.use_qa_lora:
+            assert self.group_size is not None, 'To use qa_lora you need to pass in the `group_size` param.'
+        LoraLayer.__init__(
+            self,
+            in_features=quant_linear_module.infeatures
+            if not self.use_qa_lora else quant_linear_module.infeatures
+            // self.group_size,
+            out_features=quant_linear_module.outfeatures)
+        self.quant_linear_module = quant_linear_module
+        self.weight = quant_linear_module.qweight
+        init_lora_weights = kwargs.pop('init_lora_weights', True)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout,
+                          init_lora_weights)
+        self.active_adapter = adapter_name
         super(QuantLinear, self).__init__()
+        if self.use_qa_lora:
+            self.qa_pool = torch.nn.AvgPool1d(
+                self.group_size
+            )  # using pooling layer to conduct sum operation

         def call_quant_linear_module(*args, **kwargs):
             return quant_linear_module.forward_origin(*args, **kwargs)
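The `qa_pool` trick above relies on the fact that average pooling with kernel `group_size`, multiplied back by `group_size`, is a group-wise sum over consecutive input features, which is why `in_features` shrinks to `infeatures // self.group_size` when `use_qa_lora` is set. A minimal standalone sketch of that equivalence (not part of the commit; the shapes are illustrative):

```python
import torch

# Standalone illustration: average pooling over groups of `group_size`
# features, scaled back by `group_size`, equals a group-wise sum, so
# lora_A only needs infeatures // group_size inputs.
group_size = 4                       # e.g. the GPTQ quantization group size
x = torch.randn(2, 3, 16)            # (batch, seq_len, in_features)

qa_pool = torch.nn.AvgPool1d(group_size)
pooled = qa_pool(x) * group_size     # -> (2, 3, 16 // 4) == (2, 3, 4)

# The same result written as an explicit group-wise sum.
summed = x.view(2, 3, -1, group_size).sum(dim=-1)
assert torch.allclose(pooled, summed, atol=1e-5)
print(pooled.shape)                  # torch.Size([2, 3, 4])
```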
@@ -108,12 +128,16 @@ def forward(self, x: torch.Tensor):
             if not torch.is_autocast_enabled():
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
+                if self.use_qa_lora:
+                    x = self.qa_pool(x) * self.group_size
                 output = (
                     self.lora_B[self.active_adapter](
                         self.lora_A[self.active_adapter](self.lora_dropout[
                             self.active_adapter](x))).to(expected_dtype)
                     * self.scaling[self.active_adapter])
             else:
+                if self.use_qa_lora:
+                    x = self.qa_pool(x) * self.group_size
                 output = (
                     self.lora_B[self.active_adapter](
                         self.lora_A[self.active_adapter](
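Putting the pieces together, the forward pass adds the low-rank update computed on the group-summed input to the output of the frozen quantized layer. Below is a self-contained sketch of that data flow, with plain `nn.Linear` stand-ins for `lora_A`/`lora_B` and for the GPTQ module; dropout and dtype handling are omitted:

```python
import torch
import torch.nn as nn

# Illustrative sketch of the qa-lora forward path, not the swift code itself:
# the frozen quantized module sees the full-width input, while the LoRA branch
# sees the group-summed input, so lora_A has in_features // group_size inputs.
hidden, out_features, rank, group_size = 16, 8, 4, 4

lora_A = nn.Linear(hidden // group_size, rank, bias=False)
lora_B = nn.Linear(rank, out_features, bias=False)
qa_pool = nn.AvgPool1d(group_size)
scaling = 2.0  # plays the role of lora_alpha / r

def qa_lora_forward(x, quant_forward):
    # `quant_forward` stands in for quant_linear_module.forward_origin.
    result = quant_forward(x)               # frozen quantized projection
    x_grouped = qa_pool(x) * group_size     # group-wise sum of the input
    return result + lora_B(lora_A(x_grouped)) * scaling

# A plain float Linear stands in for the GPTQ QuantLinear here.
base = nn.Linear(hidden, out_features)
y = qa_lora_forward(torch.randn(2, 3, hidden), base)
print(y.shape)  # torch.Size([2, 3, 8])
```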
@@ -179,6 +203,13 @@ class LoRAConfig(SwiftConfig):
             'help': 'Bias type. Values ca be "none", "all" or "lora_only"'
         })

+    use_qa_lora: bool = field(
+        default=False,
+        metadata={
+            'help':
+            'Use [qa-lora](https://github.com/yuhuixu1993/qa-lora) or not'
+        })
+
     def __post_init__(self):
         from .mapping import SwiftTuners
         self.swift_type = SwiftTuners.LORA
@@ -199,7 +230,8 @@ def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
             merge_weights=config.merge_weights,
             use_merged_linear=config.use_merged_linear,
             enable_lora=config.enable_lora,
-            fan_in_fan_out=config.fan_in_fan_out)
+            fan_in_fan_out=config.fan_in_fan_out,
+            use_qa_lora=config.use_qa_lora)

     def state_dict_callback(state_dict, adapter_name):
         return lora_state_dict(state_dict, adapter_name, config.bias)
@@ -237,8 +269,9 @@ def _dynamic_patch_lora(model: torch.nn.Module,
     modules = {}
     module_keys = [key for key, _ in model.named_modules()]
     assert isinstance(target_modules, (str, list))
-    AutoGPTQQuantLinear = get_auto_gptq_quant_linear(
-        get_quantization_config(model, method='gptq'))
+    auto_gptq_config = get_quantization_config(model, method='gptq')
+    AutoGPTQQuantLinear = get_auto_gptq_quant_linear(auto_gptq_config)
+    use_qa_lora = kwargs.pop('use_qa_lora', False)

     for module_key in module_keys:
         if isinstance(target_modules, str):
@@ -292,7 +325,13 @@ def _dynamic_patch_lora(model: torch.nn.Module,
                     **four_bit_kwargs)
             elif AutoGPTQQuantLinear is not None and isinstance(
                     sub_module, AutoGPTQQuantLinear):
-                lora_module = QuantLinear('default', sub_module, **kwargs)
+                lora_module = QuantLinear(
+                    'default',
+                    sub_module,
+                    use_qa_lora=use_qa_lora,
+                    group_size=getattr(auto_gptq_config, 'group_size',
+                                       None),
+                    **kwargs)
                 sub_module.weight = sub_module.qweight
             elif isinstance(sub_module, torch.nn.Linear):
                 if use_merged_linear:
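End to end, setting `use_qa_lora=True` on `LoRAConfig` makes `prepare_model` forward the flag into `_dynamic_patch_lora`, which reads `group_size` from the model's GPTQ quantization config when wrapping each `AutoGPTQQuantLinear`. A hedged usage sketch follows; apart from `use_qa_lora`, the config field names and the `Swift.prepare_model` entry point are assumed from the surrounding library rather than shown in this diff:

```python
# Hypothetical usage sketch: apart from `use_qa_lora`, the LoRAConfig fields
# shown here (r, lora_alpha, target_modules, ...) and `Swift.prepare_model`
# are assumptions about the surrounding library, not taken from this diff.
from transformers import AutoModelForCausalLM
from swift import LoRAConfig, Swift

# Placeholder model id; any GPTQ-quantized checkpoint would do.
model = AutoModelForCausalLM.from_pretrained('a-gptq-quantized-model')

config = LoRAConfig(
    r=8,
    lora_alpha=32,
    target_modules=['q_proj', 'k_proj', 'v_proj'],
    use_qa_lora=True,  # group_size is then read from the model's GPTQ config
)
model = Swift.prepare_model(model, config)
```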
