1919# Higher level op: int8 matmul + dequant + bias
2020torch .library .define (
2121 "bitsandbytes::int8_scaled_mm" ,
22- "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16 ) -> Tensor" ,
22+ "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType? dtype=None ) -> Tensor" ,
2323)
2424
2525
# NOTE(review): the span below is unified-diff residue (fused old/new line
# numbers, "-"/"+" markers, "@@" hunk headers), not runnable Python. The hunk
# is truncated: the decorator and the `def _(` / `A`, `B` parameter lines of
# the fake/meta implementation for `bitsandbytes::int8_scaled_mm` sit outside
# this view, so the function cannot be reconstructed here without guessing.
# Visible change: `dtype` moves from a hard `torch.float16` default to
# `Optional[torch.dtype] = None`, with the float16 fallback applied at the
# `torch.empty(...)` call site (`dtype or torch.float16`). The trailing
# `torch.library.define(` at old line 39 is also cut off by the next hunk.
@@ -30,10 +30,10 @@ def _(
3030 row_stats : torch .Tensor ,
3131 col_stats : torch .Tensor ,
3232 bias : Optional [torch .Tensor ] = None ,
33- dtype = torch .float16 ,
33+ dtype : Optional [ torch .dtype ] = None ,
3434) -> torch .Tensor :
3535 shapeC = (* A .shape [:- 1 ], B .shape [0 ])
36- return torch .empty (shapeC , device = A .device , dtype = dtype )
36+ return torch .empty (shapeC , device = A .device , dtype = dtype or torch . float16 )
3737
3838
3939torch .library .define (
@@ -98,15 +98,15 @@ def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
9898
9999
# Default PyTorch-native implementation
# Registered under the "default" backend key (the "+" side of the diff; the
# old code registered it with backend=None).
@register_kernel("bitsandbytes::int8_vectorwise_dequant", "default")
def _(A: torch.Tensor, stats: torch.Tensor):
    # To dequantize we divide by 127, or multiply by the reciprocal.
    # 7.874015718698502e-3 == 1/127; a multiply is cheaper than a divide.
    # stats is broadcast per-row via the (-1, 1) view.
    return A * stats.view(-1, 1) * 7.874015718698502e-3

107107torch .library .define (
108108 "bitsandbytes::int8_mm_dequant" ,
109- "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16 , Tensor? bias=None) -> Tensor" ,
109+ "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType? dtype=None , Tensor? bias=None) -> Tensor" ,
110110)
111111
112112
# NOTE(review): unified-diff residue, not runnable Python. This hunk is
# truncated at both ends: the decorator and `def _(` line of the fake/meta
# implementation for `bitsandbytes::int8_mm_dequant` are above the visible
# window, and the `torch.library.define(` at old line 125 is cut off, so
# neither definition can be reconstructed here without guessing. Visible
# change: `dtype` moves from a hard `torch.float16` default to
# `Optional[torch.dtype] = None`, with the float16 fallback applied at the
# `torch.empty_like(...)` call site; the int32 input check is unchanged.
# The final "0 commit comments" line is GitHub page-footer junk captured by
# the extraction, not part of the source file.
@@ -115,11 +115,11 @@ def _(
115115 A : torch .Tensor ,
116116 row_stats : torch .Tensor ,
117117 col_stats : torch .Tensor ,
118- dtype = torch .float16 ,
118+ dtype : Optional [ torch .dtype ] = None ,
119119 bias : Optional [torch .Tensor ] = None ,
120120) -> torch .Tensor :
121121 torch ._check (A .dtype == torch .int32 , lambda : "A must be int32" )
122- return torch .empty_like (A , dtype = dtype )
122+ return torch .empty_like (A , dtype = dtype or torch . float16 )
123123
124124
125125torch .library .define (
0 commit comments