
Commit 9bbec5d

Fix bf16 in convert script.
1 parent 9531111 commit 9bbec5d

1 file changed (+6, −2)

convert_hf_to_gguf.py

Lines changed: 6 additions & 2 deletions
@@ -299,6 +299,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
     # Repack and merge qweight, scales, and qzeros into a single tensor
     # Currently, this logic is nearly impossible to implement in quants.py
     def _modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Convert unsupported bfloat16 to float32
+        if data_torch.dtype == torch.bfloat16:
+            data_torch = data_torch.to(torch.float32)
+
         if not self.enable_t_mac or isinstance(self, BitnetModel):
             return self.modify_tensors(data_torch, name, bid)
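
The new guard upcasts bfloat16 to float32 because the repacking path below calls .numpy() on eager tensors, and NumPy has no native bfloat16 dtype, so PyTorch raises a TypeError for bf16 tensors. A minimal standalone sketch of the failure and the fix, assuming only that PyTorch is installed:

    import torch

    t = torch.ones(2, 2, dtype=torch.bfloat16)
    try:
        t.numpy()  # NumPy has no bfloat16: raises TypeError (unsupported ScalarType BFloat16)
    except TypeError:
        t = t.to(torch.float32)  # the guard this commit adds to _modify_tensors
    print(t.numpy().dtype)  # float32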

@@ -316,15 +320,15 @@ def _modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Ite
         if len(self._gptq_quant_dict[base_name]) < 3:
             return []
 
+        # Get weight components: all [out_feature, in_feature]
         qweight = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qweight"]).numpy()
         scales = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".scales"]).numpy()
         qzeros = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qzeros"]).numpy()
         name = base_name + ".weight"
         from gguf.tmac_utils import unpack_gptqv2
         w, scales, zeros, bits, group_size = unpack_gptqv2(qweight, scales, qzeros, "gptqmodel" in self.quantization_config["quantizer"])
         if bits != self.quantization_config["bits"] or group_size != self.quantization_config["group_size"]:
-            # logger.error("Error while parsing weights for quantization_config: {}, but got bits={} and group_size={}".format(
-            #     self.quantization_config, bits, group_size))
+            # Currently, we only support models in which all weights match the quantization config.
             raise ValueError("Error while parsing weights for quantization_config: {}, but got bits={} and group_size={}".format(
                 self.quantization_config, bits, group_size))
         self._t_mac_raw_shape = w.shape
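
The len(...) < 3 check in this hunk reflects how the converter buffers the three GPTQ components (qweight, scales, qzeros) per base name and repacks only once all of them have arrived. A simplified, self-contained sketch of that buffering pattern; the names and the dict-valued output are illustrative only, since the real code repacks via gguf.tmac_utils.unpack_gptqv2:

    from collections import defaultdict

    quant_dict: dict[str, dict[str, object]] = defaultdict(dict)

    def buffer_gptq(name: str, tensor: object) -> list[tuple[str, object]]:
        # GPTQ stores each linear layer as three tensors; emit nothing
        # until all three components for a base name have been seen.
        for suffix in (".qweight", ".scales", ".qzeros"):
            if name.endswith(suffix):
                base = name[: -len(suffix)]
                quant_dict[base][suffix] = tensor
                if len(quant_dict[base]) < 3:
                    return []  # wait for the remaining components
                return [(base + ".weight", quant_dict.pop(base))]
        return [(name, tensor)]  # non-GPTQ tensors pass through unchanged

    print(buffer_gptq("blk.0.attn_q.qweight", 1))  # []
    print(buffer_gptq("blk.0.attn_q.scales", 2))   # []
    print(buffer_gptq("blk.0.attn_q.qzeros", 3))   # [('blk.0.attn_q.weight', {...})]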
