Commit 2784653

fix nf4 memory issue by init op_context in forward (#1349)
* fix nf4 memory issue by init op_context in forward
* disable repack in init
* fix code style
1 parent 39097a6 commit 2784653
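
The gist of the change: the IPEX op_context (the prepacked weight) used to be created inside quantize_4bit_impl at quantization time, which could keep both the packed NF4 weight and the IPEX-prepacked copy alive; it is now created lazily on the first CPU forward, and the original buffers are released once IPEX owns the data. Below is a minimal, framework-free sketch of that lazy-initialization pattern (hypothetical class and names, not the bitsandbytes API):

import torch


class LazyPrepackLinear(torch.nn.Module):
    """Hypothetical sketch: defer an expensive weight repack until the first
    forward pass instead of doing it at init/quantization time, so two full
    copies of the weight never coexist."""

    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.prepacked = None  # plays the role of quant_state.op_context

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.prepacked is None:
            # Repack once, then release the original buffer so only the
            # repacked copy stays resident (cf. weight.data = torch.empty([1, 0])).
            self.prepacked = self.weight.data.t().contiguous()
            self.weight.data = torch.empty(0)
        return x @ self.prepacked


layer = LazyPrepackLinear(torch.randn(8, 16))
y = layer(torch.randn(2, 16))  # the repack happens here, on the first call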

File tree

3 files changed: +47 -23 lines changed

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 0 additions & 19 deletions
@@ -370,25 +370,6 @@ def quantize_4bit_impl(
         quant_type=quant_type,
     )
 
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
-        # lowp_mode: lowest precision for computation
-        lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
-        state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
-            out.reshape([input_shape[0], input_shape[1] // 2]),
-            ipex_cpu.quantization.WoqWeightDtype.NF4,
-            input_shape,  # weight shape
-            absmax.view(input_shape[0], input_shape[1] // blocksize),  # scales
-            None,  # zero_points
-            None,  # bias
-            None,  # g_idx
-            None,  # batch_size
-            blocksize,
-            int(lowp_mode),
-            -1,  # act_quant_mode. -1 means don't quant activation
-        )
-        state.absmax = torch.Tensor()
-        return torch.empty([1, 0], dtype=torch.uint8), state
-
     return out.unsqueeze(0), state

bitsandbytes/nn/modules.py

Lines changed: 23 additions & 4 deletions
@@ -19,6 +19,7 @@
     INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
     LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
     OutlierTracer,
+    enable_ipex_fusion,
 )
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -444,17 +445,35 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         save weight and bias,
         then fill state_dict with components of quant_state
         """
+        if (
+            getattr(self.weight, "quant_state", None) is not None
+            and getattr(self.weight.quant_state, "op_context", None) is not None
+        ):
+            context = self.weight.quant_state.op_context
+            self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])
+
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 
         if getattr(self.weight, "quant_state", None) is not None:
+            if (
+                self.weight.quant_state.absmax.shape.numel() == 0
+                and getattr(self.weight.quant_state, "op_context", None) is not None
+            ):
+                self.weight.quant_state.absmax = context.get_scales().reshape(-1)
+                delattr(self.weight.quant_state, "op_context")
             for k, v in self.weight.quant_state.as_dict(packed=True).items():
                 destination[prefix + "weight." + k] = v if keep_vars else v.detach()
-            if getattr(self.weight.quant_state, "op_context", None) is not None:
-                context = self.weight.quant_state.op_context
-                destination[prefix + "weight." + "absmax"] = context.get_scales().reshape(-1)
-                self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])
 
     def forward(self, x: torch.Tensor):
+        # Check if ipex fusion can be used
+        if (
+            x.device.type == "cpu"
+            and not hasattr(self.weight.quant_state, "op_context")
+            and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
+            and self.weight.quant_state.quant_type == "nf4"
+        ):
+            enable_ipex_fusion(self.weight, self.weight.quant_state)
+
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
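
With the fusion attached, _save_to_state_dict now rebuilds the public packed weight and the absmax scales from op_context before serializing, so checkpoints keep the usual bitsandbytes 4-bit layout. A hedged sketch of what this looks like from the caller's side (assuming `linear` is a Linear4bit module with quant_type="nf4" that has already run one CPU forward pass, so quant_state.op_context exists):

# Hedged sketch, not taken from the commit.
sd = linear.state_dict()

# The prepacked IPEX weight is converted back via op_context.to_public(...)
# and the scales are recovered via op_context.get_scales(), so the standard
# serialized components are present:
packed_weight = sd["weight"]     # packed 4-bit payload
scales = sd["weight.absmax"]     # flattened per-block absmax scales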

bitsandbytes/utils.py

Lines changed: 24 additions & 0 deletions
@@ -200,6 +200,30 @@ def unpack_tensor_to_dict(tensor_data):
     return unpacked_dict
 
 
+def enable_ipex_fusion(weight, quant_state):
+    from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq
+
+    if _ipex_cpu_version_prereq(2, 3):
+        import intel_extension_for_pytorch as ipex
+
+        lowp_mode = ipex.quantization.WoqLowpMode.BF16
+        quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
+            weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+            ipex.quantization.WoqWeightDtype.NF4,
+            quant_state.shape,  # weight shape
+            quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
+            None,  # zero_points
+            None,  # bias
+            None,  # g_idx
+            None,  # batch_size
+            quant_state.blocksize,
+            int(lowp_mode),
+            -1,  # act_quant_mode. -1 means don't quant activation
+        )
+        quant_state.absmax = torch.Tensor()
+        weight.data = torch.empty([1, 0], dtype=torch.uint8)
+
+
 class QuantState:
     """container for quantization state components to work with Params4bit and similar classes"""
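
The new enable_ipex_fusion helper hands the packed weight and the absmax scales to torch.ops.ipex_prepack.weight_only_qlinear_prepack and then replaces weight.data and quant_state.absmax with empty placeholders, so only the IPEX-owned copy stays resident. A rough sketch of the shape bookkeeping behind the reshape/view calls, with assumed example sizes:

# Assumed example sizes, for illustration only (not taken from the commit):
out_features, in_features, blocksize = 4096, 11008, 64

# NF4 stores two 4-bit values per byte, so the packed uint8 payload passed to
# the prepack op has in_features // 2 columns ...
packed_shape = (out_features, in_features // 2)           # (4096, 5504)

# ... and there is one absmax scale per block of `blocksize` weights:
scales_shape = (out_features, in_features // blocksize)   # (4096, 172)

# After prepacking, the original buffers are released in favour of op_context:
#   weight.data        -> torch.empty([1, 0], dtype=torch.uint8)
#   quant_state.absmax -> torch.Tensor()
print(packed_shape, scales_shape)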
