Commit 9039eb5

code cleanup and simplification
Signed-off-by: Kyle Sayers <[email protected]>
1 parent f220fb9 commit 9039eb5

File tree

5 files changed, +42 −34 lines changed


src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
 from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -52,7 +52,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert isinstance(module, Linear)
-        size = get_matrix_size(module, args.location, self.scheme.head_dim)
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
 from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -51,7 +51,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert isinstance(module, Linear)
-        size = get_matrix_size(module, args.location, self.scheme.head_dim)
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
 

src/compressed_tensors/transform/utils/matrix.py

Lines changed: 26 additions & 23 deletions
@@ -12,42 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Callable
+from typing import Callable, Optional, Tuple
 
 import torch
 from compressed_tensors.transform import TransformLocation
 
 
-__all__ = ["get_matrix_size", "apply_transform_weight"]
+__all__ = ["get_transform_size", "apply_transform_weight"]
 
 
-def get_matrix_size(
+def get_transform_size(
     module: torch.nn.Module,
     location: TransformLocation,
     head_dim: Optional[int] = None,
 ) -> int:
     """
-    Determine the size of a matrix given its location on the module
+    Determine the size of a transform matrix given its location on the module
 
     :param module: module that matrix will be applied to
     :param location: location on module
-    :TODO head_dim:
+    :param head_dim: size of head when transform is applied to mha
     :return: size of matrix
     """
-    assert isinstance(module, torch.nn.Linear)
-
-    if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
-        size = module.in_features
+    if isinstance(module, torch.nn.Linear):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            size = module.in_features
+        else:
+            size = module.out_features
     else:
-        size = module.out_features
+        raise NotImplementedError(f"Transforms on {type(module)} are not supported")
 
     if head_dim is not None:
         if size % head_dim != 0:
-            raise ValueError("Cannot ")
-        return head_dim
+            raise ValueError(
+                f"{head_dim} must divide {size} for {type(module)} at {location}"
+            )
 
-    else:
-        return size
+        size = head_dim
+
+    return size
 
 
 def apply_transform_weight(
@@ -56,22 +59,22 @@ def apply_transform_weight(
     location: TransformLocation,
     module_type: type[torch.nn.Module],
 ) -> torch.Tensor:
-    if module_type == torch.nn.Linear:
-        fn, axis = get_linear_transform_fn(module_type, location)
+    fn, axis = get_linear_transform_fn(module_type, location)
 
-    else:
-        raise NotImplementedError(
-            f"Applying transforms to {module_type} is not supported"
-        )
-
     assert weight.shape[0] == weight.shape[1]
     head_dim = weight.shape[0]
     num_heads = value.shape[axis] // head_dim
 
+    value_dtype = value.dtype
+    value = value.to(torch.float64)
+    weight = weight.to(torch.float64)
+
     value = value.unflatten(axis, (num_heads, head_dim))
     value = fn(weight, value)
     value = value.flatten(axis - 1, axis)
 
+    value = value.to(value_dtype)
+
     return value
 
 
@@ -133,10 +136,10 @@ def get_linear_transform_fn(
     elif location == TransformLocation.OUTPUT:
         fn = lambda weight, value: value @ weight
         axis = -1
-
+
     if fn is None:
         raise NotImplementedError(
             f"Applying transforms to {module_type} {location} is not supported"
        )
 
-    return fn, axis
+    return fn, axis
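
For reference, a minimal usage sketch of the renamed helper. This is not part of the commit; it only assumes the imports and the signature shown in the diff above.

import torch
from compressed_tensors.transform import TransformLocation
from compressed_tensors.transform.utils.matrix import get_transform_size

linear = torch.nn.Linear(in_features=64, out_features=128, bias=False)

# Input-side locations are sized by in_features, output-side locations by out_features
assert get_transform_size(linear, TransformLocation.INPUT) == 64
assert get_transform_size(linear, TransformLocation.OUTPUT) == 128

# With head_dim set, the transform is sized per head; head_dim must divide the full size
assert get_transform_size(linear, TransformLocation.INPUT, head_dim=16) == 16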

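The float64 lines added to apply_transform_weight wrap the existing per-head pattern: the value is upcast, unflattened into (num_heads, head_dim) along the transform axis, multiplied by the head-sized transform weight, flattened back, and returned in its original dtype. A standalone sketch of that pattern follows; it is illustrative only, not the library function, and uses the OUTPUT case where the transform right-multiplies along the last axis.

import torch

head_dim, num_heads, batch = 4, 8, 2
value = torch.randn(batch, num_heads * head_dim, dtype=torch.bfloat16)
weight = torch.eye(head_dim, dtype=torch.bfloat16)  # square head_dim x head_dim transform

value_dtype = value.dtype
v = value.to(torch.float64)   # upcast so the per-head matmul happens in high precision
w = weight.to(torch.float64)

v = v.unflatten(-1, (num_heads, head_dim))  # (batch, num_heads, head_dim)
v = v @ w                                   # apply the transform to each head
v = v.flatten(-2, -1)                       # back to (batch, num_heads * head_dim)

out = v.to(value_dtype)                     # restore the original dtype
assert torch.equal(out, value)              # identity transform leaves the value unchanged
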
tests/test_transform/conftest.py

Lines changed: 7 additions & 2 deletions
@@ -44,15 +44,20 @@ def __init__(
         self.num_key_value_groups = num_attention_heads // num_key_value_heads
         self.head_dim = hidden_size // num_attention_heads
         self.scaling = self.head_dim**-0.5
+        assert hidden_size >= num_attention_heads * self.head_dim
 
-        self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
+        self.q_proj = torch.nn.Linear(
+            hidden_size, num_attention_heads * self.head_dim, bias=False
+        )
         self.k_proj = torch.nn.Linear(
             hidden_size, num_key_value_heads * self.head_dim, bias=False
         )
         self.v_proj = torch.nn.Linear(
             hidden_size, num_key_value_heads * self.head_dim, bias=False
         )
-        self.o_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
+        self.o_proj = torch.nn.Linear(
+            num_attention_heads * self.head_dim, hidden_size, bias=False
+        )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, seq_len, hidden_size = hidden_states.shape

tests/test_transform/factory/test_correctness.py

Lines changed: 5 additions & 5 deletions
@@ -93,7 +93,7 @@ def test_correctness_model_offload(type, randomized, model_apply):
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomized", (True, False))
 @pytest.mark.parametrize("head_dim", (16, 32))
-def test_correctness_heads(type, randomized, head_dim, offload=False):
+def test_correctness_heads(type, randomized, head_dim):
     hidden_size = 64
 
     model = torch.nn.ModuleDict(
@@ -129,10 +129,10 @@ def test_correctness_heads(type, randomized, head_dim, offload=False):
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomized", (True, False))
-@pytest.mark.parametrize("head_dim", (8, 16))
-def test_correctness_attention_heads(type, randomized, head_dim, offload=False):
-    hidden_size = 4096
-    num_attention_heads = 32
+@pytest.mark.parametrize("head_dim", (4, 8))
+def test_correctness_attention_heads(type, randomized, head_dim):
+    hidden_size = 64
+    num_attention_heads = 8
 
     attention = MockAttention(
         hidden_size=hidden_size,
