Skip to content

Commit f819e9a

Browse files
[FIX] non-persistent buffer was saved incorrectly (#2242)
* The `alias_from_turtle_for_submodule()` function needs to check if the buffer is persistent. Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> * get_state_dict_for_save() needs to skip non-persistent buffers. Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> * add test_model_save.py Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> * format Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> * Potential fix for pull request finding 'First parameter of a class method is not named 'cls'' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> * add comment Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> --------- Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
1 parent 5cc0c02 commit f819e9a

File tree

3 files changed

+97
-6
lines changed

3 files changed

+97
-6
lines changed

gptqmodel/utils/model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,14 @@ def find_modules(module: nn.Module, layers=None, name: str="") -> Dict[str, nn.M
196196
return res
197197

198198

199+
def get_module_by_name(module, child_name):
    """Return the (possibly nested) submodule of *module* whose qualified
    name — as reported by ``module.named_modules()`` — equals *child_name*.

    Note: ``named_modules()`` yields ``("", module)`` first, so an empty
    *child_name* returns *module* itself.

    Raises:
        ValueError: if no submodule is registered under that name.
    """
    found = next(
        (m for qualified_name, m in module.named_modules() if qualified_name == child_name),
        None,
    )
    if found is None:
        raise ValueError(f"Cannot find child_name {child_name} in module {module}")
    return found
205+
206+
199207
def get_module_by_name_prefix(model, module_name: Union[List[str], str]):
200208
module_name_list = module_name if isinstance(module_name, list) else [module_name]
201209
for name, module in model.named_modules():
@@ -1467,7 +1475,13 @@ def _collect_state_dict_with_offload(model: nn.Module, offload_root: str) -> Dic
14671475
for name, buf in model.named_buffers():
14681476
if name in state_dict:
14691477
continue
1478+
1479+
# If the buffer is non-persistent, it does not need to be written to state_dict.
14701480
module_path, leaf = _split_parameter_path(name)
1481+
module = get_module_by_name(model, module_path)
1482+
if hasattr(module, "_non_persistent_buffers_set") and leaf in module._non_persistent_buffers_set:
1483+
continue
1484+
14711485
if getattr(buf, "is_meta", False) or buf.device.type == "meta":
14721486
source = _resolve_offload_entry(
14731487
offload_root,
@@ -1504,6 +1518,13 @@ def get_state_dict_for_save(model: nn.Module, offload_root: Optional[str] = None
15041518
for name, buf in model.named_buffers():
15051519
if name in state_dict:
15061520
continue
1521+
1522+
# If the buffer is non-persistent, it does not need to be written to state_dict.
1523+
module_path, leaf = _split_parameter_path(name)
1524+
module = get_module_by_name(model, module_path)
1525+
if hasattr(module, "_non_persistent_buffers_set") and leaf in module._non_persistent_buffers_set:
1526+
continue
1527+
15071528
state_dict[name] = TensorSource(name=name, torch_dtype=buf.dtype, shape=tuple(buf.shape), source=buf)
15081529

15091530
ptrs = collections.defaultdict(list)

gptqmodel/utils/structure.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def alias_from_turtle_for_submodule(
529529

530530
# Resolve path & source submodule (on CPU/mmap)
531531
path = _get_qualified_name(target_model, target_submodule)
532-
src_map = dict(turtle_model.named_modules())
532+
src_map: Dict[str, nn.Module] = dict(turtle_model.named_modules())
533533
if path not in src_map:
534534
raise KeyError(f"Path '{path}' not found in turtle_model.")
535535
src_sub = src_map[path]
@@ -544,25 +544,33 @@ def alias_from_turtle_for_submodule(
544544
continue
545545
t_p_new = _ensure_target_storage_on_device_(t_p, device)
546546
if t_p_new is not t_p:
547-
parent, leaf = _get_parent_and_leaf_by_path(target_submodule, name)
548-
setattr(parent, leaf, t_p_new)
547+
t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, name)
548+
setattr(t_parent, leaf, t_p_new)
549549
t_p = t_p_new
550550
t_p.detach().copy_(s_p.detach(), non_blocking=(non_blocking and s_p.is_pinned()))
551551

552552
t_bufs = dict(target_submodule.named_buffers(recurse=True))
553553
s_bufs = dict(src_sub.named_buffers(recurse=True))
554554
for name, s_b in s_bufs.items():
555555
tb = t_bufs.get(name)
556-
parent, leaf = _get_parent_and_leaf_by_path(target_submodule, name)
556+
t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, name)
557+
s_parent, _ = _get_parent_and_leaf_by_path(src_sub, name)
558+
559+
# nn.Module decides buffer persistence using `_non_persistent_buffers_set`:
560+
# the buffer is persistent unless its name is in this set.
561+
persistent = True
562+
if hasattr(s_parent, "_non_persistent_buffers_set"):
563+
persistent = leaf not in s_parent._non_persistent_buffers_set
564+
557565
if tb is None or getattr(tb, "is_meta", False) or tb.device.type == "meta":
558566
new_b = torch.empty_like(s_b, device=device)
559567
new_b.copy_(s_b.detach(), non_blocking=(non_blocking and s_b.is_pinned()))
560-
parent.register_buffer(leaf, new_b, persistent=True)
568+
t_parent.register_buffer(leaf, new_b, persistent=persistent)
561569
else:
562570
if tb.device != device:
563571
new_tb = torch.empty_like(s_b, device=device)
564572
new_tb.copy_(s_b.detach(), non_blocking=(non_blocking and s_b.is_pinned()))
565-
parent.register_buffer(leaf, new_tb, persistent=True)
573+
t_parent.register_buffer(leaf, new_tb, persistent=persistent)
566574
else:
567575
tb.copy_(s_b.detach(), non_blocking=(non_blocking and s_b.is_pinned()))
568576

tests/test_model_save.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
2+
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
3+
# SPDX-License-Identifier: Apache-2.0
4+
# Contact: qubitium@modelcloud.ai, x.com/qubitium
5+
# -- do not touch
6+
import os
7+
import tempfile
8+
9+
from datasets import load_dataset
10+
from transformers import AutoTokenizer
11+
12+
from gptqmodel.utils.torch import torch_empty_cache
13+
14+
15+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
16+
# -- end do not touch
17+
18+
import unittest # noqa: E402
19+
20+
# isort: off
21+
# isort: on
22+
from parameterized import parameterized # noqa: E402
23+
from safetensors import safe_open
24+
25+
from gptqmodel import GPTQModel, QuantizeConfig # noqa: E402
26+
27+
28+
class TestModelSave(unittest.TestCase):
    """Regression test: non-persistent buffers (e.g. rotary-embedding
    ``inv_freq``) must not be written into the saved safetensors file."""

    @classmethod
    def setUpClass(cls):
        # Local snapshot of "meta-llama/Llama-3.2-1B-Instruct"
        cls.pretrained_model_id = "/monster/data/model/Llama-3.2-1B-Instruct"  # "meta-llama/Llama-3.2-1B-Instruct"
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.pretrained_model_id, use_fast=True)
        # A single calibration sample keeps the quantization step fast.
        raw = load_dataset(path="/monster/data/model/dataset/nm-calibration", name="LLM", split="train")
        cls.calibration_dataset = raw.select(range(1))

    @parameterized.expand([True, False])
    def test_model_save_with_non_persistent_buffer(self, offload_to_disk):
        cfg = QuantizeConfig(bits=4, offload_to_disk=offload_to_disk)

        model = GPTQModel.load(self.pretrained_model_id, quantize_config=cfg)
        model.quantize(self.calibration_dataset, batch_size=1)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            model.save(tmp_dir_name)

            del model
            torch_empty_cache()

            # The non-persistent rotary buffer must be absent from the
            # serialized state dict.
            with safe_open(tmp_dir_name + "/model.safetensors", framework="pt") as f:
                print("weight_map", f.keys())
                self.assertNotIn('model.rotary_emb.inv_freq', f.keys())

0 commit comments

Comments
 (0)