1515 register_fake = torch .library .impl_abstract
1616 register_kernel = torch .library .impl
1717
# Int8 mixed precision matmul + dequant + bias
# Schema: takes the unquantized A plus quantized operands (CA, CB) with their
# scale stats (SCA, SCB), optional outlier column indices and optional bias;
# returns the matmul result and an optional second tensor (presumably the
# outlier sub-matrix of A — see the fake impl below, which gives it a
# data-dependent length).
torch.library.define(
    "bitsandbytes::int8_mixed_scaled_mm",
    "(Tensor A, Tensor CA, Tensor CB, Tensor SCA, Tensor SCB, Tensor? outlier_cols=None, Tensor? bias=None) -> (Tensor, Tensor?)",
)

24+
@register_fake("bitsandbytes::int8_mixed_scaled_mm")
def _(
    A: torch.Tensor,
    CA: torch.Tensor,
    CB: torch.Tensor,
    SCA: torch.Tensor,
    SCB: torch.Tensor,
    outlier_cols: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Fake (meta) implementation for shape/dtype propagation under tracing.

    Returns:
        out: empty tensor of shape ``(*CA.shape[:-1], CB.shape[0])`` with
            ``A``'s device and dtype.
        subA: empty 1-D int64 tensor whose length is an unbacked dynamic
            size, because the number of outliers is data-dependent and
            unknown at trace time.
    """
    shapeC = (*CA.shape[:-1], CB.shape[0])

    out = torch.empty(shapeC, device=A.device, dtype=A.dtype)

    # Fix: use a fresh local instead of clobbering the `outlier_cols`
    # parameter — reassigning the argument hid the fact that its incoming
    # value is ignored here.
    n_outliers = torch.library.get_ctx().new_dynamic_size()
    # NOTE(review): dtype=torch.int64 is kept from the original; presumably
    # only the symbolic length matters for tracing — confirm against the
    # real kernel's subA dtype if exact dtype propagation is needed.
    subA = A.new_empty(n_outliers, dtype=torch.int64)

    return out, subA

43+
1844
# Higher level op: int8 matmul + dequant + bias
# `dtype` is an optional output dtype; when left as None the registered
# implementations fall back to torch.float16 (see the `dtype or
# torch.float16` fallback in this file).
torch.library.define(
    "bitsandbytes::int8_scaled_mm",
    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType? dtype=None) -> Tensor",
)

2450
2551
@@ -30,10 +56,10 @@ def _(
3056 row_stats : torch .Tensor ,
3157 col_stats : torch .Tensor ,
3258 bias : Optional [torch .Tensor ] = None ,
33- dtype = torch .float16 ,
59+ dtype : Optional [ torch .dtype ] = None ,
3460) -> torch .Tensor :
3561 shapeC = (* A .shape [:- 1 ], B .shape [0 ])
36- return torch .empty (shapeC , device = A .device , dtype = dtype )
62+ return torch .empty (shapeC , device = A .device , dtype = dtype or torch . float16 )
3763
3864
3965torch .library .define (
@@ -98,15 +124,15 @@ def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
98124
99125
# Default PyTorch-native implementation
@register_kernel("bitsandbytes::int8_vectorwise_dequant", "default")
def _(A: torch.Tensor, stats: torch.Tensor):
    # Dequantization divides by 127; we instead multiply by the
    # (float32-rounded) reciprocal of 127, kept as the exact literal so the
    # numerics match the original implementation bit-for-bit.
    inv_127 = 7.874015718698502e-3
    row_scale = stats.view(-1, 1)
    # Same left-to-right multiply order as before: (A * scale) * inv_127.
    return A * row_scale * inv_127

105131
106132
# Dequantize an int32 matmul accumulator using row/col stats, with an
# optional fused bias add. `dtype` selects the output dtype; None falls
# back to torch.float16 in the registered implementations (see the
# `dtype or torch.float16` fallback in this file).
torch.library.define(
    "bitsandbytes::int8_mm_dequant",
    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType? dtype=None, Tensor? bias=None) -> Tensor",
)

111137
112138
@@ -115,11 +141,11 @@ def _(
115141 A : torch .Tensor ,
116142 row_stats : torch .Tensor ,
117143 col_stats : torch .Tensor ,
118- dtype = torch .float16 ,
144+ dtype : Optional [ torch .dtype ] = None ,
119145 bias : Optional [torch .Tensor ] = None ,
120146) -> torch .Tensor :
121147 torch ._check (A .dtype == torch .int32 , lambda : "A must be int32" )
122- return torch .empty_like (A , dtype = dtype )
148+ return torch .empty_like (A , dtype = dtype or torch . float16 )
123149
124150
125151torch .library .define (
0 commit comments