
Commit ee40ddf

Merge branch 'main' into kylesayrs/transform-merge
2 parents: 966b50e + 180226b

File tree

4 files changed: +125, -80 lines changed


src/compressed_tensors/compressors/model_compressors/model_compressor.py
Lines changed: 8 additions & 4 deletions

@@ -392,8 +392,8 @@ def compress_model(self, model: Module):
         for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):

             if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                module_device = get_execution_device(module).type
-                is_meta = module_device == "meta"
+                module_device = get_execution_device(module)
+                is_meta = (module_device.type == "meta")

                 exec_device = "meta" if is_meta else "cpu"
                 onloading_device = "meta" if is_meta else module_device
@@ -747,12 +747,16 @@ def _replace_weights(self, dense_weight_generator, model: Module):

 def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
-    Returns a dictionary which maps quantized module names to their quantization schemes
+    Returns a dictionary which maps quantized module names to their quantization
+    schemes. Only includes modules with weight quantization
     """
     return {
         fix_fsdp_module_name(name): module.quantization_scheme
         for name, module in model.named_modules()
-        if is_module_quantized(module)
+        if (
+            hasattr(module, "quantization_scheme") and
+            module.quantization_scheme.weights is not None
+        )
     }


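For readers skimming this hunk: the comprehension now keys off the scheme's weights field instead of is_module_quantized, so activation-only schemes are excluded from compression. A minimal sketch of that predicate on toy modules (SimpleNamespace stands in for a real QuantizationScheme; the module names and schemes below are hypothetical, not from the commit):

from types import SimpleNamespace

import torch

# Hypothetical schemes: only the weight-quantized one should survive the filter
weights_only = SimpleNamespace(weights=object(), input_activations=None)
acts_only = SimpleNamespace(weights=None, input_activations=object())

model = torch.nn.ModuleDict({"a": torch.nn.Linear(4, 4), "b": torch.nn.Linear(4, 4)})
model["a"].quantization_scheme = weights_only
model["b"].quantization_scheme = acts_only

selected = {
    name: module.quantization_scheme
    for name, module in model.named_modules()
    if (
        hasattr(module, "quantization_scheme")
        and module.quantization_scheme.weights is not None
    )
}
assert set(selected) == {"a"}  # the activation-only module "b" is skipped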
src/compressed_tensors/transform/factory/hadamard.py
Lines changed: 5 additions & 6 deletions

@@ -15,6 +15,7 @@
 import math
 from typing import Optional

+import math
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
@@ -26,7 +27,6 @@
 from compressed_tensors.utils import (
     get_execution_device,
     get_offloaded_device,
-    match_modules_set,
 )
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
@@ -107,8 +107,7 @@ def forward(self, value: Tensor) -> Tensor:

         if self.args.inverse:
             weight = weight.T
-
-        return (
-            apply_transform_weight(weight, value, self.args.location, self.module_type)
-            / self._scale
-        )
+
+        return apply_transform_weight(
+            weight, value, self.args.location, self.module_type
+        ) / self._scale

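Background for the division by self._scale that the refactored forward keeps (a sketch for context, not the factory's code; the exact value of _scale is not shown in this diff): a ±1 Hadamard matrix H of size n satisfies H @ H.T == n * I, so scaling by sqrt(n) yields an orthonormal transform whose inverse is its transpose, consistent with the inverse path above only transposing the weight.

import math

import torch

H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
H4 = torch.kron(H2, H2)                              # 4x4 Hadamard via the Sylvester construction
n = H4.shape[0]

assert torch.allclose(H4 @ H4.T, n * torch.eye(n))   # rows are mutually orthogonal
Hn = H4 / math.sqrt(n)                               # normalized transform
assert torch.allclose(Hn @ Hn.T, torch.eye(n))       # orthonormal
assert torch.allclose(torch.linalg.inv(Hn), Hn.T)    # inverse equals transpose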
src/compressed_tensors/transform/utils/matrix.py
Lines changed: 62 additions & 62 deletions

@@ -59,47 +59,13 @@ def get_transform_size(


 def apply_transform_weight(
-    weight: torch.Tensor,
+    transform_weight: torch.Tensor,
     value: torch.Tensor,
     location: TransformLocation,
     module_type: type[torch.nn.Module],
 ) -> torch.Tensor:
     """
-    :param weight: transform weight to apply
-    :param value: value to apply weight to
-    :param location: determines how weight should be applied
-    :param model_type: result of type(module), passed in to determine application of
-        weight transform. This is needed because torch uses convention:
-        - torch.nn.Linear(in_features,out_features) has weight shape
-            (out_features, in_features)
-        - torch.nn.Embedding(num_embeddings, embedding_dim) has weight shape
-            (num_embeddings, embedding_dim)
-        The transform has to account for Linear's transposed weights
-    :return: value after weight has been applied
-    """
-    # get function used to apply transform
-    fn, axis = _get_transform_method(module_type, location)
-
-    # reshape for head_dim
-    head_dim = weight.shape[0]
-    num_heads = value.shape[axis] // head_dim
-    value = value.unflatten(axis, (num_heads, head_dim))
-
-    # apply transform
-    value = fn(weight, value)
-
-    # [undo] reshape for head_dim
-    value = value.flatten(axis - 1, axis)
-
-    return value
-
-
-def _get_transform_method(
-    module_type: type[torch.nn.Module],
-    location: TransformLocation,
-) -> Tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], int]:
-    """
-    Using the transform location, determine how to apply the transform weight to the
+    Using the transform location, apply the transform_weight to the
     given value wrt linear weights. For more info on input and output transforms,
     see `TransformLocation`

@@ -129,51 +95,85 @@ def _get_transform_method(
         = y U
         = yh

-    :param weight: transform weight to apply
-    :param value: value to apply weight to
+    :param transform_weight: transform weight to apply
+    :param value: value to apply transform_weight to
     :param location: determines how weight should be applied
-    :return: value after transform weight has been applied
+    :param model_type: result of type(module), passed in to determine application of
+        weight transform
+    :return: value after transform_weight has been applied
     """
-    fn = axis = None
+
+    assert transform_weight.shape[0] == transform_weight.shape[1]

     if module_type == torch.nn.Linear:
         if location == TransformLocation.INPUT:
-            fn = lambda weight, value: value @ weight
-            axis = -1
+            return _multihead_matmul(value, transform_weight)

         elif location == TransformLocation.WEIGHT_INPUT:
-            fn = lambda weight, value: value @ weight.T
-            axis = -1
+            # equivalent to (transform_weight @ value.T).T
+            return _multihead_matmul(value, transform_weight.T)

         elif location == TransformLocation.WEIGHT_OUTPUT:
-            fn = lambda weight, value: weight.T @ value
-            axis = -2
+            # equivalent to (value.T @ transform_weight).T
+            return _multihead_matmul(transform_weight.T, value)

         elif location == TransformLocation.OUTPUT:
-            fn = lambda weight, value: value @ weight
-            axis = -1
+            return _multihead_matmul(value, transform_weight)

     # similar derivation to torch.nn.Linear, but `y = (x W)`
-    if module_type == torch.nn.Embedding:
+    elif module_type == torch.nn.Embedding:
         if location == TransformLocation.INPUT:
-            fn = lambda weight, value: value @ weight
-            axis = -1
+            return _multihead_matmul(value, transform_weight)

         elif location == TransformLocation.WEIGHT_INPUT:
-            fn = lambda weight, value: weight @ value
-            axis = -1
+            return _multihead_matmul(
+                transform_weight,
+                value,
+            )

         elif location == TransformLocation.WEIGHT_OUTPUT:
-            fn = lambda weight, value: value @ weight
-            axis = -1
+            return _multihead_matmul(value, transform_weight)

         elif location == TransformLocation.OUTPUT:
-            fn = lambda weight, value: value @ weight
-            axis = -1
+            return _multihead_matmul(value, transform_weight)

-    if fn is None:
-        raise NotImplementedError(
-            f"Applying transforms to {module_type} {location} is not supported"
-        )
+    raise NotImplementedError(
+        f"Applying transforms to {module_type} {location} is not supported"
+    )

-    return fn, axis
+
+def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+    """
+    Performs A @ B for last two dims of two matrices A and B that possibly
+    have different shapes, as is the case in multi-headed dimension. If
+    shapes are different, this is equivalent to converting the last two dims
+    of the smaller matrix into a block-diagonal matrix with the same shape as
+    the last two dims of the larger matrix.
+
+    E.g. if A is half the size of B, this function will perform
+        [[A ]  @ B
+         [ A]]
+
+    If B is a third of the size of A, this function will perform
+        A @ [[B  ]
+             [ B ]
+             [  B]]
+
+    This function will error out if the shapes are not evenly divisble
+
+    :param A: left-hand tensor
+    :param B: right-hand tensor
+    :return: result
+    """
+    if A.shape[-1] > B.shape[-2]:
+        head_dim = B.shape[-2]
+        num_heads = A.shape[-1] // head_dim
+        A = A.unflatten(-1, (num_heads, head_dim))
+        return (A @ B).flatten(-2, -1)
+    elif A.shape[-1] < B.shape[-2]:
+        head_dim = A.shape[-1]
+        num_heads = B.shape[-2] // head_dim
+        B = B.unflatten(-2, (num_heads, head_dim))
+        return (A @ B).flatten(-3, -2)
+    else:
+        return A @ B

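A quick check of the block-diagonal claim in the new _multihead_matmul docstring (a sketch outside the commit, exercising the branch where A's last dimension is wider than B):

import torch

torch.manual_seed(0)
num_heads, head_dim, rows = 2, 4, 3
A = torch.rand(rows, num_heads * head_dim)   # larger operand (last dim = 8)
B = torch.rand(head_dim, head_dim)           # per-head transform (4 x 4)

# multi-head path, as in _multihead_matmul: split A into heads, matmul, re-flatten
out_heads = (A.unflatten(-1, (num_heads, head_dim)) @ B).flatten(-2, -1)

# equivalent dense path: multiply by a block-diagonal matrix built from copies of B
out_block = A @ torch.block_diag(*([B] * num_heads))

assert torch.allclose(out_heads, out_block, atol=1e-6)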
tests/test_transform/factory/test_correctness.py
Lines changed: 50 additions & 8 deletions

@@ -29,7 +29,8 @@
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomized", (True, False))
 @pytest.mark.parametrize("head_dim", (None, 2, 4))
-def test_correctness_linear(type, randomized, head_dim):
+@pytest.mark.parametrize("input_batch_size", (1, 5, 17))
+def test_correctness_linear(type, randomized, head_dim, input_batch_size):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=False)
     scheme = TransformScheme(type=type, randomized=randomized, head_dim=head_dim)
@@ -48,7 +49,7 @@ def test_correctness_linear(type, randomized, head_dim):
         module, TransformArgs(targets="Linear", location="output", inverse=True)
     )

-    input = torch.rand((17, 5, size[0]))
+    input = torch.rand((input_batch_size, 5, size[0]))
     true_output = input @ module.weight.T
     input_transformed = input_tfm(input)
     weight_transformed = w_out_tfm(w_in_tfm(module.weight))
@@ -57,10 +58,10 @@ def test_correctness_linear(type, randomized, head_dim):


 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomize", (True, False))
+@pytest.mark.parametrize("randomized", (True, False))
 @pytest.mark.parametrize("embed_loc", ("weight_output", "output"))
 @pytest.mark.parametrize("linear_loc", ("input", "weight_input"))
-def test_correctness_embedding(type, randomize, embed_loc, linear_loc):
+def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
     model = torch.nn.Sequential(
         torch.nn.Embedding(2, 4),
         torch.nn.Linear(4, 8, bias=False),
@@ -73,7 +74,7 @@ def test_correctness_embedding(type, randomize, embed_loc, linear_loc):
         config_groups={
             "": TransformScheme(
                 type=type,
-                randomize=randomize,
+                randomized=randomized,
                 apply=[
                     TransformArgs(targets="Embedding", location=embed_loc),
                     TransformArgs(targets="Linear", location=linear_loc, inverse=True),
@@ -155,6 +156,47 @@ def test_correctness_attention_heads(type, randomize, head_dim):
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomize", (True, False))
-def test_correctness_model_offload(type, randomize, model_apply):
-    test_correctness_model(type, randomize, model_apply, offload=True)
+@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("input_batch_size", (1, 5, 17))
+def test_correctness_model_offload(type, randomized, input_batch_size, model_apply):
+    test_correctness_model(
+        type, randomized, input_batch_size, model_apply, offload=True
+    )
+
+
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("head_dim", (4, 8))
+@pytest.mark.parametrize("input_batch_size", (1, 5, 17))
+def test_correctness_attention_heads(type, randomized, head_dim, input_batch_size):
+    hidden_size = 64
+    num_attention_heads = 8
+
+    attention = MockAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_key_value_heads=head_dim,
+    )
+
+    input = torch.rand(input_batch_size, 5, hidden_size)
+    true_output = attention(input)
+
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                type=type,
+                randomized=randomized,
+                head_dim=head_dim,
+                apply=[
+                    TransformArgs(targets="v_proj", location="weight_output"),
+                    TransformArgs(
                        targets="o_proj", location="weight_input", inverse=True
+                    ),
+                ],
+            )
+        }
+    )
+    apply_transform_config(attention, config)
+
+    output = attention(input)
+    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)

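The new test_correctness_attention_heads pairs a weight_output transform on v_proj with its inverse on o_proj and expects the attention output to be unchanged. Below is a minimal pure-torch sketch of that cancellation, independent of MockAttention and the factory APIs; the fusion formulas are illustrative assumptions, not necessarily the library's exact location conventions.

import torch

torch.manual_seed(0)
H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
Hn = torch.kron(H2, H2) / 2.0     # 4x4 orthonormal Hadamard: Hn @ Hn.T == I

x = torch.rand(3, 4)
W_v = torch.rand(4, 4)            # stand-in for v_proj.weight
W_o = torch.rand(4, 4)            # stand-in for o_proj.weight
baseline = (x @ W_v.T) @ W_o.T

W_v_fused = Hn @ W_v              # transform fused into the first weight's output
W_o_fused = W_o @ Hn.T            # inverse fused into the next weight's input
transformed = (x @ W_v_fused.T) @ W_o_fused.T

# the transform and its inverse cancel, so the composed output is unchanged
assert torch.allclose(baseline, transformed, atol=1e-6)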