@@ -1,12 +1,8 @@
-import ctypes as ct
 from math import prod
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
-from .cextension import lib
-from .functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr, is_on_gpu
-
 _IS_TORCH_GTE_24 = False
 
 if hasattr(torch.library, "register_fake"):
@@ -27,11 +23,10 @@
 # return () instead of `None` for compatibility, see here: https://github.com/pytorch/pytorch/issues/125044
 torch.library.define(
     "bitsandbytes::int8_linear_matmul",
-    "(Tensor A, Tensor B, Tensor(a!)? out=None, ScalarType dtype=int32) -> Tensor(a!)",
+    "(Tensor A, Tensor B, Tensor? out=None, ScalarType dtype=int32) -> Tensor",
 )
 
 
-# Fake/abstract op
 @register_fake("bitsandbytes::int8_linear_matmul")
 def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
     shapeC = (*A.shape[:-1], B.shape[0])
@@ -40,103 +35,71 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     return out
 
 
-# CPU implementation
-@register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
-    # Naive implementation: perform matmul in fp32
-    result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
-    if out is not None:
-        result = out.copy_(result)
-    return result
+torch.library.define(
+    "bitsandbytes::int8_vectorwise_quant",
+    "(Tensor A, Scalar threshold=0.0) -> (Tensor, Tensor, Tensor?)",
+)
 
 
-# MPS impl
-@register_kernel("bitsandbytes::int8_linear_matmul", "mps")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
-    pass
+@register_fake("bitsandbytes::int8_vectorwise_quant")
+def _(A: torch.Tensor, threshold=0.0):
+    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
+    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
 
+    if threshold == 0.0:
+        return out_row, row_stats, None
 
-# XPU impl
-@register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
-    pass
+    outlier_cols = torch.library.get_ctx().new_dynamic_size()
 
+    return out_row, row_stats, A.new_empty(outlier_cols, dtype=torch.int64)
 
-# Ascend NPU impl
-@register_kernel("bitsandbytes::int8_linear_matmul", "npu")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
-    pass
 
+torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")
 
-# CUDA/ROCm impl
-@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
-    A, B = B, A
-
-    shapeA = A.shape
-    shapeB = B.shape
-
-    assert A.dtype == torch.int8
-    assert B.dtype == torch.int8
-    assert A.ndim == 2, "Only two dimensional matrices are supported for argument B"
-    assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A"
-    assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}"
-    assert out is None or out.dtype == dtype
-
-    shapeC = (*shapeB[:-1], shapeA[0])
-
-    k, m = shapeA
-    n = prod(shapeB[:-1])
-    lda = shapeA[-1]  # Weights (outputs, inputs)
-    ldb = shapeB[-1]  # Activations (batch, tokens, inputs)
-    ldc = shapeC[-1]  # Output (batch, tokens, outputs)
-
-    assert (
-        lda == ldb
-    ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
-
-    # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
-    # We'll fall back to a slower fp32 calculation in this circumstance.
-    # Fortunately, this should not be very common.
-    if lda % 4 != 0:
-        result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
-        if out is not None:
-            result = out.copy_(result)
-        return result
 
-    if out is None:
-        out = torch.empty(shapeC, device=A.device, dtype=dtype)
-
-    is_on_gpu([A, B, out])
-
-    with _cuda_device_of(A):
-        ctx = CUBLAS_Context.get_instance().get_context(A.device)
-        ptrA = get_ptr(A)
-        ptrB = get_ptr(B)
-        ptrC = get_ptr(out)
-        ptrRowScale = None
-        m = ct.c_int32(m)
-        n = ct.c_int32(n)
-        k = ct.c_int32(k)
-        lda = ct.c_int32(lda)
-        ldb = ct.c_int32(ldb)
-        ldc = ct.c_int32(ldc)
-        stream = _get_tensor_stream(A)
-
-        if dtype == torch.int32:
-            has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
-        else:
-            has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
-
-        if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
-            raise NotImplementedError("int8_linear_matmul not implemented!")
-
-        if has_error:
-            raise RuntimeError(
-                f"cublasLt ran into an error!\n"
-                f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
-                f"\t{(lda, ldb, ldc)=}\n"
-                f"\t{(m, n, k)=}"
-            )
+@register_fake("bitsandbytes::int8_vectorwise_dequant")
+def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
+    torch._check(A.dtype == torch.int8, "A must be int8")
+    return torch.empty_like(A, dtype=torch.float32)
 
-    return out
+
+torch.library.define(
+    "bitsandbytes::int8_mm_dequant",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, Tensor? out, Tensor? bias) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::int8_mm_dequant")
+def _(
+    A: torch.Tensor,
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    torch._check(A.dtype == torch.int32, "A must be int32")
+    return torch.empty_like(A, dtype=torch.float16)
+
+
+torch.library.define(
+    "bitsandbytes::int8_double_quant",
+    "(Tensor A, Tensor? col_stats, Tensor? row_stats, Tensor? out_col, Tensor? out_row, Scalar threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
+)
+
+
+@register_fake("bitsandbytes::int8_double_quant")
+def _(
+    A: torch.Tensor,
+    col_stats: Optional[torch.Tensor] = None,
+    row_stats: Optional[torch.Tensor] = None,
+    out_col: Optional[torch.Tensor] = None,
+    out_row: Optional[torch.Tensor] = None,
+    threshold=0.0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    out_row = torch.empty_like(A, dtype=torch.int8)
+    out_col = torch.empty_like(A, dtype=torch.int8)
+    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
+    col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32)
+    outlier_n = torch.library.get_ctx().new_dynamic_size()
+    outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
+    return out_row, out_col, row_stats, col_stats, outlier_cols
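
Note (not part of this diff): a minimal sketch of how the fake/meta registrations added above can be exercised. Under FakeTensorMode, calls to the bitsandbytes:: ops dispatch to the @register_fake bodies, so shape and dtype propagation can be checked without compiled kernels or a GPU. The sketch assumes that importing the bitsandbytes package registers these ops (it may emit a warning when no GPU is present); the concrete shapes in the comments follow the fake implementations shown in the diff.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import bitsandbytes  # noqa: F401  -- assumed to register the bitsandbytes:: ops on import

with FakeTensorMode():
    # Fake int8 inputs; no real data is allocated under FakeTensorMode.
    A = torch.empty(4, 16, dtype=torch.int8)
    B = torch.empty(8, 16, dtype=torch.int8)

    # Dispatches to the registered fake: result shape is (*A.shape[:-1], B.shape[0]).
    C = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
    print(C.shape, C.dtype)  # torch.Size([4, 8]) torch.int32

    # With the default threshold=0.0 the fake returns (int8 rows, per-row stats, None).
    q, stats, outliers = torch.ops.bitsandbytes.int8_vectorwise_quant(A.float())
    print(q.shape, stats.shape, outliers)  # torch.Size([4, 16]) torch.Size([4]) None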