1+ from collections .abc import Sequence
12import ctypes as ct
23import logging
4+ from math import prod
35
46import torch
57
@@ -76,10 +78,8 @@ def _(
7678 torch ._check_is_size (blocksize )
7779 torch ._check (A .dtype == torch .uint8 , lambda : f"A must be uint8, got { A .dtype } " )
7880
79- # Only FP32 has c++ kernrl
81+ out = torch . empty_like ( A , dtype = dtype )
8082 if dtype == torch .float32 :
81- out = torch .empty_like (A , dtype = dtype )
82-
8383 lib .cdequantize_blockwise_cpu_fp32 (
8484 get_ptr (code ),
8585 get_ptr (A ),
@@ -88,6 +88,24 @@ def _(
8888 ct .c_longlong (blocksize ),
8989 ct .c_longlong (A .numel ()),
9090 )
91+ elif dtype == torch .bfloat16 :
92+ lib .cdequantize_blockwise_cpu_bf16 (
93+ get_ptr (code ),
94+ get_ptr (A ),
95+ get_ptr (absmax ),
96+ get_ptr (out ),
97+ ct .c_longlong (blocksize ),
98+ ct .c_longlong (A .numel ()),
99+ )
100+ elif dtype == torch .float16 :
101+ lib .cdequantize_blockwise_cpu_fp16 (
102+ get_ptr (code ),
103+ get_ptr (A ),
104+ get_ptr (absmax ),
105+ get_ptr (out ),
106+ ct .c_longlong (blocksize ),
107+ ct .c_longlong (A .numel ()),
108+ )
91109 else :
92110 out = code [A .reshape (- 1 ).int ()]
93111 blocks = out .shape [- 1 ] // blocksize
@@ -99,3 +117,103 @@ def _(
99117 out = out .reshape (A .shape )
100118
101119 return out
120+
@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize a packed 4-bit tensor on CPU through the native blockwise kernels.

    Args:
        A: Packed quantized data. Reinterpreted as ``uint8`` when it has another
            dtype, then reshaped to ``(m, n // 2)`` before being handed to the kernel.
        absmax: Scaling tensor passed to the kernel. Converted to ``float32`` when
            it has another dtype.
        blocksize: Block size forwarded to the kernel; validated with
            ``torch._check_is_size``.
        quant_type: Either ``"nf4"`` or ``"fp4"``.
        shape: Shape of the dequantized result.
        dtype: Output dtype; one of ``torch.bfloat16``, ``torch.float16`` or
            ``torch.float32``.

    Returns:
        A new tensor of ``dtype`` on ``A``'s device. When ``shape`` is 1-D the
        native path allocates and returns a ``(1, shape[0])`` tensor. When
        ``shape[-1]`` is odd, the result of the generic ``_dequantize_4bit_impl``
        is returned instead.

    Raises:
        Whatever ``torch._check`` / ``torch._check_is_size`` raise when
        ``blocksize``, ``quant_type`` or ``dtype`` fail validation.
    """
    # Suffix used in the native symbol name for every supported output dtype.
    dtype_suffix = {
        torch.float32: "fp32",
        torch.bfloat16: "bf16",
        torch.float16: "fp16",
    }

    torch._check_is_size(blocksize)
    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )

    # Odd shape is not supported by this kernel; fallback to generic implementation
    if shape[-1] % 2 != 0:
        from ..default.ops import _dequantize_4bit_impl

        return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)

    # Enable non uint8 dtype
    if A.dtype != torch.uint8:
        A = A.view(torch.uint8)

    # TODO: support half precision absmax
    if absmax.dtype != torch.float32:
        absmax = absmax.float()

    # The kernel takes a 2-D (m, n) problem, so a 1-D shape becomes a single row.
    if len(shape) == 1:
        shape = (1, shape[0])

    m = prod(shape[:-1])
    n = shape[-1]

    A = A.reshape(m, n // 2)
    out = torch.empty(shape, dtype=dtype, device=A.device)

    # Both quant_type and dtype were validated above, so the symbol name is fully
    # determined here, e.g. "cdequantize_blockwise_cpu_nf4_bf16". All six kernels
    # share one argument list, which is why a single call site is enough.
    kernel = getattr(lib, f"cdequantize_blockwise_cpu_{quant_type}_{dtype_suffix[dtype]}")
    kernel(
        get_ptr(A),
        get_ptr(absmax),
        get_ptr(out),
        ct.c_longlong(blocksize),
        ct.c_longlong(m),
        ct.c_longlong(n),
    )

    return out
0 commit comments