Commit 8fe6325

Fix XPU 4bit (#1567)
* fix 4bit XPU dequant 4bit
* fix default value
* fix ipex linear set
* fix ipex linear set to false when calling state dict
* fix Int8Param device patch

Signed-off-by: jiqing-feng <[email protected]>
1 parent 249a3cd · commit 8fe6325

File tree (2 files changed, +14 / -15 lines):

  bitsandbytes/functional.py
  bitsandbytes/nn/modules.py

bitsandbytes/functional.py

Lines changed: 9 additions & 9 deletions
@@ -1067,7 +1067,7 @@ def dequantize_fp4(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")

@@ -1077,7 +1077,7 @@ def dequantize_nf4(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")

@@ -1087,8 +1087,8 @@ def dequantize_4bit(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
-    quant_type="fp4",
+    blocksize: Optional[int] = None,
+    quant_type: Optional[str] = "fp4",
 ) -> torch.Tensor:
     """Dequantizes a packed 4-bit quantized tensor.

@@ -1106,9 +1106,9 @@ def dequantize_4bit(
             Required if `quant_state` is not provided and ignored otherwise.
         out (`torch.Tensor`, *optional*): A tensor to use to store the result.
         blocksize (`int`, *optional*):
-            The size of the blocks. Defaults to 64.
+            The size of the blocks. Defaults to 64 if not HIP_ENVIRONMENT else 128.
             Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
-        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
+        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to "fp4".

     Raises:
         ValueError: Raised when the input data type or blocksize is not supported.

@@ -1118,9 +1118,9 @@ def dequantize_4bit(
     """
     ensure_backend_is_available(A.device.type)
     if quant_state is not None:
-        absmax = absmax or quant_state.absmax
-        quant_type = quant_type or quant_state.quant_type
-        blocksize = blocksize or quant_state.blocksize
+        absmax = quant_state.absmax
+        quant_type = quant_state.quant_type
+        blocksize = quant_state.blocksize
     if blocksize is None:
         # Some AMD GPUs have warpsize 64
         # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
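Taken together, the hunks above change two things: passing a quant_state now makes it authoritative for absmax, quant_type, and blocksize (the old "x or quant_state.x" merge is gone), and the new signature default blocksize=None defers to the HIP-aware fallback shown in the last context lines. A minimal standalone sketch of that resolution order, assuming a stand-in HIP_ENVIRONMENT constant rather than the library's real flag:

# Sketch only, not bitsandbytes code; HIP_ENVIRONMENT is a stand-in constant.
from typing import Optional

HIP_ENVIRONMENT = False  # assume a non-ROCm build for this example


def resolve_dequant_args(quant_state, absmax, blocksize: Optional[int], quant_type: Optional[str]):
    if quant_state is not None:
        # quant_state is now authoritative; caller-supplied overrides are ignored.
        absmax = quant_state.absmax
        quant_type = quant_state.quant_type
        blocksize = quant_state.blocksize
    if blocksize is None:
        # Some AMD GPUs have warpsize 64, so HIP builds default to 128.
        blocksize = 128 if HIP_ENVIRONMENT else 64
    return absmax, blocksize, quant_type

One practical benefit of dropping the "or" merge: a caller-supplied multi-element absmax tensor would have made "absmax or quant_state.absmax" raise, since a tensor with more than one element has no defined truth value.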

bitsandbytes/nn/modules.py

Lines changed: 5 additions & 6 deletions
@@ -487,6 +487,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
                 self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))

             self.weight.quant_state.ipex = False
+            self.ipex_linear_is_set = False

         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
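The added reset pairs with quant_state.ipex = False: forward() appears to attempt IPEX setup only while ipex_linear_is_set is still False (see the "# Check if ipex fusion can be used" comment in the next hunk), so clearing both markers after the weight layout is restored lets a later forward re-fuse instead of running on the unpacked weight. A hedged, illustrative stand-in for that set-once pattern (not the real class):

# Illustrative stand-in, not bitsandbytes code; maybe_enable_fusion stands in
# for set_ipex_linear / enable_ipex_fusion.
class FusedLinearSketch:
    def __init__(self):
        self.ipex_linear_is_set = False
        self.fused = False

    def maybe_enable_fusion(self, x):
        self.fused = True  # simplified: repack the weight for IPEX

    def forward(self, x):
        if not self.ipex_linear_is_set:
            self.maybe_enable_fusion(x)
            self.ipex_linear_is_set = True
        return x  # placeholder compute

    def save_state(self):
        # Mirrors the hunk above: undo the fused layout and clear both markers
        # so the next forward() can re-enable fusion.
        self.fused = False
        self.ipex_linear_is_set = False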

@@ -496,15 +497,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):

     def set_ipex_linear(self, x: torch.Tensor):
         if (
-            (x.device.type in ("cpu", "xpu"))
-            and not getattr(self.weight.quant_state, "ipex", False)
+            not getattr(self.weight.quant_state, "ipex", False)
             and self.weight.data.dtype == torch.uint8
             and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
             and self.weight.quant_state.quant_type == "nf4"
-            and not self.training
-            and x.requires_grad == False
         ):
-            enable_ipex_fusion(self, x)
+            if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False):
+                enable_ipex_fusion(self, x)

     def forward(self, x: torch.Tensor):
         # Check if ipex fusion can be used
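The rewritten condition keeps the weight-format checks (uint8 storage, blocksize-aligned shape, nf4) for both backends but moves the device and training checks inside: XPU now enables the fusion unconditionally, while CPU still requires pure inference (not training, input without grad). A small standalone predicate restating that gate; the names are illustrative, not the module's API:

# Standalone restatement of the new gate; weight_format_ok stands in for the
# dtype/blocksize/quant_type checks and is not a real bitsandbytes helper.
def should_enable_ipex_fusion(
    device_type: str,
    already_fused: bool,
    weight_format_ok: bool,
    training: bool,
    input_requires_grad: bool,
) -> bool:
    if already_fused or not weight_format_ok:
        return False
    if device_type == "xpu":
        return True  # XPU: fuse regardless of training/grad state
    # CPU: only fuse for pure inference
    return device_type == "cpu" and not training and not input_requires_grad


# e.g. an XPU module in training mode now qualifies:
assert should_enable_ipex_fusion("xpu", False, True, training=True, input_requires_grad=True)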
@@ -695,7 +694,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

-        if device in ("cuda", "xpu", "cpu"):
+        if device is not None and device.type in ("cuda", "xpu", "cpu"):
             if device.type == "cuda" and self.data.device.type == "cpu":
                 return self.cuda(device)
             elif device.type == "cpu":
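torch._C._nn._parse_to returns a torch.device (or None when only a dtype is given), so the old membership test compared a device object against plain strings and effectively never took the custom branch; the fix guards for None and compares device.type. A minimal illustration of the corrected routing, assuming standard torch.device semantics; route_to is a made-up helper, not part of bitsandbytes:

# Made-up helper for illustration; mirrors the fixed guard in the to() override above.
import torch


def route_to(device):
    # device may be None, e.g. for param.to(torch.float16); compare the
    # .type attribute rather than the device object against strings.
    if device is not None and device.type in ("cuda", "xpu", "cpu"):
        return f"custom {device.type} path"
    return "default nn.Parameter behavior"


print(route_to(None))                    # default nn.Parameter behavior
print(route_to(torch.device("cpu")))     # custom cpu path
print(route_to(torch.device("cuda:0")))  # custom cuda path (constructing the device needs no GPU)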
