@@ -112,31 +112,62 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
     dtype=torch.float32,
     device="cpu",
 )
+_FP4_QUANT_TABLE = torch.tensor(
+    [
+        0.0000,
+        0.0052,
+        0.6667,
+        1.0000,
+        0.3333,
+        0.5000,
+        0.1667,
+        0.2500,
+        0.0000,
+        -0.0052,
+        -0.6667,
+        -1.0000,
+        -0.3333,
+        -0.5000,
+        -0.1667,
+        -0.2500,
+    ],
+    dtype=torch.float32,
+    device="cpu",
+)
+CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}
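For context on where the new fp4 entries come from: they appear to be the FP4 (E2M1) magnitudes {0, 0.0625, 8, 12, 4, 6, 2, 3} normalized by the largest value, 12, with codes 8-15 being the negated counterparts of codes 0-7. A quick sanity check under that assumption (not the kernel's code):

```python
import torch

# Hypothetical reconstruction of the fp4 table above: E2M1 magnitudes divided by 12,
# with codes 8-15 the negations of codes 0-7.
e2m1 = torch.tensor([0.0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0]) / 12.0
fp4_guess = torch.cat([e2m1, -e2m1])
print(torch.round(fp4_guess, decimals=4))
# tensor([ 0.0000,  0.0052,  0.6667,  1.0000,  0.3333,  0.5000,  0.1667,  0.2500,
#         -0.0000, -0.0052, -0.6667, -1.0000, -0.3333, -0.5000, -0.1667, -0.2500])
```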


 @register_kernel("bitsandbytes::quantize_4bit", "cpu")
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
-    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
+    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}")
     torch._check(
         A.dtype in [torch.bfloat16, torch.float16, torch.float32],
         lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
     )

     n = A.numel()
-
-    # TODO: Support when weight matrix is not divisible by blocksize
-    torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
-
-    # Divide into blocks and normalize
-    blocks = A.reshape(-1, blocksize)
-    absmax = blocks.abs().max(dim=1).values.float()
-    scaled = blocks / absmax.unsqueeze(-1)
-
+    blocks = n // blocksize
+    blocks += 1 if n % blocksize > 0 else 0
+    rem = n % blocksize
+    has_rem = rem > 0
+
+    # Scale tensor to [-1, 1]
+    absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
+    A_reshaped = A.reshape(n)
+    A_com_reshaped = A_reshaped[: n - rem].reshape(n // blocksize, blocksize)
+    absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+    scaled = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
+    scaled = scaled.reshape(-1)
+    if has_rem:
+        absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+        scaled_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
+        scaled = torch.cat([scaled, scaled_rem], dim=0)
     # Quantize with the lookup table
-    quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8)
+    quant_table = CODE[quant_type]
+    quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - quant_table), dim=-1, keepdim=True).to(torch.uint8)

     # Pack two quantized values per byte
     packed = quantized[::2] << 4 | quantized[1::2]
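As a standalone illustration of the remainder handling added here (a sketch, not the kernel itself; `quantize_blockwise_ref` and `table` are made-up names), the same scheme in plain PyTorch: scale each full block by its absmax, give the trailing partial block its own scale, then snap every value to the nearest code-table entry:

```python
import torch

def quantize_blockwise_ref(A: torch.Tensor, quant_table: torch.Tensor, blocksize: int = 64):
    """Minimal reference of the block-wise scheme above: per-block absmax scaling to
    [-1, 1], then nearest-neighbor lookup into the 4-bit code table. The trailing
    partial block gets its own scale, mirroring the has_rem branch in the kernel."""
    n = A.numel()
    rem = n % blocksize
    full = n - rem
    blocks = n // blocksize + (1 if rem else 0)
    flat = A.reshape(n).float()
    absmax = torch.zeros(blocks, dtype=torch.float32)
    scaled = torch.empty(n, dtype=torch.float32)
    if full:
        body = flat[:full].reshape(-1, blocksize)
        absmax[: full // blocksize] = body.abs().max(dim=-1).values
        scaled[:full] = torch.clamp(body / absmax[: full // blocksize, None], -1, 1).reshape(-1)
    if rem:
        absmax[-1] = flat[full:].abs().max()
        scaled[full:] = torch.clamp(flat[full:] / absmax[-1], -1, 1)
    codes = torch.argmin((scaled[:, None] - quant_table).abs(), dim=-1).to(torch.uint8)
    return codes, absmax

# Usage: 100 elements with blocksize 64 leaves a 36-element remainder block.
table = torch.tensor([0.0, 0.0052, 0.6667, 1.0, 0.3333, 0.5, 0.1667, 0.25,
                      0.0, -0.0052, -0.6667, -1.0, -0.3333, -0.5, -0.1667, -0.25])
codes, absmax = quantize_blockwise_ref(torch.randn(100), table)
print(codes.shape, absmax.shape)  # torch.Size([100]) torch.Size([2])
```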
@@ -157,32 +188,47 @@ def _(
     dtype: torch.dtype,
 ) -> torch.Tensor:
     torch._check_is_size(blocksize)
-    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
+    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}")
     torch._check(
         dtype in [torch.bfloat16, torch.float16, torch.float32],
         lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
     )
-    torch._check(
-        A.dtype == torch.uint8,
-        lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
-    )
-
-    A = A.view(-1, 1)
-
-    # Grab upper and lower nibbles. Using int64 for indexing in the LUT.
-    upper = (A >> 4).to(torch.int64)
-    lower = (A & 0x0F).to(torch.int64)
-
-    # Expand to blocks
-    blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)

-    # Dequantize
-    blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]
+    # Allow non-uint8 storage dtypes by reinterpreting the raw bytes as uint8
+    device = A.device
+    if A.dtype != torch.uint8:
+        bytes_value = A.cpu().numpy().tobytes()
+        A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device)
+
+    A = A.reshape(-1)
+    # Unpack the nibbles and map the 4-bit codes to [-1, 1]
+    out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
+    n = out_dq.numel()
+    out_dq[1::2] = A & 0xF
+    out_dq[::2] = A >> 4
+    # The lookup table is fp32; cast it to the output dtype to avoid a dtype mismatch
+    code = CODE[quant_type].to(dtype)
+    out_dq = code[out_dq]
+
+    # Apply scales
+    if out_dq.numel() != n:
+        assert out_dq.numel() == n + 1
+        out_dq = torch.narrow(out_dq, 0, 0, n)
+    blocks = n // blocksize
+    blocks += 1 if n % blocksize > 0 else 0
+    rem = n % blocksize
+    has_rem = rem > 0
+
+    out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1)
+    if has_rem:
+        out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1)
+        out[n - rem :] = out_dq[n - rem :] * absmax[-1]
+    else:
+        out = out_dq.view(-1, blocksize) * absmax.view(-1, 1)
+
+    out = out.reshape(-1, *shape[1:]).to(dtype)

-    # Reshape to original shape
-    blocks = blocks.reshape(-1, *shape[1:])
-
-    return blocks.to(dtype)
+    return out


 @register_kernel("bitsandbytes::gemv_4bit", "cpu")
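Two mechanics the rewritten dequantize path relies on, shown in isolation as a sketch: reinterpreting a non-uint8 storage tensor as raw bytes, and unpacking two codes per byte with the first element in the high nibble (matching the packing in the quantize kernel above). The `.view(torch.uint8)` reinterpretation is an aside, not what the kernel does:

```python
import torch

# Packed example: two 4-bit codes per byte, high nibble first.
codes = torch.tensor([3, 11, 7, 0], dtype=torch.uint8)
packed = codes[::2] << 4 | codes[1::2]   # -> tensor([ 59, 112], dtype=torch.uint8)

# If the packed bytes were stored in a wider dtype (e.g. quant_storage=torch.uint16),
# reinterpret the raw bytes as uint8 before unpacking. Tensor.view(torch.uint8) is a
# zero-copy alternative to the numpy tobytes()/frombuffer round trip in the kernel.
stored = packed.view(torch.uint8)        # no-op here; also works for wider storage

# Unpack: even output positions take the high nibble, odd positions the low nibble.
unpacked = torch.empty(stored.numel() * 2, dtype=torch.int64)
unpacked[::2] = (stored >> 4).to(torch.int64)
unpacked[1::2] = (stored & 0xF).to(torch.int64)
assert torch.equal(unpacked, codes.to(torch.int64))
```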
@@ -194,17 +240,13 @@ def _(
     code: torch.Tensor,
     blocksize: int,
 ) -> torch.Tensor:
-    # TODO: We need to determine whether `code` is NF4, FP4, or other.
-    # Right now we assume NF4, as this is the only one supported on CPU.
-
-    B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
-        B,
-        absmax,
-        blocksize,
-        "nf4",
-        shape=shapeB,
-        dtype=A.dtype,
-    )
+    # Inlined from dequantize_4bit so the caller-provided code table (nf4 or fp4) is used
+    B = B.view(-1, 1)
+    upper = (B >> 4).to(torch.int64)
+    lower = (B & 0x0F).to(torch.int64)
+    blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
+    B_dq = code[blocks] * absmax[:, None]
+    B_dq = B_dq.reshape(-1, *shapeB[1:]).to(A.dtype)

     # User called gemv with B.t(), so we need to transpose it back.
     # if B.shape[0] == 1:
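To make the inlined unpack above easier to review, here is the same logic factored into a standalone helper plus an illustrative matmul; `dequantize_packed_ref` and the random inputs are hypothetical, and the kernel's own transpose handling of `B.t()` is omitted:

```python
import torch

def dequantize_packed_ref(B_packed, code, absmax, blocksize, shapeB, dtype):
    """Sketch of the inlined dequantize used by the gemv kernel above:
    unpack high/low nibbles, look them up in `code`, scale per block, reshape."""
    B_packed = B_packed.view(-1, 1)
    upper = (B_packed >> 4).to(torch.int64)
    lower = (B_packed & 0x0F).to(torch.int64)
    blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
    B_dq = code[blocks] * absmax[:, None]
    return B_dq.reshape(-1, *shapeB[1:]).to(dtype)

# Illustrative usage: a (4, 64) weight packed into 128 bytes with one scale per 64-value block.
# `code`, `absmax`, and `packed` would normally come from the quantize kernel.
code = torch.linspace(-1, 1, 16)
packed = torch.randint(0, 256, (4 * 64 // 2,), dtype=torch.uint8)
absmax = torch.rand(4 * 64 // 64)
W = dequantize_packed_ref(packed, code, absmax, 64, (4, 64), torch.float32)
x = torch.randn(1, 64)
y = x @ W.t()   # gemv expressed as a matmul against the dequantized weight
print(y.shape)  # torch.Size([1, 4])
```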