     int8_linear_matmul_impl,
     int8_mm_dequant_impl,
     quantize_4bit_impl,
+    _ipex_xpu_version_prereq
 )
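+# Probe for Intel Extension for PyTorch (IPEX); ipex_xpu stays None when the
+# extension is missing or was built without XPU support.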
+try:
+    import intel_extension_for_pytorch as ipex
+    ipex_xpu = ipex if ipex._C._has_xpu() else None
+except BaseException:
+    ipex_xpu = None
 
 Tensor = torch.Tensor
 
 
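+# Fused 8-bit blockwise optimizer kernels from IPEX, keyed by optimizer name;
+# each tuple is ordered by gradient dtype: (fp32, fp16, bf16).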
+str2optimizer8bit_blockwise = {}
+if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7):
+    str2optimizer8bit_blockwise = {
+        "adam": (
+            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32,
+            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16,
+            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16,
+        ),
+    }
+
+
 def assert_on_xpu(tensors):
     on_xpu = True
     for t in tensors:
@@ -35,6 +56,9 @@ class XPUBackend(Backend):
     mm_dequant_compute_dtype = torch.bfloat16
     mm_dequant_output_dtype = torch.bfloat16
 
+    def device_synchronize(self):
+        torch.xpu.synchronize()
+
     def int8_double_quant(
         self,
         A: torch.Tensor,
@@ -185,7 +209,20 @@ def dequantize_blockwise(
         blocksize: int = 4096,
         nested=False,
     ) -> torch.Tensor:
-        raise NotImplementedError
+        if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
+            raise RuntimeError("Please install intel_extension_for_pytorch >= 2.7 to use the 8-bit optimizer backend on XPU devices.")
+
+        # void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
+        if out.dtype == torch.float16:
+            ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
+        elif out.dtype == torch.bfloat16:
+            ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
+        elif out.dtype == torch.float32:
+            ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
+        else:
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
+        return out
+
 
     def quantize_blockwise(
         self,
@@ -220,7 +257,50 @@ def optimizer_update_8bit_blockwise(
         gnorm_scale: float = 1.0,
         skip_zeros=False,
     ) -> None:
-        raise NotImplementedError
+        optim_func = None
+        if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
+            raise RuntimeError("Please install intel_extension_for_pytorch >= 2.7 to use the 8-bit optimizer backend on XPU devices.")
+
+        assert_on_xpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
+
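+        # Select the fused kernel variant matching the gradient dtype; 8-bit optimizer state tensors are uint8.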
+        if g.dtype == torch.float32 and state1.dtype == torch.uint8:
+            optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
+        elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
+            optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
+        elif (
+            g.dtype == torch.bfloat16
+            and state1.dtype == torch.uint8
+            and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
+        ):
+            optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
+        else:
+            raise ValueError(
+                f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+            )
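+        # The final argument tells the kernel how many gradient elements to process.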
+        optim_func(
+            p,
+            g,
+            state1,
+            state2,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            step,
+            lr,
+            qmap1,
+            qmap2,
+            absmax1,
+            absmax2,
+            weight_decay,
+            gnorm_scale,
+            skip_zeros,
+            g.numel(),
+        )
+
 
     def optimizer_update_32bit(
         self,