Commit de9cfc6
[Transform] Support loading random hadamards on meta device (#445)
* meta hadamards
* fix dynamic weights keys
* break out _tie_offloaded_tensors, add test
* better comments
* better comments
* simplify function
* style

Signed-off-by: Kyle Sayers <[email protected]>
1 parent aa06b72 commit de9cfc6

File tree

6 files changed: +96 −44 lines changed

src/compressed_tensors/transform/apply.py
src/compressed_tensors/transform/factory/base.py
src/compressed_tensors/transform/factory/hadamard.py
src/compressed_tensors/transform/utils/hadamard.py
tests/test_transform/conftest.py
tests/test_transform/factory/test_serialization.py

src/compressed_tensors/transform/apply.py

Lines changed: 35 additions & 0 deletions
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Dict
+
 import torch
+from accelerate.utils import has_offloaded_params
 from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory

@@ -34,3 +37,35 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):

     # attach config to model for compression/serialization
     setattr(model, TRANSFORM_CONFIG_NAME, config)
+
+    # ensure that tied weight transforms can be serialized without aliases
+    # In the future, this could be done by transformers or model compressor,
+    # which would make this more robust to changing dispatches after transforms
+    _tie_offloaded_tensors(model)
+
+
+def _tie_offloaded_tensors(model: torch.nn.Module):
+    """
+    When accelerate replaces tensors with meta tensors during offloading, the meta
+    tensors may not be identical, even if the offloaded values are identical.
+
+    However, transformers can only serialize correctly if meta tensors are identical
+    (see transformers#39263).
+
+    This function collects all meta tensors which share offloaded values and sets
+    those tensors to be identical so that duplicates can be removed during serialization.
+
+    :param model: model potentially containing offloaded meta tensors to fix
+    """
+
+    # ensure that if locations share an offloaded tensor pointer, the meta
+    # tensor is also identical (assigned to the first instance of the parameter)
+    ptr_to_meta: Dict[int, torch.nn.Parameter] = dict()
+    for module in model.modules():
+        if has_offloaded_params(module):
+            for key, _ in module.named_parameters(recurse=False):
+                offloaded_ptr = module._hf_hook.weights_map[key].data_ptr()
+
+                if offloaded_ptr not in ptr_to_meta:
+                    ptr_to_meta[offloaded_ptr] = getattr(module, key)
+                setattr(module, key, ptr_to_meta[offloaded_ptr])
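The invariant `_tie_offloaded_tensors` restores can be shown without any accelerate plumbing. A minimal sketch, assuming only plain torch (a dict stands in for the `_hf_hook.weights_map` offload map):

import torch

# two offloaded parameters that alias one CPU buffer
offloaded = torch.randn(4, 4)

# accelerate materializes a separate meta tensor per module, so the aliases
# match in shape and dtype but are not the same python object
meta_a = torch.empty_like(offloaded, device="meta")
meta_b = torch.empty_like(offloaded, device="meta")
assert meta_a is not meta_b

# deduplicate on the offloaded buffer's data_ptr(), keeping the first meta
# tensor seen, exactly as the loop above does with ptr_to_meta
ptr_to_meta = {}
meta_a = ptr_to_meta.setdefault(offloaded.data_ptr(), meta_a)
meta_b = ptr_to_meta.setdefault(offloaded.data_ptr(), meta_b)
assert meta_a is meta_b  # identical objects; transformers can dedup on save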

src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 36 deletions
@@ -13,8 +13,7 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import List, Optional, Set, Tuple
+from typing import List, Optional

 import torch
 import torch.nn.utils.parametrize as P

@@ -57,7 +56,6 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
-        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)

@@ -101,8 +99,6 @@ def apply_to_model(self, model: Module, use_tqdm=True):
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
             self._apply_to_module(module, arg)

-        self._update_tied_weights()
-
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module

@@ -120,7 +116,6 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         # create transform as submodule
         transform_name = f"{self.name}_{args.location}"
         transform = self.create_transform(module, args)
-        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)

         # register input transformation hook

@@ -165,31 +160,6 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()

-    def _update_tied_weights(self):
-        """
-        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
-        which is used by transformers to detect and remove shared pointers
-        during saving
-        """
-        # map from data_ptrs to keys
-        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
-        for transform in self.transforms:
-            for name, param in transform.named_parameters(recurse=False):
-                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
-                if has_offloaded_params(transform):
-                    param = transform._hf_hook.weights_map[name]
-                ptr_to_keys[param.data_ptr()].append((transform, name))
-
-        # populate `_dynamic_tied_weights_keys` if there is more than one key
-        # and ensure that they share tensors
-        for shared_keys in ptr_to_keys.values():
-            if len(shared_keys) > 1:
-                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
-
-                for transform, name in shared_keys:
-                    transform._dynamic_tied_weights_keys.add(name)
-                    setattr(transform, name, tensor)
-

 class TransformBase(InternalModule, ABC):
     """

@@ -198,11 +168,7 @@ class TransformBase(InternalModule, ABC):

     args: TransformArgs
     weight: Parameter
-    _dynamic_tied_weights_keys: Set[str]
-
-    def __init__(self):
-        super().__init__()
-        self._dynamic_tied_weights_keys = set()
+    _dynamic_tied_weights_keys: List[str] = ["weight"]

     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
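Net effect of this refactor: instead of the factory tracking every transform it creates and back-filling a per-instance set after the fact, each transform class now statically declares which of its parameter names may be tied, and `_tie_offloaded_tensors` repairs aliasing once at apply time. A minimal sketch of how a class-level declaration behaves (`MyTransform` is invented for illustration, not part of the library):

import torch
from torch.nn import Module, Parameter

class MyTransform(Module):
    # class attribute: these parameter names may be tied across instances
    _dynamic_tied_weights_keys = ["weight"]

    def __init__(self, weight: Parameter):
        super().__init__()
        self.weight = weight

shared = Parameter(torch.eye(8))
a, b = MyTransform(shared), MyTransform(shared)

# both instances hold the identical tensor object, so a serializer that
# deduplicates by identity can store the tensor once and restore the tie
assert a.weight is b.weight

`HadamardTransform` in the next file extends the same declaration with "perm", since random hadamards also carry a permutation parameter that may be shared.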

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 3 additions & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import List, Optional

 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme

@@ -84,6 +84,8 @@ def _create_permutation(self, weight: Parameter) -> Parameter:


 class HadamardTransform(TransformBase):
+    _dynamic_tied_weights_keys: List[str] = ["weight", "perm"]
+
     def __init__(
         self,
         weight: Parameter,

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 5 additions & 2 deletions
@@ -115,13 +115,16 @@ def _fetch_hadamard_divisor(
     than forcing callers to manage the file open context

     :param n: size of known hadamard matrix
+    :param dtype: data type to move fetched hadamard to
+    :param device: device to move fetched hadamard to
     :return: a known hadamard matrix of size `n` if one exists, else None
     """
-    with safe_open(file_path, framework="pt", device=str(device)) as file:
+    open_device = torch.device("cpu") if device.type == "meta" else device
+    with safe_open(file_path, framework="pt", device=str(open_device)) as file:
         divisors = sorted((int(key) for key in file.keys()), reverse=True)
         for divisor in divisors:
             if n % divisor == 0 and is_pow2(n // divisor):
-                return file.get_tensor(str(divisor)).to(dtype=dtype)
+                return file.get_tensor(str(divisor)).to(dtype=dtype, device=device)

     return None
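The reason for `open_device`: safetensors cannot materialize tensors directly onto the meta device, so values are read on CPU and then moved, where `.to(device="meta")` keeps only shape and dtype. A minimal sketch of the pattern, with a random tensor standing in for `file.get_tensor(...)`:

import torch

device = torch.device("meta")

# safe_open() cannot target the meta device, so read on CPU instead
open_device = torch.device("cpu") if device.type == "meta" else device
tensor = torch.randn(4, 4, device=open_device)  # stand-in for file.get_tensor(...)

# moving to meta drops the values but keeps shape and dtype, which is all
# a meta-device model needs at load time
meta_tensor = tensor.to(dtype=torch.bfloat16, device=device)
assert meta_tensor.device.type == "meta" and meta_tensor.shape == (4, 4)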

tests/test_transform/conftest.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@


 class TransformableModel(PreTrainedModel):
+    config_class = PretrainedConfig
+
     def __init__(self, *sizes):
         super().__init__(config=PretrainedConfig())
         self.fcs = torch.nn.ModuleList(
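`config_class` matters here because `PreTrainedModel.from_pretrained` consults it to reconstruct the config when a saved model is reloaded; without it, the toy test model can be saved but not round-tripped. A sketch of the idiom (`TinyModel` is invented for illustration):

import torch
from transformers import PretrainedConfig, PreTrainedModel

class TinyModel(PreTrainedModel):
    # tells from_pretrained() which config type to instantiate on reload
    config_class = PretrainedConfig

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.fc = torch.nn.Linear(4, 4)

model = TinyModel(PretrainedConfig())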

tests/test_transform/factory/test_serialization.py

Lines changed: 49 additions & 5 deletions
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+
 import pytest
 import torch
 from compressed_tensors.transform import (

@@ -20,7 +22,9 @@
     apply_transform_config,
 )
 from compressed_tensors.utils import offloaded_dispatch
+from safetensors import safe_open
 from tests.testing_utils import requires_accelerate, requires_gpu
+from transformers import AutoModelForCausalLM, AutoTokenizer


 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))

@@ -38,17 +42,57 @@ def test_serialization(type, randomize, model_apply, tmp_path, offload=False):
     apply_transform_config(model, config)

     # save model
-    model.save_pretrained(tmp_path)
+    model_path = os.path.join(tmp_path, "test_model_path")
+    model.save_pretrained(model_path)
+
+    # check that saved values match model values
+    # note that shared weights are only serialized once
+    safetensors_path = os.path.join(model_path, "model.safetensors")
+    with safe_open(safetensors_path, framework="pt", device="cpu") as file:
+        saved_keys = set(file.keys())
+        assert {
+            "fcs.0.weight",
+            "fcs.1.weight",
+            "fcs.2.weight",
+            "fcs.3.weight",
+            "fcs.4.weight",
+        } <= saved_keys
+        for key in saved_keys:
+            param = model.get_parameter(key)
+            saved_param = file.get_tensor(key)

-    # TODO: reload model
+            if param.device.type != "meta":  # skip testing values in offload case
+                assert torch.equal(param, saved_param)


-@pytest.mark.skip(reason="Requires changes in upstream transformers")
-# https://github.com/huggingface/transformers/pull/39280
-# https://github.com/huggingface/transformers/pull/39263
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomize", (True, False))
 def test_serialization_offload(type, randomize, model_apply, tmp_path):
     test_serialization(type, randomize, model_apply, tmp_path, offload=True)
+
+
+@pytest.mark.skip("Requires transformers#40673")
+@requires_gpu
+@pytest.mark.parametrize(
+    "model_stub,exp_perplexity",
+    [
+        ("nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", 10.0),
+        ("nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", 10.0),
+    ],
+)
+def test_load_perplexity(model_stub, exp_perplexity):
+    model = AutoModelForCausalLM.from_pretrained(model_stub, device_map="cuda")
+    tokenizer = AutoTokenizer.from_pretrained(model_stub)
+
+    prompt = "The capital of France is Paris, the capital of Germany is Berlin"
+    inputs = tokenizer(prompt, return_tensors="pt")
+    inputs = {key: value.to(model.device) for key, value in inputs.items()}
+    labels = inputs["input_ids"]
+
+    with torch.no_grad():
+        outputs = model(**inputs, labels=labels)
+
+    perplexity = torch.exp(outputs.loss)
+    assert perplexity <= exp_perplexity
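For reference, the perplexity asserted above is the exponential of the mean per-token cross-entropy loss returned by the causal LM head. A standalone check of the relation (the loss value is made up):

import torch

# a mean cross-entropy of 2.0 nats corresponds to perplexity e^2 ≈ 7.389
loss = torch.tensor(2.0)
perplexity = torch.exp(loss)
assert torch.isclose(perplexity, torch.tensor(7.3891), atol=1e-3)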

0 commit comments