@@ -84,12 +84,32 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        use_qa_lora: bool = False,
         **kwargs,
     ):
-        super(ActivationMixin,
-              self).__init__(adapter_name, quant_linear_module, r,
-                             lora_alpha, lora_dropout, **kwargs)
+        from peft.tuners.lora import LoraLayer
+        torch.nn.Module.__init__(self)
+        self.group_size = kwargs.get('group_size', None)
+        self.use_qa_lora = use_qa_lora
+        if self.use_qa_lora:
+            assert self.group_size is not None, 'To use qa_lora you need to pass in the `group_size` param.'
+        LoraLayer.__init__(
+            self,
+            in_features=quant_linear_module.infeatures
+            if not self.use_qa_lora else quant_linear_module.infeatures
+            // self.group_size,
+            out_features=quant_linear_module.outfeatures)
+        self.quant_linear_module = quant_linear_module
+        self.weight = quant_linear_module.qweight
+        init_lora_weights = kwargs.pop('init_lora_weights', True)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout,
+                          init_lora_weights)
+        self.active_adapter = adapter_name
         super(QuantLinear, self).__init__()
+        if self.use_qa_lora:
+            self.qa_pool = torch.nn.AvgPool1d(
+                self.group_size
+            )  # using pooling layer to conduct sum operation

         def call_quant_linear_module(*args, **kwargs):
             return quant_linear_module.forward_origin(*args, **kwargs)
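Aside (not part of the commit): the comment on `qa_pool` refers to the QA-LoRA trick of summing the input activations inside each quantization group. A minimal standalone sketch with illustrative shapes, showing that `AvgPool1d` rescaled by `group_size` is exactly that group-wise sum, which is also why `in_features` of the adapter shrinks to `infeatures // group_size`:

import torch

# toy shapes, not taken from the repo
in_features, group_size = 64, 16
x = torch.randn(2, 8, in_features)                 # (batch, seq, in_features)
qa_pool = torch.nn.AvgPool1d(group_size)           # same module as in the hunk above

pooled = qa_pool(x) * group_size                   # (2, 8, in_features // group_size)
manual = x.view(2, 8, in_features // group_size, group_size).sum(dim=-1)
assert torch.allclose(pooled, manual, atol=1e-5)   # pooling * group_size == group-wise sum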
@@ -108,12 +128,16 @@ def forward(self, x: torch.Tensor):
             if not torch.is_autocast_enabled():
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
+                if self.use_qa_lora:
+                    x = self.qa_pool(x) * self.group_size
                 output = (
                     self.lora_B[self.active_adapter](
                         self.lora_A[self.active_adapter](self.lora_dropout[
                             self.active_adapter](x))).to(expected_dtype)
                     * self.scaling[self.active_adapter])
             else:
+                if self.use_qa_lora:
+                    x = self.qa_pool(x) * self.group_size
                 output = (
                     self.lora_B[self.active_adapter](
                         self.lora_A[self.active_adapter](
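For orientation, a simplified sketch of the forward path this hunk implements: the frozen quantized module produces the base output, the pooled (group-summed) input feeds the low-rank branch, and the scaled delta is added back. The class below is illustrative only (single adapter, no autocast or dtype handling), with `nn.Linear` standing in for the GPTQ quant-linear module:

import torch
import torch.nn as nn


class QALoRASketch(nn.Module):

    def __init__(self, quant_linear, in_features, out_features, r, group_size,
                 lora_alpha=16, lora_dropout=0.0):
        super().__init__()
        self.quant_linear = quant_linear            # stand-in for the frozen quantized layer
        self.group_size = group_size
        self.qa_pool = nn.AvgPool1d(group_size)     # group-wise sum after rescaling
        self.lora_A = nn.Linear(in_features // group_size, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        self.dropout = nn.Dropout(lora_dropout)
        self.scaling = lora_alpha / r

    def forward(self, x):
        result = self.quant_linear(x)
        x = self.qa_pool(x) * self.group_size       # the lines added in this hunk
        return result + self.lora_B(self.lora_A(self.dropout(x))) * self.scaling


layer = QALoRASketch(nn.Linear(64, 128), 64, 128, r=4, group_size=16)
print(layer(torch.randn(2, 8, 64)).shape)           # torch.Size([2, 8, 128])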
@@ -179,6 +203,13 @@ class LoRAConfig(SwiftConfig):
             'help': 'Bias type. Values ca be "none", "all" or "lora_only"'
         })

+    use_qa_lora: bool = field(
+        default=False,
+        metadata={
+            'help':
+            'Use [qa-lora](https://github.com/yuhuixu1993/qa-lora) or not'
+        })
+
     def __post_init__(self):
         from .mapping import SwiftTuners
         self.swift_type = SwiftTuners.LORA
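Assuming the package-level API defined elsewhere in the repo (`Swift.prepare_model` and the existing `LoRAConfig` fields such as `r`, `lora_alpha` and `target_modules`, none of which appear in this diff), enabling the flag would look roughly like the sketch below; the user does not pass `group_size`, since it is read from the model's GPTQ quantization config (see the last hunk):

from transformers import AutoModelForCausalLM

from swift import LoRAConfig, Swift

# hypothetical GPTQ-quantized checkpoint; any GPTQ causal LM would do
model = AutoModelForCausalLM.from_pretrained(
    'TheBloke/Llama-2-7B-GPTQ', device_map='auto')
lora_config = LoRAConfig(
    r=8,
    lora_alpha=32,
    target_modules=['q_proj', 'k_proj', 'v_proj'],
    use_qa_lora=True)  # only meaningful on a GPTQ-quantized base model
model = Swift.prepare_model(model, lora_config)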
@@ -199,7 +230,8 @@ def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
             merge_weights=config.merge_weights,
             use_merged_linear=config.use_merged_linear,
             enable_lora=config.enable_lora,
-            fan_in_fan_out=config.fan_in_fan_out)
+            fan_in_fan_out=config.fan_in_fan_out,
+            use_qa_lora=config.use_qa_lora)

         def state_dict_callback(state_dict, adapter_name):
             return lora_state_dict(state_dict, adapter_name, config.bias)
@@ -237,8 +269,9 @@ def _dynamic_patch_lora(model: torch.nn.Module,
         modules = {}
         module_keys = [key for key, _ in model.named_modules()]
         assert isinstance(target_modules, (str, list))
-        AutoGPTQQuantLinear = get_auto_gptq_quant_linear(
-            get_quantization_config(model, method='gptq'))
+        auto_gptq_config = get_quantization_config(model, method='gptq')
+        AutoGPTQQuantLinear = get_auto_gptq_quant_linear(auto_gptq_config)
+        use_qa_lora = kwargs.pop('use_qa_lora', False)

         for module_key in module_keys:
             if isinstance(target_modules, str):
@@ -292,7 +325,13 @@ def _dynamic_patch_lora(model: torch.nn.Module,
                         **four_bit_kwargs)
                 elif AutoGPTQQuantLinear is not None and isinstance(
                         sub_module, AutoGPTQQuantLinear):
-                    lora_module = QuantLinear('default', sub_module, **kwargs)
+                    lora_module = QuantLinear(
+                        'default',
+                        sub_module,
+                        use_qa_lora=use_qa_lora,
+                        group_size=getattr(auto_gptq_config, 'group_size',
+                                           None),
+                        **kwargs)
                     sub_module.weight = sub_module.qweight
                 elif isinstance(sub_module, torch.nn.Linear):
                     if use_merged_linear:
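A short aside on where `group_size` comes from, sketched under the assumption that the base model was quantized with transformers' `GPTQConfig`, which is what `get_quantization_config(model, method='gptq')` returns for such models:

from transformers import GPTQConfig

gptq_config = GPTQConfig(bits=4, group_size=128)   # illustrative values
# the patched code above forwards this attribute to QuantLinear
print(getattr(gptq_config, 'group_size', None))    # 128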