11from collections .abc import Sequence
22import ctypes as ct
3- from typing import Optional
43
54import torch
65
@@ -24,29 +23,6 @@ def _(A: torch.Tensor, B: torch.Tensor):
2423 ).reshape (* A .shape [:- 1 ], B .shape [0 ])
2524
2625
@register_kernel("bitsandbytes::int8_mm_dequant", "cpu")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Dequantize an int32 accumulator using per-row and per-column statistics.

    Each element of ``A`` is scaled by the outer product of ``row_stats`` and
    ``col_stats`` times a fixed scale constant, ``bias`` is added when given,
    and the result is cast to ``dtype`` (float16 when not given).
    Note: the returned tensor keeps the flattened 2-D shape of the
    computation, matching the original implementation.
    """
    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

    # Collapse leading dims to a 2-D view; shape the stats so they broadcast
    # down the rows and across the columns respectively.
    flat = A.view(-1, A.shape[-1])
    rows = row_stats.reshape(-1).unsqueeze(-1)
    cols = col_stats.reshape(-1).unsqueeze(0)

    # Keep the exact operation order of the original to preserve float results.
    result = flat * (rows * cols) * 6.200124e-05
    if bias is not None:
        result += bias

    return result.to(torch.float16 if dtype is None else dtype)
48-
49-
5026@register_kernel ("bitsandbytes::quantize_blockwise" , "cpu" )
5127def _ (A : torch .Tensor , code : torch .Tensor , blocksize : int ) -> tuple [torch .Tensor , torch .Tensor ]:
5228 torch ._check_is_size (blocksize )
0 commit comments