1212
1313import bitsandbytes as bnb
1414from bitsandbytes .cextension import ROCM_WARP_SIZE_64
15- from bitsandbytes .functional import QuantState , convert_weight_packed_for_cpu , has_avx512bf16
15+ from bitsandbytes .functional import QuantState , _convert_weight_packed_for_cpu , has_avx512bf16
1616from bitsandbytes .optim import GlobalOptimManager
1717from bitsandbytes .utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING , OutlierTracer
1818
@@ -479,7 +479,7 @@ def __init__(
479479 self .compute_type_is_set = compute_dtype is not None
480480 self .quant_state = None
481481 self .quant_storage = quant_storage
482- self .enable_optimized_cpu = False
482+ self .packing_format_for_cpu = False
483483
484484 def set_compute_type (self , x ):
485485 if x .dtype in [torch .float32 , torch .bfloat16 ]:
@@ -513,19 +513,19 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
513513 destination [prefix + "weight." + k ] = v if keep_vars else v .detach ()
514514
515515 def forward (self , x : torch .Tensor ):
516- quant_state = self .weight .quant_state
517516 fix_4bit_weight_quant_state_from_module (self )
517+ quant_state = self .weight .quant_state
518518
519519 if (
520- not self .enable_optimized_cpu
520+ not self .packing_format_for_cpu
521521 and x .device .type == "cpu"
522522 and has_avx512bf16 ()
523523 and not self .training
524524 and x .requires_grad == False
525525 ):
526- self .weight .data , quant_state = convert_weight_packed_for_cpu (self .weight .data , quant_state )
527- self .enable_optimized_cpu = True
528- quant_state .enable_optimized_cpu = True
526+ self .weight .data , quant_state = _convert_weight_packed_for_cpu (self .weight .data , quant_state )
527+ self .packing_format_for_cpu = True
528+ quant_state .packing_format_for_cpu = True
529529
530530 # weights are cast automatically as Int8Params, but the bias has to be cast manually
531531 if self .bias is not None and self .bias .dtype != x .dtype :
@@ -540,7 +540,7 @@ def forward(self, x: torch.Tensor):
540540 x = x .to (self .compute_dtype )
541541
542542 bias = None if self .bias is None else self .bias .to (self .compute_dtype )
543- weight = self .weight if getattr (quant_state , "enable_optimized_cpu" , False ) else self .weight .t ()
543+ weight = self .weight if getattr (quant_state , "packing_format_for_cpu" , False ) else self .weight .t ()
544544
545545 return bnb .matmul_4bit (x , weight , bias = bias , quant_state = quant_state ).to (inp_dtype )
546546
0 commit comments