
Commit b2df366

[Transform] Serialize with tied weights (#370)
* add utilities (Signed-off-by: Kyle Sayers <[email protected]>)
* add tests (Signed-off-by: Kyle Sayers <[email protected]>)
* add additional tests (Signed-off-by: Kyle Sayers <[email protected]>)
* add utils and tests (Signed-off-by: Kyle Sayers <[email protected]>)
* Implement transform factories (Signed-off-by: Kyle Sayers <[email protected]>)
* add permutations (Signed-off-by: Kyle Sayers <[email protected]>)
* add delete_offload_module (Signed-off-by: Kyle Sayers <[email protected]>)
* key inverses by weight (Signed-off-by: Kyle Sayers <[email protected]>)
* fix tests (Signed-off-by: Kyle Sayers <[email protected]>)
* standardize random hadamard (Signed-off-by: Kyle Sayers <[email protected]>)
* prepend input hooks (Signed-off-by: Kyle Sayers <[email protected]>)
* apply sqrt division first (Signed-off-by: Kyle Sayers <[email protected]>)
* use divided hadamards (Signed-off-by: Kyle Sayers <[email protected]>)
* fix typo (Signed-off-by: Kyle Sayers <[email protected]>)
* add random option (Signed-off-by: Kyle Sayers <[email protected]>)
* use random seeds, rename matrix multiply (Signed-off-by: Kyle Sayers <[email protected]>)
* add deterministic generation to random matrix (Signed-off-by: Kyle Sayers <[email protected]>)
* fix perm math (Signed-off-by: Kyle Sayers <[email protected]>)
* update docstrings (Signed-off-by: Kyle Sayers <[email protected]>)
* update docstrings (Signed-off-by: Kyle Sayers <[email protected]>)
* cleanup (Signed-off-by: Kyle Sayers <[email protected]>)
* cleanup 2 (Signed-off-by: Kyle Sayers <[email protected]>)
* make seed optional (Signed-off-by: Kyle Sayers <[email protected]>)
* remove iterable check and missing return value (Signed-off-by: Kyle Sayers <[email protected]>)
* Remove unrelated changes
* simplify code (Signed-off-by: Kyle Sayers <[email protected]>)
* implement apply, use in tests (Signed-off-by: Kyle Sayers <[email protected]>)
* use hadamards database file (Signed-off-by: Kyle Sayers <[email protected]>)
* try manifest (Signed-off-by: Kyle Sayers <[email protected]>)
* try setup, update hadamards list (Signed-off-by: Kyle Sayers <[email protected]>)
* fix setup (Signed-off-by: Kyle Sayers <[email protected]>)
* add docstrings, cleanup (Signed-off-by: Kyle Sayers <[email protected]>)
* fix setup, thank you @dbarbuzzi (Signed-off-by: Kyle Sayers <[email protected]>)
* remove numpy, add tests (Signed-off-by: Kyle Sayers <[email protected]>)
* solidify dtype, add gpu tests (Signed-off-by: Kyle Sayers <[email protected]>)
* fix docstring (Signed-off-by: Kyle Sayers <[email protected]>)
* add device option (Signed-off-by: Kyle Sayers <[email protected]>)
* construct on execution device, cache on offload device (Signed-off-by: Kyle Sayers <[email protected]>)
* save construction device changes for later (Signed-off-by: Kyle Sayers <[email protected]>)
* construct on execution device, cache on offload device
* cite nja sloane (Signed-off-by: Kyle Sayers <[email protected]>)
* remove dreg (Signed-off-by: Kyle Sayers <[email protected]>)
* put on device via safe_open (Signed-off-by: Kyle Sayers <[email protected]>)
* nits and docstrings (Signed-off-by: Kyle Sayers <[email protected]>)
* update docstring (Signed-off-by: Kyle Sayers <[email protected]>)
* Merge
* merge with construct: construct in float32 (Signed-off-by: Kyle Sayers <[email protected]>)
* construct with same dtype, constructing on fp32 found no difference (Signed-off-by: Kyle Sayers <[email protected]>)
* remove unnecessary imports (Signed-off-by: Kyle Sayers <[email protected]>)
* bugfixes (#375) (Signed-off-by: Brian Dellabetta <[email protected]>)
* use factory_kwargs (Signed-off-by: Kyle Sayers <[email protected]>)
* add frozen dict to deps (Signed-off-by: Kyle Sayers <[email protected]>)
* fix style (Signed-off-by: Kyle Sayers <[email protected]>)
* merge (Signed-off-by: Kyle Sayers <[email protected]>)
* use delete_offload_module (Signed-off-by: Kyle Sayers <[email protected]>)
* add docstring (Signed-off-by: Kyle Sayers <[email protected]>)
* use parametrize (Signed-off-by: Kyle Sayers <[email protected]>)
* populate _dynamic_tied_weights_keys (Signed-off-by: Kyle Sayers <[email protected]>)
* ensure serializable (Signed-off-by: Kyle Sayers <[email protected]>)
* remove extra space (Signed-off-by: Kyle Sayers <[email protected]>)
* apply style (Signed-off-by: Kyle Sayers <[email protected]>)
* merge dregs
* skip offloading tests until transformers changes land (Signed-off-by: Kyle Sayers <[email protected]>)
* use set (Signed-off-by: Kyle Sayers <[email protected]>)

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
1 parent 957a1d1 commit b2df366

File tree

6 files changed: +128 -29 lines changed


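Background for the diffs that follow: "tied" weights are two module attributes that point at the same underlying tensor storage. A plain state_dict() still lists one key per attribute, so a serializer has to detect the sharing (or be told about it, which is what populating _dynamic_tied_weights_keys does) to avoid writing the same data twice or failing on shared storage. A minimal standalone sketch, not code from this repository:

import torch

lin_a = torch.nn.Linear(4, 4, bias=False)
lin_b = torch.nn.Linear(4, 4, bias=False)
lin_b.weight = lin_a.weight  # tie the two modules to one Parameter

container = torch.nn.ModuleDict({"a": lin_a, "b": lin_b})
state = container.state_dict()

# two keys, one storage
print(list(state.keys()))                                            # ['a.weight', 'b.weight']
print(state["a.weight"].data_ptr() == state["b.weight"].data_ptr())  # True
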
src/compressed_tensors/transform/factory/base.py

Lines changed: 48 additions & 3 deletions
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import Optional
+from collections import defaultdict
+from typing import List, Optional, Tuple, Set
 
 import torch
 import torch.nn.utils.parametrize as P
@@ -49,10 +50,13 @@ class TransformFactory(RegistryMixin, ABC):
     :param seed: random seed used to transform weight randomization
     """
 
+    transforms: List["TransformBase"]
+
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
+        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)
 
@@ -90,16 +94,26 @@ def apply_to_model(self, model: Module):
             for _, module in match_named_modules(model, arg.targets, arg.ignore):
                 self._apply_to_module(module, arg)
 
+        self._update_tied_weights()
+
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
 
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
+        if has_offloaded_params(module):
+            if module._hf_hook.place_submodules:
+                raise NotImplementedError(
+                    "Applying transforms to offloaded submodules with "
+                    "`place_submodules=True` is not supported"
+                )
+
         # create transform as submodule
         transform_name = f"{self.name}_{args.location}"
         transform = self.create_transform(module, args)
+        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)
 
         # register input transformation hook
@@ -128,8 +142,9 @@ def input_hook(_, args):
                     raise ValueError("Offloaded training is not supported")
                 P.register_parametrization(module, "weight", transform)
 
-            # transform is no longer needed (unfusing is not supported)
-            delete_offload_module(module, transform_name)
+            else:
+                # transform is no longer needed (unfusing is not supported)
+                delete_offload_module(module, transform_name)
 
         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
@@ -143,6 +158,31 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()
 
+    def _update_tied_weights(self):
+        """
+        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+        which is used by transformers to detect and remove shared pointers
+        during saving
+        """
+        # map from data_ptrs to keys
+        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
+        for transform in self.transforms:
+            for name, param in transform.named_parameters(recurse=False):
+                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
+                if has_offloaded_params(transform):
+                    param = transform._hf_hook.weights_map[name]
+                ptr_to_keys[param.data_ptr()].append((transform, name))
+
+        # populate `_dynamic_tied_weights_keys` if there is more than one key
+        # and ensure that they share tensors
+        for shared_keys in ptr_to_keys.values():
+            if len(shared_keys) > 1:
+                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
+
+                for transform, name in shared_keys:
+                    transform._dynamic_tied_weights_keys.add(name)
+                    setattr(transform, name, tensor)
+
 
 class TransformBase(InternalModule, ABC):
     """
@@ -151,6 +191,11 @@ class TransformBase(InternalModule, ABC):
 
     args: TransformArgs
     weight: Parameter
+    _dynamic_tied_weights_keys: Set[str]
+
+    def __init__(self):
+        super().__init__()
+        self._dynamic_tied_weights_keys = set()
 
     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:

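The new `_update_tied_weights` method above groups transform parameters by `data_ptr()` to find keys that share storage. A small standalone sketch of that detection idea, using plain torch modules rather than the library's TransformBase:

from collections import defaultdict

import torch

a = torch.nn.Linear(4, 4, bias=False)
b = torch.nn.Linear(4, 4, bias=False)
b.weight = a.weight  # tied: both attributes reference one Parameter

ptr_to_keys = defaultdict(list)
for mod_name, module in {"a": a, "b": b}.items():
    for param_name, param in module.named_parameters(recurse=False):
        ptr_to_keys[param.data_ptr()].append(f"{mod_name}.{param_name}")

# any group with more than one key is a set of tied weights
tied_groups = [keys for keys in ptr_to_keys.values() if len(keys) > 1]
print(tied_groups)  # [['a.weight', 'b.weight']]
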
src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
 
     def _create_inverse(self, weight: Parameter) -> Parameter:
         data = high_precision_invert(weight.data)
+        data = data.contiguous()  # ensure proper serialization
         return Parameter(data, requires_grad=False)
 
 

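The one-line change above exists because tensor views (for example, the result of a transpose) can be non-contiguous, while serializers such as safetensors expect a compact, contiguous buffer. A quick illustration with plain torch, not the factory's `high_precision_invert`:

import torch

w = torch.randn(4, 8)
v = w.t()  # a view with swapped strides, no data copied

print(v.is_contiguous())               # False
print(v.contiguous().is_contiguous())  # True, now a compact copy
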
tests/test_transform/conftest.py

Lines changed: 4 additions & 3 deletions
@@ -14,12 +14,13 @@
 
 import pytest
 import torch
-from compressed_tensors.transform import TransformArgs
+from compressed_tensors.transform import TransformArgs, TransformFactory
+from transformers import PretrainedConfig, PreTrainedModel
 
 
-class TransformableModel(torch.nn.Module):
+class TransformableModel(PreTrainedModel):
     def __init__(self, *sizes):
-        super().__init__()
+        super().__init__(config=PretrainedConfig())
         self.fcs = torch.nn.ModuleList(
             [
                 torch.nn.Linear(sizes[index], sizes[index + 1], bias=False)

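Switching the test model's base class from torch.nn.Module to PreTrainedModel is what makes model.save_pretrained() available in the serialization test below. A minimal sketch of that pattern with hypothetical names (TinyModel is not part of the repository; it only assumes transformers is installed):

import torch
from transformers import PretrainedConfig, PreTrainedModel


class TinyModel(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config=None):
        super().__init__(config if config is not None else PretrainedConfig())
        self.fc = torch.nn.Linear(4, 4, bias=False)

    def forward(self, x):
        return self.fc(x)


model = TinyModel()
model.save_pretrained("tiny_model_out")  # writes config.json and the weights file
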
tests/test_transform/factory/test_correctness.py

Lines changed: 15 additions & 17 deletions
@@ -27,13 +27,13 @@
 
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("head_dim", (None, 2, 4))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
-def test_correctness_linear(type, randomized, head_dim, input_batch_size):
+def test_correctness_linear(type, randomize, head_dim, input_batch_size):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=False)
-    scheme = TransformScheme(type=type, randomized=randomized, head_dim=head_dim)
+    scheme = TransformScheme(type=type, randomize=randomize, head_dim=head_dim)
     factory = TransformFactory.from_scheme(scheme, name="")
 
     input_tfm = factory.create_transform(
@@ -58,10 +58,10 @@ def test_correctness_linear(type, randomized, head_dim, input_batch_size):
 
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("embed_loc", ("weight_output", "output"))
 @pytest.mark.parametrize("linear_loc", ("input", "weight_input"))
-def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
+def test_correctness_embedding(type, randomize, embed_loc, linear_loc):
     model = torch.nn.Sequential(
         torch.nn.Embedding(2, 4),
         torch.nn.Linear(4, 8, bias=False),
@@ -74,7 +74,7 @@ def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
         config_groups={
             "": TransformScheme(
                 type=type,
-                randomized=randomized,
+                randomize=randomize,
                 apply=[
                     TransformArgs(targets="Embedding", location=embed_loc),
                     TransformArgs(targets="Linear", location=linear_loc, inverse=True),
@@ -90,10 +90,10 @@
 
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
 def test_correctness_model(
-    type, randomized, input_batch_size, model_apply, offload=False
+    type, randomize, input_batch_size, model_apply, offload=False
 ):
     # load model
     model = model_apply[0]
@@ -109,7 +109,7 @@
     # apply transforms
     config = TransformConfig(
         config_groups={
-            "": TransformScheme(type=type, randomized=randomized, apply=model_apply[1])
+            "": TransformScheme(type=type, randomize=randomize, apply=model_apply[1])
         }
     )
     apply_transform_config(model, config)
@@ -122,19 +122,17 @@
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
-def test_correctness_model_offload(type, randomized, input_batch_size, model_apply):
-    test_correctness_model(
-        type, randomized, input_batch_size, model_apply, offload=True
-    )
+def test_correctness_model_offload(type, randomize, input_batch_size, model_apply):
+    test_correctness_model(type, randomize, input_batch_size, model_apply, offload=True)
 
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("head_dim", (4, 8))
 @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
-def test_correctness_attention_heads(type, randomized, head_dim, input_batch_size):
+def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size):
     hidden_size = 64
     num_attention_heads = 8
 
@@ -151,7 +149,7 @@ def test_correctness_attention_heads(type, randomized, head_dim, input_batch_siz
         config_groups={
             "": TransformScheme(
                 type=type,
-                randomized=randomized,
+                randomize=randomize,
                 head_dim=head_dim,
                 apply=[
                     TransformArgs(targets="v_proj", location="weight_output"),

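The rename threaded through these tests comes down to the keyword accepted by TransformScheme. A usage sketch of the renamed option, with argument values taken from the tests above (assumes a compressed-tensors install that includes this change):

from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme

scheme = TransformScheme(
    type="hadamard",
    randomize=True,  # previously spelled `randomized` in these tests
    apply=[
        TransformArgs(targets="Linear", location="input"),
        TransformArgs(targets="Linear", location="weight_input", inverse=True),
    ],
)
config = TransformConfig(config_groups={"": scheme})
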
tests/test_transform/factory/test_memory.py

Lines changed: 6 additions & 6 deletions
@@ -29,9 +29,9 @@
 
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 @pytest.mark.parametrize("requires_grad", (True, False))
-def test_memory_sharing(type, randomized, requires_grad, offload=False):
+def test_memory_sharing(type, randomize, requires_grad, offload=False):
     # load model (maybe with offloading)
     model = TransformableModel(2, 2, 4, 4, 8, 8)
     if offload:
@@ -42,7 +42,7 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
         config_groups={
             "": TransformScheme(
                 type=type,
-                randomzied=randomized,
+                randomzied=randomize,
                 requires_grad=requires_grad,
                 apply=[
                     TransformArgs(targets="Linear", location="input"),
@@ -84,9 +84,9 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("randomize", (True, False))
 def test_memory_sharing_offload(
     type,
-    randomized,
+    randomize,
 ):
-    test_memory_sharing(type, randomized, requires_grad=False, offload=True)
+    test_memory_sharing(type, randomize, requires_grad=False, offload=True)

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.transform import (
+    TransformConfig,
+    TransformScheme,
+    apply_transform_config,
+)
+from compressed_tensors.utils import offloaded_dispatch
+from tests.testing_utils import requires_accelerate, requires_gpu
+
+
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomize", (True, False))
+def test_serialization(type, randomize, model_apply, tmp_path, offload=False):
+    # get model, maybe offload
+    model, apply = model_apply
+    if offload:
+        offloaded_dispatch(model, torch.device("cuda"))
+
+    # apply transforms to model
+    config = TransformConfig(
+        config_groups={"": TransformScheme(type=type, randomize=randomize, apply=apply)}
+    )
+    apply_transform_config(model, config)
+
+    # save model
+    model.save_pretrained(tmp_path)
+
+    # TODO: reload model
+
+
+@pytest.mark.skip(reason="Requires changes in upstream transformers")
+# https://github.com/huggingface/transformers/pull/39280
+# https://github.com/huggingface/transformers/pull/39263
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomize", (True, False))
+def test_serialization_offload(type, randomize, model_apply, tmp_path):
+    test_serialization(type, randomize, model_apply, tmp_path, offload=True)

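The new test stops at save_pretrained and leaves reloading as a TODO. Purely as a hypothetical follow-up sketch (not part of this commit, and assuming the default single-shard safetensors layout), one way to inspect the saved artifact is to read the weights file back and check which keys were written:

from pathlib import Path

from safetensors.torch import load_file


def saved_weight_keys(save_dir):
    # save_pretrained writes non-sharded weights to model.safetensors by default
    tensors = load_file(str(Path(save_dir) / "model.safetensors"))
    return sorted(tensors.keys())
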