Commit 89c0620: Final fixes to ensure faithfulness
1 parent: b5d3411

File tree: 12 files changed, +414 -530 lines


analyze_conversion_faithfulness.py

Lines changed: 0 additions & 443 deletions
This file was deleted.

conversion/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -210,6 +210,7 @@
 _loaded_text_modules = set()
 _loaded_mmproj_modules = set()
 
+
 # Function to load all model modules
 def _load_all_models():
     """Import all model modules to trigger registration."""
@@ -253,6 +254,7 @@ def get_model_class(name: str, mmproj: bool = False) -> Type[ModelBase]:
     module = __import__(f"conversion.{module_name}", fromlist=[class_name])
     return getattr(module, class_name)
 
+
 def print_registered_models():
     logger.error("TEXT models:")
     for name in sorted(TEXT_MODEL_MAP.keys()):
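The blank lines added in this file are cosmetic, but the surrounding registry is worth a quick illustration: model classes are registered by name and only imported when they are actually requested. A rough sketch of that lazy lookup, using only the identifiers visible in the hunk (the map entry shown is hypothetical):

# Hypothetical sketch of the lazy registry lookup pattern in conversion/__init__.py.
TEXT_MODEL_MAP = {
    # architecture name -> (module name, class name); this entry is illustrative only
    "ChameleonForCausalLM": ("chameleon", "ChameleonModel"),
}

def get_model_class(name: str):
    module_name, class_name = TEXT_MODEL_MAP[name]
    # Import conversion.<module_name> on demand and pull the class out of it.
    module = __import__(f"conversion.{module_name}", fromlist=[class_name])
    return getattr(module, class_name)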

conversion/afmoe.py

Lines changed: 1 addition & 1 deletion
@@ -69,4 +69,4 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         if name.endswith(".expert_bias"):
             name = name.replace(".expert_bias", ".expert_bias.bias")
 
-        return [(self.map_tensor_name(name), data_torch)]
+        return [(self.map_tensor_name(name), data_torch)]

conversion/base.py

Lines changed: 102 additions & 17 deletions
@@ -196,26 +196,25 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
             logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
             if is_safetensors:
-                from safetensors import safe_open
-                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
+                ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
             with ctx as model_part:
                 assert model_part is not None
                 for name in model_part.keys():
                     if is_safetensors:
+                        data: gguf.utility.LocalTensor = model_part[name]
                         if self.lazy:
-                            data = model_part.get_slice(name)
-                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
+                            data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data)  # noqa: E731
                         else:
-                            data = model_part.get_tensor(name)
-                            data_gen = lambda data=data: data  # noqa: E731
+                            dtype = LazyTorchTensor._dtype_str_map[data.dtype]
+                            data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape)  # noqa: E731
                     else:
-                        data = model_part[name]
+                        data_torch: Tensor = model_part[name]
                         if self.lazy:
-                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                            data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data)  # noqa: E731
                         else:
-                            data_gen = lambda data=data: data  # noqa: E731
+                            data_gen = lambda data=data_torch: data  # noqa: E731
                     tensors[name] = data_gen
         # verify tensor name presence and identify potentially missing files
         if len(tensor_names_from_index) > 0:
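A note on the `lambda data=data: ...` style used throughout this loop: the generators are created inside a loop, so each tensor (or its dtype) is frozen as a default argument. A plain closure would instead see only the last loop value by the time it is called. A small standalone illustration of the difference (the names here are illustrative, not from the codebase):

items = ["tok_embd", "output", "blk.0.attn_q"]

# Late binding: every closure reads `name` when called, i.e. the final loop value.
late = [lambda: name for name in items]
# Default-argument binding, as in the diff above: each lambda keeps its own value.
bound = [lambda name=name: name for name in items]

print([f() for f in late])   # ['blk.0.attn_q', 'blk.0.attn_q', 'blk.0.attn_q']
print([f() for f in bound])  # ['tok_embd', 'output', 'blk.0.attn_q']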
@@ -249,14 +248,15 @@ def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
             # The scale is inverted
             return data / scale.float()
 
-        def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+        def dequant_simple(weight: Tensor, scale: Tensor, block_size: Sequence[int] | None = None) -> Tensor:
             scale = scale.float()
-            if (weight_block_size := quant_config.get("weight_block_size")):
-                # TODO: make sure it's a list of integers
-                for i, size in enumerate(weight_block_size):
+
+            if block_size is not None:
+                for i, size in enumerate(block_size):
                     scale = scale.repeat_interleave(size, i)
-            # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
-            scale = scale[tuple(slice(0, size) for size in weight.shape)]
+                # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                scale = scale[tuple(slice(0, size) for size in weight.shape)]
+
             return weight.float() * scale
 
         # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
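The block-wise path of the reworked `dequant_simple` is easiest to picture with a tiny example: the per-block scale is expanded with `repeat_interleave` along every dimension, sliced back to the weight's shape to drop any padding, and applied element-wise. A minimal sketch with made-up sizes, a 3x5 weight dequantized with 2x2 blocks:

import torch

weight = torch.ones(3, 5)               # stand-in for the quantized weight
scale = torch.tensor([[2.0, 3.0, 4.0],  # one scale per 2x2 block: ceil(3/2) x ceil(5/2) = 2x3
                      [5.0, 6.0, 7.0]])
block_size = (2, 2)

for i, size in enumerate(block_size):
    scale = scale.repeat_interleave(size, i)                   # 2x3 -> 4x6
# unpad: the expanded scale overshoots when a dim is not a multiple of the block size
scale = scale[tuple(slice(0, size) for size in weight.shape)]  # 4x6 -> 3x5

dequant = weight.float() * scale
print(dequant.shape)  # torch.Size([3, 5])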
@@ -294,6 +294,41 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
             if quant_config.get("checkpoint_format", "gptq") == "gptq":
                 zeros += 1
             return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+
+        def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: Tensor | None, num_bits: int, group_size: int):
+            assert w.dtype == torch.int32
+            shape = tuple(shape_tensor.tolist())
+            assert len(shape) == 2
+            mask = (1 << num_bits) - 1
+
+            shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)
+            if self.lazy:
+                shifts = LazyTorchTensor.from_eager(shifts)
+
+            if zero_point is None:
+                offset = 1 << (num_bits - 1)
+            else:
+                assert len(zero_point.shape) == 2
+                offset = (zero_point.unsqueeze(1) >> shifts.reshape(1, -1, 1)) & mask
+                offset = offset.reshape(-1, zero_point.shape[1])
+                # trim padding, and prepare for broadcast
+                # NOTE: the zero-point is packed along dim 0
+                offset = offset[:shape[0], :].unsqueeze(-1)
+
+            # extract values
+            # NOTE: the weights are packed along dim 1
+            unpacked = (w.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
+            unpacked = unpacked.reshape(shape[0], -1)
+
+            # trim padding
+            unpacked = unpacked[:, :shape[1]]
+
+            # prepare for broadcast of the scale
+            unpacked = unpacked.reshape(shape[0], (unpacked.shape[-1] + group_size - 1) // group_size, group_size)
+            unpacked = unpacked - offset
+
+            return (unpacked * scale.unsqueeze(-1).float()).reshape(shape)
+
         if quant_method == "bitnet":
             for name in self.model_tensors.keys():
                 if name.endswith(".weight_scale"):
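The heart of the new `dequant_packed` is the shift-and-mask step that pulls several small integers out of each packed int32 word; the zero-point (or the implicit offset of 2**(num_bits - 1) when none is stored) and the per-group scale are then applied to the unpacked values. A self-contained illustration of just the unpacking for 4-bit values (the packed word below is made up):

import torch

num_bits = 4
mask = (1 << num_bits) - 1   # 0xF
# eight 4-bit values packed low-nibble-first into one int32 word
packed = torch.tensor([0x76543210], dtype=torch.int32)

shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)  # 0, 4, ..., 28
unpacked = (packed.unsqueeze(-1) >> shifts) & mask
print(unpacked.tolist())  # [[0, 1, 2, 3, 4, 5, 6, 7]]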
@@ -303,12 +338,13 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                         self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
                         tensors_to_remove.append(name)
         elif quant_method == "fp8":
+            block_size = quant_config.get("weight_block_size")
             for name in self.model_tensors.keys():
                 if name.endswith(".weight_scale_inv"):
                     weight_name = name.removesuffix("_scale_inv")
                     w = self.model_tensors[weight_name]
                     s = self.model_tensors[name]
-                    self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                    self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
                     tensors_to_remove.append(name)
         elif quant_method == "gptq":
             for name in self.model_tensors.keys():
@@ -332,11 +368,56 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                             ".scales",
                         )
                     ]
+        elif quant_method == "compressed-tensors":
+            quant_format = quant_config["format"]
+            groups = quant_config["config_groups"]
+            if len(groups) > 1:
+                raise NotImplementedError("Can't handle multiple config groups for compressed-tensors yet")
+            weight_config = tuple(groups.values())[0]["weights"]
+
+            if quant_format == "float-quantized" or quant_format == "int-quantized" or quant_format == "naive-quantized":
+                block_size = weight_config.get("block_structure", None)
+                strategy = weight_config.get("strategy")
+                assert strategy == "channel" or strategy == "block"
+                assert weight_config.get("group_size") is None  # didn't find a model using this yet
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale"):
+                        weight_name = name.removesuffix("_scale")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), block_size)
+                        tensors_to_remove.append(name)
+            elif quant_format == "pack-quantized":
+                assert weight_config.get("strategy") == "group"
+                assert weight_config.get("type", "int") == "int"
+                num_bits = weight_config.get("num_bits")
+                group_size = weight_config.get("group_size")
+                assert isinstance(num_bits, int)
+                assert isinstance(group_size, int)
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_packed"):
+                        base_name = name.removesuffix("_packed")
+                        w = self.model_tensors[name]
+                        scale = self.model_tensors[base_name + "_scale"]
+                        shape = self.model_tensors[base_name + "_shape"]
+                        zero_point = self.model_tensors.get(base_name + "_zero_point", lambda: None)
+                        new_tensors[base_name] = (
+                            lambda w=w, scale=scale, shape=shape, zero_point=zero_point: dequant_packed(
+                                w(), scale(), shape(), zero_point(), num_bits, group_size,
+                            )
+                        )
+                        tensors_to_remove += [base_name + n for n in ("_packed", "_shape", "_scale")]
+                        if (base_name + "_zero_point") in self.model_tensors:
+                            tensors_to_remove.append(base_name + "_zero_point")
+            else:
+                raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
         else:
             raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
+
         for name in tensors_to_remove:
             if name in self.model_tensors:
                 del self.model_tensors[name]
+
         for name, value in new_tensors.items():
             self.model_tensors[name] = value
 
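For orientation, a compressed-tensors quantization config of the kind this new branch parses might look roughly like the sketch below. It is assembled only from the keys the code reads ("format", "config_groups", "weights", "type", "num_bits", "group_size", "strategy"); the group name and the concrete values are hypothetical:

# Illustrative only: one config group in pack-quantized format,
# shaped the way the branch above expects to find it.
quant_config = {
    "format": "pack-quantized",
    "config_groups": {
        "group_0": {
            "weights": {
                "type": "int",
                "num_bits": 4,
                "group_size": 128,
                "strategy": "group",
            },
        },
    },
}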
@@ -940,6 +1021,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
             # ref: https://huggingface.co/JetBrains/Mellum-4b-base
             res = "mellum"
+        if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
+            # ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
+            res = "afmoe"
         if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
             # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
             res = "bailingmoe2"
@@ -990,10 +1074,11 @@ def _set_vocab_qwen(self):
         vocab_size = hparams["vocab_size"]
         assert max(tokenizer.get_vocab().values()) < vocab_size
         tokpre = self.get_vocab_base_pre(tokenizer)
-        QwenModel = _get_qwen_model()
+
         merges = []
         vocab = {}
         mergeable_ranks = tokenizer.mergeable_ranks
+        QwenModel = _get_qwen_model()
         for token, rank in mergeable_ranks.items():
             vocab[QwenModel.token_bytes_to_string(token)] = rank
             if len(token) == 1:

conversion/chameleon.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@
 from torch import Tensor
 
 
-@ModelBase.register("ChameleonForConditionalGeneration", "ChameleonForCausalLM")
+@ModelBase.register("ChameleonForConditionalGeneration")
+@ModelBase.register("ChameleonForCausalLM")
 class ChameleonModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CHAMELEON
 
0 commit comments
