@@ -39,7 +39,7 @@ class GGUFQuantizedTensor(TorchAOBaseTensor):
3939 @staticmethod
4040 def __new__ (
4141 cls ,
42- n_super_blocks ,
42+ n_blocks_per_superblock ,
4343 super_block_scale_scale ,
4444 super_block_min_scale ,
4545 quantized_block_scale ,
@@ -55,7 +55,7 @@ def __new__(
5555
5656 def __init__ (
5757 self ,
58- n_super_blocks ,
58+ n_blocks_per_superblock ,
5959 super_block_scale_scale ,
6060 super_block_min_scale ,
6161 quantized_block_scale ,
@@ -64,7 +64,7 @@ def __init__(
6464 shape ,
6565 ** kwargs ,
6666 ):
67- self .n_super_blocks = n_super_blocks
67+ self .n_blocks_per_superblock = n_blocks_per_superblock
6868 self .super_block_scale_scale = super_block_scale_scale
6969 self .super_block_min_scale = super_block_min_scale
7070 self .quantized_block_scale = quantized_block_scale
@@ -73,7 +73,7 @@ def __init__(
7373
7474 def _apply_fn_to_data (self , fn ):
7575 return self .__class__ (
76- self .n_super_blocks ,
76+ self .n_blocks_per_superblock ,
7777 fn (self .super_block_scale_scale ),
7878 fn (self .super_block_min_sclae ),
7979 fn (self .quantized_block_scale ),
@@ -91,7 +91,7 @@ def __tensor_flatten__(self):
9191 "quantized_block_min" ,
9292 "int_data" ,
9393 ], (
94- self .n_super_blocks ,
94+ self .n_blocks_per_superblock ,
9595 self .dtype ,
9696 self .shape ,
9797 )
@@ -113,9 +113,9 @@ def __tensor_unflatten__(
113113 tensor_data_dict ["quantized_block_min" ],
114114 tensor_data_dict ["int_data" ],
115115 )
116- n_super_blocks , dtype , shape = attributes
116+ n_blocks_per_superblock , dtype , shape = attributes
117117 return cls (
118- n_super_blocks ,
118+ n_blocks_per_superblock ,
119119 super_block_scale_scale ,
120120 super_block_min_scale ,
121121 quantized_block_scale ,
@@ -127,7 +127,7 @@ def __tensor_unflatten__(
127127
128128 def dequantize (self , output_dtype : Optional [torch .dtype ] = None ) -> torch .Tensor :
129129 block_size = tuple (
130- [1 ] * (self .int_data .ndim - 1 ) + [_QK_K // self .n_super_blocks ]
130+ [1 ] * (self .int_data .ndim - 1 ) + [_QK_K // self .n_blocks_per_superblock ]
131131 )
132132 return dequantize_gguf (
133133 self .int_data ,
@@ -144,7 +144,7 @@ def detach(self):
144144 Returns a new `GGUFQuantizedTensor`.
145145 """
146146 return self .__class__ (
147- self .n_super_blocks ,
147+ self .n_blocks_per_superblock ,
148148 self .super_block_scale_scale .detach (),
149149 self .super_block_min_scale .detach (),
150150 self .quantized_block_scale .detach (),
@@ -162,7 +162,7 @@ def requires_grad_(self, requires_grad=False):
162162 return self
163163
164164 @classmethod
165- def from_float (cls , input_float , n_super_blocks , target_dtype ):
165+ def from_float (cls , input_float , n_blocks_per_superblock , target_dtype ):
166166 """
167167 Method used to convert a linear weight tensor to an instance of the
168168 GGUFQuantizedTensor subclass.
@@ -176,7 +176,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype):
176176 assert (
177177 target_dtype == torch .uint4
178178 ), "only uint4 quantization is supported right now"
179- block_size = (1 , _QK_K // n_super_blocks )
179+ block_size = (1 , _QK_K // n_blocks_per_superblock )
180180 (
181181 super_block_scale_scale ,
182182 super_block_min_scale ,
@@ -194,7 +194,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype):
194194 quantized_block_min ,
195195 )
196196 return cls (
197- n_super_blocks ,
197+ n_blocks_per_superblock ,
198198 super_block_scale_scale ,
199199 super_block_min_scale ,
200200 quantized_block_scale ,
@@ -208,7 +208,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype):
208208@dataclass
209209class GGUFWeightOnlyConfig (AOBaseConfig ):
210210 dtype : torch .dtype = torch .uint4
211- n_super_blocks : int = 8
211+ n_blocks_per_superblock : int = 8
212212
213213
214214@register_quantize_module_handler (GGUFWeightOnlyConfig )
@@ -221,7 +221,7 @@ def _gguf_weight_only_transform(
221221
222222 Args:
223223 dtype: target quantized dtype; only torch.uint4 is supported right now.
224- n_super_blocks : the number of super blocks in a 256 element block for gguf, e.g. when it is 8
224+ n_blocks_per_superblock : the number of blocks in a 256 element superblock for gguf, e.g. when it is 8
225225 it means we have blocks of 32 and 8 blocks in a superblock of 256 elements.
226226 Returns:
227227 Callable for quantization transformation.
@@ -231,7 +231,9 @@ def _gguf_weight_only_transform(
231231 return module
232232
233233 quantized_weight = GGUFQuantizedTensor .from_float (
234- weight , n_super_blocks = config .n_super_blocks , target_dtype = config .dtype
234+ weight ,
235+ n_blocks_per_superblock = config .n_blocks_per_superblock ,
236+ target_dtype = config .dtype ,
235237 )
236238 module .weight = torch .nn .Parameter (quantized_weight , requires_grad = False )
237239 return module
0 commit comments