 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import torch
 from ..utils import DTYPE_BITS_MAPPING
 from functools import reduce
 from peft.tuners.lora import LoraLayer, LoraModel
 from peft.utils.other import transpose
 from intel_extension_for_transformers.transformers.llm.quantization.autograd import (
-    matmul_kbit,
-)
+    matmul_kbit, )
 import intel_extension_for_transformers.qbits as qbits  # pylint: disable=E0611, E0401
 
 
 class DropoutQBits_(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, input, probability):
         mask = qbits.dropout_fwd(input, probability)
         if any(ctx.needs_input_grad[:1]):
-            ctx.tensors = (mask,)
+            ctx.tensors = (mask, )
         else:
-            ctx.tensors = (None,)
+            ctx.tensors = (None, )
         return input
 
     @staticmethod
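
For context, `DropoutQBits_` follows the standard `torch.autograd.Function` dropout pattern: forward draws a mask and stashes it, and backward reapplies that mask to the incoming gradient. A minimal pure-PyTorch sketch of the same pattern, assuming the fused `qbits.dropout_fwd`/`dropout_bwd` kernels behave like ordinary inverted dropout (`DropoutSketch_` is a hypothetical name):

import torch

class DropoutSketch_(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, probability):
        # Inverted dropout: scaled Bernoulli mask so expectations match eval mode.
        mask = torch.empty_like(input).bernoulli_(1 - probability).div_(1 - probability)
        ctx.save_for_backward(mask)
        return input * mask

    @staticmethod
    def backward(ctx, grad_output):
        # Reuse the forward mask; no gradient w.r.t. probability.
        (mask,) = ctx.saved_tensors
        return grad_output * mask, None
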
@@ -51,6 +52,7 @@ def backward(ctx, grad_output):
 
 
 class DropoutQBits(torch.nn.Module):
+
     def __init__(self, p=0.0):
         super().__init__()
         self.p = p
@@ -63,6 +65,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 
 class ParamsQBits(torch.nn.Parameter):
+
     def __new__(
         cls,
         data=None,
@@ -87,6 +90,7 @@ def __new__(
 
 
 class QuantizedLinearQBits(torch.nn.Linear):
+
     def __init__(
         self,
         input_features,
@@ -156,6 +160,9 @@ def forward(self, x: torch.Tensor):
         shape[-1] = self.out_features
         out = out.view(shape)
 
+        if os.environ.get("backend", None) == "use_vllm":
+            return out, None
+
         return out
 
     def set_fp_weights_bias(self, weight_data, bias=None):
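
The branch added here makes `forward` return an `(out, None)` tuple when the `backend` environment variable is `"use_vllm"`, presumably to match vLLM's convention that linear layers return an `(output, bias)` pair. A hedged usage sketch (the feature sizes and the bare constructor call are illustrative assumptions):

import os
import torch

os.environ["backend"] = "use_vllm"        # opt in before running the layer

layer = QuantizedLinearQBits(4096, 4096)  # hypothetical sizes / default config
out, bias = layer(torch.randn(1, 4096))   # now a tuple; the bias slot is None
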
@@ -264,33 +271,24 @@ def quant_weight_w_scale(self, weight, scale, zp, group_size=-1):
         if zp is not None:
             zp = zp.to(device)
         if group_size == -1:
-            return (
-                weight.div_(scale).round_()
-                if zp is None
-                else weight.div_(scale).add_(zp).round_()
-            )
+            return (weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_())
         int_weight = torch.zeros(weight.shape).to(device)
         leng = weight.shape[1] // group_size
         tail_flag = False if weight.shape[1] % group_size == 0 else True
         for i in range(leng):
-            int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(
-                scale[:, i].unsqueeze(1)
-            )
+            int_weight_tmp = weight[:, i * group_size:(i + 1) * group_size].div_(scale[:, i].unsqueeze(1))
             if zp is not None:
                 int_weight_tmp.add_(zp[:, i].unsqueeze(1))
-            int_weight[:, i * group_size : (i + 1) * group_size].copy_(
-                int_weight_tmp.round_()
-            )
+            int_weight[:, i * group_size:(i + 1) * group_size].copy_(int_weight_tmp.round_())
         if tail_flag:
-            int_weight_tmp = weight[:, leng * group_size :].div_(
-                scale[:, -1].unsqueeze(1)
-            )
+            int_weight_tmp = weight[:, leng * group_size:].div_(scale[:, -1].unsqueeze(1))
             if zp is not None:
                 int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
-            int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
+            int_weight[:, leng * group_size:].copy_(int_weight_tmp.round_())
         return int_weight
 
     def recover_qparms(self):
+
         def recover_idx(ret_idx, k, blocksize):
             g_idx = torch.zeros(k, dtype=int)
             value_range = (k + blocksize - 1) // blocksize
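
`quant_weight_w_scale` above rounds the weight group by group: each `group_size`-wide column slice is divided by its group scale, offset by the zero point in the asymmetric case, and rounded, with a ragged tail group handled via the last scale column. A self-contained sketch of the same math (hypothetical helper name; assumes `scale`/`zp` carry one column per group, as above):

import torch

def quant_weight_w_scale_sketch(weight, scale, zp=None, group_size=-1):
    """q = round(w / s [+ zp]), computed per group of input channels."""
    if group_size == -1:
        q = weight / scale
        return (q if zp is None else q + zp).round_()
    out_ch, in_ch = weight.shape
    groups = (in_ch + group_size - 1) // group_size  # ceil-div covers the tail
    int_weight = torch.zeros_like(weight)
    for g in range(groups):
        cols = slice(g * group_size, min((g + 1) * group_size, in_ch))
        q = weight[:, cols] / scale[:, g].unsqueeze(1)
        if zp is not None:
            q += zp[:, g].unsqueeze(1)
        int_weight[:, cols] = q.round_()
    return int_weight
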
@@ -328,18 +326,12 @@ def recover_int_weight(g_idx, int_weight):
         else:
             g_idx = None
         weight_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 6)
-        weight_dtype = "".join(
-            chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist()
-        )
+        weight_dtype = "".join(chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist())
         bits = 4 if weight_dtype in ["nf4", "int4_clip", "fp4", "int4_fullrange"] else 8
         compute_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 7)
-        compute_dtype = "".join(
-            chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist()
-        )
+        compute_dtype = "".join(chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist())
         scales_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 8)
-        scales_dtype = "".join(
-            chr(ascii_code) for ascii_code in scales_dtype_ascii.tolist()
-        )
+        scales_dtype = "".join(chr(ascii_code) for ascii_code in scales_dtype_ascii.tolist())
         if scales_dtype is None:
             assert False, "scales dtype only support fp32."
         scales = qbits.acquire_packed_weight_info(self.weight, 9)
@@ -356,9 +348,7 @@ def recover_int_weight(g_idx, int_weight):
 
         revert_wei = torch.zeros(in_features, out_features, dtype=torch.float)
 
-        qbits.dequantize_packed_weight(
-            self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype
-        )
+        qbits.dequantize_packed_weight(self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype)
 
         int_weight = self.quant_weight_w_scale(
             revert_wei.t(),
@@ -426,9 +416,7 @@ def __init__(
         except:
             qbits_customop_available = False
         if lora_dropout > 0 and qbits_customop_available:
-            self.lora_dropout = torch.nn.ModuleDict(
-                {adapter_name: DropoutQBits(p=lora_dropout)}
-            )
+            self.lora_dropout = torch.nn.ModuleDict({adapter_name: DropoutQBits(p=lora_dropout)})
 
     def merge(self, safe_merge: bool = False) -> None:
         """Merge the active adapter weights into the base weights.
@@ -440,10 +428,8 @@ def merge(self, safe_merge: bool = False) -> None:
                 NaNs. Defaults to `False`.
         """
         if self.merged:
-            print(
-                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
-                f"You are now additionally merging {','.join(self.active_adapters)}."
-            )
+            print(f"Already following adapters were merged {','.join(self.merged_adapters)}. "
+                  f"You are now additionally merging {','.join(self.active_adapters)}.")
         w_dequant = torch.zeros(
             self.out_features,
             self.in_features,
@@ -468,8 +454,7 @@ def merge(self, safe_merge: bool = False) -> None:
 
                 if not torch.isfinite(orig_weights).all():
                     raise ValueError(
-                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
-                    )
+                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken")
 
                 w_data = orig_weights
             else:
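
Because the base weight is packed, `merge` has to round-trip through floating point: dequantize into `w_dequant`, add each adapter's delta, optionally guard against NaNs, then repack. A toy floating-point-only sketch of that accumulation and guard (`merge_sketch` and its arguments are illustrative; the real method dequantizes via `qbits.dequantize_packed_weight` first and repacks afterwards):

import torch

def merge_sketch(base_fp32, deltas, safe_merge=False):
    # Accumulate adapter deltas onto a dequantized copy of the base weight.
    w = base_fp32.clone()
    for name, delta in deltas.items():
        merged = w + delta
        if safe_merge and not torch.isfinite(merged).all():
            raise ValueError(f"NaNs detected in the merged weights. The adapter {name} seems to be broken")
        w = merged
    return w
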
@@ -541,13 +526,10 @@ def unmerge(self) -> None:
         )
 
     def get_delta_weight(self, adapter) -> torch.Tensor:
-        return (
-            transpose(
-                self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
-                False,
-            )
-            * self.scaling[adapter]
-        )
+        return (transpose(
+            self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
+            False,
+        ) * self.scaling[adapter])
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.disable_adapters:
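
`get_delta_weight` computes the standard LoRA update, delta_W = (B @ A) * scaling, shaped `out_features x in_features` (with `fan_in_fan_out=False`, `transpose(..., False)` returns its input unchanged). A quick self-contained check with toy shapes:

import torch

r, lora_alpha = 8, 16
in_f, out_f = 64, 32
lora_A = torch.randn(r, in_f)          # A: r x in_features
lora_B = torch.randn(out_f, r)         # B: out_features x r
scaling = lora_alpha / r

delta_w = (lora_B @ lora_A) * scaling  # out_features x in_features
assert delta_w.shape == (out_f, in_f)
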
@@ -602,24 +584,18 @@ def _create_new_module(self, lora_config, adapter_name, target, **kwargs):
             bias = kwargs.pop("bias", False)
             in_features, out_features = target.in_features, target.out_features
             if kwargs["fan_in_fan_out"]:
-                print(
-                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
-                    "Setting fan_in_fan_out to False."
-                )
+                print("fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
+                      "Setting fan_in_fan_out to False.")
                 kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
             kwargs["compute_dtype"] = target.compute_dtype
             kwargs["compress_statistics"] = target.compress_statistics
             kwargs["weight_dtype"] = target.weight_dtype
             kwargs["scale_dtype"] = target.scale_dtype
             kwargs["blocksize"] = target.blocksize
             kwargs["scheme"] = target.scheme
-            new_module = QuantizedLoraLinearQBits(
-                adapter_name, in_features, out_features, bias=bias, **kwargs
-            )
+            new_module = QuantizedLoraLinearQBits(adapter_name, in_features, out_features, bias=bias, **kwargs)
         else:
-            new_module = QBitsLoraModel._create_new_module_(
-                lora_config, adapter_name, target, **kwargs
-            )
+            new_module = QBitsLoraModel._create_new_module_(lora_config, adapter_name, target, **kwargs)
         return new_module
 
 