
Commit f509cd8

Merge branch 'ggml-org:master' into master
2 parents: 1df584e + 0bcb40b

File tree

7 files changed: +294, -76 lines

convert_hf_to_gguf.py

Lines changed: 177 additions & 67 deletions
@@ -90,10 +90,8 @@ class ModelBase:
     use_temp_file: bool
     lazy: bool
     dry_run: bool
-    part_names: list[str]
-    is_safetensors: bool
     hparams: dict[str, Any]
-    tensor_names: set[str] | None
+    model_tensors: dict[str, Callable[[], Tensor]]
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
@@ -137,25 +135,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
-        if remote_hf_model_id is not None:
-            self.is_safetensors = True
-
-            def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
-                logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
-                remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
-                self.tensor_names = set(name for name in remote_tensors.keys())
-                for name, remote_tensor in remote_tensors.items():
-                    yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
-
-            self.get_tensors = get_remote_tensors
-        else:
-            prefix = "model" if not self.is_mistral_format else "consolidated"
-            self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
-            self.is_safetensors = len(self.part_names) > 0
-            if not self.is_safetensors:
-                self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
-        self.tensor_names = None
+        self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
@@ -171,6 +152,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16
 
+        self.dequant_model()
+
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -192,67 +175,215 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
             return None
         raise KeyError(f"could not find any of: {keys}")
 
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        tensor_names_from_parts: set[str] = set()
+    def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
+        tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if remote_hf_model_id is not None:
+            is_safetensors = True
+
+            logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
+            remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
+            for name, remote_tensor in remote_tensors.items():
+                tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
+
+            return tensors
+
+        prefix = "model" if not self.is_mistral_format else "consolidated"
+        part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
+        is_safetensors: bool = len(part_names) > 0
+        if not is_safetensors:
+            part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+
+        tensor_names_from_index: set[str] = set()
 
         if not self.is_mistral_format:
-            index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+            index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
             index_name += ".index.json"
             index_file = self.dir_model / index_name
 
             if index_file.is_file():
-                self.tensor_names = set()
                 logger.info(f"gguf: loading model weight map from '{index_name}'")
                 with open(index_file, "r", encoding="utf-8") as f:
                     index: dict[str, Any] = json.load(f)
                     weight_map = index.get("weight_map")
                     if weight_map is None or not isinstance(weight_map, dict):
                         raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
-                    self.tensor_names.update(weight_map.keys())
+                    tensor_names_from_index.update(weight_map.keys())
             else:
-                self.tensor_names = tensor_names_from_parts
                 weight_map = {}
         else:
-            self.tensor_names = tensor_names_from_parts
             weight_map = {}
 
-        for part_name in self.part_names:
-            logger.info(f"gguf: loading model part '{part_name}'")
+        for part_name in part_names:
+            logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
-            if self.is_safetensors:
+            if is_safetensors:
                 from safetensors import safe_open
                 ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
 
             with ctx as model_part:
-                tensor_names_from_parts.update(model_part.keys())
+                assert model_part is not None
 
                 for name in model_part.keys():
-                    if self.is_safetensors:
+                    if is_safetensors:
                         if self.lazy:
                             data = model_part.get_slice(name)
-                            data = LazyTorchTensor.from_safetensors_slice(data)
+                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
                         else:
                             data = model_part.get_tensor(name)
+                            data_gen = lambda data=data: data  # noqa: E731
                     else:
                         data = model_part[name]
                         if self.lazy:
-                            data = LazyTorchTensor.from_eager(data)
-                    yield name, data
+                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                        else:
+                            data_gen = lambda data=data: data  # noqa: E731
+                    tensors[name] = data_gen
 
         # verify tensor name presence and identify potentially missing files
-        if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
-            missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
-            extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
-            missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
-            if len(extra) == 0 and len(missing_files) > 0:
-                raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
-                                 f"Missing tensors: {missing}")
+        if len(tensor_names_from_index) > 0:
+            tensor_names_from_parts = set(tensors.keys())
+            if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
+                missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
+                extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
+                missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
+                if len(extra) == 0 and len(missing_files) > 0:
+                    raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
+                                     f"Missing tensors: {missing}")
+                else:
+                    raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
+                                     f"Missing tensors: {missing}\n"
+                                     f"Extra tensors: {extra}")
+
+        return tensors
+
+    def dequant_model(self):
+        tensors_to_remove: list[str] = []
+        new_tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
+            quant_method = quant_config.get("quant_method")
+
+            def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
+                weight = weight.view(torch.uint8)
+                orig_shape = weight.shape
+
+                shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
+                data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
+                data = data & 3
+                data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
+
+                # The scale is inverted
+                return data / scale.float()
+
+            def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+                scale = scale.float()
+
+                if (weight_block_size := quant_config.get("weight_block_size")):
+                    # TODO: make sure it's a list of integers
+                    for i, size in enumerate(weight_block_size):
+                        scale = scale.repeat_interleave(size, i)
+                    # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                    scale = scale[tuple(slice(0, size) for size in weight.shape)]
+
+                return weight.float() * scale
+
+            # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
+            def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
+                bits = quant_config["bits"]
+                assert bits in (2, 3, 4, 8)
+                assert qweight.dtype == qzeros.dtype
+                maxq = (2 ** bits) - 1
+                weight = None
+                zeros = None
+                pack_dtype_bits = qweight.dtype.itemsize * 8
+
+                if bits in [2, 4, 8]:
+                    pack_factor = pack_dtype_bits // bits
+                    wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
+                    if self.lazy:
+                        wf = LazyTorchTensor.from_eager(wf)
+
+                    zeros = torch.bitwise_right_shift(
+                        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
+                        wf.unsqueeze(0)
+                    ).to(torch.int16 if bits == 8 else torch.int8)
+                    zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
+
+                    weight = torch.bitwise_and(
+                        torch.bitwise_right_shift(
+                            qweight.unsqueeze(1).expand(-1, pack_factor, -1),
+                            wf.unsqueeze(-1)
+                        ).to(torch.int16 if bits == 8 else torch.int8),
+                        maxq
+                    )
+                elif bits == 3:
+                    raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
+
+                assert weight is not None
+                assert zeros is not None
+
+                weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+                # gptq_v2 doesn't need to offset zeros
+                if quant_config.get("checkpoint_format", "gptq") == "gptq":
+                    zeros += 1
+
+                return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+
+            if quant_method == "bitnet":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale"):
+                        weight_name = name.removesuffix("_scale")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "fp8":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale_inv"):
+                        weight_name = name.removesuffix("_scale_inv")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "gptq":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".qweight"):
+                        base_name = name.removesuffix(".qweight")
+                        g_idx = self.model_tensors[base_name + ".g_idx"]
+                        qweight = self.model_tensors[base_name + ".qweight"]
+                        qzeros = self.model_tensors[base_name + ".qzeros"]
+                        scales = self.model_tensors[base_name + ".scales"]
+                        new_tensors[base_name + ".weight"] = (
+                            lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
+                                g(), w(), z(), s()
+                            )
+                        )
+                        tensors_to_remove += [
+                            base_name + n
+                            for n in (
+                                ".g_idx",
+                                ".qzeros",
+                                ".qweight",
+                                ".scales",
+                            )
+                        ]
             else:
-                raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
-                                 f"Missing tensors: {missing}\n"
-                                 f"Extra tensors: {extra}")
+                raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
+
+        for name in tensors_to_remove:
+            if name in self.model_tensors:
+                del self.model_tensors[name]
+
+        for name, value in new_tensors.items():
+            self.model_tensors[name] = value
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        for name, gen in self.model_tensors.items():
+            yield name, gen()
 
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
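Aside: the block-wise scale handling in dequant_simple above (repeat_interleave per dimension, then crop to the weight shape) can be exercised in isolation. The following is a minimal sketch with made-up shapes; the helper name and the 2x2 block size are illustrative, not from the commit.

import torch

def dequant_blockwise(weight: torch.Tensor, scale: torch.Tensor, block: int = 2) -> torch.Tensor:
    # expand one scale value per (block x block) tile to the full weight shape,
    # then crop in case the weight is not an exact multiple of the block size
    s = scale.float()
    for dim in range(weight.ndim):
        s = s.repeat_interleave(block, dim)
    s = s[tuple(slice(0, n) for n in weight.shape)]
    return weight.float() * s

w = torch.arange(12.0).reshape(3, 4)             # 3x4 weight, 2x2 blocks -> 2x2 scale grid
scales = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
print(dequant_blockwise(w, scales))              # each tile is scaled by its own factor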
@@ -4381,27 +4512,6 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)
 
-    _has_tok_embd = False
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
-        output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
-        tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
-
-        new_name = self.map_tensor_name(name)
-
-        # assuming token_embd.weight is seen before output.weight
-        if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
-            # even though the tensor file(s) does not contain the word embeddings they are still in the weight map
-            if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
-                logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
-                self.tensor_names.remove("transformer.wte.weight")
-        elif new_name == tok_embd_name:
-            self._has_tok_embd = True
-
-        return [(new_name, data_torch)]
-
 
 @ModelBase.register("InternLM2ForCausalLM")
 class InternLM2Model(TextModel):
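The core change in this file is that tensors are now indexed as zero-argument callables (name -> thunk), so nothing is materialized until get_tensors() is iterated, and dequantization only rewires entries in that dict. Below is a minimal standalone sketch of that pattern; the tensor names and values are hypothetical, and plain torch tensors stand in for the real LazyTorchTensor and safetensors machinery.

from typing import Callable

import torch
from torch import Tensor

# index phase: map each tensor name to a thunk; nothing is loaded yet
tensors: dict[str, Callable[[], Tensor]] = {
    "blk.0.ffn_up.weight": lambda: torch.ones(4, 4),
    "blk.0.ffn_up.weight_scale_inv": lambda: torch.tensor(0.5),
}

def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
    # block-wise scale expansion omitted for brevity
    return weight.float() * scale.float()

# dequant phase: replace the weight entry with a composed thunk and drop the scale entry
w = tensors["blk.0.ffn_up.weight"]
s = tensors["blk.0.ffn_up.weight_scale_inv"]
tensors["blk.0.ffn_up.weight"] = lambda w=w, s=s: dequant_simple(w(), s())
del tensors["blk.0.ffn_up.weight_scale_inv"]

# get_tensors phase: only now are tensors materialized
for name, gen in tensors.items():
    print(name, gen())

The default-argument trick (lambda w=w, s=s: ...) mirrors the commit: it binds the current values at definition time and avoids Python's late binding of closure variables inside loops.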

examples/model-conversion/scripts/causal/run-org-model.py

Lines changed: 4 additions & 4 deletions
@@ -138,7 +138,7 @@ def fn(_m, input, output):
         "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
     )
 
-config = AutoConfig.from_pretrained(model_path)
+config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 print("Model type: ", config.model_type)
 print("Vocab size: ", config.vocab_size)
@@ -148,8 +148,8 @@ def fn(_m, input, output):
 print("EOS token id: ", config.eos_token_id)
 
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-config = AutoConfig.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
@@ -171,7 +171,7 @@ def fn(_m, input, output):
         exit(1)
 else:
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload"
+        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
     )
 
 for name, module in model.named_modules():
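All three loading calls in this script now pass trust_remote_code=True, so checkpoints that ship custom modeling code can be loaded. A condensed sketch of the resulting pattern (the model id is a placeholder; the script normally takes it from --model-path or MODEL_PATH):

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_path = "some-org/custom-model"  # placeholder
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
)

Note that trust_remote_code executes Python code bundled with the checkpoint, so it should only be enabled for repositories you trust.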
