 import ctypes as ct
 from math import prod
-from typing import Optional, Tuple
+from typing import Optional, Sequence, Tuple

 import torch

-from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr, is_on_gpu
+from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr

 from ..._ops import register_kernel
 from ...cextension import lib
@@ -17,12 +17,12 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     shapeA = A.shape
     shapeB = B.shape

-    assert A.dtype == torch.int8
-    assert B.dtype == torch.int8
-    assert A.ndim == 2, "Only two dimensional matrices are supported for argument B"
-    assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A"
-    assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}"
-    assert out is None or out.dtype == dtype
+    torch._check(A.dtype == torch.int8, lambda: "B must be int8")
+    torch._check(B.dtype == torch.int8, lambda: "A must be int8")
+    torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B")
+    torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A")
+    torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}")
+    torch._check(out is None or out.dtype == dtype)

     shapeC = (*shapeB[:-1], shapeA[0])
@@ -32,9 +32,10 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     ldb = shapeB[-1]  # Activations (batch, tokens, inputs)
     ldc = shapeC[-1]  # Output (batch, tokens, outputs)

-    assert (
-        lda == ldb
-    ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
+    torch._check(
+        lda == ldb,
+        lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
+    )

     # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
     # We'll fall back to a slower fp32 calculation in this circumstance.
@@ -48,8 +49,6 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     if out is None:
         out = torch.empty(shapeC, device=A.device, dtype=dtype)

-    is_on_gpu([A, B, out])
-
     with _cuda_device_of(A):
         ctx = CUBLAS_Context.get_instance().get_context(A.device)
         ptrA = get_ptr(A)
@@ -69,16 +68,18 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
         else:
             has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)

-    if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
-        raise NotImplementedError("int8_linear_matmul not implemented!")
-
     if has_error:
-        raise RuntimeError(
-            f"cublasLt ran into an error!\n"
-            f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
-            f"\t{(lda, ldb, ldc)=}\n"
-            f"\t{(m, n, k)=}"
-        )
+        if has_error == 100:
+            # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
+            # TODO: Warn and implement a fallback to fp32 compute?
+            raise NotImplementedError("int8_linear_matmul not implemented!")
+        else:
+            raise RuntimeError(
+                f"cublasLt ran into an error!\n"
+                f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
+                f"\t{(lda, ldb, ldc)=}\n"
+                f"\t{(m, n, k)=}"
+            )

     return out

@@ -91,10 +92,10 @@ def _(
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    assert A.dtype == torch.int32
+    torch._check(A.dtype == torch.int32, lambda: "A must be int32")

     if bias is not None:
-        assert bias.dtype == torch.float16
+        torch._check(bias.dtype == torch.float16)

     if out is None:
         out = torch.empty_like(A, dtype=torch.float16)
@@ -107,8 +108,6 @@ def _(
     numRows = ct.c_int32(prod(A.shape[:-1]))
     numCols = ct.c_int32(A.shape[-1])

-    is_on_gpu([A, row_stats, col_stats, out, bias])
-
     with _cuda_device_of(A):
         lib.cdequant_mm_int32_fp16(
             ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
@@ -119,8 +118,7 @@ def _(

 @register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
 def _(A: torch.Tensor, threshold=0.0):
-    assert A.dtype == torch.half
-    is_on_gpu([A])
+    torch._check(A.dtype == torch.float16, lambda: "A must be float16")

     rows = prod(A.shape[:-1])
     cols = A.shape[-1]
@@ -188,7 +186,7 @@ def _get_col_absmax(
     A: torch.Tensor,
     threshold=0.0,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-    assert A.is_floating_point()
+    torch._check(A.is_floating_point())

     outlier_mask = None

@@ -203,3 +201,150 @@ def _get_col_absmax(
     col_stats = absA.amax(dim=0, keepdim=False).float()

     return col_stats, outlier_mask
+
+
+@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
+def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+
+    n = A.numel()
+    blocks = -(n // -blocksize)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.zeros_like(A, dtype=torch.uint8)
+
+    with _cuda_device_of(A):
+        args = (
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int32(blocksize),
+            ct.c_int(A.numel()),
+        )
+
+        if A.dtype == torch.float16:
+            lib.cquantize_blockwise_fp16(*args)
+        elif A.dtype == torch.bfloat16:
+            lib.cquantize_blockwise_bf16(*args)
+        elif A.dtype == torch.float32:
+            lib.cquantize_blockwise_fp32(*args)
+        else:
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+
+    return out, absmax
+
+
+@register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
+def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
+    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+
+    out = torch.empty_like(A, dtype=dtype)
+
+    with _cuda_device_of(A):
+        args = (
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int(blocksize),
+            ct.c_int(A.numel()),
+            _get_tensor_stream(A),
+        )
+
+        if dtype == torch.float16:
+            lib.cdequantize_blockwise_fp16(*args)
+        elif dtype == torch.bfloat16:
+            lib.cdequantize_blockwise_bf16(*args)
+        elif dtype == torch.float32:
+            lib.cdequantize_blockwise_fp32(*args)
+        else:
+            raise ValueError(f"Blockwise dequantization only supports 16/32-bit floats, but got {dtype}")
+
+    return out
+
+
+@register_kernel("bitsandbytes::quantize_4bit", "cuda")
+def _(
+    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    torch._check(quant_type in ["fp4", "nf4"])
+
+    n = A.numel()
+    blocks = -(n // -blocksize)
+    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
+
+    with _cuda_device_of(A):
+        args = (
+            None,
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int32(blocksize),
+            ct.c_int(n),
+        )
+
+        if A.dtype == torch.bfloat16:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_bf16_fp4(*args)
+            else:
+                lib.cquantize_blockwise_bf16_nf4(*args)
+        elif A.dtype == torch.float16:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_fp16_fp4(*args)
+            else:
+                lib.cquantize_blockwise_fp16_nf4(*args)
+        elif A.dtype == torch.float32:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_fp32_fp4(*args)
+            else:
+                lib.cquantize_blockwise_fp32_nf4(*args)
+        else:
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+
+    return out, absmax
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "cuda")
+def _(
+    A: torch.Tensor, absmax: torch.Tensor, blocksize: int, quant_type: str, shape: Sequence[int], dtype: torch.dtype
+) -> torch.Tensor:
+    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    torch._check(quant_type in ["fp4", "nf4"])
+
+    out = torch.empty(shape, dtype=dtype, device=A.device)
+    n = out.numel()
+
+    stream = _get_tensor_stream(A)
+
+    with _cuda_device_of(A):
+        args = (
+            None,
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int(blocksize),
+            ct.c_int(n),
+            stream,
+        )
+
+        if out.dtype == torch.bfloat16:
+            if quant_type == "fp4":
+                lib.cdequantize_blockwise_bf16_fp4(*args)
+            else:
+                lib.cdequantize_blockwise_bf16_nf4(*args)
+        elif out.dtype == torch.float16:
+            if quant_type == "fp4":
+                lib.cdequantize_blockwise_fp16_fp4(*args)
+            else:
+                lib.cdequantize_blockwise_fp16_nf4(*args)
+        elif out.dtype == torch.float32:
+            if quant_type == "fp4":
+                lib.cdequantize_blockwise_fp32_fp4(*args)
+            else:
+                lib.cdequantize_blockwise_fp32_nf4(*args)
+        else:
+            raise ValueError(f"Blockwise dequantization only supports 16/32-bit floats, but got {out.dtype}")
+
+    return out
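A note on the shape bookkeeping enforced by the new checks in the int8_linear_matmul kernel above: the output shape and leading dimensions are derived directly from the two operands. A minimal sketch of those relations, with made-up sizes (only the relations themselves come from the diff):

    # Illustrative sizes; the relations mirror the checks in the kernel above.
    outputs, inputs, batch, tokens = 4096, 1024, 8, 128

    shapeA = (outputs, inputs)          # 2-D int8 weight matrix
    shapeB = (batch, tokens, inputs)    # 2-D or 3-D int8 activations
    shapeC = (*shapeB[:-1], shapeA[0])  # (batch, tokens, outputs) accumulator output

    lda = shapeA[-1]  # inputs
    ldb = shapeB[-1]  # inputs; must equal lda, and cuBLASLt also needs it divisible by 4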
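Once registered, the blockwise kernels are reached through the torch.ops dispatcher. A round-trip sketch, assuming the op schemas declared in bitsandbytes/_ops.py accept these positional arguments and that create_dynamic_map from bitsandbytes.functional supplies the 256-entry code table (both assumptions, not shown in this diff):

    import torch
    from bitsandbytes.functional import create_dynamic_map

    A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    code = create_dynamic_map().to(A.device)  # 256-entry quantization map

    # One uint8 code per element, one fp32 absmax per 256-element block.
    quantized, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, 256)

    # Dequantize back to fp16; the error is bounded by the per-block scaling.
    A_hat = torch.ops.bitsandbytes.dequantize_blockwise(quantized, absmax, code, 256, torch.float16)
    print((A - A_hat).abs().max())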
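The out allocation in quantize_4bit packs two 4-bit codes per byte and then treats the byte buffer as elements of quant_storage, which is where (n + 1) // (quant_storage.itemsize * 2) comes from. A quick arithmetic check with illustrative sizes:

    import torch

    n = 4096 * 4096  # number of elements being quantized

    # uint8 storage: 1 byte per element of `out`, 2 codes per byte -> ~n/2 rows
    print((n + 1) // (torch.uint8.itemsize * 2))    # 8388608

    # float16 storage: 2 bytes per element, 4 codes per element -> ~n/4 rows
    print((n + 1) // (torch.float16.itemsize * 2))  # 4194304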