We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0c88d43, commit feb8ad2 (Copy full SHA for feb8ad2)
bitsandbytes/nn/modules.py
@@ -488,6 +488,7 @@ def __init__(
488
self.compute_type_is_set = compute_dtype is not None
489
self.quant_state = None
490
self.quant_storage = quant_storage
491
+ self.support_avx512bf16_for_cpu = has_avx512bf16()
492
493
def set_compute_type(self, x):
494
if x.dtype in [torch.float32, torch.bfloat16]:
@@ -530,7 +531,7 @@ def forward(self, x: torch.Tensor):
530
531
if (
532
not getattr(quant_state, "packing_format_for_cpu", False)
533
and x.device.type == "cpu"
- and has_avx512bf16()
534
+ and self.support_avx512bf16_for_cpu
535
and not self.training
536
and x.requires_grad == False
537
):
0 commit comments