Commit f220fb9

clean up reshaping
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 116b9f9

File tree

4 files changed: +73 −44 lines


src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 6 additions & 2 deletions
@@ -60,7 +60,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args)
+        return HadamardTransform(weight, perm, args, type(module))
 
     def _create_weight(
         self,
@@ -85,11 +85,13 @@ def __init__(
         weight: Parameter,
         perm: Optional[Parameter],
         args: TransformArgs,
+        module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
         self.args = args
+        self.module_type = module_type
 
     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -100,4 +102,6 @@ def forward(self, value: Tensor) -> Tensor:
         if self.args.inverse:
             weight = weight.T
 
-        return apply_transform_weight(weight, value, self.args.location)
+        return apply_transform_weight(
+            weight, value, self.args.location, self.module_type
+        )
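A note on why `forward` can treat transposition as inversion: a normalized Hadamard matrix (entries ±1/√n) is orthogonal, so its transpose is exactly its inverse. A standalone check, not taken from the factory's construction code:

```python
import math
import torch

# a 2x2 normalized Hadamard matrix; rows are orthogonal and entries are
# +-1/sqrt(n), so H @ H.T == I and args.inverse can simply use weight.T
n = 2
H = torch.tensor([[1.0, 1.0], [1.0, -1.0]]) / math.sqrt(n)
assert torch.allclose(H @ H.T, torch.eye(n), atol=1e-6)
```

Passing `type(module)` at construction time also means the transform never keeps a reference to the parent module itself, only the one piece of information `apply_transform_weight` needs to choose a reshaping rule.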

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 14 additions & 4 deletions
@@ -59,7 +59,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args)
+        return RandomMatrixTransform(weight, args, type(module))
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)
@@ -74,17 +74,27 @@ def _create_inverse(self, weight: Parameter) -> Parameter:
 
 
 class RandomMatrixTransform(TransformBase):
-    def __init__(self, weight: Tensor, args: TransformArgs):
+    def __init__(
+        self,
+        weight: Tensor,
+        args: TransformArgs,
+        module_type: type[torch.nn.Module],
+    ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
         self.args = args
+        self.module_type = module_type
 
     def forward(self, value: Tensor) -> Parameter:
-        return apply_transform_weight(self.weight, value, self.args.location)
+        return apply_transform_weight(
+            self.weight, value, self.args.location, self.module_type
+        )
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
-        return apply_transform_weight(inverse, value, self.args.location)
+        return apply_transform_weight(
+            inverse, value, self.args.location, self.module_type
+        )
 
 
 def high_precision_invert(weight: Tensor) -> Tensor:
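`right_inverse` inverts the stored weight on the fly via `high_precision_invert`, whose definition begins just below this hunk but is not shown in the diff. A minimal sketch of the usual approach, assuming the helper simply upcasts before inverting (the library's actual implementation may differ):

```python
import torch

def high_precision_invert(weight: torch.Tensor) -> torch.Tensor:
    # hedged sketch: invert in float64 to limit round-off error,
    # then cast back to the original dtype
    return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)
```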

src/compressed_tensors/transform/utils/matrix.py

Lines changed: 51 additions & 36 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Tuple, Callable
 
 import torch
 from compressed_tensors.transform import TransformLocation
@@ -42,7 +42,8 @@ def get_matrix_size(
         size = module.out_features
 
     if head_dim is not None:
-        assert size % head_dim == 0
+        if size % head_dim != 0:
+            raise ValueError("Cannot ")
         return head_dim
 
     else:
@@ -53,18 +54,35 @@ def apply_transform_weight(
     weight: torch.Tensor,
     value: torch.Tensor,
     location: TransformLocation,
+    module_type: type[torch.nn.Module],
 ) -> torch.Tensor:
-    return apply_transform_weight_linear(weight, value, location)
+    if module_type == torch.nn.Linear:
+        fn, axis = get_linear_transform_fn(module_type, location)
 
+    else:
+        raise NotImplementedError(
+            f"Applying transforms to {module_type} is not supported"
+        )
+
+    assert weight.shape[0] == weight.shape[1]
+    head_dim = weight.shape[0]
+    num_heads = value.shape[axis] // head_dim
 
-def apply_transform_weight_linear(
-    weight: torch.Tensor,
-    value: torch.Tensor,
+    value = value.unflatten(axis, (num_heads, head_dim))
+    value = fn(weight, value)
+    value = value.flatten(axis - 1, axis)
+
+    return value
+
+
+def get_linear_transform_fn(
+    module_type: type[torch.nn.Module],
     location: TransformLocation,
-):
+) -> Tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], int]:
     """
     Using the transform location, determine how to apply the transform weight to the
-    given value. For more info on input and output transforms, see `TransformLocation`
+    given value wrt linear weights. For more info on input and output transforms,
+    see `TransformLocation`
 
     The following explains how weights should be applied to values according to location
@@ -97,31 +115,28 @@ def apply_transform_weight_linear(
     :param location: determines how weight should be applied
     :return: value after transform weight has been applied
     """
-    value_shape = value.shape
-    weight_size = weight.shape[0]
-    assert weight.shape[0] == weight.shape[1]
-
-    if location == TransformLocation.INPUT:
-        num_heads = value_shape[1] // weight_size
-        value = value.reshape(value_shape[0], num_heads, weight_size)
-        ret = value @ weight
-
-    elif location == TransformLocation.WEIGHT_INPUT:
-        num_heads = value_shape[1] // weight_size
-        value = value.reshape(value_shape[0], num_heads, weight_size)
-        ret = value @ weight.T
-
-    elif location == TransformLocation.WEIGHT_OUTPUT:
-        num_heads = value_shape[0] // weight_size
-        value = value.reshape(num_heads, weight_size, value_shape[1])
-        ret = weight.T @ value
-
-    elif location == TransformLocation.OUTPUT:
-        num_heads = value_shape[1] // weight_size
-        value = value.reshape(value_shape[0], num_heads, weight_size)
-        ret = value @ weight
-
-    else:
-        raise NotImplementedError(f"{location} has not been implemented yet")
-
-    return ret.reshape(value_shape)
+    fn = axis = None
+
+    if module_type == torch.nn.Linear:
+        if location == TransformLocation.INPUT:
+            fn = lambda weight, value: value @ weight
+            axis = -1
+
+        elif location == TransformLocation.WEIGHT_INPUT:
+            fn = lambda weight, value: value @ weight.T
+            axis = -1
+
+        elif location == TransformLocation.WEIGHT_OUTPUT:
+            fn = lambda weight, value: weight.T @ value
+            axis = -2
+
+        elif location == TransformLocation.OUTPUT:
+            fn = lambda weight, value: value @ weight
+            axis = -1
+
+    if fn is None:
+        raise NotImplementedError(
+            f"Applying transforms to {module_type} {location} is not supported"
+        )
+
+    return fn, axis
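The substance of the cleanup: the old `apply_transform_weight_linear` reshaped around absolute dimensions (`value_shape[0]`, `value_shape[1]`), which silently assumed a 2-D value. The new code unflattens a negative axis into `(num_heads, head_dim)`, applies a per-head matmul, and flattens back, so any number of leading batch dimensions broadcast through. A standalone check that the per-head multiply matches the equivalent block-diagonal transform (sizes here are hypothetical):

```python
import torch

head_dim, num_heads = 4, 3
weight = torch.rand(head_dim, head_dim)           # square per-head transform
value = torch.rand(17, 5, num_heads * head_dim)   # (batch, seq, hidden)

# new path for TransformLocation.INPUT: axis == -1
axis = -1
out = value.unflatten(axis, (num_heads, head_dim)) @ weight
out = out.flatten(axis - 1, axis)

# reference: the same transform expressed as one block-diagonal matrix
full = torch.block_diag(*[weight] * num_heads)
assert torch.allclose(out, value @ full, atol=1e-5)
assert out.shape == value.shape
```

`WEIGHT_OUTPUT` is the one location that uses `axis = -2`: a `torch.nn.Linear` weight is stored as `(out_features, in_features)`, so the output dimension being transformed is the second-to-last axis and the weight must be applied from the left.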

tests/test_transform/factory/test_correctness.py

Lines changed: 2 additions & 2 deletions
@@ -47,7 +47,7 @@ def test_correctness_linear(type, randomized):
         module, TransformArgs(targets="Linear", location="output", inverse=True)
     )
 
-    input = torch.rand((17, size[0]))
+    input = torch.rand((17, 5, size[0]))
     true_output = input @ module.weight.T
     input_transformed = input_tfm(input)
     weight_transformed = w_out_tfm(w_in_tfm(module.weight))
@@ -64,7 +64,7 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
         model = offloaded_dispatch(model, torch.device("cuda"))
 
     # get output
-    input = torch.rand((17, model.fcs[0].in_features))
+    input = torch.rand((17, 5, model.fcs[0].in_features))
     if offload:
         input = input.to(torch.device("cuda"))
     true_output = model(input)
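The extra middle dimension is what exercises the new reshaping: under the old 2-D assumption, a `(batch, seq, hidden)` input breaks immediately. A small repro of the old failure mode (sizes are hypothetical):

```python
import torch

value = torch.rand(17, 5, 12)  # 3-D activations: (batch, seq, hidden)
weight_size = 4

# old logic read num_heads from dim 1, which is the sequence dim here
num_heads = value.shape[1] // weight_size  # 5 // 4 == 1: already wrong

# and the 2-D reshape cannot view 17*5*12 elements as (17, 1, 4)
try:
    value.reshape(value.shape[0], num_heads, weight_size)
except RuntimeError as err:
    print("old reshape fails on 3-D input:", err)
```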
