@@ -299,6 +299,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
299299 # Repack and merge qweight, scales, and qzeros into a single tensor
300300 # Currently, this logic is nearly impossible to be implemented in quants.py
301301 def _modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
302+ # Convert unsupported bfloat16 to float32
303+ if data_torch .dtype == torch .bfloat16 :
304+ data_torch = data_torch .to (torch .float32 )
305+
302306 if not self .enable_t_mac or isinstance (self , BitnetModel ):
303307 return self .modify_tensors (data_torch , name , bid )
304308
@@ -316,15 +320,15 @@ def _modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Ite
316320 if len (self ._gptq_quant_dict [base_name ]) < 3 :
317321 return []
318322
323+ # Get weight components: all [out_feature, in_feature]
319324 qweight = LazyTorchTensor .to_eager (self ._gptq_quant_dict [base_name ][".qweight" ]).numpy ()
320325 scales = LazyTorchTensor .to_eager (self ._gptq_quant_dict [base_name ][".scales" ]).numpy ()
321326 qzeros = LazyTorchTensor .to_eager (self ._gptq_quant_dict [base_name ][".qzeros" ]).numpy ()
322327 name = base_name + ".weight"
323328 from gguf .tmac_utils import unpack_gptqv2
324329 w , scales , zeros , bits , group_size = unpack_gptqv2 (qweight , scales , qzeros , "gptqmodel" in self .quantization_config ["quantizer" ])
325330 if bits != self .quantization_config ["bits" ] or group_size != self .quantization_config ["group_size" ]:
326- # logger.error("Error while parsing weights for quantization_config: {}, but got bits={} and group_size={}".format(
327- # self.quantization_config, bits, group_size))
327331 # Currently, we only support models in which all weights conform to the quantization config.
328332 raise ValueError ("Error while parsing weights for quantization_config: {}, but got bits={} and group_size={}" .format (
329333 self .quantization_config , bits , group_size ))
330334 self ._t_mac_raw_shape = w .shape
0 commit comments