Skip to content

Commit de5fb9c

Browse files
committed
fix comments
Signed-off-by: jiqing-feng <[email protected]>
1 parent 302a5fe commit de5fb9c

File tree

4 files changed

+13
-22
lines changed

4 files changed

+13
-22
lines changed

CMakeLists.txt

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -280,24 +280,15 @@ if (BUILD_CPU)
     include(CheckCXXCompilerFlag)
     check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG)
     check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG)
-    check_cxx_compiler_flag(-mavx512dq HAS_AVX512DQ)
-    check_cxx_compiler_flag(-mavx512bw HAS_AVX512BW)
-    check_cxx_compiler_flag(-mavx512vl HAS_AVX512VL)
     if (HAS_AVX512F_FLAG)
         target_compile_options(bitsandbytes PRIVATE -mavx512f)
-    endif()
-    if (HAS_AVX512BF16_FLAG)
-        target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
-    endif()
-    if(HAS_AVX512DQ)
         target_compile_options(bitsandbytes PRIVATE -mavx512dq)
-    endif()
-    if(HAS_AVX512BW)
         target_compile_options(bitsandbytes PRIVATE -mavx512bw)
-    endif()
-    if(HAS_AVX512VL)
         target_compile_options(bitsandbytes PRIVATE -mavx512vl)
     endif()
+    if (HAS_AVX512BF16_FLAG)
+        target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
+    endif()
     target_compile_options(
         bitsandbytes PRIVATE
         -mprefer-vector-width=256

bitsandbytes/autograd/_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def matmul_4bit(
     if A.device.type == "cpu":
         quant_state.dtype = A.dtype

-        if getattr(quant_state, "enable_optimized_cpu", False):
+        if getattr(quant_state, "packing_format_for_cpu", False):
             out = F.gemv_4bit(A, B, out, state=quant_state)
             if bias is not None:
                 out += bias

bitsandbytes/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2103,7 +2103,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     return out


-def convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32):
+def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32):
     """
     qweight: (K * N / 2) uint8
     return: packed_weight

bitsandbytes/nn/modules.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@

 import bitsandbytes as bnb
 from bitsandbytes.cextension import ROCM_WARP_SIZE_64
-from bitsandbytes.functional import QuantState, convert_weight_packed_for_cpu, has_avx512bf16
+from bitsandbytes.functional import QuantState, _convert_weight_packed_for_cpu, has_avx512bf16
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer

@@ -479,7 +479,7 @@ def __init__(
         self.compute_type_is_set = compute_dtype is not None
         self.quant_state = None
         self.quant_storage = quant_storage
-        self.enable_optimized_cpu = False
+        self.packing_format_for_cpu = False

     def set_compute_type(self, x):
         if x.dtype in [torch.float32, torch.bfloat16]:
@@ -513,19 +513,19 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()

     def forward(self, x: torch.Tensor):
-        quant_state = self.weight.quant_state
         fix_4bit_weight_quant_state_from_module(self)
+        quant_state = self.weight.quant_state

         if (
-            not self.enable_optimized_cpu
+            not self.packing_format_for_cpu
             and x.device.type == "cpu"
             and has_avx512bf16()
             and not self.training
             and x.requires_grad == False
         ):
-            self.weight.data, quant_state = convert_weight_packed_for_cpu(self.weight.data, quant_state)
-            self.enable_optimized_cpu = True
-            quant_state.enable_optimized_cpu = True
+            self.weight.data, quant_state = _convert_weight_packed_for_cpu(self.weight.data, quant_state)
+            self.packing_format_for_cpu = True
+            quant_state.packing_format_for_cpu = True

         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -540,7 +540,7 @@ def forward(self, x: torch.Tensor):
             x = x.to(self.compute_dtype)

         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-        weight = self.weight if getattr(quant_state, "enable_optimized_cpu", False) else self.weight.t()
+        weight = self.weight if getattr(quant_state, "packing_format_for_cpu", False) else self.weight.t()

         return bnb.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype)

0 commit comments

Comments (0)