
Commit 5c48b33

XPU backend support 8bit optimizer (#1565)
* enable xpu 8bit optim
* add dequant_blockwise
* dequantize_blockwise
* add backend synchronize
* refine code
* ipex dep
* ipex dep too
* ipex version check

---------

Co-authored-by: jiqing-feng <[email protected]>
1 parent 54a2ad5 commit 5c48b33

File tree

8 files changed: +101 -6 lines changed

bitsandbytes/backends/cpu.py

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,9 @@ class CPUBackend(Backend):
     mm_dequant_compute_dtype = torch.bfloat16
     mm_dequant_output_dtype = torch.bfloat16
 
+    def device_synchronize(self):
+        pass
+
     def int8_double_quant(
         self,
         A: torch.Tensor,

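The no-op above is intentional: CPU execution is already synchronous, while accelerator backends need a real barrier. A minimal sketch of the device_synchronize dispatch pattern this commit introduces, using illustrative class and registry names rather than the library's real Backend types:

# Minimal sketch, assuming a device-type-keyed registry like bitsandbytes' backends dict;
# the class names here are illustrative, not the actual Backend implementations.
import torch


class _CPUSketch:
    def device_synchronize(self):
        # CPU ops complete eagerly, so there is nothing to wait for.
        pass


class _CUDASketch:
    def device_synchronize(self):
        # Block the host until all queued CUDA kernels finish.
        torch.cuda.synchronize()


_backends = {"cpu": _CPUSketch(), "cuda": _CUDASketch()}


def sync_for(t: torch.Tensor) -> None:
    # Mirrors backends[p.device.type].device_synchronize() in optimizer.py below.
    _backends[t.device.type].device_synchronize()


sync_for(torch.zeros(1))  # CPU tensor: dispatches to the no-op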
bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def _ipex_xpu_version_prereq(major, minor):
 
 def _maybe_torch_compile(func):
     # torch.compile requires g++ and pytorch >= 2.0
-    if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu:
+    if gxx_available and _torch_version_prereq(2, 0) and ipex_cpu_only:
         options = {}
         # fx_graph_cache requires pytorch >= 2.2
         if _torch_version_prereq(2, 2):

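The flipped condition means torch.compile is only applied on the IPEX CPU-only path; when the XPU extension is loaded, functions are left uncompiled. A hedged sketch of a conditional-compile decorator in the same spirit (the flag names mimic cpu_xpu_common.py but are recreated locally here; this is not the exact library implementation):

# Hedged sketch of a _maybe_torch_compile-style decorator.
# gxx_available / ipex_cpu_only are local stand-ins, set up as assumptions.
import shutil
import torch

gxx_available = shutil.which("g++") is not None  # torch.compile needs a C++ toolchain
ipex_cpu_only = False                            # assume True only for CPU-only IPEX installs


def maybe_compile(func):
    # Wrap with torch.compile only when it is supported and expected to help.
    if gxx_available and hasattr(torch, "compile") and ipex_cpu_only:
        return torch.compile(func)
    return func


@maybe_compile
def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0


print(scale(torch.ones(3)))  # behaves the same whether or not it was compiled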
bitsandbytes/backends/cuda.py

Lines changed: 3 additions & 0 deletions
@@ -97,6 +97,9 @@
 
 
 class CUDABackend(Backend):
+    def device_synchronize(self):
+        torch.cuda.synchronize()
+
     def transform(
         self,
         A: torch.Tensor,

bitsandbytes/backends/mps.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,9 @@
 
 
 class MPSBackend(Backend):
+    def device_synchronize(self):
+        torch.mps.synchronize()
+
     def double_quant(
         self,
         A: torch.Tensor,

bitsandbytes/backends/npu.py

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,9 @@ def assert_on_npu(tensors):
 
 
 class NPUBackend(Backend):
+    def device_synchronize(self):
+        torch.npu.synchronize()
+
     def int8_double_quant(
         self,
         A: torch.Tensor,

bitsandbytes/backends/xpu.py

Lines changed: 75 additions & 2 deletions
@@ -12,11 +12,28 @@
     int8_linear_matmul_impl,
     int8_mm_dequant_impl,
     quantize_4bit_impl,
+    _ipex_xpu_version_prereq
 )
+try:
+    import intel_extension_for_pytorch as ipex
+    ipex_xpu = ipex if ipex._C._has_xpu() else None
+except BaseException:
+    ipex_xpu = None
 
 Tensor = torch.Tensor
 
 
+str2optimizer8bit_blockwise = {}
+if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7):
+    str2optimizer8bit_blockwise = {
+        "adam": (
+            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32,
+            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16,
+            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16,
+        ),
+    }
+
+
 def assert_on_xpu(tensors):
     on_xpu = True
     for t in tensors:
@@ -35,6 +52,9 @@ class XPUBackend(Backend):
     mm_dequant_compute_dtype = torch.bfloat16
     mm_dequant_output_dtype = torch.bfloat16
 
+    def device_synchronize(self):
+        torch.xpu.synchronize()
+
     def int8_double_quant(
         self,
         A: torch.Tensor,
@@ -185,7 +205,19 @@ def dequantize_blockwise(
         blocksize: int = 4096,
         nested=False,
     ) -> torch.Tensor:
-        raise NotImplementedError
+        if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
+            raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.")
+
+        # void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
+        if out.dtype == torch.float16:
+            ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
+        elif out.dtype == torch.bfloat16:
+            ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
+        elif out.dtype == torch.float32:
+            ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
+        else:
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
+
 
     def quantize_blockwise(
         self,
@@ -220,7 +252,48 @@ def optimizer_update_8bit_blockwise(
         gnorm_scale: float = 1.0,
         skip_zeros=False,
     ) -> None:
-        raise NotImplementedError
+        optim_func = None
+        if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
+            raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.")
+
+        assert_on_xpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
+
+        if g.dtype == torch.float32 and state1.dtype == torch.uint8:
+            optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
+        elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
+            optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
+        elif (
+            g.dtype == torch.bfloat16
+            and state1.dtype == torch.uint8
+            and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
+        ):
+            optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
+        else:
+            raise ValueError(
+                f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+            )
+        optim_func(
+            p,
+            g,
+            state1,
+            state2,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            step,
+            lr,
+            qmap1,
+            qmap2,
+            absmax1,
+            absmax2,
+            weight_decay,
+            gnorm_scale,
+            skip_zeros,
+            g.numel()
+        )
+
 
     def optimizer_update_32bit(
         self,

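Two things stand out in the XPU backend above: dequantize_blockwise picks the IPEX kernel by output dtype, and optimizer_update_8bit_blockwise picks its fused kernel from str2optimizer8bit_blockwise by gradient dtype. A self-contained sketch of that dtype-keyed dispatch, with plain Python callables standing in for the ipex.xpu.bitsandbytes kernels:

# Sketch of the dtype -> kernel dispatch used by optimizer_update_8bit_blockwise.
# The lambdas are stand-ins for ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_{fp32,fp16,bf16}.
import torch

kernels = {
    "adam": (
        lambda *args: "fp32 kernel",   # index 0: float32 gradients
        lambda *args: "fp16 kernel",   # index 1: float16 gradients
        lambda *args: "bf16 kernel",   # index 2: bfloat16 gradients
    ),
}


def pick_kernel(name: str, g: torch.Tensor, state1: torch.Tensor):
    # state1 holds the quantized first moment, so it must be uint8.
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        return kernels[name][0]
    if g.dtype == torch.float16 and state1.dtype == torch.uint8:
        return kernels[name][1]
    if g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and len(kernels[name]) == 3:
        return kernels[name][2]
    raise ValueError(f"unsupported combination: grad {g.dtype}, optimizer state {state1.dtype}")


g = torch.zeros(8, dtype=torch.bfloat16)
state1 = torch.zeros(8, dtype=torch.uint8)
print(pick_kernel("adam", g, state1)())  # -> "bf16 kernel"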
bitsandbytes/functional.py

Lines changed: 10 additions & 1 deletion
@@ -859,7 +859,16 @@ def dequantize_blockwise(
     if out is None:
         out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)
 
-    if A.device.type != "cpu":
+    if A.device.type == "xpu":
+        backends[A.device.type].dequantize_blockwise(
+            A=A,
+            quant_state=quant_state,
+            absmax=absmax,
+            code=quant_state.code,
+            out=out,
+            blocksize=blocksize,
+            nested=quant_state.nested,)
+    elif A.device.type != "cpu":
         code = quant_state.code.to(A.device)
         supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
         # Some AMD GPUs have warpsize 64

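With this change, dequantize_blockwise forwards XPU tensors to the XPU backend instead of taking the generic non-CPU path. A hedged usage sketch of a round trip that would hit the new branch; it assumes an XPU device with intel_extension_for_pytorch >= 2.7, that quantize_blockwise is available on the source device, and that QuantState.to moves the state tensors (all assumptions, not guarantees of this commit):

# Hedged usage sketch: blockwise quantize on CPU, dequantize on XPU via the new branch.
import torch
import bitsandbytes.functional as F

x = torch.randn(4096)
q, quant_state = F.quantize_blockwise(x, blocksize=256)  # uint8 payload + absmax/code state

if hasattr(torch, "xpu") and torch.xpu.is_available():
    q = q.to("xpu")
    quant_state.to("xpu")  # assumption: moves code/absmax alongside the payload

x_hat = F.dequantize_blockwise(q, quant_state)  # XPU tensors take backends["xpu"].dequantize_blockwise
print((x.to(x_hat.device) - x_hat).abs().max())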
bitsandbytes/optim/optimizer.py

Lines changed: 3 additions & 2 deletions
@@ -10,6 +10,7 @@
 import torch
 
 import bitsandbytes.functional as F
+from bitsandbytes.backends import backends
 
 
 class MockArgs:
@@ -289,11 +290,11 @@ def step(self, closure=None):
 
                 self.prefetch_state(p)
                 self.update_step(group, p, gindex, pindex)
-                torch.cuda.synchronize()
+                backends[p.device.type].device_synchronize()
         if self.is_paged:
             # all paged operation are asynchronous, we need
             # to sync to make sure all tensors are in the right state
-            torch.cuda.synchronize()
+            backends[p.device.type].device_synchronize()
 
         return loss
 

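Because step() now synchronizes through the backend registry instead of calling torch.cuda.synchronize() directly, the same 8-bit optimizer code can run on CUDA or XPU builds. A hedged end-to-end sketch of one training step (the device selection and CPU fallback are only there to keep the snippet self-contained; 8-bit optimizer states still require a supported accelerator backend):

# Hedged sketch: one step with an 8-bit Adam optimizer on whichever device is available.
import torch
import bitsandbytes as bnb

if hasattr(torch, "xpu") and torch.xpu.is_available():
    device = "xpu"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"  # illustration only; the 8-bit path needs an accelerator backend

model = torch.nn.Linear(128, 128).to(device)
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

loss = model(torch.randn(16, 128, device=device)).sum()
loss.backward()
opt.step()       # internally ends with backends[p.device.type].device_synchronize()
opt.zero_grad()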