@@ -26,22 +26,42 @@ def _(A: torch.Tensor, B: torch.Tensor):
 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
-    torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")
 
     n = A.numel()
-    blocks = -(n // -blocksize)
-
-    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
-    out = torch.empty_like(A, dtype=torch.uint8)
-
-    lib.cquantize_blockwise_cpu_fp32(
-        get_ptr(code),
-        get_ptr(A),
-        get_ptr(absmax),
-        get_ptr(out),
-        ct.c_longlong(blocksize),
-        ct.c_longlong(n),
-    )
+
+    # Only FP32 has a C++ kernel
+    if A.dtype == torch.float32:
+        blocks = -(n // -blocksize)
+
+        absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
+        out = torch.empty_like(A, dtype=torch.uint8)
+
+        lib.cquantize_blockwise_cpu_fp32(
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_longlong(blocksize),
+            ct.c_longlong(n),
+        )
+    else:
+        rem = n % blocksize
+        has_rem = rem > 0
+        blocks = n // blocksize + has_rem
+        absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
+        A_reshaped = A.reshape(n)
+        A_com = A_reshaped[: n - rem]
+        A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
+        absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+        scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
+        scaled_A = scaled_A.reshape(-1)
+        if has_rem:
+            absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+            scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
+            scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
+
+        diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
+        out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
 
     return out, absmax
 
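The new `else` branch above is a pure-PyTorch reimplementation of blockwise quantization for dtypes the C++ kernel does not cover: per-block absmax scaling followed by a nearest-entry lookup into `code`. Below is a minimal standalone sketch of that idea, assuming a toy uniform `code` table (the real table is built by the library's dynamic-map helper); `quantize_blockwise_ref` is an illustrative name, not part of the library.

```python
import torch

def quantize_blockwise_ref(A: torch.Tensor, code: torch.Tensor, blocksize: int = 4):
    n = A.numel()
    rem = n % blocksize
    blocks = n // blocksize + (1 if rem else 0)
    absmax = torch.zeros(blocks, dtype=A.dtype)

    flat = A.reshape(n)
    full = flat[: n - rem].reshape(-1, blocksize)            # complete blocks
    absmax[: full.shape[0]] = full.abs().max(dim=-1).values
    scaled = (full / absmax[: full.shape[0]].view(-1, 1)).clamp(-1, 1).reshape(-1)
    if rem:                                                   # trailing partial block
        tail = flat[n - rem:]
        absmax[-1] = tail.abs().max()
        scaled = torch.cat([scaled, (tail / absmax[-1]).clamp(-1, 1)])

    # nearest-code-entry lookup: index of the closest codebook value per element
    idx = (scaled.unsqueeze(-1) - code).abs().argmin(dim=-1)
    return idx.to(torch.uint8).reshape(A.shape), absmax

code = torch.linspace(-1, 1, 256)                # toy stand-in for the real code table
q, absmax = quantize_blockwise_ref(torch.randn(10, dtype=torch.bfloat16), code, blocksize=4)
```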
@@ -50,18 +70,28 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
     torch._check_is_size(blocksize)
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
-    torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")
 
-    out = torch.empty_like(A, dtype=dtype)
-
-    lib.cdequantize_blockwise_cpu_fp32(
-        get_ptr(code),
-        get_ptr(A),
-        get_ptr(absmax),
-        get_ptr(out),
-        ct.c_longlong(blocksize),
-        ct.c_longlong(A.numel()),
-    )
+    # Only FP32 has a C++ kernel
+    if dtype == torch.float32:
+        out = torch.empty_like(A, dtype=dtype)
+
+        lib.cdequantize_blockwise_cpu_fp32(
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_longlong(blocksize),
+            ct.c_longlong(A.numel()),
+        )
+    else:
+        out = code[A.reshape(-1).int()]
+        blocks = out.shape[-1] // blocksize
+        res = out.shape[-1] % blocksize
+        if res != 0:
+            out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
+        out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
+        out = out[: blocks * blocksize + res]
+        out = out.reshape(A.shape)
 
     return out
 
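A companion sketch of the dequantize fallback added in the second hunk: index the `code` table with the uint8 values, pad the trailing partial block, rescale each block by its `absmax`, and trim back to the original shape. `dequantize_blockwise_ref` and the toy inputs are illustrative only, not library API.

```python
import torch

def dequantize_blockwise_ref(q: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor,
                             blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    out = code[q.reshape(-1).int()]                # codebook lookup -> values in [-1, 1]
    n = out.shape[-1]
    rem = n % blocksize
    if rem:                                        # pad so the last block is full width
        out = torch.nn.functional.pad(out, (0, blocksize - rem))
    out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
    return out[:n].reshape(q.shape)                # drop the padding again

code = torch.linspace(-1, 1, 256)                  # toy stand-in for the real code table
q = torch.randint(0, 256, (10,), dtype=torch.uint8)
absmax = torch.tensor([2.0, 1.5, 0.5])             # one scale per block; 3 blocks cover 10 elements
x = dequantize_blockwise_ref(q, absmax, code, blocksize=4, dtype=torch.bfloat16)
```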