
Commit 7aec12c

break out _tie_offloaded_tensors, add test
Signed-off-by: Kyle Sayers <[email protected]>
1 parent dfdbd3f

4 files changed: +88 −6 lines

src/compressed_tensors/transform/apply.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import defaultdict
+from typing import List, Tuple
+
 import torch
 from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory
@@ -34,3 +37,36 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
 
     # attach config to model for compression/serialization
     setattr(model, TRANSFORM_CONFIG_NAME, config)
+
+    # ensure that tied weight transforms can be serialized without aliases.
+    # In the future, this could be done by transformers or model compressor,
+    # which would make this more robust to changing dispatches after transforms
+    _tie_offloaded_tensors(model)
+
+
+def _tie_offloaded_tensors(model: torch.nn.Module):
+    """
+    Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+    which is used by transformers to detect and remove shared pointers
+    during saving
+    """
+    from compressed_tensors.utils import has_offloaded_params
+
+    # map from tensor ids to the (module, key) pairs that reference them
+    offloaded_ptrs: dict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list)
+    for module in model.modules():
+        # NOTE: previously asserted that parent._hf_hook.place_submodules=False
+        if has_offloaded_params(module):
+            for key, _ in module.named_parameters(recurse=False):
+                param = module._hf_hook.weights_map[key]
+                offloaded_ptrs[id(param)].append((module, key))
+
+    # if there is more than one key pointing at the same offloaded tensor,
+    # ensure that the keys share tensors (meta tensors in the offloaded case)
+    for shared_keys in offloaded_ptrs.values():
+        if len(shared_keys) > 1:
+            first_tensor = getattr(shared_keys[0][0], shared_keys[0][1])
+            assert first_tensor.device.type == "meta"
+            for module, key in shared_keys:
+                assert getattr(module, key).device.type == "meta"
+                setattr(module, key, first_tensor)
```
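For intuition, here is a minimal sketch (toy modules, not from this commit) of the aliasing behavior `_tie_offloaded_tensors` establishes: once two module attributes hold the same tensor object, the shared storage is detectable by pointer, which is what lets `save_pretrained` serialize the weight once rather than writing aliased copies.

```python
# Minimal sketch (toy modules, not from this commit): tying means two
# attributes reference one tensor object, so shared storage can be
# detected by pointer and the weight serialized only once.
import torch

a = torch.nn.Linear(4, 4, bias=False)
b = torch.nn.Linear(4, 4, bias=False)
b.weight = a.weight  # tie: both modules now hold the same Parameter

# shared storage is detectable, analogous to the id(param) grouping above
assert a.weight.data_ptr() == b.weight.data_ptr()
```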

src/compressed_tensors/transform/factory/base.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import List, Optional, Set
+from typing import List, Optional
 
 import torch
 import torch.nn.utils.parametrize as P
@@ -56,7 +56,6 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
-        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)
 
@@ -117,7 +116,6 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         # create transform as submodule
         transform_name = f"{self.name}_{args.location}"
         transform = self.create_transform(module, args)
-        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)
 
         # register input transformation hook
```
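With the factory-level `self.transforms` list removed, attached transforms are no longer tracked centrally. A hedged sketch of how they can still be enumerated, assuming only the `f"{self.name}_{args.location}"` submodule naming used above (the helper name is hypothetical, not part of the repo):

```python
# Hypothetical helper (not in the repo): recover attached transforms by
# walking the module tree instead of reading the removed factory-level
# list, relying on the "<factory name>_<location>" naming seen above.
import torch

def iter_transforms(model: torch.nn.Module, factory_name: str):
    for name, submodule in model.named_modules():
        leaf = name.rsplit(".", 1)[-1]  # the submodule's own attribute name
        if leaf.startswith(f"{factory_name}_"):
            yield name, submodule
```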

tests/test_transform/conftest.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -19,6 +19,8 @@
 
 
 class TransformableModel(PreTrainedModel):
+    config_class = PretrainedConfig
+
     def __init__(self, *sizes):
         super().__init__(config=PretrainedConfig())
         self.fcs = torch.nn.ModuleList(
```
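Setting `config_class` is what lets a custom `PreTrainedModel` subclass round-trip through `save_pretrained`/`from_pretrained`, since loading rebuilds the config from `config.json` via `cls.config_class`. A toy sketch (hypothetical `ToyModel`, independent of this repo):

```python
# Toy sketch (hypothetical ToyModel, not from this repo): from_pretrained()
# uses cls.config_class to rebuild config.json before instantiating.
import torch
from transformers import PretrainedConfig, PreTrainedModel

class ToyModel(PreTrainedModel):
    config_class = PretrainedConfig  # without this, loading cannot rebuild the config

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.fc = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.fc(x)

model = ToyModel(PretrainedConfig())
model.save_pretrained("toy_model")                # writes config.json + weights
reloaded = ToyModel.from_pretrained("toy_model")  # rebuilds config via config_class
```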

tests/test_transform/factory/test_serialization.py

Lines changed: 49 additions & 3 deletions
```diff
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+
 import pytest
 import torch
 from compressed_tensors.transform import (
@@ -20,7 +22,9 @@
     apply_transform_config,
 )
 from compressed_tensors.utils import offloaded_dispatch
+from safetensors import safe_open
 from tests.testing_utils import requires_accelerate, requires_gpu
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
@@ -38,15 +42,57 @@ def test_serialization(type, randomize, model_apply, tmp_path, offload=False):
     apply_transform_config(model, config)
 
     # save model
-    model.save_pretrained(tmp_path)
+    model_path = os.path.join(tmp_path, "test_model_path")
+    model.save_pretrained(model_path)
+
+    # check that saved values match model values
+    # note that shared weights are only serialized once
+    safetensors_path = os.path.join(model_path, "model.safetensors")
+    with safe_open(safetensors_path, framework="pt", device="cpu") as file:
+        saved_keys = set(file.keys())
+        assert {
+            "fcs.0.weight",
+            "fcs.1.weight",
+            "fcs.2.weight",
+            "fcs.3.weight",
+            "fcs.4.weight",
+        } <= saved_keys
+        for key in saved_keys:
+            param = model.get_parameter(key)
+            saved_param = file.get_tensor(key)
 
-    # TODO: reload model
+            if param.device.type != "meta":  # skip testing values in offload case
+                assert torch.equal(param, saved_param)
 
 
-@pytest.mark.skip(reason="Requires changes in upstream transformers")
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomize", (True, False))
 def test_serialization_offload(type, randomize, model_apply, tmp_path):
     test_serialization(type, randomize, model_apply, tmp_path, offload=True)
+
+
+@pytest.mark.skip("Requires transformers#40673")
+@requires_gpu
+@pytest.mark.parametrize(
+    "model_stub,exp_perplexity",
+    [
+        ("nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", 10.0),
+        ("nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", 10.0),
+    ],
+)
+def test_load_perplexity(model_stub, exp_perplexity):
+    model = AutoModelForCausalLM.from_pretrained(model_stub, device_map="cuda")
+    tokenizer = AutoTokenizer.from_pretrained(model_stub)
+
+    prompt = "The capital of France is Paris, the capital of Germany is Berlin"
+    inputs = tokenizer(prompt, return_tensors="pt")
+    inputs = {key: value.to(model.device) for key, value in inputs.items()}
+    labels = inputs["input_ids"]
+
+    with torch.no_grad():
+        outputs = model(**inputs, labels=labels)
+
+    perplexity = torch.exp(outputs.loss)
+    assert perplexity <= exp_perplexity
```
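For reference, the perplexity the new test asserts on is just the exponential of the mean next-token cross-entropy that the `labels=` forward pass returns as `outputs.loss`:

```python
# Perplexity as used in the test: exp of the mean next-token cross-entropy
# returned by model(**inputs, labels=labels) as outputs.loss.
import torch

mean_nll = torch.tensor(2.0)      # example mean negative log-likelihood
perplexity = torch.exp(mean_nll)  # exp(2.0) ~= 7.39, under the 10.0 threshold
assert perplexity <= 10.0
```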
