Commit 84386fa

Merge branch 'kylesayrs/transform_save' into kylesayrs/transform-merge
2 parents: 2267846 + 4085613

File tree

10 files changed: +148, -40 lines


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 1 addition & 1 deletion
@@ -393,7 +393,7 @@ def compress_model(self, model: Module):

             if prefix in module_to_scheme or prefix in sparse_compression_targets:
                 module_device = get_execution_device(module).type
-                is_meta = (module_device == "meta")
+                is_meta = module_device == "meta"

                 exec_device = "meta" if is_meta else "cpu"
                 onloading_device = "meta" if is_meta else module_device

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 6 additions & 2 deletions
@@ -178,9 +178,13 @@ def sparse24_bitmask_compress(

     if tensor.is_meta:
         num_rows, num_cols = tensor.shape
-        compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
+        compressed_values = torch.empty(
+            (num_rows, num_cols // 2), dtype=tensor.dtype, device="meta"
+        )
         packed_cols = (num_cols + 7) // 8
-        bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")
+        bitmasks_packed = torch.empty(
+            (num_rows, packed_cols), dtype=torch.uint8, device="meta"
+        )
         return compressed_values, bitmasks_packed

     bytemasks = get_24_bytemasks(tensor=tensor)
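
The meta branch only needs output shapes, not values: 2:4 sparsity keeps two of every four values (halving the columns), and the bitmask packs one bit per original column into uint8 bytes. A quick sanity sketch of that shape arithmetic (hypothetical tile size, not the library's API):

import torch

# hypothetical 4x16 weight tile under 2:4 sparsity
num_rows, num_cols = 4, 16

# 2:4 sparsity keeps 2 of every 4 values, so half the columns survive
compressed_cols = num_cols // 2    # 8

# one mask bit per original column, packed 8 bits to a byte
packed_cols = (num_cols + 7) // 8  # ceil(16 / 8) == 2

values = torch.empty((num_rows, compressed_cols), dtype=torch.bfloat16)
bitmask = torch.empty((num_rows, packed_cols), dtype=torch.uint8)
assert values.shape == (4, 8) and bitmask.shape == (4, 2)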

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 6 additions & 1 deletion
@@ -189,7 +189,12 @@ def _initialize_scale_zero_point(
     else:
         # TODO: consider erroring out in the future as if the dtype if not one of these,
         # there is likely bug
-        if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
+        if scale_dtype not in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+            torch.float64,
+        ]:
             scale_dtype = torch.float16
         zp_dtype = quantization_args.pytorch_dtype()

src/compressed_tensors/transform/factory/base.py

Lines changed: 52 additions & 3 deletions
@@ -13,7 +13,8 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
-from typing import Optional
+from collections import defaultdict
+from typing import List, Optional, Tuple

 import torch
 import torch.nn.utils.parametrize as P
@@ -49,10 +50,13 @@ class TransformFactory(RegistryMixin, ABC):
     :param seed: random seed used to transform weight randomization
     """

+    transforms: List["TransformBase"]
+
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
+        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)

@@ -91,16 +95,26 @@ def apply_to_model(self, model: Module):
                 if is_target(name, module, arg.targets, arg.ignore):
                     self._apply_to_module(module, arg)

+        self._update_tied_weights()
+
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module

         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
+        if has_offloaded_params(module):
+            if module._hf_hook.place_submodules:
+                raise NotImplementedError(
+                    "Applying transforms to offloaded submodules with "
+                    "`place_submodules=True` is not supported"
+                )
+
         # create transform as submodule
         transform_name = f"{self.name}_{args.location.value}"
         transform = self.create_transform(module, args)
+        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)

         # register input transformation hook
@@ -129,8 +143,9 @@ def input_hook(_, args):
                     raise ValueError("Offloaded training is not supported")
                 P.register_parametrization(module, "weight", transform)

-            # transform is no longer needed (unfusing is not supported)
-            delete_offload_module(module, transform_name)
+            else:
+                # transform is no longer needed (unfusing is not supported)
+                delete_offload_module(module, transform_name)

         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
@@ -144,6 +159,35 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()

+    def _update_tied_weights(self):
+        """
+        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+        which is used by transformers to detect and remove shared pointers
+        during saving
+        """
+        # avoid issues with this method being called twice
+        for transform in self.transforms:
+            transform._dynamic_tied_weights_keys = list()
+
+        # map from data_ptrs to keys
+        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
+        for transform in self.transforms:
+            for name, param in transform.named_parameters(recurse=False):
+                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
+                if has_offloaded_params(transform):
+                    param = transform._hf_hook.weights_map[name]
+                ptr_to_keys[param.data_ptr()].append((transform, name))
+
+        # populate `_dynamic_tied_weights_keys` if there is more than one key
+        # and ensure that they share tensors
+        for shared_keys in ptr_to_keys.values():
+            if len(shared_keys) > 1:
+                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
+
+                for transform, name in shared_keys:
+                    transform._dynamic_tied_weights_keys.append(name)
+                    setattr(transform, name, tensor)
+

 class TransformBase(InternalModule, ABC):
     """
@@ -152,6 +196,11 @@ class TransformBase(InternalModule, ABC):

     args: TransformArgs
     weight: Parameter
+    _dynamic_tied_weights_keys: List[str]
+
+    def __init__(self):
+        super().__init__()
+        self._dynamic_tied_weights_keys = list()

     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
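
The save path is the motivation here: transformers deduplicates tied weights by storage pointer when serializing, so any transform parameters that alias the same tensor must be listed in `_dynamic_tied_weights_keys` and actually share storage. A standalone sketch of the `data_ptr()` grouping (hypothetical modules standing in for transforms, not the factory itself):

import torch
from collections import defaultdict

# two hypothetical modules sharing one Parameter
shared = torch.nn.Parameter(torch.randn(4, 4), requires_grad=False)
a = torch.nn.Linear(4, 4)
b = torch.nn.Linear(4, 4)
a.weight = shared
b.weight = shared

# group (module, name) pairs by underlying storage pointer
ptr_to_keys = defaultdict(list)
for module in (a, b):
    for name, param in module.named_parameters(recurse=False):
        ptr_to_keys[param.data_ptr()].append((module, name))

# groups with more than one entry are tied and must be deduplicated on save
tied_groups = [keys for keys in ptr_to_keys.values() if len(keys) > 1]
assert tied_groups == [[(a, "weight"), (b, "weight")]]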

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:

     def _create_inverse(self, weight: Parameter) -> Parameter:
         data = high_precision_invert(weight.data)
+        data = data.contiguous()  # ensure proper serialization
         return Parameter(data, requires_grad=False)

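The `.contiguous()` call matters because inversion can hand back a transposed, non-contiguous view, and serializers such as safetensors expect contiguous storage. A small illustration (plain `torch.linalg.inv`, not the library's `high_precision_invert`):

import torch

w = torch.randn(4, 4)
inv = torch.linalg.inv(w).T  # .T yields a non-contiguous view

print(inv.is_contiguous())               # False
print(inv.contiguous().is_contiguous())  # True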
tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 1 addition & 4 deletions
@@ -446,10 +446,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         cpu_model, s_config, q_format
     )
     # Only stores dtype because meta model does not store values
-    expected = {
-        k: v.dtype
-        for k, v in reference_compressor.compress(cpu_model).items()
-    }
+    expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}

     # Load model on meta device
     meta_model = AutoModelForCausalLM.from_pretrained(

tests/test_transform/conftest.py

Lines changed: 4 additions & 3 deletions
@@ -14,12 +14,13 @@

 import pytest
 import torch
-from compressed_tensors.transform import TransformArgs
+from compressed_tensors.transform import TransformArgs, TransformFactory
+from transformers import PretrainedConfig, PreTrainedModel


-class TransformableModel(torch.nn.Module):
+class TransformableModel(PreTrainedModel):
     def __init__(self, *sizes):
-        super().__init__()
+        super().__init__(config=PretrainedConfig())
         self.fcs = torch.nn.ModuleList(
             [
                 torch.nn.Linear(sizes[index], sizes[index + 1], bias=False)
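
Rebasing the fixture onto `PreTrainedModel` is what gives it `save_pretrained`, which the new serialization tests depend on. A minimal sketch of the same pattern, assuming only `torch` and `transformers` (hypothetical `TinyModel`):

import torch
from transformers import PretrainedConfig, PreTrainedModel


class TinyModel(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config=None):
        super().__init__(config if config is not None else PretrainedConfig())
        self.fc = torch.nn.Linear(2, 2, bias=False)

    def forward(self, x):
        return self.fc(x)


model = TinyModel()
model.save_pretrained("/tmp/tiny_model")  # writes config.json + model.safetensors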

tests/test_transform/factory/test_correctness.py

Lines changed: 17 additions & 17 deletions
@@ -57,10 +57,10 @@ def test_correctness_linear(type, randomized, head_dim):


 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("embed_loc", ("weight_output", "output"))
 @pytest.mark.parametrize("linear_loc", ("input", "weight_input"))
-def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
+def test_correctness_embedding(type, randomize, embed_loc, linear_loc):
     model = torch.nn.Sequential(
         torch.nn.Embedding(2, 4),
         torch.nn.Linear(4, 8, bias=False),
@@ -73,7 +73,7 @@ def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
         config_groups={
             "": TransformScheme(
                 type=type,
-                randomized=randomized,
+                randomize=randomize,
                 apply=[
                     TransformArgs(targets="Embedding", location=embed_loc),
                     TransformArgs(targets="Linear", location=linear_loc, inverse=True),
@@ -89,8 +89,8 @@ def test_correctness_embedding(type, randomized, embed_loc, linear_loc):


 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
-def test_correctness_model(type, randomized, model_apply, offload=False):
+@pytest.mark.parametrize("randomize", (True, False))
+def test_correctness_model(type, randomize, model_apply, offload=False):
     # load model
     model = model_apply[0]
     if offload:
@@ -105,7 +105,7 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
     # apply transforms
     config = TransformConfig(
         config_groups={
-            "": TransformScheme(type=type, randomized=randomized, apply=model_apply[1])
+            "": TransformScheme(type=type, randomize=randomize, apply=model_apply[1])
         }
     )
     apply_transform_config(model, config)
@@ -115,18 +115,10 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


-@requires_gpu
-@requires_accelerate()
-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
-def test_correctness_model_offload(type, randomized, model_apply):
-    test_correctness_model(type, randomized, model_apply, offload=True)
-
-
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("head_dim", (4, 8))
-def test_correctness_attention_heads(type, randomized, head_dim):
+def test_correctness_attention_heads(type, randomize, head_dim):
     hidden_size = 64
     num_attention_heads = 8

@@ -143,7 +135,7 @@ def test_correctness_attention_heads(type, randomized, head_dim):
         config_groups={
             "": TransformScheme(
                 type=type,
-                randomized=randomized,
+                randomize=randomize,
                 head_dim=head_dim,
                 apply=[
                     TransformArgs(targets="v_proj", location="weight_output"),
@@ -158,3 +150,11 @@ def test_correctness_attention_heads(type, randomized, head_dim):

     output = attention(input)
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomize", (True, False))
+def test_correctness_model_offload(type, randomize, model_apply):
+    test_correctness_model(type, randomize, model_apply, offload=True)

tests/test_transform/factory/test_memory.py

Lines changed: 6 additions & 9 deletions
@@ -29,9 +29,9 @@


 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("requires_grad", (True, False))
-def test_memory_sharing(type, randomized, requires_grad, offload=False):
+def test_memory_sharing(type, randomize, requires_grad, offload=False):
     # load model (maybe with offloading)
     model = TransformableModel(2, 2, 4, 4, 8, 8)
     if offload:
@@ -42,7 +42,7 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
         config_groups={
             "": TransformScheme(
                 type=type,
-                randomzied=randomized,
+                randomize=randomize,
                 requires_grad=requires_grad,
                 apply=[
                     TransformArgs(targets="Linear", location="input"),
@@ -84,9 +84,6 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
-def test_memory_sharing_offload(
-    type,
-    randomized,
-):
-    test_memory_sharing(type, randomized, requires_grad=False, offload=True)
+@pytest.mark.parametrize("randomize", (True, False))
+def test_memory_sharing_offload(type, randomize):
+    test_memory_sharing(type, randomize, requires_grad=False, offload=True)
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.transform import (
+    TransformConfig,
+    TransformScheme,
+    apply_transform_config,
+)
+from compressed_tensors.utils import offloaded_dispatch
+from tests.testing_utils import requires_accelerate, requires_gpu
+
+
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomize", (True, False))
+def test_serialization(type, randomize, model_apply, tmp_path, offload=False):
+    # get model, maybe offload
+    model, apply = model_apply
+    if offload:
+        offloaded_dispatch(model, torch.device("cuda"))
+
+    # apply transforms to model
+    config = TransformConfig(
+        config_groups={"": TransformScheme(type=type, randomize=randomize, apply=apply)}
+    )
+    apply_transform_config(model, config)
+
+    # save model
+    model.save_pretrained(tmp_path)
+
+    # TODO: reload model
+
+
+@pytest.mark.skip(reason="Requires changes in upstream transformers")
+# https://github.com/huggingface/transformers/pull/39280
+# https://github.com/huggingface/transformers/pull/39263
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomize", (True, False))
+def test_serialization_offload(type, randomize, model_apply, tmp_path):
+    test_serialization(type, randomize, model_apply, tmp_path, offload=True)
