77from .base import Backend
88from .cpu_xpu_common import (
99 dequantize_4bit_impl ,
10+ double_quant_impl ,
1011 gemm_4bit_impl ,
1112 igemmlt_impl ,
1213 mm_dequant_impl ,
1617Tensor = torch .Tensor
1718
1819
def assert_on_hpu(tensors):
    """Check that every non-None entry of *tensors* lives on an HPU device.

    Args:
        tensors: iterable of tensors; ``None`` entries are ignored
            (NULL pointers are fine).

    Returns:
        ``True`` when all non-None tensors are on HPU.

    Raises:
        TypeError: if any non-None tensor resides on a different device;
            the message lists ``(shape, device)`` for each tensor argument.
    """
    on_hpu = all(t.device.type == "hpu" for t in tensors if t is not None)
    if not on_hpu:
        raise TypeError(
            "All input tensors need to be on HPU, but found some tensors to not be on HPU:\n"
            f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
        )
    return on_hpu
32+
33+
1934class HPUBackend (Backend ):
2035 mm_dequant_compute_dtype = torch .bfloat16
2136 mm_dequant_output_dtype = torch .bfloat16
2237
38+ def double_quant (
39+ self ,
40+ A : torch .Tensor ,
41+ col_stats : Optional [torch .Tensor ] = None ,
42+ row_stats : Optional [torch .Tensor ] = None ,
43+ out_col : Optional [torch .Tensor ] = None ,
44+ out_row : Optional [torch .Tensor ] = None ,
45+ threshold = 0.0 ,
46+ ):
47+ raise NotImplementedError ("Not yet implemented for HPU backend" )
48+
2349 def transform (
2450 self ,
2551 A : torch .Tensor ,
@@ -32,20 +58,10 @@ def transform(
3258 ):
3359 """
3460 Transform tensor A to to_order. It is originally designed for CUDA.
35- For HPU , it returns the original tensor if transpose=False.
61+ For CPU , it returns the original tensor if transpose=False.
3662 Otherwise, it returns the transpose of A
3763 """
38- if transpose :
39- if out is not None :
40- out .copy_ (A .T )
41- else :
42- out = A .T
43- else :
44- if out is not None :
45- out .copy_ (A )
46- else :
47- out = A
48- return out , state
64+ raise NotImplementedError ("Not yet implemented for HPU backend" )
4965
5066 def igemmlt (
5167 self ,
@@ -56,9 +72,8 @@ def igemmlt(
5672 out : Optional [torch .Tensor ] = None ,
5773 Sout : Optional [Tuple [torch .Size , str ]] = None ,
5874 dtype = torch .int32 ,
59- ) -> Union [torch .Tensor , Tuple [Optional [Tuple [torch .Tensor , Tuple [torch .Size ,
60- str ]]]]]:
61-
75+ ) -> Union [torch .Tensor , Tuple [Optional [Tuple [torch .Tensor , Tuple [torch .Size , str ]]]]]:
76+ assert_on_hpu ([A , B ])
6277 return igemmlt_impl (A , B , SA , SB , out , Sout , dtype )
6378
6479 def mm_dequant (
@@ -72,7 +87,7 @@ def mm_dequant(
7287 new_col_stats : Optional [torch .Tensor ] = None ,
7388 bias : Optional [torch .Tensor ] = None ,
7489 ) -> torch .Tensor :
75-
90+ assert_on_hpu ([ A , row_stats , col_stats , out , bias ])
7691 return mm_dequant_impl (
7792 A ,
7893 quant_state ,
@@ -95,7 +110,7 @@ def extract_outliers(
95110 """
96111 Extract columns of A by idx
97112 """
98-
113+ assert_on_hpu ([ A ])
99114 return A [:, idx ].contiguous ()
100115
101116 def quantize_4bit (
@@ -108,12 +123,12 @@ def quantize_4bit(
108123 quant_type : Literal ["fp4" , "nf4" ] = "fp4" ,
109124 quant_storage = torch .uint8 ,
110125 ) -> Tuple [torch .Tensor , QuantState ]:
111-
112126 if blocksize is None :
113127 blocksize = 64
114- assert quant_storage == torch .uint8
115- return quantize_4bit_impl (
116- A , absmax , out , blocksize , compress_statistics , quant_type )
128+
129+ assert_on_hpu ([A , absmax , out ])
130+ assert quant_storage == torch .uint8 , "HPU backend only supports uint8 quant_storage"
131+ return quantize_4bit_impl (A , absmax , out , blocksize , compress_statistics , quant_type )
117132
118133 def dequantize_4bit (
119134 self ,
@@ -124,9 +139,10 @@ def dequantize_4bit(
124139 blocksize : int = 64 ,
125140 quant_type : Literal ["fp4" , "nf4" ] = "fp4" ,
126141 ) -> torch .Tensor :
127-
128142 if blocksize is None :
129143 blocksize = 64
144+
145+ assert_on_hpu ([A , absmax , out ])
130146 return dequantize_4bit_impl (A , quant_state , absmax , out , blocksize , quant_type )
131147
132148 def gemv_4bit (
@@ -138,10 +154,73 @@ def gemv_4bit(
138154 transposed_B = False ,
139155 state : QuantState = None ,
140156 ) -> torch .Tensor :
141-
157+ assert_on_hpu ([ A , B , out ])
142158 if state is None :
143- raise ValueError (
144- "state cannot be None. gemv_4bit() requires the state from quantize_4bit()"
145- )
159+ raise ValueError ("state cannot be None. gemv_4bit() requires the state from quantize_4bit()" )
146160
147- return gemm_4bit_impl (A , B , out , transposed_A , transposed_B , state )
161+ return gemm_4bit_impl (A , B , out , transposed_A , transposed_B , state )
162+
163+ def dequantize_blockwise (
164+ self ,
165+ A : torch .Tensor ,
166+ quant_state : Optional [QuantState ] = None ,
167+ absmax : Optional [torch .Tensor ] = None ,
168+ code : Optional [torch .Tensor ] = None ,
169+ out : Optional [torch .Tensor ] = None ,
170+ blocksize : int = 4096 ,
171+ nested = False ,
172+ ) -> torch .Tensor :
173+ raise NotImplementedError ("Not yet implemented for HPU backend" )
174+
175+ def quantize_blockwise (
176+ self ,
177+ A : torch .Tensor ,
178+ code : Optional [torch .Tensor ] = None ,
179+ absmax : Optional [torch .Tensor ] = None ,
180+ out : Optional [torch .Tensor ] = None ,
181+ blocksize = 4096 ,
182+ nested = False ,
183+ ) -> Tuple [torch .Tensor , QuantState ]:
184+ raise NotImplementedError ("Not yet implemented for HPU backend" )
185+
186+ def optimizer_update_8bit_blockwise (
187+ self ,
188+ optimizer_name : str ,
189+ g : torch .Tensor ,
190+ p : torch .Tensor ,
191+ state1 : torch .Tensor ,
192+ state2 : Optional [torch .Tensor ],
193+ beta1 : float ,
194+ beta2 : float ,
195+ eps : float ,
196+ step : int ,
197+ lr : float ,
198+ qmap1 : torch .Tensor ,
199+ qmap2 : Optional [torch .Tensor ],
200+ absmax1 : torch .Tensor ,
201+ absmax2 : Optional [torch .Tensor ],
202+ weight_decay : float = 0.0 ,
203+ gnorm_scale : float = 1.0 ,
204+ skip_zeros = False ,
205+ ) -> None :
206+ raise NotImplementedError ("Not yet implemented for HPU backend" )
207+
208+ def optimizer_update_32bit (
209+ self ,
210+ optimizer_name : str ,
211+ g : torch .Tensor ,
212+ p : torch .Tensor ,
213+ state1 : torch .Tensor ,
214+ beta1 : float ,
215+ eps : float ,
216+ step : int ,
217+ lr : float ,
218+ state2 : Optional [torch .Tensor ] = None ,
219+ beta2 : float = 0.0 ,
220+ weight_decay : float = 0.0 ,
221+ gnorm_scale : float = 1.0 ,
222+ unorm_vec : Optional [torch .Tensor ] = None ,
223+ max_unorm : float = 0.0 ,
224+ skip_zeros = False ,
225+ ) -> None :
226+ raise NotImplementedError ("Not yet implemented for HPU backend" )